uploader-bot/app/core/crypto/keywrap.py

107 lines
3.4 KiB
Python

from __future__ import annotations
import base64
import os
import threading
from typing import Optional
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
# Format version written as the first byte of every wrapped payload.
_VERSION = 1
_PREFIX_LEN = 1  # version byte
# Nonce size in bytes; wrap_dek draws a fresh random nonce of this size per call.
_NONCE_LEN = 12
# Only used in unwrap_dek's minimum-length check (presumably the AES-GCM tag size).
_TAG_LEN = 16
# Byte lengths accepted by _decode_key_material for decoded key material.
_valid_key_lengths = {16, 24, 32}
# Guards first-time KEK loading; _cached_kek holds the decoded KEK once loaded.
_kek_lock = threading.Lock()
_cached_kek: Optional[bytes] = None
class KeyWrapError(RuntimeError):
    """Signal a KEK configuration problem or a failed DEK unwrap."""
def _normalize_base64(value: str) -> str:
v = value.strip()
missing = (-len(v)) % 4
if missing:
v += "=" * missing
return v
def _decode_key_material(value: str) -> bytes:
    """Decode key text supplied either as hex or as base64.

    Hex is tried first (an optional ``0x``/``0X`` prefix is ignored); if that
    does not produce a 16/24/32-byte result, the original text is decoded as
    base64 with missing padding restored.

    Raises:
        KeyWrapError: if base64 decoding fails, or if neither decoding yields
            16, 24 or 32 bytes.
    """
    hex_text = value.strip()
    if hex_text.startswith(("0x", "0X")):
        hex_text = hex_text[2:]

    try:
        from_hex = bytes.fromhex(hex_text)
    except ValueError:
        # Not hex; fall through to the base64 attempt.
        from_hex = None
    if from_hex is not None and len(from_hex) in _valid_key_lengths:
        return from_hex

    try:
        # Base64 uses the untouched input: the "0x" stripping is hex-only.
        from_b64 = base64.b64decode(_normalize_base64(value), validate=False)
    except Exception as exc:  # noqa: BLE001 - we want to re-raise as KeyWrapError
        raise KeyWrapError(f"invalid KEK encoding: {exc}") from exc
    if len(from_b64) in _valid_key_lengths:
        return from_b64

    raise KeyWrapError("KEK must decode to 16/24/32 bytes")
def _load_kek() -> bytes:
    """Return the 32-byte KEK, decoding it from the environment on first use.

    The value is read from ``CONTENT_KEY_KEK_B64`` or, if that is unset or
    empty, ``CONTENT_KEY_KEK_HEX``, and is cached in ``_cached_kek`` after a
    successful load. The cache is checked once without the lock and again
    after acquiring ``_kek_lock``.

    Raises:
        KeyWrapError: if neither variable is set, the value cannot be decoded,
            or the decoded key is not exactly 32 bytes.
    """
    global _cached_kek

    key = _cached_kek
    if key is not None:
        return key

    with _kek_lock:
        # Another thread may have populated the cache while we waited.
        key = _cached_kek
        if key is None:
            encoded = os.getenv("CONTENT_KEY_KEK_B64")
            if not encoded:
                encoded = os.getenv("CONTENT_KEY_KEK_HEX")
            if not encoded:
                raise KeyWrapError("CONTENT_KEY_KEK_B64 or CONTENT_KEY_KEK_HEX must be set")
            key = _decode_key_material(encoded)
            if len(key) != 32:
                # Force 256-bit KEK for uniform security properties
                raise KeyWrapError("KEK must be 32 bytes (256-bit) for AES-256-GCM")
            _cached_kek = key
        return key
def wrap_dek(plaintext: bytes) -> str:
    """Encrypt a DEK under the KEK with AES-GCM and return it base64-encoded.

    The encoded payload is ``version byte || 12-byte random nonce ||
    AES-GCM output``; no associated data is used.

    Note: an empty *plaintext* is accepted here, but the resulting payload is
    shorter than the minimum length ``unwrap_dek`` requires.

    Raises:
        TypeError: if *plaintext* is not ``bytes`` or ``bytearray``.
        KeyWrapError: if the KEK cannot be loaded.
    """
    if not isinstance(plaintext, (bytes, bytearray)):
        raise TypeError("plaintext must be bytes")

    aead = AESGCM(_load_kek())
    iv = os.urandom(_NONCE_LEN)
    sealed = aead.encrypt(iv, bytes(plaintext), associated_data=None)
    payload = b"".join((bytes([_VERSION]), iv, sealed))
    return base64.b64encode(payload).decode()
def unwrap_dek(encoded: str) -> bytes:
    """Unwrap a DEK from its base64 string form.

    Two payload formats are accepted after base64 decoding:

    * Wrapped: ``version byte || 12-byte nonce || AES-GCM output``, decrypted
      with the configured KEK.
    * Legacy: the raw DEK itself (16, 24 or 32 bytes, no version prefix).

    A legacy DEK may happen to start with the version byte. Such a payload is
    returned as a legacy key when it cannot be a valid wrapped payload (it is
    too short, or decryption fails) instead of being rejected.

    Raises:
        KeyWrapError: if the payload is empty, is not valid base64, fails to
            decrypt, matches neither format, or the KEK cannot be loaded.
    """
    if not encoded:
        raise KeyWrapError("empty key payload")
    try:
        raw = base64.b64decode(_normalize_base64(encoded), validate=False)
    except Exception as exc:  # noqa: BLE001
        raise KeyWrapError(f"invalid base64 payload: {exc}") from exc
    if not raw:
        raise KeyWrapError("decoded payload is empty")

    legacy_sized = len(raw) in {16, 24, 32}
    header_len = _PREFIX_LEN + _NONCE_LEN
    # A wrapped payload carries at least one ciphertext byte plus the tag.
    min_wrapped_len = header_len + _TAG_LEN + 1

    if raw[0] == _VERSION:
        if len(raw) >= min_wrapped_len:
            nonce = raw[_PREFIX_LEN:header_len]
            ciphertext = raw[header_len:]
            cipher = AESGCM(_load_kek())
            try:
                return cipher.decrypt(nonce, ciphertext, associated_data=None)
            except Exception as exc:  # noqa: BLE001
                if not legacy_sized:
                    raise KeyWrapError(f"unwrap failed: {exc}") from exc
                # Legacy-sized payload whose first byte merely equals the
                # version marker: fall through and treat it as a raw DEK.
                # This does not weaken authentication, since legacy-sized
                # payloads with any other first byte are already returned
                # without it.
        elif not legacy_sized:
            raise KeyWrapError("wrapped payload too short")

    # Legacy fallback: value is raw DEK (no version prefix)
    if legacy_sized:
        return raw
    raise KeyWrapError("unknown key payload format")