598 lines
17 KiB
Python
598 lines
17 KiB
Python
"""
|
|
Comprehensive security module with encryption, JWT tokens, password hashing, and access control.
|
|
Provides secure file encryption, token management, and authentication utilities.
|
|
"""
|
|
|
|
import hashlib
|
|
import hmac
|
|
import secrets
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Any, Union
|
|
from uuid import UUID
|
|
|
|
import bcrypt
|
|
import jwt
|
|
from cryptography.fernet import Fernet
|
|
from cryptography.hazmat.primitives import hashes
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
import base64
|
|
|
|
from app.core.config import get_settings
|
|
from app.core.logging import get_logger
|
|
|
|
# Module-level logger. It is called in this module with keyword arguments
# (e.g. ``logger.error("...", error=str(e))``), so it is presumably a
# structured (structlog-style) logger — confirm in app.core.logging.
logger = get_logger(__name__)
# Application settings, resolved once at import time and shared by every
# function below (SECRET_KEY, ENCRYPTION_KEY, REFRESH_TOKEN_EXPIRE_DAYS, ...).
settings = get_settings()
|
|
|
|
class SecurityManager:
    """Main security manager for encryption, tokens, and authentication.

    Holds the Fernet instance used by ``encrypt_data`` / ``decrypt_data``.
    """

    def __init__(self):
        self.fernet_key = self._get_or_create_fernet_key()
        self.fernet = Fernet(self.fernet_key)

    def _get_or_create_fernet_key(self) -> bytes:
        """Get or create the Fernet encryption key from settings.

        When ``settings.ENCRYPTION_KEY`` is set, a 32-byte key is derived from it
        with PBKDF2-HMAC-SHA256 (100,000 iterations), salted with the first 16
        bytes of ``settings.SECRET_KEY``. The same settings therefore always
        yield the same key. Otherwise a random key is generated for this process
        only.

        Returns:
            bytes: URL-safe base64 encoding of a 32-byte key, as Fernet expects.
        """
        encryption_key = getattr(settings, 'ENCRYPTION_KEY', None)
        if encryption_key:
            kdf = PBKDF2HMAC(
                algorithm=hashes.SHA256(),
                length=32,
                # The salt is tied to SECRET_KEY. Changing either setting changes
                # the derived key and makes existing ciphertext undecryptable.
                salt=settings.SECRET_KEY.encode()[:16],
                iterations=100000,
            )
            return base64.urlsafe_b64encode(kdf.derive(encryption_key.encode()))

        # Random key (development only). Every process gets a different key, so
        # data encrypted with it cannot be decrypted after a restart or by another
        # worker. Say so loudly instead of failing silently later.
        logger.warning(
            "ENCRYPTION_KEY is not configured; using an ephemeral random encryption key. "
            "Encrypted data will not be decryptable after a restart or by other processes."
        )
        return Fernet.generate_key()
|
|
|
|
# Global security manager instance, created at import time.
# encrypt_data/decrypt_data use its ``.fernet``. If settings has no ENCRYPTION_KEY,
# SecurityManager falls back to Fernet.generate_key(), a new random key per
# process, so ciphertext produced then will not survive a restart.
_security_manager = SecurityManager()
|
|
|
|
def hash_password(password: str) -> str:
    """
    Produce a bcrypt hash of a plain-text password.

    A fresh random salt with cost factor 12 is generated for every call, so
    hashing the same password twice yields different strings.

    Args:
        password: Plain text password

    Returns:
        str: The bcrypt hash as UTF-8 text (salt and cost are embedded in it)

    Raises:
        Any error from encoding or bcrypt, after logging it.
    """
    try:
        cost_salt = bcrypt.gensalt(rounds=12)
        digest = bcrypt.hashpw(password.encode('utf-8'), cost_salt)
        return digest.decode('utf-8')
    except Exception as exc:
        logger.error("Failed to hash password", error=str(exc))
        raise
|
|
|
|
def verify_password(password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a stored bcrypt hash.

    Args:
        password: Plain text password
        hashed_password: Bcrypt hashed password

    Returns:
        bool: True on a match. False on a mismatch, and also False (after
        logging) if the hash is malformed or checking fails for any other reason.
    """
    try:
        candidate = password.encode('utf-8')
        stored = hashed_password.encode('utf-8')
        return bcrypt.checkpw(candidate, stored)
    except Exception as exc:
        logger.error("Failed to verify password", error=str(exc))
        return False
|
|
|
|
def generate_access_token(
    payload: Dict[str, Any],
    expires_in: int = 3600,
    token_type: str = "access"
) -> str:
    """
    Generate an HS256-signed JWT.

    Standard claims added: ``iat`` (now), ``exp`` (now + ``expires_in``),
    ``type`` (``token_type``) and ``jti`` (random unique id). Entries of
    ``payload`` are merged in after them.

    Args:
        payload: Token payload data
        expires_in: Token expiration time in seconds (anything ``int()`` accepts,
            e.g. a numeric string from configuration)
        token_type: Type of token (access, refresh, api)

    Returns:
        str: JWT token signed with ``settings.SECRET_KEY``

    Raises:
        Any error from conversion or encoding, after logging it.
    """
    from datetime import timezone

    try:
        # One conversion covers ints, numeric strings and floats alike.
        expires_in = int(expires_in)

        # Timezone-aware UTC. datetime.utcnow() is deprecated since Python 3.12.
        # PyJWT converts both forms to the same epoch seconds.
        now = datetime.now(timezone.utc)
        token_payload = {
            "iat": now,
            "exp": now + timedelta(seconds=expires_in),
            "type": token_type,
            "jti": secrets.token_urlsafe(16),  # Unique token ID
            # NOTE(review): payload is spread last, so a caller-supplied "type",
            # "exp", "iat" or "jti" overrides the values above. Make sure payload
            # never carries untrusted keys.
            **payload
        }

        token = jwt.encode(
            token_payload,
            settings.SECRET_KEY,
            algorithm="HS256"
        )

        logger.debug(
            "Access token generated",
            token_type=token_type,
            expires_in=expires_in,
            user_id=payload.get("user_id")
        )

        return token

    except Exception as e:
        logger.error("Failed to generate access token", error=str(e))
        raise
|
|
|
|
def verify_access_token(token: str, token_type: str = "access") -> Optional[Dict[str, Any]]:
    """
    Verify and decode a JWT.

    ``jwt.decode`` checks the HS256 signature against ``settings.SECRET_KEY``
    and raises for an expired ``exp``. This function additionally requires the
    ``type`` claim to equal ``token_type`` and an ``exp`` claim to be present.

    Args:
        token: JWT token string
        token_type: Expected token type

    Returns:
        Optional[Dict]: Decoded payload, or None if the token is invalid,
        expired, of the wrong type or missing ``exp``. All failures are logged
        and never raised.
    """
    from datetime import timezone

    try:
        payload = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=["HS256"]
        )

        # Verify token type
        if payload.get("type") != token_type:
            logger.warning("Token type mismatch", expected=token_type, actual=payload.get("type"))
            return None

        exp = payload.get("exp")
        if exp is None:
            logger.warning("Token missing exp claim")
            return None

        # Defence in depth on top of jwt.decode's own check. Both sides must be
        # UTC. A naive datetime.fromtimestamp() is *local* time, and comparing it
        # with utcnow() rejects fresh tokens outright on hosts west of UTC.
        if datetime.now(timezone.utc) > datetime.fromtimestamp(exp, tz=timezone.utc):
            logger.warning("Token expired", exp=exp)
            return None

        return payload

    except jwt.ExpiredSignatureError:
        logger.warning("Token expired")
        return None
    except jwt.InvalidTokenError as e:
        logger.warning("Invalid token", error=str(e))
        return None
    except Exception as e:
        logger.error("Failed to verify token", error=str(e))
        return None
|
|
|
|
def generate_refresh_token(user_id: UUID, device_id: Optional[str] = None) -> str:
    """
    Issue a long-lived JWT of type ``"refresh"``.

    The token carries the user id (as a string), the optional device id and a
    random ``token_family`` value intended for refresh-token rotation. Its
    lifetime is ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days.

    Args:
        user_id: User UUID
        device_id: Optional device identifier

    Returns:
        str: Refresh token
    """
    claims = dict(
        user_id=str(user_id),
        device_id=device_id,
        token_family=secrets.token_urlsafe(16),  # For token rotation
    )
    lifetime_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 3600
    return generate_access_token(claims, expires_in=lifetime_seconds, token_type="refresh")
|
|
|
|
def generate_api_key(
    user_id: UUID,
    permissions: List[str],
    name: str,
    expires_in: Optional[int] = None
) -> str:
    """
    Issue a JWT of type ``"api"`` scoped to a list of permissions.

    Args:
        user_id: User UUID
        permissions: List of permissions embedded in the token
        name: API key name
        expires_in: Optional lifetime in seconds. Any falsy value (None or 0)
            falls back to one year.

    Returns:
        str: API key token
    """
    one_year = 365 * 24 * 3600
    lifetime = expires_in if expires_in else one_year

    claims = dict(
        user_id=str(user_id),
        permissions=permissions,
        name=name,
        key_id=secrets.token_urlsafe(16),
    )
    return generate_access_token(claims, expires_in=lifetime, token_type="api")
|
|
|
|
def encrypt_data(data: Union[str, bytes], context: str = "") -> str:
    """
    Encrypt data with the module's Fernet key.

    With a non-empty ``context``, the plaintext is framed as
    ``b"{context}:{len(data)}:" + data`` before encryption. ``decrypt_data``
    verifies this frame when it is given the same context.

    NOTE(review): the Fernet token is already URL-safe base64 and is encoded a
    second time here. Also, a context containing ':' can make frames ambiguous
    between contexts. Both are part of the stored format, so leave as-is unless
    you migrate the data.

    Args:
        data: Data to encrypt (``str`` is encoded as UTF-8)
        context: Optional context bound into the ciphertext

    Returns:
        str: URL-safe base64 of the Fernet token
    """
    try:
        plaintext = data.encode('utf-8') if isinstance(data, str) else data

        if context:
            frame = f"{context}:{len(plaintext)}:".encode('utf-8')
            plaintext = frame + plaintext

        fernet_token = _security_manager.fernet.encrypt(plaintext)
        return base64.urlsafe_b64encode(fernet_token).decode('utf-8')
    except Exception as exc:
        logger.error("Failed to encrypt data", error=str(exc))
        raise
|
|
|
|
def decrypt_data(encrypted_data: str, context: str = "") -> Union[str, bytes]:
    """
    Reverse ``encrypt_data``.

    Args:
        encrypted_data: URL-safe base64 of a Fernet token, as produced by
            ``encrypt_data``
        context: The context used at encryption time, if any. When given, the
            ``"{context}:{length}:"`` frame is verified and stripped.

    Returns:
        Union[str, bytes]: The decrypted bytes. Fernet returns bytes, so in
        practice this is always ``bytes``.

    Raises:
        ValueError: context mismatch, malformed frame, or length mismatch.
        Any Fernet or base64 error for tampered or invalid input. All errors are
        logged first.
    """
    try:
        fernet_token = base64.urlsafe_b64decode(encrypted_data.encode('utf-8'))
        plaintext = _security_manager.fernet.decrypt(fernet_token)

        if not context:
            return plaintext

        prefix = f"{context}:".encode('utf-8')
        if not plaintext.startswith(prefix):
            raise ValueError("Context mismatch during decryption")

        length_field, separator, body = plaintext[len(prefix):].partition(b':')
        if not separator:
            raise ValueError("Invalid encrypted data format")

        declared_length = int(length_field.decode('utf-8'))
        if len(body) != declared_length:
            raise ValueError("Data length mismatch")

        return body
    except Exception as exc:
        logger.error("Failed to decrypt data", error=str(exc))
        raise
|
|
|
|
def encrypt_file(file_data: bytes, file_id: str) -> bytes:
    """
    Encrypt file contents bound to a specific file id.

    The data is encrypted with context ``"file:{file_id}"``. The outer base64
    layer that ``encrypt_data`` adds is removed, so the result is the raw
    Fernet token bytes.

    Args:
        file_data: File bytes to encrypt
        file_id: Unique file identifier

    Returns:
        bytes: Encrypted file data (Fernet token)
    """
    try:
        wrapped_text = encrypt_data(file_data, context=f"file:{file_id}")
        fernet_token = base64.urlsafe_b64decode(wrapped_text.encode('utf-8'))
        return fernet_token
    except Exception as exc:
        logger.error("Failed to encrypt file", file_id=file_id, error=str(exc))
        raise
|
|
|
|
def decrypt_file(encrypted_data: bytes, file_id: str) -> bytes:
    """
    Decrypt file contents produced by ``encrypt_file`` for the same file id.

    Args:
        encrypted_data: Encrypted file bytes (Fernet token)
        file_id: Unique file identifier; must match the one used to encrypt

    Returns:
        bytes: Decrypted file data
    """
    try:
        wrapped_text = base64.urlsafe_b64encode(encrypted_data).decode('utf-8')
        plaintext = decrypt_data(wrapped_text, context=f"file:{file_id}")
        if not isinstance(plaintext, bytes):
            plaintext = plaintext.encode('utf-8')
        return plaintext
    except Exception as exc:
        logger.error("Failed to decrypt file", file_id=file_id, error=str(exc))
        raise
|
|
|
|
# Extensions that generate_secure_filename keeps (compared after lower-casing
# and stripping whitespace). Any other extension is dropped.
_ALLOWED_UPLOAD_EXTENSIONS = frozenset({
    'txt', 'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx',
    'jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'svg',
    'mp3', 'wav', 'flac', 'ogg', 'mp4', 'avi', 'mkv', 'webm',
    'zip', 'rar', '7z', 'tar', 'gz', 'json', 'xml', 'csv',
})


def generate_secure_filename(original_filename: str, user_id: UUID) -> str:
    """
    Build a storage filename that shares nothing with the user-supplied name.

    The result is ``{UTC %Y%m%d_%H%M%S}_{sha256(user_id)[:8]}_{random}``. The
    last extension of ``original_filename`` is appended only when it is on the
    allow-list. Directory components, traversal sequences and unknown extensions
    from the original name are discarded.

    Args:
        original_filename: Original filename
        user_id: User UUID

    Returns:
        str: Secure filename
    """
    _, dot, tail = original_filename.rpartition('.')
    extension = tail if dot else ''

    stamp = datetime.utcnow().strftime('%Y%m%d_%H%M%S')
    owner_tag = hashlib.sha256(str(user_id).encode()).hexdigest()[:8]
    nonce = secrets.token_urlsafe(8)
    filename = f"{stamp}_{owner_tag}_{nonce}"

    normalized = extension.lower().strip()
    if normalized in _ALLOWED_UPLOAD_EXTENSIONS:
        filename = f"{filename}.{normalized}"

    return filename
|
|
|
|
def validate_file_signature(file_data: bytes, claimed_type: str) -> bool:
    """
    Validate a file's leading bytes against its claimed MIME type.

    Most formats are recognised by a fixed magic-number prefix. Two container
    formats have identifying bytes that are not at offset 0, so they are checked
    structurally:

    * ``image/webp``: a RIFF container, ``b'RIFF' + <4-byte size> + b'WEBP'``.
      Checking only the ``RIFF`` prefix would also accept WAV and AVI files.
    * ``video/mp4``: an ISO base-media file whose first box type is ``ftyp`` at
      offset 4. The box size (offset 0) and the major brand (offset 8:
      ``isom``, ``mp41``, ``mp42``, ...) vary by encoder, so they are not pinned.

    Args:
        file_data: File bytes to validate
        claimed_type: Claimed MIME type

    Returns:
        bool: True if the signature matches the claimed type, or if no signature
        is known for that type (e.g. text/plain). Always False for inputs
        shorter than 8 bytes.
    """
    if len(file_data) < 8:
        return False

    if claimed_type == 'image/webp':
        return file_data[:4] == b'RIFF' and file_data[8:12] == b'WEBP'

    if claimed_type == 'video/mp4':
        return file_data[4:8] == b'ftyp'

    # File signatures (magic numbers) found at offset 0
    prefix_signatures = {
        'image/jpeg': (b'\xFF\xD8\xFF',),
        'image/png': (b'\x89PNG\r\n\x1a\n',),
        'image/gif': (b'GIF87a', b'GIF89a'),
        'application/pdf': (b'%PDF-',),
        'application/zip': (b'PK\x03\x04', b'PK\x05\x06', b'PK\x07\x08'),
        'audio/mpeg': (b'ID3', b'\xFF\xFB', b'\xFF\xF3', b'\xFF\xF2'),
        'text/plain': (),  # Text files don't have reliable signatures
    }

    expected = prefix_signatures.get(claimed_type, ())

    # No known signature (text files, unlisted types): allow
    if not expected:
        return True

    return file_data.startswith(expected)
|
|
|
|
def _csrf_now() -> int:
    """Return the current Unix time in whole seconds.

    ``time.time()`` is a true epoch value regardless of the host timezone.
    ``datetime.utcnow().timestamp()`` is not, because ``timestamp()`` treats a
    naive datetime as *local* time. That shifts the value by the host's UTC
    offset, which also jumps at DST changes.
    """
    import time

    return int(time.time())


def _csrf_signature(data: str) -> str:
    """Return the hex HMAC-SHA256 of ``data`` keyed with ``settings.SECRET_KEY``."""
    return hmac.new(
        settings.SECRET_KEY.encode(),
        data.encode(),
        hashlib.sha256
    ).hexdigest()


def generate_csrf_token(user_id: UUID, session_id: str) -> str:
    """
    Generate CSRF token for form protection.

    The token is the URL-safe base64 of
    ``"{user_id}:{session_id}:{unix_ts}:{signature}"``. The signature is the
    HMAC-SHA256 (keyed with ``settings.SECRET_KEY``) of the first three fields.

    Args:
        user_id: User UUID
        session_id: Session identifier (may contain ':')

    Returns:
        str: CSRF token
    """
    data = f"{user_id}:{session_id}:{_csrf_now()}"
    token_data = f"{data}:{_csrf_signature(data)}"
    return base64.urlsafe_b64encode(token_data.encode()).decode()


def verify_csrf_token(token: str, user_id: UUID, session_id: str, max_age: int = 3600) -> bool:
    """
    Verify a token produced by ``generate_csrf_token``.

    NOTE: tokens issued by the earlier ``utcnow().timestamp()`` implementation
    on a non-UTC host carry a timestamp offset by that host's UTC offset. They
    may be rejected, or accepted for longer, until they age out.

    Args:
        token: CSRF token to verify
        user_id: User UUID the token must belong to
        session_id: Session identifier the token must belong to
        max_age: Maximum token age in seconds

    Returns:
        bool: True only if the token decodes, matches ``user_id`` and
        ``session_id``, is at most ``max_age`` seconds old and carries a valid
        signature. Any parsing error yields False (logged as a warning).
    """
    try:
        token_data = base64.urlsafe_b64decode(token.encode()).decode()

        # Layout is user_id:session_id:timestamp:signature. The UUID, the integer
        # timestamp and the hex signature contain no ':', so peel them off both
        # ends. Any remaining ':' belongs to session_id.
        token_user_id, separator, rest = token_data.partition(':')
        trailing = rest.rsplit(':', 2)
        if not separator or len(trailing) != 3:
            return False
        token_session_id, timestamp, signature = trailing

        # Verify components
        if token_user_id != str(user_id) or token_session_id != session_id:
            return False

        # Check age
        if _csrf_now() - int(timestamp) > max_age:
            return False

        # Verify signature
        expected_signature = _csrf_signature(f"{token_user_id}:{token_session_id}:{timestamp}")
        return hmac.compare_digest(signature, expected_signature)

    except Exception as e:
        logger.warning("Failed to verify CSRF token", error=str(e))
        return False
|
|
|
|
def sanitize_input(input_data: str, max_length: int = 1000) -> str:
    """
    Strip characters commonly abused for XSS/injection from user input.

    The input is first cut to ``max_length`` characters. Then every
    ``< > " ' &``, NUL, CR and LF is deleted (removed, not escaped), and finally
    surrounding whitespace is trimmed.

    Args:
        input_data: Input string to sanitize
        max_length: Maximum allowed length (applied before removal)

    Returns:
        str: Sanitized input ("" for empty or None input)
    """
    if not input_data:
        return ""

    text = input_data
    if len(text) > max_length:
        text = text[:max_length]

    # One pass deleting every unsafe character (mapped to None).
    removal_table = str.maketrans('', '', '<>"\'&\x00\r\n')
    return text.translate(removal_table).strip()
|
|
|
|
def check_permission(user_permissions: List[str], required_permission: str) -> bool:
    """
    Decide whether granted permissions cover a required one.

    A grant covers the requirement when it is ``'admin'``, equals the required
    permission, or is a wildcard ``'<prefix>.*'``. Here ``<prefix>`` is any
    leading run of the dot-separated segments of ``required_permission``,
    including all of them, so ``'files.read.*'`` also grants ``'files.read'``.
    A bare ``'*'`` is not treated as a wildcard.

    Args:
        user_permissions: List of user permissions
        required_permission: Required permission string

    Returns:
        bool: True if user has permission
    """
    if 'admin' in user_permissions or required_permission in user_permissions:
        return True

    segments = required_permission.split('.')
    return any(
        ('.'.join(segments[:depth]) + '.*') in user_permissions
        for depth in range(1, len(segments) + 1)
    )
|
|
|
|
def rate_limit_key(identifier: str, action: str, window: str = "default") -> str:
    """
    Derive a fixed-length cache key for rate limiting.

    The key is the SHA-256 hex digest of
    ``"rate_limit:{action}:{window}:{identifier}"``. This keeps raw identifiers
    (e.g. IPs) out of the cache key space and bounds its length.

    Args:
        identifier: User/IP identifier
        action: Action being rate limited
        window: Time window identifier

    Returns:
        str: 64-character hex cache key
    """
    hasher = hashlib.sha256()
    hasher.update(f"rate_limit:{action}:{window}:{identifier}".encode())
    return hasher.hexdigest()
|
|
|
|
def generate_otp(length: int = 6) -> str:
    """
    Generate a numeric one-time password from a cryptographically secure RNG.

    Args:
        length: Number of digits

    Returns:
        str: Numeric OTP (empty for length <= 0)
    """
    digits = [str(secrets.randbelow(10)) for _ in range(length)]
    return ''.join(digits)
|
|
|
|
def constant_time_compare(a: str, b: str) -> bool:
    """
    Compare two strings in time independent of where they first differ.

    Both strings are UTF-8 encoded first, so non-ASCII input is supported.

    Args:
        a: First string
        b: Second string

    Returns:
        bool: True if strings are equal
    """
    left = a.encode('utf-8')
    right = b.encode('utf-8')
    return hmac.compare_digest(left, right)
|
|
# --- Added for optional auth compatibility ---
from typing import Optional

try:
    # If get_current_user already exists in this module, import it
    # NOTE(review): nothing in this file defines get_current_user. If this file
    # *is* app.core.security, this self-import runs while the module is still
    # being initialised. That normally raises ImportError, so the stub below is
    # what gets bound. Confirm which module this file actually is.
    from app.core.security import get_current_user  # type: ignore
except Exception:
    # Fallback stub in case the project structure differs; will only be used if referenced directly
    def get_current_user():
        """Placeholder used when the import above fails; always raises RuntimeError."""
        raise RuntimeError("get_current_user is not available")


def get_current_user_optional() -> Optional[object]:
    """
    Return current user if authenticated, otherwise None.

    Designed to be used in dependencies for routes that allow anonymous access.

    Calls get_current_user() with no arguments and turns *any* exception into
    None, including the RuntimeError from the fallback stub.
    NOTE(review): while the stub is bound, this always returns None. A real
    get_current_user that requires arguments (e.g. a token or DB session) would
    raise TypeError here and also yield None. Verify against the actual
    implementation.
    """
    try:
        return get_current_user()  # type: ignore
    except Exception:
        return None
# --- End added block ---