138 lines
4.3 KiB
Python
138 lines
4.3 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import struct
|
|
from typing import BinaryIO, Iterator, AsyncIterator
|
|
|
|
from Crypto.Cipher import SIV
|
|
from Crypto.Cipher import AES
|
|
|
|
|
|
MAGIC = b"ENCF"
VERSION = 1

# Scheme codes
SCHEME_AES_SIV = 0x02  # RFC5297 AES-SIV (CMAC-based)

# Fixed header prefix: MAGIC(4) | version(1) | scheme(1) | chunk_bytes(4, BE) | salt_len(1)
_HEADER_PREFIX_LEN = 4 + 1 + 1 + 4 + 1
# Zero-filled reserved bytes that end every header.
_HEADER_RESERVED_LEN = 5


def build_header(chunk_bytes: int, salt: bytes, scheme: int = SCHEME_AES_SIV) -> bytes:
    """Serialize an ENCF v1 header.

    Layout (integers big-endian):
        MAGIC(4) | version(1) | scheme(1) | chunk_bytes(4) | salt_len(1) | salt(N) | reserved(5 zeros)

    Args:
        chunk_bytes: Chunk size recorded in the header; must be in ``1 .. 2**31``.
        salt: Salt stored verbatim in the header; must be 1 to 255 bytes long.
        scheme: One-byte scheme code, ``SCHEME_AES_SIV`` by default.

    Returns:
        The complete header as ``bytes``.

    Raises:
        ValueError: if ``chunk_bytes`` or the salt length is out of range, or if
            ``scheme`` does not fit in a single byte.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not 0 < chunk_bytes <= (1 << 31):
        raise ValueError("chunk_bytes must be in the range 1..2**31")
    if not 1 <= len(salt) <= 255:
        raise ValueError("salt must be 1..255 bytes long")
    return b"".join(
        (
            MAGIC,
            bytes((VERSION, scheme)),  # ValueError if scheme is not in 0..255
            struct.pack(">I", int(chunk_bytes)),
            bytes((len(salt),)),
            salt,
            bytes(_HEADER_RESERVED_LEN),
        )
    )


def parse_header(buf: bytes) -> tuple[int, int, int, bytes, int]:
    """Parse the ENCF header at the start of ``buf``.

    ``buf`` may extend past the header (e.g. into the first frame); only the
    header bytes are examined. Version and scheme are returned unchecked so the
    caller decides which ones it supports.

    Returns:
        ``(version, scheme, chunk_bytes, salt, header_len)``, where ``header_len``
        is the number of leading bytes of ``buf`` occupied by the header.

    Raises:
        ValueError: ``"header too short"`` if ``buf`` cannot hold the fixed
            prefix, ``"bad magic"`` on a wrong signature, or
            ``"incomplete header"`` if ``buf`` ends inside the salt or the
            reserved bytes.
    """
    if len(buf) < _HEADER_PREFIX_LEN:
        raise ValueError("header too short")
    if buf[:4] != MAGIC:
        raise ValueError("bad magic")
    version, scheme, chunk_bytes, salt_len = struct.unpack_from(">BBIB", buf, 4)
    header_len = _HEADER_PREFIX_LEN + salt_len + _HEADER_RESERVED_LEN
    if len(buf) < header_len:
        raise ValueError("incomplete header")
    # Always hand back immutable bytes (as annotated), even for bytearray or
    # memoryview input.
    salt = bytes(buf[_HEADER_PREFIX_LEN:_HEADER_PREFIX_LEN + salt_len])
    # The reserved bytes are skipped without validation.
    return version, scheme, chunk_bytes, salt, header_len
|
|
|
|
|
|
def _ad(salt: bytes, idx: int) -> bytes:
|
|
return salt + struct.pack(">Q", idx)
|
|
|
|
|
|
def encrypt_file_to_encf(src: BinaryIO, key: bytes, chunk_bytes: int, salt: bytes) -> Iterator[bytes]:
    """Stream-encrypt ``src`` into the ENCF v1 format, yielding output pieces.

    The first piece is the header from ``build_header``. After that, every
    non-empty ``src.read(chunk_bytes)`` result becomes one frame, emitted as
    three consecutive pieces:

        [plaintext length: 4 bytes, big-endian][ciphertext][tag(16)]

    Each frame is sealed by a fresh AES-SIV (RFC5297) object whose associated
    data is ``salt || frame_index``. Reading stops at the first empty (falsy)
    read.

    NOTE(review): no end-of-stream frame is written, so truncation exactly at a
    frame boundary cannot be detected from the stream alone.
    NOTE(review): PyCryptodome exposes SIV as ``AES.new(key, AES.MODE_SIV)``;
    confirm the installed crypto package provides ``SIV.new(key=..., ciphermod=...)``.
    """
    yield build_header(chunk_bytes, salt, SCHEME_AES_SIV)

    frame_index = 0
    plaintext = src.read(chunk_bytes)
    while plaintext:
        sealer = SIV.new(key=key, ciphermod=AES)  # fresh cipher object per frame
        sealer.update(_ad(salt, frame_index))
        ciphertext, tag = sealer.encrypt_and_digest(plaintext)
        yield struct.pack(">I", len(plaintext))
        yield ciphertext
        yield tag
        frame_index += 1
        plaintext = src.read(chunk_bytes)
|
|
|
|
|
|
async def decrypt_encf_to_file(byte_iter: AsyncIterator[bytes], key: bytes, out_path: str) -> None:
    """
    Parse ENCF v1 stream from async byte iterator and write plaintext to out_path.

    The stream is a header followed by frames of
    ``[p_len: 4 bytes, big-endian][ciphertext: p_len bytes][tag: 16 bytes]``.
    Each frame is checked with AES-SIV using ``salt || frame_index`` as
    associated data, and its plaintext is written as soon as it verifies.

    Args:
        byte_iter: Async iterator yielding arbitrarily sized pieces of the stream.
        key: Key passed to ``SIV.new`` for every frame.
        out_path: Destination path. It is opened with ``aiofiles`` in ``'wb'``
            mode only after the header has been validated.

    Raises:
        ValueError: bad magic, a truncated header or frame, an unsupported
            version or scheme, an out-of-range chunk size, or a frame longer
            than the header's chunk size. Tag mismatches surface from
            ``decrypt_and_verify`` (ValueError in PyCryptodome).

    NOTE(review): plaintext of frames that already verified stays in
    ``out_path`` if a later frame fails. Frames carry no end-of-stream marker,
    so a stream truncated exactly on a frame boundary decrypts to a shorter
    file without error.
    NOTE(review): confirm the installed crypto package provides
    ``SIV.new(key=..., ciphermod=...)``; PyCryptodome spells it
    ``AES.new(key, AES.MODE_SIV)``.
    """
    import aiofiles

    from Crypto.Cipher import SIV as _SIV

    from Crypto.Cipher import AES as _AES

    tag_len = 16
    prefix_len = 4 + 1 + 1 + 4 + 1  # MAGIC | version | scheme | chunk_bytes | salt_len
    reserved_len = 5

    buf = bytearray()

    async def _fill(n: int) -> None:
        """Pull from byte_iter until buf holds at least n bytes or the iterator ends."""
        while len(buf) < n:
            try:
                piece = await byte_iter.__anext__()
            except StopAsyncIteration:
                return
            if piece:
                buf.extend(piece)

    # ---- Header -----------------------------------------------------------
    await _fill(prefix_len)
    if buf[:4] != MAGIC:
        raise ValueError("bad magic")
    if len(buf) < prefix_len:
        # The stream ended inside the fixed prefix; the salt_len byte is missing.
        raise ValueError("header too short")
    hdr_len = prefix_len + buf[prefix_len - 1] + reserved_len  # last prefix byte is salt_len
    await _fill(hdr_len)
    # Pass only the header bytes so any frame data already buffered is not copied.
    version, scheme, chunk_bytes, salt, consumed = parse_header(bytes(buf[:hdr_len]))
    del buf[:consumed]
    if version != 1:
        raise ValueError("unsupported ENCF version")
    if scheme != SCHEME_AES_SIV:
        raise ValueError("unsupported scheme")
    if not 0 < chunk_bytes <= (1 << 31):
        # Same range build_header accepts; chunk_bytes also caps per-frame buffering below.
        raise ValueError("invalid chunk size")

    # ---- Frames -----------------------------------------------------------
    async with aiofiles.open(out_path, 'wb') as out:
        idx = 0
        while True:
            await _fill(4)
            if not buf:
                break  # clean EOF exactly on a frame boundary
            if len(buf) < 4:
                raise ValueError("truncated frame length")
            (p_len,) = struct.unpack_from(">I", buf)
            if p_len > chunk_bytes:
                # The encoder never emits more than chunk_bytes per frame; refuse
                # before buffering so a forged length cannot force a huge read.
                raise ValueError("frame length exceeds chunk size")
            del buf[:4]
            frame_len = p_len + tag_len
            await _fill(frame_len)
            if len(buf) < frame_len:
                raise ValueError("truncated cipher/tag")
            ciphertext = bytes(buf[:p_len])
            tag = bytes(buf[p_len:frame_len])
            del buf[:frame_len]
            siv = _SIV.new(key=key, ciphermod=_AES)
            siv.update(_ad(salt, idx))
            await out.write(siv.decrypt_and_verify(ciphertext, tag))
            idx += 1
|
|
|