# File listing metadata: 107 lines, 3.4 KiB, Python
from __future__ import annotations
|
|
|
|
import base64
|
|
import os
|
|
import threading
|
|
from typing import Optional
|
|
|
|
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
|
|
|
|
|
# Wrapped-blob layout used by this module: version byte | nonce | ciphertext.
_VERSION = 1  # value of the leading version byte
_PREFIX_LEN = 1  # version byte
_NONCE_LEN = 12  # number of random nonce bytes that follow the prefix
_TAG_LEN = 16  # counted when checking the minimum length of a wrapped blob

# Key sizes, in bytes, that decoded key material may have.
_valid_key_lengths = {16, 24, 32}

# Cached KEK (None until the first successful load) and the lock that
# guards populating it.
_cached_kek: Optional[bytes] = None
_kek_lock = threading.Lock()
|
|
|
|
|
|
class KeyWrapError(RuntimeError):
    """Error for KEK configuration problems and key payloads that cannot be unwrapped."""
|
|
|
|
|
|
def _normalize_base64(value: str) -> str:
    """Return *value* with whitespace removed and ``=`` padding completed.

    The result has a length that is a multiple of four, so base64 text whose
    trailing padding was dropped can still be decoded.

    Args:
        value: Base64 text, possibly unpadded, surrounded by whitespace, or
            wrapped across several lines.

    Returns:
        The text without any whitespace, followed by as many ``=`` characters
        as are needed to reach a multiple of four.
    """
    # Remove *all* whitespace, not only the leading/trailing run: a lenient
    # decoder skips interior whitespace, so counting it when computing the
    # padding would append the wrong number of "=" for line-wrapped input.
    # ``value[:0]`` is the empty sequence of value's own type, which keeps
    # already-padded bytes input passing through unchanged as it did before.
    compact = value[:0].join(value.split())
    missing = -len(compact) % 4
    if missing:
        compact += "=" * missing
    return compact
|
|
|
|
|
|
def _decode_key_material(value: str) -> bytes:
    """Decode key material given as hexadecimal or base64 text.

    Hexadecimal (with an optional ``0x``/``0X`` prefix) is tried first.  If
    the text is not valid hex, or decodes to an unsupported length, the
    original text is decoded as base64 after padding normalisation.

    Args:
        value: The encoded key text.

    Returns:
        The decoded key, whose length is one of ``_valid_key_lengths``.

    Raises:
        KeyWrapError: if base64 decoding fails, or neither interpretation
            yields a key of a supported length.
    """
    hex_text = value.strip()
    if hex_text[:2] in ("0x", "0X"):
        hex_text = hex_text[2:]

    try:
        from_hex = bytes.fromhex(hex_text)
    except ValueError:
        # Not hexadecimal; fall through to the base64 interpretation.
        from_hex = None
    if from_hex is not None and len(from_hex) in _valid_key_lengths:
        return from_hex

    try:
        from_b64 = base64.b64decode(_normalize_base64(value), validate=False)
    except Exception as exc:  # noqa: BLE001 - we want to re-raise as KeyWrapError
        raise KeyWrapError(f"invalid KEK encoding: {exc}") from exc

    if len(from_b64) not in _valid_key_lengths:
        raise KeyWrapError("KEK must decode to 16/24/32 bytes")
    return from_b64
|
|
|
|
|
|
def _load_kek() -> bytes:
    """Return the KEK, reading it from the environment on first use.

    The encoded key is taken from ``CONTENT_KEY_KEK_B64`` or, when that is
    unset or empty, from ``CONTENT_KEY_KEK_HEX``, and decoded with
    ``_decode_key_material``.  A successfully loaded key is stored in
    ``_cached_kek`` and returned by later calls; failures are not cached.

    Returns:
        The 32-byte key-encryption key.

    Raises:
        KeyWrapError: if neither variable is set, the value cannot be
            decoded, or the decoded key is not exactly 32 bytes long.
    """
    global _cached_kek

    kek = _cached_kek
    if kek is not None:
        # Already loaded: no need to take the lock.
        return kek

    with _kek_lock:
        kek = _cached_kek
        if kek is not None:
            # Another thread populated the cache while we waited for the lock.
            return kek

        configured = os.getenv("CONTENT_KEY_KEK_B64") or os.getenv("CONTENT_KEY_KEK_HEX")
        if not configured:
            raise KeyWrapError("CONTENT_KEY_KEK_B64 or CONTENT_KEY_KEK_HEX must be set")

        kek = _decode_key_material(configured)
        if len(kek) != 32:
            # Force 256-bit KEK for uniform security properties
            raise KeyWrapError("KEK must be 32 bytes (256-bit) for AES-256-GCM")

        _cached_kek = kek
        return kek
|
|
|
|
|
|
def wrap_dek(plaintext: bytes) -> str:
    """Encrypt a DEK under the KEK with AES-GCM and return it as base64 text.

    The encoded blob is the version byte, then a fresh random nonce of
    ``_NONCE_LEN`` bytes, then the AES-GCM output for *plaintext*.  No
    associated data is used.

    Args:
        plaintext: The DEK as ``bytes`` or ``bytearray``.

    Returns:
        The wrapped DEK as a base64 string.

    Raises:
        TypeError: if *plaintext* is neither ``bytes`` nor ``bytearray``.
        KeyWrapError: if the KEK cannot be loaded.
    """
    if not isinstance(plaintext, (bytes, bytearray)):
        raise TypeError("plaintext must be bytes")

    kek = _load_kek()
    nonce = os.urandom(_NONCE_LEN)
    sealed = AESGCM(kek).encrypt(nonce, bytes(plaintext), associated_data=None)

    wrapped = b"".join((bytes([_VERSION]), nonce, sealed))
    return base64.b64encode(wrapped).decode()
|
|
|
|
|
|
def _unwrap_v1(raw: bytes) -> bytes:
    """Decrypt a version-1 blob (version byte | nonce | ciphertext) with the KEK.

    Raises:
        KeyWrapError: if the blob is too short, the KEK cannot be loaded, or
            decryption fails.
    """
    if len(raw) < _PREFIX_LEN + _NONCE_LEN + _TAG_LEN + 1:
        raise KeyWrapError("wrapped payload too short")

    nonce_end = _PREFIX_LEN + _NONCE_LEN
    nonce = raw[_PREFIX_LEN:nonce_end]
    ciphertext = raw[nonce_end:]

    cipher = AESGCM(_load_kek())
    try:
        return cipher.decrypt(nonce, ciphertext, associated_data=None)
    except Exception as exc:  # noqa: BLE001
        raise KeyWrapError(f"unwrap failed: {exc}") from exc


def unwrap_dek(encoded: str) -> bytes:
    """Unwrap a DEK from its base64 text form.

    Two payload formats are accepted:

    * Version 1: version byte, nonce, then AES-GCM ciphertext, decrypted
      with the KEK.
    * Legacy: the base64 text is the raw DEK itself (16, 24 or 32 bytes,
      no version prefix), returned as-is.

    Args:
        encoded: Base64 text of the stored key payload; missing ``=``
            padding is tolerated.

    Returns:
        The plaintext DEK.

    Raises:
        KeyWrapError: if the payload is empty, is not valid base64, fails
            version-1 unwrapping without being a legacy-sized key, or has
            an unrecognised format.
    """
    if not encoded:
        raise KeyWrapError("empty key payload")

    try:
        raw = base64.b64decode(_normalize_base64(encoded), validate=False)
    except Exception as exc:  # noqa: BLE001
        raise KeyWrapError(f"invalid base64 payload: {exc}") from exc

    if not raw:
        raise KeyWrapError("decoded payload is empty")

    legacy_sized = len(raw) in _valid_key_lengths

    if raw[0] == _VERSION:
        try:
            return _unwrap_v1(raw)
        except KeyWrapError:
            # A legacy raw DEK can start with the version byte purely by
            # chance (about 1 key in 256).  Such a key used to be rejected
            # here even though the legacy format is supported.  Wrapping adds
            # 29 bytes (prefix + nonce + tag), so a wrapped 16/24/32-byte DEK
            # is 45/53/61 bytes and never has a raw-key length itself;
            # treating a raw-key-sized payload as legacy is unambiguous.
            if not legacy_sized:
                raise

    # Legacy fallback: value is raw DEK (no version prefix)
    if legacy_sized:
        return raw

    raise KeyWrapError("unknown key payload format")
|