78 lines
2.5 KiB
Python
78 lines
2.5 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import hmac
|
|
import hashlib
|
|
from typing import BinaryIO, Iterator
|
|
|
|
from Crypto.Cipher import AES
|
|
|
|
|
|
def _chunk_bytes_from_env(name: str = "CRYPTO_CHUNK_BYTES", default: int = 1048576) -> int:
    """Return the plaintext chunk size configured by environment variable ``name``.

    Falls back to ``default`` (1 MiB) when the variable is unset.

    Raises:
        ValueError: if the value is not an integer, or is not positive. A size of 0
            would make ``src.read(0)`` return b"" immediately, so encryption would
            silently produce no output. A negative size would make ``read`` consume
            the whole input as a single frame and break decryption framing.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    value = int(raw)
    if value <= 0:
        raise ValueError(f"{name} must be a positive integer, got {raw!r}")
    return value


# Plaintext bytes per AES-GCM frame (1 MiB by default). Encryption and decryption
# must run with the same value, otherwise frame boundaries will not line up.
CHUNK_BYTES = _chunk_bytes_from_env()
|
|
|
|
|
|
def _derive_nonce(salt: bytes, chunk_index: int) -> bytes:
|
|
"""Derive a 12-byte GCM nonce deterministically from per-file salt and chunk index."""
|
|
idx = chunk_index.to_bytes(8, 'big')
|
|
digest = hmac.new(salt, idx, hashlib.sha256).digest()
|
|
return digest[:12]
|
|
|
|
|
|
def encrypt_stream_aesgcm(src: BinaryIO, key: bytes, salt: bytes) -> Iterator[bytes]:
    """
    Encrypt ``src`` chunk by chunk with AES-GCM and return an iterator of output pieces.

    Plaintext is read in chunks of ``CHUNK_BYTES``. Chunk ``i`` is encrypted under the
    nonce ``_derive_nonce(salt, i)``. For each chunk the iterator yields the ciphertext
    (same length as the plaintext chunk) and then its 16-byte GCM tag. The concatenated
    output is therefore ``[C_0][TAG_0][C_1][TAG_1]...``.

    Short reads from ``src`` are retried. Every chunk except the last is exactly
    ``CHUNK_BYTES`` long; the last may be shorter. This assumes ``src.read(n)`` never
    returns more than ``n`` bytes. An empty ``src`` yields nothing.

    Security notes:
    - Nonces are a deterministic function of (salt, index). The same (key, salt) pair
      must never be used for two different plaintexts, or GCM nonces are reused.
    - The framing does not authenticate the number of chunks. Removing whole frames
      from the end of the stream is not detectable at decryption time.

    Raises:
        ValueError: if ``key`` is not 16, 24 or 32 bytes long, or ``salt`` is shorter
            than 12 bytes. Raised when this function is called, not on the first
            ``next()``.
    """
    # Use real exceptions rather than ``assert``, which ``python -O`` strips.
    if len(key) not in (16, 24, 32):
        raise ValueError(f"AES key must be 16, 24 or 32 bytes, got {len(key)}")
    if len(salt) < 12:
        raise ValueError(f"salt must be at least 12 bytes, got {len(salt)}")
    return _encrypt_frames(src, key, salt)


def _read_full_chunk(src: BinaryIO, size: int) -> bytes:
    """Read up to ``size`` bytes from ``src``, retrying short reads; shorter only at EOF."""
    parts = []
    remaining = size
    while remaining > 0:
        piece = src.read(remaining)
        if not piece:  # b"" (EOF) or None (no data available) ends the chunk
            break
        parts.append(piece)
        remaining -= len(piece)
    if len(parts) == 1:
        # Common case: a single read filled the chunk, so skip the join copy.
        return parts[0]
    return b"".join(parts)


def _encrypt_frames(src: BinaryIO, key: bytes, salt: bytes) -> Iterator[bytes]:
    """Yield ``[C_i]`` then ``[TAG_i]`` for each ``CHUNK_BYTES`` chunk read from ``src``."""
    idx = 0
    while True:
        # A short frame in the middle of the stream could never be re-framed by the
        # decryptor, so always fill the chunk completely unless EOF is reached.
        block = _read_full_chunk(src, CHUNK_BYTES)
        if not block:
            break
        cipher = AES.new(key, AES.MODE_GCM, nonce=_derive_nonce(salt, idx))
        ciphertext, tag = cipher.encrypt_and_digest(block)
        yield ciphertext
        yield tag
        idx += 1
|
|
|
|
|
|
def decrypt_stream_aesgcm_iter(byte_iter: Iterator[bytes], key: bytes, salt: bytes) -> Iterator[bytes]:
    """
    Decrypt a stream produced by ``encrypt_stream_aesgcm`` and yield plaintext chunks.

    Frame format: ``[C_0][TAG_0][C_1][TAG_1]...``.
    - Every ``C_i`` except the last is exactly ``CHUNK_BYTES`` long.
    - The last ``C_i`` holds between 1 and ``CHUNK_BYTES`` bytes.
    - Every ``TAG_i`` is 16 bytes.

    ``byte_iter`` may split the stream at arbitrary boundaries; bytes are buffered and
    re-framed here. Each frame is authenticated (``decrypt_and_verify``) before its
    plaintext is yielded. ``CHUNK_BYTES`` must equal the value used for encryption.

    Note: the format does not record the number of frames. A stream truncated exactly
    at a frame boundary therefore decrypts without error to a prefix of the plaintext.

    Raises:
        ValueError: if ``key`` is not 16, 24 or 32 bytes long (raised at call time).
        ValueError: during iteration, if a frame fails authentication.
        ValueError: during iteration, if the stream ends with a leftover too short
            to hold ciphertext plus a tag.
    """
    # Use a real exception rather than ``assert``, which ``python -O`` strips.
    if len(key) not in (16, 24, 32):
        raise ValueError(f"AES key must be 16, 24 or 32 bytes, got {len(key)}")
    return _decrypt_frames(byte_iter, key, salt)


def _open_frame(key: bytes, salt: bytes, idx: int, ciphertext: bytes, tag: bytes) -> bytes:
    """Authenticate and decrypt frame ``idx``; raise ValueError naming the chunk on failure."""
    cipher = AES.new(key, AES.MODE_GCM, nonce=_derive_nonce(salt, idx))
    try:
        return cipher.decrypt_and_verify(ciphertext, tag)
    except ValueError as e:  # pycryptodome signals a MAC mismatch with ValueError
        raise ValueError(f"Decrypt failed at chunk {idx}: {e}") from e


def _decrypt_frames(byte_iter: Iterator[bytes], key: bytes, salt: bytes) -> Iterator[bytes]:
    """Re-frame ``byte_iter`` into ``[C_i][TAG_i]`` frames and yield each decrypted chunk."""
    tag_len = 16
    frame_len = CHUNK_BYTES + tag_len
    buf = bytearray()
    idx = 0
    for piece in byte_iter:
        if not piece:
            continue
        buf.extend(piece)
        # The final frame is always shorter than frame_len. So once the buffer holds
        # frame_len bytes, its head must be a complete full-size frame.
        while len(buf) >= frame_len:
            ciphertext = bytes(buf[:CHUNK_BYTES])
            tag = bytes(buf[CHUNK_BYTES:frame_len])
            del buf[:frame_len]
            yield _open_frame(key, salt, idx, ciphertext, tag)
            idx += 1
    if buf:
        # Any leftover is the short final frame: ciphertext followed by its tag.
        # The encryptor never emits an empty ciphertext, so at least one
        # ciphertext byte must precede the tag.
        if len(buf) <= tag_len:
            raise ValueError("Trailing bytes in encrypted stream (incomplete frame)")
        yield _open_frame(key, salt, idx, bytes(buf[:-tag_len]), bytes(buf[-tag_len:]))
|
|
|