from __future__ import annotations import os import hmac import hashlib from typing import BinaryIO, Iterator from Crypto.Cipher import AES CHUNK_BYTES = int(os.getenv("CRYPTO_CHUNK_BYTES", "1048576")) # 1 MiB def _derive_nonce(salt: bytes, chunk_index: int) -> bytes: """Derive a 12-byte GCM nonce deterministically from per-file salt and chunk index.""" idx = chunk_index.to_bytes(8, 'big') digest = hmac.new(salt, idx, hashlib.sha256).digest() return digest[:12] def encrypt_stream_aesgcm(src: BinaryIO, key: bytes, salt: bytes) -> Iterator[bytes]: """ Read plaintext from src by CHUNK_BYTES, encrypt each chunk with AES-GCM using a deterministic nonce derived from (salt, index). Yields bytes in framing: [C_i][TAG_i]... Ciphertext length equals plaintext chunk length. Tag is 16 bytes. """ assert len(key) in (16, 24, 32) assert len(salt) >= 12 idx = 0 while True: block = src.read(CHUNK_BYTES) if not block: break nonce = _derive_nonce(salt, idx) cipher = AES.new(key, AES.MODE_GCM, nonce=nonce) ciphertext, tag = cipher.encrypt_and_digest(block) yield ciphertext yield tag idx += 1 def decrypt_stream_aesgcm_iter(byte_iter: Iterator[bytes], key: bytes, salt: bytes) -> Iterator[bytes]: """ Decrypt a stream that was produced by encrypt_stream_aesgcm. Frame format: concatenation of [C_i][TAG_i] for each i, where |C_i| = CHUNK_BYTES and |TAG_i|=16. We accept arbitrary chunking from the underlying iterator and reframe accordingly. """ assert len(key) in (16, 24, 32) buf = bytearray() idx = 0 TAG_LEN = 16 def _try_yield(): nonlocal idx out = [] while len(buf) >= CHUNK_BYTES + TAG_LEN: c = bytes(buf[:CHUNK_BYTES]) t = bytes(buf[CHUNK_BYTES:CHUNK_BYTES+TAG_LEN]) del buf[:CHUNK_BYTES+TAG_LEN] nonce = _derive_nonce(salt, idx) cipher = AES.new(key, AES.MODE_GCM, nonce=nonce) try: p = cipher.decrypt_and_verify(c, t) except Exception as e: raise ValueError(f"Decrypt failed at chunk {idx}: {e}") out.append(p) idx += 1 return out for chunk in byte_iter: if not chunk: continue buf.extend(chunk) for p in _try_yield(): yield p # At end, buffer must be empty if len(buf) != 0: raise ValueError("Trailing bytes in encrypted stream (incomplete frame)")