uploader-bot/app/core/security.py

571 lines
16 KiB
Python

"""
Comprehensive security module with encryption, JWT tokens, password hashing, and access control.
Provides secure file encryption, token management, and authentication utilities.
"""
import hashlib
import hmac
import secrets
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from uuid import UUID
import bcrypt
import jwt
from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
import base64
from app.core.config import get_settings
from app.core.logging import get_logger
# Module-level logger and settings object, both obtained once at import time.
# Every function below reads secrets (e.g. SECRET_KEY) from `settings`.
logger = get_logger(__name__)
settings = get_settings()
class SecurityManager:
    """
    Holds the Fernet cipher used for symmetric encryption in this module.

    Attributes:
        fernet_key: URL-safe base64 encoded 32-byte Fernet key.
        fernet: Fernet instance built from ``fernet_key``.
    """

    def __init__(self):
        self.fernet_key = self._get_or_create_fernet_key()
        self.fernet = Fernet(self.fernet_key)

    def _get_or_create_fernet_key(self) -> bytes:
        """
        Derive the Fernet key from settings, or generate a random one.

        When ``settings.ENCRYPTION_KEY`` is set, the key is derived from it
        with PBKDF2-HMAC-SHA256 (100000 iterations), salted with the first
        16 bytes of ``settings.SECRET_KEY``. Otherwise a fresh random key is
        generated for this instance only.

        Returns:
            bytes: URL-safe base64 encoded Fernet key
        """
        encryption_secret = getattr(settings, 'ENCRYPTION_KEY', None)
        if not encryption_secret:
            # A random key lives only as long as this instance: anything
            # encrypted with it cannot be decrypted by another process or
            # after a restart. Make that visible instead of failing silently.
            logger.warning(
                "ENCRYPTION_KEY is not configured; using an ephemeral encryption key"
            )
            return Fernet.generate_key()

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=settings.SECRET_KEY.encode()[:16],
            iterations=100000,
        )
        derived = kdf.derive(encryption_secret.encode())
        # Fernet expects its 32-byte key in URL-safe base64 form.
        return base64.urlsafe_b64encode(derived)
# Global security manager instance, constructed once when this module is
# imported. encrypt_data() and decrypt_data() use its `fernet` cipher.
_security_manager = SecurityManager()
def hash_password(password: str) -> str:
    """
    Produce a salted bcrypt hash (12 rounds) of a plain-text password.

    Args:
        password: Plain text password

    Returns:
        str: The bcrypt hash decoded as UTF-8 text

    Raises:
        Exception: Any error raised while hashing is logged and re-raised.
    """
    try:
        fresh_salt = bcrypt.gensalt(rounds=12)
        password_bytes = password.encode('utf-8')
        digest = bcrypt.hashpw(password_bytes, fresh_salt)
        return digest.decode('utf-8')
    except Exception as exc:
        logger.error("Failed to hash password", error=str(exc))
        raise
def verify_password(password: str, hashed_password: str) -> bool:
    """
    Check a plain-text password against a stored bcrypt hash.

    Any exception raised during the check is logged and reported as a
    failed match rather than propagated.

    Args:
        password: Plain text password
        hashed_password: Bcrypt hashed password

    Returns:
        bool: True if the password matches, False otherwise or on error
    """
    try:
        candidate = password.encode('utf-8')
        stored_hash = hashed_password.encode('utf-8')
        return bcrypt.checkpw(candidate, stored_hash)
    except Exception as exc:
        logger.error("Failed to verify password", error=str(exc))
        return False
def generate_access_token(
    payload: Dict[str, Any],
    expires_in: int = 3600,
    token_type: str = "access"
) -> str:
    """
    Generate a signed JWT (HS256, keyed with settings.SECRET_KEY).

    The token carries the caller's payload plus four generated claims:
    ``iat``, ``exp``, ``type`` and ``jti``. The generated claims always win:
    a payload key with one of those names is overwritten, so caller-supplied
    data cannot change the token's lifetime or type.

    Args:
        payload: Token payload data
        expires_in: Token expiration time in seconds
        token_type: Type of token (access, refresh, api)

    Returns:
        str: JWT token

    Raises:
        Exception: Any error raised while building or signing the token is
            logged and re-raised.
    """
    try:
        issued_at = datetime.utcnow()
        token_payload = {
            **payload,
            # Generated claims come last so they override same-named
            # payload keys; verify_access_token() relies on "type".
            "iat": issued_at,
            "exp": issued_at + timedelta(seconds=expires_in),
            "type": token_type,
            "jti": secrets.token_urlsafe(16),  # Unique token ID
        }
        token = jwt.encode(
            token_payload,
            settings.SECRET_KEY,
            algorithm="HS256"
        )
        logger.debug(
            "Access token generated",
            token_type=token_type,
            expires_in=expires_in,
            user_id=payload.get("user_id")
        )
        return token
    except Exception as e:
        logger.error("Failed to generate access token", error=str(e))
        raise
def verify_access_token(token: str, token_type: str = "access") -> Optional[Dict[str, Any]]:
    """
    Verify and decode a JWT produced by generate_access_token().

    The signature (HS256, settings.SECRET_KEY) and the ``exp`` claim are
    validated by ``jwt.decode``. On top of that the ``type`` claim must equal
    ``token_type`` and an ``exp`` claim must be present. No exception
    escapes: every failure is logged and reported as ``None``.

    Args:
        token: JWT token string
        token_type: Expected token type

    Returns:
        Optional[Dict]: Decoded payload or None if invalid
    """
    try:
        payload = jwt.decode(
            token,
            settings.SECRET_KEY,
            algorithms=["HS256"]
        )
        # Verify token type
        actual_type = payload.get("type")
        if actual_type != token_type:
            logger.warning("Token type mismatch", expected=token_type, actual=actual_type)
            return None
        # jwt.decode() has already rejected expired tokens (handled below as
        # ExpiredSignatureError), but only when "exp" is present. Comparing
        # naive UTC "now" with a local-time datetime here would misjudge
        # expiry on any host not running in UTC, so only require the claim.
        if "exp" not in payload:
            logger.warning("Token missing expiration claim")
            return None
        return payload
    except jwt.ExpiredSignatureError:
        logger.warning("Token expired")
        return None
    except jwt.InvalidTokenError as e:
        logger.warning("Invalid token", error=str(e))
        return None
    except Exception as e:
        logger.error("Failed to verify token", error=str(e))
        return None
def generate_refresh_token(user_id: UUID, device_id: Optional[str] = None) -> str:
    """
    Issue a refresh-type JWT for a user.

    The lifetime is settings.REFRESH_TOKEN_EXPIRE_DAYS converted to seconds.
    A random ``token_family`` value is embedded in every token.

    Args:
        user_id: User UUID
        device_id: Optional device identifier

    Returns:
        str: Refresh token
    """
    claims = {
        "user_id": str(user_id),
        "device_id": device_id,
        "token_family": secrets.token_urlsafe(16),  # For token rotation
    }
    lifetime_seconds = settings.REFRESH_TOKEN_EXPIRE_DAYS * 24 * 3600
    return generate_access_token(
        claims,
        expires_in=lifetime_seconds,
        token_type="refresh",
    )
def generate_api_key(
    user_id: UUID,
    permissions: List[str],
    name: str,
    expires_in: Optional[int] = None
) -> str:
    """
    Issue an api-type JWT that carries a permission list.

    Args:
        user_id: User UUID
        permissions: List of permissions
        name: API key name
        expires_in: Optional expiration time in seconds; any falsy value
            (None or 0) selects the one-year default

    Returns:
        str: API key token
    """
    one_year_seconds = 365 * 24 * 3600
    claims = {
        "user_id": str(user_id),
        "permissions": permissions,
        "name": name,
        "key_id": secrets.token_urlsafe(16),
    }
    if expires_in:
        lifetime = expires_in
    else:
        lifetime = one_year_seconds  # Default 1 year
    return generate_access_token(claims, expires_in=lifetime, token_type="api")
def encrypt_data(data: Union[str, bytes], context: str = "") -> str:
    """
    Encrypt data with the module-wide Fernet cipher.

    Text input is UTF-8 encoded first. When a context is given, the
    plaintext is prefixed with ``"<context>:<length>:"`` before encryption
    so that decrypt_data() can check it.

    Args:
        data: Data to encrypt
        context: Optional context bound into the ciphertext

    Returns:
        str: URL-safe base64 encoding of the Fernet token

    Raises:
        Exception: Any encryption error is logged and re-raised.
    """
    try:
        plaintext = data.encode('utf-8') if isinstance(data, str) else data
        if context:
            # Header records the context and the payload length.
            header = f"{context}:{len(plaintext)}:".encode('utf-8')
            plaintext = header + plaintext
        fernet_token = _security_manager.fernet.encrypt(plaintext)
        encoded_token = base64.urlsafe_b64encode(fernet_token)
        return encoded_token.decode('utf-8')
    except Exception as exc:
        logger.error("Failed to encrypt data", error=str(exc))
        raise
def decrypt_data(encrypted_data: str, context: str = "") -> Union[str, bytes]:
    """
    Reverse encrypt_data(): base64-decode, Fernet-decrypt, strip the header.

    When a context is given, the plaintext must start with
    ``"<context>:<length>:"`` and the remaining bytes must have exactly
    ``<length>`` bytes. Without a context the decrypted bytes are returned
    untouched.

    Args:
        encrypted_data: Base64 encoded encrypted data
        context: Optional context for verification

    Returns:
        Union[str, bytes]: Decrypted data (the visible code paths return bytes)

    Raises:
        ValueError: On context mismatch, malformed header, or length mismatch.
        Exception: Any other decoding/decryption error. All errors are
            logged and re-raised.
    """
    try:
        fernet_token = base64.urlsafe_b64decode(encrypted_data.encode('utf-8'))
        plaintext = _security_manager.fernet.decrypt(fernet_token)

        if not context:
            return plaintext

        # Verify and remove the context header.
        expected_prefix = f"{context}:".encode('utf-8')
        if not plaintext.startswith(expected_prefix):
            raise ValueError("Context mismatch during decryption")

        # What follows the prefix is b"<length>:<payload>".
        after_context = plaintext[len(expected_prefix):]
        length_field, separator, body = after_context.partition(b':')
        if not separator:
            raise ValueError("Invalid encrypted data format")

        declared_length = int(length_field.decode('utf-8'))
        if len(body) != declared_length:
            raise ValueError("Data length mismatch")
        return body
    except Exception as exc:
        logger.error("Failed to decrypt data", error=str(exc))
        raise
def encrypt_file(file_data: bytes, file_id: str) -> bytes:
    """
    Encrypt file bytes, binding them to the context ``"file:<file_id>"``.

    Delegates to encrypt_data() and undoes its outer base64 layer so the
    result is returned as bytes.

    Args:
        file_data: File bytes to encrypt
        file_id: Unique file identifier

    Returns:
        bytes: Encrypted file data

    Raises:
        Exception: Any failure is logged with the file id and re-raised.
    """
    try:
        file_context = f"file:{file_id}"
        encoded = encrypt_data(file_data, context=file_context)
        return base64.urlsafe_b64decode(encoded.encode('utf-8'))
    except Exception as exc:
        logger.error("Failed to encrypt file", file_id=file_id, error=str(exc))
        raise
def decrypt_file(encrypted_data: bytes, file_id: str) -> bytes:
    """
    Decrypt file bytes that were bound to the context ``"file:<file_id>"``.

    Re-applies the base64 layer that decrypt_data() expects, then delegates
    to it. A text result is UTF-8 encoded so the caller always gets bytes.

    Args:
        encrypted_data: Encrypted file bytes
        file_id: Unique file identifier

    Returns:
        bytes: Decrypted file data

    Raises:
        Exception: Any failure is logged with the file id and re-raised.
    """
    try:
        encoded = base64.urlsafe_b64encode(encrypted_data).decode('utf-8')
        file_context = f"file:{file_id}"
        plaintext = decrypt_data(encoded, context=file_context)
        if isinstance(plaintext, bytes):
            return plaintext
        return plaintext.encode('utf-8')
    except Exception as exc:
        logger.error("Failed to decrypt file", file_id=file_id, error=str(exc))
        raise
def generate_secure_filename(original_filename: str, user_id: UUID) -> str:
    """
    Build a storage-safe filename that reuses only the original's extension.

    The result has the form ``<YYYYmmdd_HHMMSS>_<user hash>_<random>``, where
    the user hash is the first 8 hex digits of SHA-256 over ``str(user_id)``.
    The original extension (lower-cased, stripped) is appended only if it is
    on the allow-list; otherwise the name has no extension.

    Args:
        original_filename: Original filename
        user_id: User UUID

    Returns:
        str: Secure filename
    """
    permitted_extensions = {
        'txt', 'pdf', 'doc', 'docx', 'xls', 'xlsx', 'ppt', 'pptx',
        'jpg', 'jpeg', 'png', 'gif', 'bmp', 'webp', 'svg',
        'mp3', 'wav', 'flac', 'ogg', 'mp4', 'avi', 'mkv', 'webm',
        'zip', 'rar', '7z', 'tar', 'gz', 'json', 'xml', 'csv',
    }

    # Everything after the last dot is the candidate extension.
    _, dot, tail = original_filename.rpartition('.')
    extension = tail if dot else ''

    stamp = datetime.utcnow().strftime('%Y%m%d_%H%M%S')
    nonce = secrets.token_urlsafe(8)
    owner_digest = hashlib.sha256(str(user_id).encode()).hexdigest()[:8]
    base_name = "_".join((stamp, owner_digest, nonce))

    normalized = extension.lower().strip()
    if normalized in permitted_extensions:
        return f"{base_name}.{normalized}"
    return base_name
def validate_file_signature(file_data: bytes, claimed_type: str) -> bool:
    """
    Validate file signature (magic number) against claimed MIME type.

    Rules:
        * Data shorter than 8 bytes is always rejected.
        * WebP must have ``RIFF`` at offset 0 and ``WEBP`` at offset 8,
          so other RIFF containers (WAV, AVI) do not pass as WebP.
        * MP4 must have an ``ftyp`` box tag at offset 4; the box size and
          brand vary between encoders and are not checked.
        * Other known types must start with one of their magic prefixes.
        * Types with no known signature (including ``text/plain`` and any
          type not listed here) are allowed.

    Args:
        file_data: File bytes to validate
        claimed_type: Claimed MIME type

    Returns:
        bool: True if signature matches type
    """
    if len(file_data) < 8:
        return False

    # Formats whose marker is not a plain leading prefix.
    if claimed_type == 'image/webp':
        return file_data[:4] == b'RIFF' and file_data[8:12] == b'WEBP'
    if claimed_type == 'video/mp4':
        return file_data[4:8] == b'ftyp'

    # Leading magic numbers; any one of the listed prefixes is sufficient.
    prefix_signatures = {
        'image/jpeg': (b'\xFF\xD8\xFF',),
        'image/png': (b'\x89PNG\r\n\x1a\n',),
        'image/gif': (b'GIF87a', b'GIF89a'),
        'application/pdf': (b'%PDF-',),
        'application/zip': (b'PK\x03\x04', b'PK\x05\x06', b'PK\x07\x08'),
        'audio/mpeg': (b'ID3', b'\xFF\xFB', b'\xFF\xF3', b'\xFF\xF2'),
        'text/plain': (),  # Text files don't have reliable signatures
    }
    expected_prefixes = prefix_signatures.get(claimed_type, ())

    # If no signatures defined, allow (like text files)
    if not expected_prefixes:
        return True

    return file_data.startswith(expected_prefixes)
def generate_csrf_token(user_id: UUID, session_id: str) -> str:
    """
    Create an HMAC-signed CSRF token bound to a user and session.

    The token is the URL-safe base64 encoding of
    ``"<user_id>:<session_id>:<timestamp>:<signature>"``, where the signature
    is the hex HMAC-SHA256 of the first three fields, keyed with
    settings.SECRET_KEY.

    Args:
        user_id: User UUID
        session_id: Session identifier

    Returns:
        str: CSRF token
    """
    issued_at = str(int(datetime.utcnow().timestamp()))
    signed_fields = f"{user_id}:{session_id}:{issued_at}"
    mac = hmac.new(
        key=settings.SECRET_KEY.encode(),
        msg=signed_fields.encode(),
        digestmod=hashlib.sha256,
    )
    raw_token = f"{signed_fields}:{mac.hexdigest()}"
    encoded_token = base64.urlsafe_b64encode(raw_token.encode())
    return encoded_token.decode()
def verify_csrf_token(token: str, user_id: UUID, session_id: str, max_age: int = 3600) -> bool:
    """
    Verify a CSRF token produced by generate_csrf_token().

    The decoded token has the form
    ``"<user_id>:<session_id>:<timestamp>:<signature>"``. It is parsed from
    the right, so a session id that itself contains ``:`` still verifies.
    The token is valid when the identity matches, it is not older than
    ``max_age``, and the HMAC-SHA256 signature is correct.

    Args:
        token: CSRF token to verify
        user_id: User UUID
        session_id: Session identifier
        max_age: Maximum token age in seconds

    Returns:
        bool: True if token is valid; False on any mismatch or decoding error
    """
    try:
        token_data = base64.urlsafe_b64decode(token.encode()).decode()

        # Timestamp (digits) and signature (hex) never contain ':', so they
        # are always the last two colon-separated fields.
        fields = token_data.rsplit(':', 2)
        if len(fields) != 3:
            return False
        identity, timestamp, signature = fields

        # Verify components
        if identity != f"{user_id}:{session_id}":
            return False

        # Check age
        token_time = int(timestamp)
        current_time = int(datetime.utcnow().timestamp())
        if current_time - token_time > max_age:
            return False

        # Verify signature
        data = f"{identity}:{timestamp}"
        expected_signature = hmac.new(
            settings.SECRET_KEY.encode(),
            data.encode(),
            hashlib.sha256
        ).hexdigest()
        return hmac.compare_digest(signature, expected_signature)
    except Exception as e:
        logger.warning("Failed to verify CSRF token", error=str(e))
        return False
def sanitize_input(input_data: str, max_length: int = 1000) -> str:
    """
    Truncate a user-supplied string and delete risky characters from it.

    The input is cut to ``max_length`` first. Then every occurrence of
    ``< > " ' &``, NUL, CR and LF is removed (not escaped), and surrounding
    whitespace is stripped.

    Args:
        input_data: Input string to sanitize
        max_length: Maximum allowed length

    Returns:
        str: Sanitized input ("" for empty input)
    """
    if not input_data:
        return ""

    clipped = input_data
    if len(clipped) > max_length:
        clipped = clipped[:max_length]

    # One pass that deletes every character in the deny-list.
    deletion_table = str.maketrans('', '', '<>"\'&\x00\r\n')
    return clipped.translate(deletion_table).strip()
def check_permission(user_permissions: List[str], required_permission: str) -> bool:
    """
    Decide whether a permission list grants a required permission.

    Access is granted when the list contains ``'admin'``, the exact
    permission, or a wildcard ``"<prefix>.*"`` for any dot-separated prefix
    of the required permission (the full permission counts as a prefix).

    Args:
        user_permissions: List of user permissions
        required_permission: Required permission string

    Returns:
        bool: True if user has permission
    """
    # Admin has all permissions; an exact match needs no wildcard lookup.
    if 'admin' in user_permissions or required_permission in user_permissions:
        return True

    segments = required_permission.split('.')
    wildcard_candidates = (
        '.'.join(segments[:depth]) + '.*'
        for depth in range(1, len(segments) + 1)
    )
    return any(candidate in user_permissions for candidate in wildcard_candidates)
def rate_limit_key(identifier: str, action: str, window: str = "default") -> str:
    """
    Build the cache key for a rate-limit counter.

    The key is the SHA-256 hex digest of
    ``"rate_limit:<action>:<window>:<identifier>"``.

    Args:
        identifier: User/IP identifier
        action: Action being rate limited
        window: Time window identifier

    Returns:
        str: Rate limit cache key (64 hex characters)
    """
    hasher = hashlib.sha256()
    hasher.update(f"rate_limit:{action}:{window}:{identifier}".encode())
    return hasher.hexdigest()
def generate_otp(length: int = 6) -> str:
    """
    Produce a numeric one-time password from the ``secrets`` module.

    Args:
        length: Length of OTP

    Returns:
        str: Numeric OTP of exactly ``length`` digits
    """
    digits = [secrets.choice('0123456789') for _ in range(length)]
    return ''.join(digits)
def constant_time_compare(a: str, b: str) -> bool:
    """
    Compare two strings via ``hmac.compare_digest`` to resist timing attacks.

    Both strings are UTF-8 encoded before comparison.

    Args:
        a: First string
        b: Second string

    Returns:
        bool: True if strings are equal
    """
    left = a.encode('utf-8')
    right = b.encode('utf-8')
    return hmac.compare_digest(left, right)