from __future__ import annotations import base64 import os import threading from typing import Optional from cryptography.hazmat.primitives.ciphers.aead import AESGCM _VERSION = 1 _PREFIX_LEN = 1 # version byte _NONCE_LEN = 12 _TAG_LEN = 16 _valid_key_lengths = {16, 24, 32} _kek_lock = threading.Lock() _cached_kek: Optional[bytes] = None class KeyWrapError(RuntimeError): """Raised when KEK configuration or unwrap operations fail.""" def _normalize_base64(value: str) -> str: v = value.strip() missing = (-len(v)) % 4 if missing: v += "=" * missing return v def _decode_key_material(value: str) -> bytes: v = value.strip() if v.startswith("0x") or v.startswith("0X"): v = v[2:] try: raw = bytes.fromhex(v) if len(raw) in _valid_key_lengths: return raw except ValueError: pass try: raw = base64.b64decode(_normalize_base64(value), validate=False) if len(raw) in _valid_key_lengths: return raw except Exception as exc: # noqa: BLE001 - we want to re-raise as KeyWrapError raise KeyWrapError(f"invalid KEK encoding: {exc}") from exc raise KeyWrapError("KEK must decode to 16/24/32 bytes") def _load_kek() -> bytes: global _cached_kek if _cached_kek is not None: return _cached_kek with _kek_lock: if _cached_kek is not None: return _cached_kek env = os.getenv("CONTENT_KEY_KEK_B64") or os.getenv("CONTENT_KEY_KEK_HEX") if not env: raise KeyWrapError("CONTENT_KEY_KEK_B64 or CONTENT_KEY_KEK_HEX must be set") kek = _decode_key_material(env) if len(kek) != 32: # Force 256-bit KEK for uniform security properties raise KeyWrapError("KEK must be 32 bytes (256-bit) for AES-256-GCM") _cached_kek = kek return _cached_kek def wrap_dek(plaintext: bytes) -> str: """Wrap a DEK (plaintext bytes) with AES-256-GCM; return base64 string.""" if not isinstance(plaintext, (bytes, bytearray)): raise TypeError("plaintext must be bytes") kek = _load_kek() nonce = os.urandom(_NONCE_LEN) cipher = AESGCM(kek) ct = cipher.encrypt(nonce, bytes(plaintext), associated_data=None) blob = bytes([_VERSION]) + nonce + ct return base64.b64encode(blob).decode() def unwrap_dek(encoded: str) -> bytes: """Unwrap DEK from base64 string. Supports legacy (raw base64 key) values.""" if not encoded: raise KeyWrapError("empty key payload") try: raw = base64.b64decode(_normalize_base64(encoded), validate=False) except Exception as exc: # noqa: BLE001 raise KeyWrapError(f"invalid base64 payload: {exc}") from exc if not raw: raise KeyWrapError("decoded payload is empty") version = raw[0] if version == _VERSION: if len(raw) < _PREFIX_LEN + _NONCE_LEN + _TAG_LEN + 1: raise KeyWrapError("wrapped payload too short") nonce = raw[_PREFIX_LEN:_PREFIX_LEN + _NONCE_LEN] ciphertext = raw[_PREFIX_LEN + _NONCE_LEN:] kek = _load_kek() cipher = AESGCM(kek) try: return cipher.decrypt(nonce, ciphertext, associated_data=None) except Exception as exc: # noqa: BLE001 raise KeyWrapError(f"unwrap failed: {exc}") from exc # Legacy fallback: value is raw DEK (no version prefix) if len(raw) in {16, 24, 32}: return raw raise KeyWrapError("unknown key payload format")