"""
Comprehensive security module with encryption, JWT tokens, password hashing,
and access control.

Provides secure file encryption, token management, and authentication utilities.
"""

import base64
import hashlib
import hmac
import secrets
import time
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Union
from uuid import UUID

import bcrypt
import jwt
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from app.core.config import get_settings
from app.core.logging import get_logger

logger = get_logger(__name__)
settings = get_settings()


# Lower-case extensions that generate_secure_filename keeps; any other
# extension is dropped from the generated name.
_ALLOWED_UPLOAD_EXTENSIONS = frozenset({
    'txt', 'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx',
    'jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'svg',
    'mp3', 'wav', 'flac', 'ogg', 'mp4', 'avi', 'mkv', 'webm',
    'zip', 'rar', '7z', 'tar', 'gz', 'json', 'xml', 'csv',
})

# Leading magic bytes per MIME type, used by validate_file_signature.
# An empty list means the type has no reliable signature and is accepted.
# image/webp and video/mp4 are checked separately because their identifying
# markers are not at offset 0.
_FILE_SIGNATURES: Dict[str, List[bytes]] = {
    'image/jpeg': [b'\xFF\xD8\xFF'],
    'image/png': [b'\x89PNG\r\n\x1a\n'],
    'image/gif': [b'GIF87a', b'GIF89a'],
    'application/pdf': [b'%PDF-'],
    'application/zip': [b'PK\x03\x04', b'PK\x05\x06', b'PK\x07\x08'],
    'audio/mpeg': [b'ID3', b'\xFF\xFB', b'\xFF\xF3', b'\xFF\xF2'],
    'text/plain': [],  # Text files don't have reliable signatures
}

# Translation table that deletes the characters sanitize_input strips out.
_SANITIZE_DELETE_TABLE = str.maketrans('', '', '<>"\'&\x00\r\n')


class SecurityManager:
    """Hold the Fernet cipher used by the module-level encryption helpers."""

    def __init__(self):
        self.fernet_key = self._get_or_create_fernet_key()
        self.fernet = Fernet(self.fernet_key)

    def _get_or_create_fernet_key(self) -> bytes:
        """
        Return the Fernet key (url-safe base64 of 32 bytes).

        If ``settings.ENCRYPTION_KEY`` is set, the key is derived from it with
        PBKDF2-HMAC-SHA256 (100k iterations). The salt is the first 16 bytes
        of ``settings.SECRET_KEY``. The derivation is deterministic, so every
        process with the same settings gets the same key. Changing either
        setting makes previously encrypted data undecryptable.

        Otherwise a random key is generated. Data encrypted with that key
        cannot be decrypted by any other process or after a restart.
        """
        encryption_key = getattr(settings, 'ENCRYPTION_KEY', None)
        if encryption_key:
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                salt=settings.SECRET_KEY.encode()[:16],
                iterations=100000,
            )
            return base64.urlsafe_b64encode(kdf.derive(encryption_key.encode()))

        # Development fallback: make the data-loss risk visible in the logs.
        logger.warning(
            "ENCRYPTION_KEY not set; using an ephemeral Fernet key. "
            "Encrypted data will not be decryptable after a restart."
        )
        return Fernet.generate_key()


# Global security manager instance (created at import time).
_security_manager = SecurityManager()


def hash_password(password: str) -> str:
    """
    Hash password using bcrypt with salt.

    Args:
        password: Plain text password

    Returns:
        str: Hashed password (bcrypt, cost factor 12)

    Raises:
        Exception: Any bcrypt/encoding error is logged and re-raised.
    """
    try:
        salt = bcrypt.gensalt(rounds=12)
        hashed = bcrypt.hashpw(password.encode('utf-8'), salt)
        return hashed.decode('utf-8')
    except Exception as e:
        logger.error("Failed to hash password", error=str(e))
        raise


def verify_password(password: str, hashed_password: str) -> bool:
    """
    Verify password against hash.

    Args:
        password: Plain text password
        hashed_password: Bcrypt hashed password

    Returns:
        bool: True if password matches; False on mismatch or if the hash is
        malformed (the error is logged).
    """
    try:
        return bcrypt.checkpw(password.encode('utf-8'), hashed_password.encode('utf-8'))
    except Exception as e:
        logger.error("Failed to verify password", error=str(e))
        return False


def generate_access_token(
    payload: Dict[str, Any],
    expires_in: int = 3600,
    token_type: str = "access"
) -> str:
    """
    Generate an HS256-signed JWT.

    Args:
        payload: Token payload data. The reserved claims ``iat``, ``exp``,
            ``type`` and ``jti`` are always set by this function and cannot be
            overridden through the payload.
        expires_in: Token lifetime in seconds (anything ``int()`` accepts).
        token_type: Type of token (access, refresh, api)

    Returns:
        str: JWT token

    Raises:
        Exception: Conversion or encoding errors are logged and re-raised.
    """
    try:
        # Accept numeric strings/floats (e.g. from config) as before.
        expires_in = int(expires_in)

        now = datetime.now(timezone.utc)
        token_payload = {
            **payload,
            # Registered claims come after the caller payload so they win;
            # otherwise a payload key could extend "exp" or change "type".
            "iat": now,
            "exp": now + timedelta(seconds=expires_in),
            "type": token_type,
            "jti": secrets.token_urlsafe(16),  # Unique token ID
        }

        token = jwt.encode(
            token_payload,
            settings.SECRET_KEY,
            algorithm="HS256"
        )

        logger.debug(
            "Access token generated",
            token_type=token_type,
            expires_in=expires_in,
            user_id=payload.get("user_id")
        )

        return token

    except Exception as e:
        logger.error("Failed to generate access token", error=str(e))
        raise


def verify_access_token(token: str, token_type: str = "access") -> Optional[Dict[str, Any]]:
    """
    Verify and decode JWT token.

    Args:
        token: JWT token string
        token_type: Expected value of the token's ``type`` claim

    Returns:
        Optional[Dict]: Decoded payload, or None if the signature is invalid,
        the token is expired, the type does not match, or ``exp`` is missing.
    """
    try:
        # jwt.decode checks the signature and, when present, the "exp" claim
        # (raising ExpiredSignatureError). No manual clock comparison is done:
        # the previous one mixed naive local and UTC times and rejected valid
        # tokens on non-UTC hosts.
        payload = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=["HS256"]
        )
    except jwt.ExpiredSignatureError:
        logger.warning("Token expired")
        return None
    except jwt.InvalidTokenError as e:
        logger.warning("Invalid token", error=str(e))
        return None
    except Exception as e:
        logger.error("Failed to verify token", error=str(e))
        return None

    # Verify token type
    if payload.get("type") != token_type:
        logger.warning("Token type mismatch", expected=token_type, actual=payload.get("type"))
        return None

    # Tokens without an expiry are never accepted (jwt.decode only validates
    # "exp" when it is present).
    if "exp" not in payload:
        logger.warning("Token missing exp claim")
        return None

    return payload


def generate_refresh_token(user_id: UUID, device_id: Optional[str] = None) -> str:
    """
    Generate long-lived refresh token.

    Args:
        user_id: User UUID
        device_id: Optional device identifier

    Returns:
        str: Refresh token valid for settings.REFRESH_TOKEN_EXPIRE_DAYS days
    """
    payload = {
        "user_id": str(user_id),
        "device_id": device_id,
        "token_family": secrets.token_urlsafe(16)  # For token rotation
    }

    return generate_access_token(
        payload,
        expires_in=settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 3600,
        token_type="refresh"
    )


def generate_api_key(
    user_id: UUID,
    permissions: List[str],
    name: str,
    expires_in: Optional[int] = None
) -> str:
    """
    Generate API key with specific permissions.

    Args:
        user_id: User UUID
        permissions: List of permissions
        name: API key name
        expires_in: Optional expiration time in seconds; None means one year.

    Returns:
        str: API key token (a JWT with type "api")
    """
    payload = {
        "user_id": str(user_id),
        "permissions": permissions,
        "name": name,
        "key_id": secrets.token_urlsafe(16)
    }

    # Only None selects the default. A falsy-but-explicit value such as 0 is
    # no longer silently turned into a one-year key.
    expires = (365 * 24 * 3600) if expires_in is None else expires_in
    return generate_access_token(payload, expires_in=expires, token_type="api")


def encrypt_data(data: Union[str, bytes], context: str = "") -> str:
    """
    Encrypt data using Fernet symmetric encryption.

    Args:
        data: Data to encrypt (str is encoded as UTF-8)
        context: Optional context bound into the plaintext as
            ``"{context}:{len(data)}:"``; decrypt_data must be given the same
            context to succeed.

    Returns:
        str: URL-safe base64 of the Fernet token (the token itself is already
        base64, so this is a second encoding layer kept for format
        compatibility with existing ciphertexts).

    Raises:
        Exception: Encryption errors are logged and re-raised.
    """
    try:
        if isinstance(data, str):
            data = data.encode('utf-8')

        # Add context to data for additional security
        if context:
            data = f"{context}:{len(data)}:".encode('utf-8') + data

        encrypted = _security_manager.fernet.encrypt(data)
        return base64.urlsafe_b64encode(encrypted).decode('utf-8')

    except Exception as e:
        logger.error("Failed to encrypt data", error=str(e))
        raise


def decrypt_data(encrypted_data: str, context: str = "") -> Union[str, bytes]:
    """
    Decrypt data produced by encrypt_data.

    Args:
        encrypted_data: Base64 encoded encrypted data
        context: Context that was passed to encrypt_data (if any)

    Returns:
        Union[str, bytes]: Decrypted data. The code path always yields bytes;
        the str option in the annotation is kept for compatibility.

    Raises:
        ValueError: Context prefix or embedded length does not match.
        Exception: Fernet/base64 errors (e.g. an invalid token) are logged
            and re-raised.
    """
    try:
        encrypted_bytes = base64.urlsafe_b64decode(encrypted_data.encode('utf-8'))
        decrypted = _security_manager.fernet.decrypt(encrypted_bytes)

        # Verify and remove context if provided
        if context:
            context_prefix = f"{context}:".encode('utf-8')
            if not decrypted.startswith(context_prefix):
                raise ValueError("Context mismatch during decryption")

            # Layout after the prefix is b"<length>:<data>".
            length_field, separator, data = decrypted[len(context_prefix):].partition(b':')
            if not separator:
                raise ValueError("Invalid encrypted data format")

            expected_length = int(length_field.decode('utf-8'))
            if len(data) != expected_length:
                raise ValueError("Data length mismatch")

            return data

        return decrypted

    except Exception as e:
        logger.error("Failed to decrypt data", error=str(e))
        raise


def encrypt_file(file_data: bytes, file_id: str) -> bytes:
    """
    Encrypt file data with file-specific context.

    Args:
        file_data: File bytes to encrypt
        file_id: Unique file identifier (bound as context ``file:{file_id}``)

    Returns:
        bytes: Raw Fernet token (outer base64 layer from encrypt_data removed)
    """
    try:
        encrypted_str = encrypt_data(file_data, context=f"file:{file_id}")
        return base64.urlsafe_b64decode(encrypted_str.encode('utf-8'))
    except Exception as e:
        logger.error("Failed to encrypt file", file_id=file_id, error=str(e))
        raise


def decrypt_file(encrypted_data: bytes, file_id: str) -> bytes:
    """
    Decrypt file data produced by encrypt_file.

    Args:
        encrypted_data: Encrypted file bytes
        file_id: Unique file identifier used at encryption time

    Returns:
        bytes: Decrypted file data
    """
    try:
        encrypted_str = base64.urlsafe_b64encode(encrypted_data).decode('utf-8')
        decrypted = decrypt_data(encrypted_str, context=f"file:{file_id}")
        return decrypted if isinstance(decrypted, bytes) else decrypted.encode('utf-8')
    except Exception as e:
        logger.error("Failed to decrypt file", file_id=file_id, error=str(e))
        raise


def generate_secure_filename(original_filename: str, user_id: UUID) -> str:
    """
    Generate secure filename to prevent path traversal and collisions.

    Only the extension of the original name is reused, and only if it is in
    _ALLOWED_UPLOAD_EXTENSIONS; the rest of the name is discarded.

    Args:
        original_filename: Original filename
        user_id: User UUID

    Returns:
        str: ``{UTC timestamp}_{8-hex user hash}_{random}[.{ext}]``
    """
    # Extract extension
    parts = original_filename.rsplit('.', 1)
    extension = parts[1] if len(parts) > 1 else ''

    # Generate secure base name
    timestamp = datetime.now(timezone.utc).strftime('%Y%m%d_%H%M%S')
    random_part = secrets.token_urlsafe(8)
    user_hash = hashlib.sha256(str(user_id).encode()).hexdigest()[:8]

    secure_name = f"{timestamp}_{user_hash}_{random_part}"

    if extension:
        clean_extension = extension.lower().strip()
        if clean_extension in _ALLOWED_UPLOAD_EXTENSIONS:
            secure_name += f".{clean_extension}"

    return secure_name


def validate_file_signature(file_data: bytes, claimed_type: str) -> bool:
    """
    Validate file signature against claimed MIME type.

    Args:
        file_data: File bytes to validate
        claimed_type: Claimed MIME type

    Returns:
        bool: True if the leading bytes match the claimed type, or if the type
        has no known signature (e.g. text/plain or unlisted types).
    """
    if claimed_type == 'image/webp':
        # RIFF container: b"RIFF" + 4-byte little-endian size + b"WEBP".
        # Checking b"RIFF" alone would also accept WAV/AVI files.
        return len(file_data) >= 12 and file_data[:4] == b'RIFF' and file_data[8:12] == b'WEBP'

    if claimed_type == 'video/mp4':
        # ISO base media file: 4-byte box size, then the b"ftyp" box type.
        # The size and brand (mp41, mp42, isom, ...) vary between encoders.
        return len(file_data) >= 8 and file_data[4:8] == b'ftyp'

    expected_sigs = _FILE_SIGNATURES.get(claimed_type, [])

    # If no signatures defined, allow (like text files), whatever the length.
    if not expected_sigs:
        return True

    if len(file_data) < 8:
        return False

    # Check if file starts with any expected signature
    return file_data.startswith(tuple(expected_sigs))


def generate_csrf_token(user_id: UUID, session_id: str) -> str:
    """
    Generate CSRF token for form protection.

    Args:
        user_id: User UUID
        session_id: Session identifier

    Returns:
        str: URL-safe base64 of ``"{user_id}:{session_id}:{unix_ts}:{hmac}"``,
        where the HMAC-SHA256 (keyed with settings.SECRET_KEY) covers
        everything before the last colon.
    """
    # time.time() is a true Unix timestamp. datetime.utcnow().timestamp()
    # treated the naive UTC value as local time and was wrong on non-UTC hosts.
    timestamp = str(int(time.time()))
    data = f"{user_id}:{session_id}:{timestamp}"

    signature = hmac.new(
        settings.SECRET_KEY.encode(),
        data.encode(),
        hashlib.sha256
    ).hexdigest()

    token_data = f"{data}:{signature}"
    return base64.urlsafe_b64encode(token_data.encode()).decode()


def verify_csrf_token(token: str, user_id: UUID, session_id: str, max_age: int = 3600) -> bool:
    """
    Verify CSRF token.

    Args:
        token: CSRF token to verify
        user_id: User UUID
        session_id: Session identifier (may contain ':')
        max_age: Maximum token age in seconds

    Returns:
        bool: True if the token was issued for this user/session, is not
        older than max_age, and its signature is valid. Malformed tokens
        return False.
    """
    try:
        token_data = base64.urlsafe_b64decode(token.encode()).decode()

        # Match the known user/session prefix instead of splitting on every
        # ':', so session ids containing ':' still verify.
        prefix = f"{user_id}:{session_id}:"
        if not token_data.startswith(prefix):
            return False

        timestamp, separator, signature = token_data[len(prefix):].partition(':')
        if not separator or ':' in signature:
            return False

        # Check age
        token_time = int(timestamp)
        current_time = int(time.time())
        if current_time - token_time > max_age:
            return False

        # Verify signature
        data = f"{user_id}:{session_id}:{timestamp}"
        expected_signature = hmac.new(
            settings.SECRET_KEY.encode(),
            data.encode(),
            hashlib.sha256
        ).hexdigest()

        return hmac.compare_digest(signature, expected_signature)

    except Exception as e:
        logger.warning("Failed to verify CSRF token", error=str(e))
        return False


def sanitize_input(input_data: str, max_length: int = 1000) -> str:
    """
    Sanitize user input to prevent XSS and injection attacks.

    Truncates to max_length first, then deletes ``< > " ' &``, NUL, CR and LF,
    then strips surrounding whitespace.

    Args:
        input_data: Input string to sanitize
        max_length: Maximum allowed length (applied before character removal)

    Returns:
        str: Sanitized input ("" for empty/None input)
    """
    if not input_data:
        return ""

    return input_data[:max_length].translate(_SANITIZE_DELETE_TABLE).strip()


def check_permission(user_permissions: List[str], required_permission: str) -> bool:
    """
    Check if user has required permission.

    Grants access if the user holds "admin", the exact permission, or a
    wildcard on any dotted prefix of it (for "a.b.c": "a.*", "a.b.*", "a.b.c.*").

    Args:
        user_permissions: List of user permissions
        required_permission: Required permission string

    Returns:
        bool: True if user has permission
    """
    # Admin has all permissions
    if 'admin' in user_permissions:
        return True

    # Check exact permission
    if required_permission in user_permissions:
        return True

    # Check wildcard permissions
    permission_parts = required_permission.split('.')
    for i in range(1, len(permission_parts) + 1):
        if '.'.join(permission_parts[:i]) + '.*' in user_permissions:
            return True

    return False


def rate_limit_key(identifier: str, action: str, window: str = "default") -> str:
    """
    Generate rate limiting key.

    Args:
        identifier: User/IP identifier
        action: Action being rate limited
        window: Time window identifier

    Returns:
        str: SHA-256 hex digest of ``rate_limit:{action}:{window}:{identifier}``
    """
    key_data = f"rate_limit:{action}:{window}:{identifier}"
    return hashlib.sha256(key_data.encode()).hexdigest()


def generate_otp(length: int = 6) -> str:
    """
    Generate one-time password.

    Args:
        length: Number of digits

    Returns:
        str: Numeric OTP drawn with the ``secrets`` CSPRNG
    """
    return ''.join(secrets.choice('0123456789') for _ in range(length))


def constant_time_compare(a: str, b: str) -> bool:
    """
    Constant time string comparison to prevent timing attacks.

    Args:
        a: First string
        b: Second string

    Returns:
        bool: True if strings are equal
    """
    return hmac.compare_digest(a.encode('utf-8'), b.encode('utf-8'))


# --- Added for optional auth compatibility ---
try:
    # If get_current_user already exists in this module, import it.
    # NOTE(review): if this file *is* app/core/security.py (likely given the
    # sibling app.core imports - TODO confirm), this self-import runs while
    # the module is still initialising. No get_current_user is defined above,
    # so it fails and the stub below is used, which makes
    # get_current_user_optional() always return None.
    from app.core.security import get_current_user  # type: ignore
except Exception:
    # Fallback stub in case the project structure differs; will only be used if referenced directly
    def get_current_user():
        """Stub that always raises; used when the real dependency is unavailable."""
        raise RuntimeError("get_current_user is not available")


def get_current_user_optional() -> Optional[object]:
    """
    Return current user if authenticated, otherwise None.

    Designed to be used in dependencies for routes that allow anonymous access.
    get_current_user() is called with no arguments; any exception it raises
    is logged at debug level and treated as "anonymous".
    """
    try:
        return get_current_user()  # type: ignore
    except Exception as e:
        # Deliberate best-effort: anonymous access is allowed, so failures
        # degrade to None, but leave a trace for debugging.
        logger.debug("No current user resolved", error=str(e))
        return None
# --- End added block ---