"""
Enhanced API middleware with security, rate limiting, and monitoring
"""

import asyncio
import time
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Callable
import json

from sanic import Request, HTTPResponse
from sanic.response import json as json_response, text as text_response
from sanic.exceptions import Unauthorized, Forbidden, BadRequest, SanicException

import structlog

from app.core.config import settings, SecurityConfig, CACHE_KEYS
from app.core.database import get_cache
from app.core.logging import request_id_var, user_id_var, operation_var, log_performance
from app.core.models.user import User
from app.core.models.base import BaseModel

logger = structlog.get_logger(__name__)


# TooManyRequests may not exist in the installed Sanic version, so define our own.
class TooManyRequests(SanicException):
    """Raised when a client exceeds its rate limit.

    Derives from ``SanicException`` with ``status_code = 429`` so that Sanic's
    default error handling also answers 429 (a plain ``Exception`` would be
    rendered as 500) when ``exception_middleware`` is not the active handler.
    """

    status_code = 429


class SecurityMiddleware:
    """Security middleware for request validation and protection"""

    @staticmethod
    def add_security_headers(
        response: HTTPResponse, request: Optional[Request] = None
    ) -> HTTPResponse:
        """Add CORS, security, and Content-Security-Policy headers to *response*.

        Args:
            response: Outgoing response; its headers are updated in place.
            request: Optional originating request. When given and its context
                carries ``request_id``, that value is used for ``X-Request-ID``;
                otherwise the header is ``"unknown"``.

        Returns:
            The same *response* object, for chaining.
        """
        # The former implementation looked up ``ctx`` on the Request *class*,
        # which never carries a request ID, so the header was always "unknown".
        request_id = getattr(getattr(request, 'ctx', None), 'request_id', 'unknown')

        response.headers.update({
            # CORS headers
            # NOTE(review): not yet restricted per request origin.
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": (
                "Origin, Content-Type, Accept, Authorization, "
                "X-Requested-With, X-API-Key, X-Request-ID"
            ),
            "Access-Control-Max-Age": "86400",

            # Security headers
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",

            # Custom headers
            "X-API-Version": settings.PROJECT_VERSION,
            "X-Request-ID": request_id,
        })

        # CSP header: each (directive, sources) pair from CSP_DIRECTIVES becomes
        # "directive src1 src2", and the pairs are joined with "; ".
        csp_directives = "; ".join(
            f"{directive} {' '.join(sources)}"
            for directive, sources in SecurityConfig.CSP_DIRECTIVES.items()
        )
        response.headers["Content-Security-Policy"] = csp_directives

        return response

    @staticmethod
    def validate_request_size(request: Request) -> None:
        """Reject requests whose Content-Length exceeds SecurityConfig.MAX_REQUEST_SIZE.

        Raises:
            BadRequest: if the header is not a non-negative integer, or if the
                declared size is over the limit.
        """
        content_length = request.headers.get('content-length')
        if not content_length:
            return

        try:
            size = int(content_length)
        except ValueError:
            # Previously a raw ValueError escaped with an int() parsing message.
            raise BadRequest("Invalid Content-Length header") from None

        if size < 0:
            raise BadRequest("Invalid Content-Length header")
        if size > SecurityConfig.MAX_REQUEST_SIZE:
            raise BadRequest(f"Request too large: {size} bytes")

    @staticmethod
    def validate_content_type(request: Request) -> None:
        """Enforce SecurityConfig.MAX_JSON_SIZE on JSON bodies of POST/PUT/PATCH.

        Raises:
            BadRequest: if the JSON body is larger than the configured limit.
        """
        if request.method not in ('POST', 'PUT', 'PATCH'):
            return

        content_type = request.headers.get('content-type', '')
        if 'application/json' not in content_type:
            return

        # The size error used to be raised inside a broad try/except that
        # rewrote it to "Invalid JSON payload"; a None body also crashed len().
        body = getattr(request, 'body', None) or b''
        if len(body) > SecurityConfig.MAX_JSON_SIZE:
            raise BadRequest("JSON payload too large")

    @staticmethod
    def check_origin(request: Request) -> bool:
        """Return True if the request Origin is permitted by SecurityConfig.CORS_ORIGINS.

        Requests without an Origin header (direct API calls) are allowed.
        Each configured entry has trailing '/' and '*' characters stripped:

        * an entry that becomes empty (e.g. ``"*"``) allows every origin;
        * otherwise the origin must equal the entry, or extend it only with a
          numeric port (``http://localhost:*`` matches ``http://localhost:3000``;
          ``https://app.example.com`` matches ``https://app.example.com:8443``).

        A plain prefix match is deliberately NOT used: it would let
        ``https://example.com.evil.net`` pass for ``https://example.com``.
        """
        origin = request.headers.get('origin')
        if not origin:
            return True  # Allow requests without origin (direct API calls)

        for allowed_origin in SecurityConfig.CORS_ORIGINS:
            prefix = allowed_origin.rstrip('/*')
            if not prefix or origin == prefix:
                return True

            if prefix.endswith(':'):
                port = origin[len(prefix):] if origin.startswith(prefix) else ''
            else:
                port = origin[len(prefix) + 1:] if origin.startswith(prefix + ':') else ''

            if port.isascii() and port.isdigit():
                return True

        return False


class RateLimitMiddleware:
    """Fixed-window rate limiter backed by the cache returned by ``get_cache()``.

    Each (pattern, identifier) pair owns a counter stored under
    ``CACHE_KEYS["rate_limit"]``. The counter is created with a TTL equal to
    the pattern's window and is incremented on every allowed request.
    """

    def __init__(self):
        # Cache client; fetched lazily on first use.
        self.cache = None

    async def get_cache(self):
        """Return the cache client, fetching and memoizing it on first use."""
        if not self.cache:
            self.cache = await get_cache()
        return self.cache

    @staticmethod
    def _limits_for(pattern: str) -> Dict[str, int]:
        """Return the {"requests", "window"} limits configured for *pattern*."""
        return SecurityConfig.RATE_LIMIT_PATTERNS.get(pattern, {
            "requests": settings.RATE_LIMIT_REQUESTS,
            "window": settings.RATE_LIMIT_WINDOW,
        })

    @staticmethod
    def _cache_key(pattern: str, identifier: str) -> str:
        """Build the cache key holding the counter for *identifier* under *pattern*."""
        return CACHE_KEYS["rate_limit"].format(pattern=pattern, identifier=identifier)

    async def check_rate_limit(
        self,
        request: Request,
        identifier: str,
        pattern: str = "api",
        limits: Optional[Dict[str, int]] = None,
    ) -> bool:
        """Count one request for *identifier* and report whether it is allowed.

        Args:
            request: The current request (not inspected; kept for API stability).
            identifier: Who is being limited (callers pass the client IP).
            pattern: Rate-limit bucket name used for limits lookup and the key.
            limits: Optional ``{"requests": int, "window": int}`` overriding the
                configured limits for *pattern*.

        Returns:
            False when the limit is already reached; True otherwise. Cache
            failures are logged and fail open (True).
        """
        try:
            cache = await self.get_cache()
            if limits is None:
                limits = self._limits_for(pattern)
            cache_key = self._cache_key(pattern, identifier)

            current_count = await cache.get(cache_key)
            if current_count is None:
                # First request in window
                await cache.set(cache_key, "1", ttl=limits["window"])
                return True

            current_count = int(current_count)
            if current_count >= limits["requests"]:
                logger.warning(
                    "Rate limit exceeded",
                    identifier=identifier,
                    pattern=pattern,
                    count=current_count,
                    limit=limits["requests"]
                )
                return False

            new_count = await cache.incr(cache_key)
            if new_count is not None and int(new_count) == 1:
                # The key expired between get() and incr(); INCR recreated it
                # without a TTL, which would keep the counter alive forever and
                # eventually block this client permanently. Restore the window.
                await cache.redis.expire(cache_key, limits["window"])
            return True

        except Exception as e:
            logger.error("Rate limit check failed", error=str(e))
            return True  # Allow request if rate limiting fails

    async def get_rate_limit_info(
        self,
        identifier: str,
        pattern: str = "api"
    ) -> Dict[str, Any]:
        """Return limit, remaining, reset_time (epoch seconds) and window for *identifier*.

        Returns an empty dict (after logging) if the cache cannot be queried.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(pattern, identifier)

            current_count = await cache.get(cache_key) or "0"
            ttl = await cache.redis.ttl(cache_key)

            return {
                "limit": limits["requests"],
                "remaining": max(0, limits["requests"] - int(current_count)),
                "reset_time": int(time.time()) + max(0, ttl),
                "window": limits["window"]
            }

        except Exception as e:
            logger.error("Failed to get rate limit info", error=str(e))
            return {}


class AuthenticationMiddleware:
    """Authentication middleware for API access"""

    @staticmethod
    async def extract_token(request: Request) -> Optional[str]:
        """Return the credential sent with *request*, or None.

        Sources, in order: ``Authorization: Bearer <token>`` (scheme matched
        case-insensitively), the ``X-API-Key`` header, then the ``token``
        query parameter.
        """
        # Check Authorization header
        auth_header = request.headers.get('authorization')
        if auth_header:
            scheme, _, credentials = auth_header.partition(' ')
            credentials = credentials.strip()
            if scheme.lower() == 'bearer' and credentials:
                return credentials

        # Check X-API-Key header
        api_key = request.headers.get('x-api-key')
        if api_key:
            return api_key

        # Check query parameter (less secure, for backward compatibility)
        return request.args.get('token')

    @staticmethod
    async def validate_token(token: str, session) -> Optional[User]:
        """Resolve *token* to a User, or return None.

        Only the ``"<user_uuid>:<token_hash>"`` form is handled: the user is
        loaded by id and ``user.verify_token(token_hash)`` must succeed.
        Errors are logged and yield None.
        """
        if not token:
            return None

        try:
            # For now, implement simple token validation
            # In production, implement JWT or database token validation

            # Example: if token format is user_id:hash
            if ':' in token:
                user_id_str, token_hash = token.split(':', 1)
                try:
                    user_id = uuid.UUID(user_id_str)
                    user = await User.get_by_id(session, user_id)
                    if user and user.verify_token(token_hash):  # Implement in User model
                        return user
                except (ValueError, AttributeError):
                    pass

            # Fallback: try to find user by API token
            # This would require implementing token storage in User model
            return None

        except Exception as e:
            logger.error("Token validation failed", token=token[:8] + "...", error=str(e))
            return None

    @staticmethod
    async def check_permissions(user: User, request: Request) -> bool:
        """Return whether *user* may access the request's endpoint.

        Rules are checked in order on substrings of the path: /admin/ needs
        ``is_admin``; /mod/ needs ``is_moderator``; writes to /user/ need the
        'user:write' permission; upload/content endpoints need
        ``can_upload_content()``. Everything else is allowed.
        """
        endpoint = request.path
        method = request.method

        # Admin endpoints
        if '/admin/' in endpoint:
            return user.is_admin

        # Moderator endpoints
        if '/mod/' in endpoint:
            return user.is_moderator

        # User-specific endpoints
        if '/user/' in endpoint and method in ['POST', 'PUT', 'DELETE']:
            return user.has_permission('user:write')

        # Content upload endpoints.
        # NOTE(review): `and` binds tighter than `or`, so this has always meant
        # "any /upload path, or a POST to a /content path". The parentheses only
        # make that explicit (the stricter reading is kept); confirm whether
        # non-POST /upload requests were meant to be gated.
        if '/upload' in endpoint or ('/content' in endpoint and method == 'POST'):
            return user.can_upload_content()

        # Default: allow read access for authenticated users
        return True


class RequestContextMiddleware:
    """Request context middleware for tracking and logging"""

    @staticmethod
    def generate_request_id() -> str:
        """Generate unique request ID"""
        return str(uuid.uuid4())

    @staticmethod
    async def add_request_context(request: Request) -> None:
        """Populate ``request.ctx`` with tracking data and log the request start.

        Sets request_id (also pushed into ``request_id_var``), start_time,
        client_ip, user_agent, user (None) and rate_limit_info ({}).
        """
        # Generate and set request ID
        request_id = RequestContextMiddleware.generate_request_id()
        request.ctx.request_id = request_id
        request_id_var.set(request_id)

        # Set request start time
        request.ctx.start_time = time.time()

        # Extract client information
        request.ctx.client_ip = RequestContextMiddleware.get_client_ip(request)
        request.ctx.user_agent = request.headers.get('user-agent', 'Unknown')

        # Initialize context
        request.ctx.user = None
        request.ctx.rate_limit_info = {}

        logger.info(
            "Request started",
            method=request.method,
            path=request.path,
            client_ip=request.ctx.client_ip,
            user_agent=request.ctx.user_agent
        )

    @staticmethod
    def get_client_ip(request: Request) -> str:
        """Return the client IP from X-Forwarded-For, X-Real-IP, or request.ip.

        NOTE(review): the forwarded headers are trusted as sent. Unless a
        trusted proxy always overwrites them, clients can spoof their IP and
        with it their rate-limit identity.
        """
        # Check for forwarded headers
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()

        real_ip = request.headers.get('x-real-ip')
        if real_ip:
            return real_ip

        # Fallback to request IP
        return getattr(request, 'ip', '127.0.0.1')

    @staticmethod
    async def log_request_completion(request: Request, response: HTTPResponse) -> None:
        """Log request completion with duration, size, client IP and user id.

        Tolerates requests whose context was never populated (for example
        OPTIONS preflights, which skip ``add_request_context``).
        """
        ctx = request.ctx
        duration = time.time() - getattr(ctx, 'start_time', time.time())
        user = getattr(ctx, 'user', None)
        body = getattr(response, 'body', None)

        logger.info(
            "Request completed",
            method=request.method,
            path=request.path,
            status_code=response.status,
            duration_ms=round(duration * 1000, 2),
            response_size=len(body) if body else 0,
            client_ip=getattr(ctx, 'client_ip', 'unknown'),
            user_id=str(user.id) if user else None
        )


# Initialize middleware instances
security_middleware = SecurityMiddleware()
rate_limit_middleware = RateLimitMiddleware()
auth_middleware = AuthenticationMiddleware()
context_middleware = RequestContextMiddleware()


async def request_middleware(request: Request):
    """Run the request pipeline: preflight, context, security, rate limit, auth.

    Returns an ``HTTPResponse`` to short-circuit the request (preflight answer,
    failed validation, rate-limit rejection, or insufficient permissions), or
    None to let the route handler run.
    """
    # Handle OPTIONS requests for CORS
    if request.method == 'OPTIONS':
        response = text_response('OK')
        return security_middleware.add_security_headers(response, request)

    # Add request context
    await context_middleware.add_request_context(request)

    # Security validations
    try:
        security_middleware.validate_request_size(request)
        security_middleware.validate_content_type(request)

        if not security_middleware.check_origin(request):
            raise Forbidden("Origin not allowed")

    except Exception as e:
        logger.warning("Security validation failed", error=str(e))
        # Keep the status carried by Sanic exceptions (403 for a rejected
        # origin) instead of reporting every failure as 400.
        status = getattr(e, 'status_code', 400)
        response = json_response({"error": str(e)}, status=status)
        return security_middleware.add_security_headers(response, request)

    # Rate limiting
    if settings.RATE_LIMIT_ENABLED:
        client_identifier = context_middleware.get_client_ip(request)

        # Determine rate limit pattern based on endpoint
        if '/auth/' in request.path:
            pattern = "auth"
        elif '/upload' in request.path:
            pattern = "upload"
        elif '/admin/' in request.path:
            pattern = "heavy"
        else:
            pattern = "api"

        if not await rate_limit_middleware.check_rate_limit(request, client_identifier, pattern):
            rate_info = await rate_limit_middleware.get_rate_limit_info(client_identifier, pattern)
            response = json_response(
                {
                    "error": "Rate limit exceeded",
                    "rate_limit": rate_info
                },
                status=429
            )
            return security_middleware.add_security_headers(response, request)

        # Store rate limit info for response headers
        request.ctx.rate_limit_info = await rate_limit_middleware.get_rate_limit_info(
            client_identifier, pattern
        )

    # Authentication (for protected endpoints)
    if not request.path.startswith('/api/system') and request.path != '/':
        from app.core.database import db_manager

        async with db_manager.get_session() as session:
            token = await auth_middleware.extract_token(request)
            if token:
                user = await auth_middleware.validate_token(token, session)
                if user:
                    request.ctx.user = user
                    user_id_var.set(str(user.id))

                    # Check permissions
                    if not await auth_middleware.check_permissions(user, request):
                        response = json_response({"error": "Insufficient permissions"}, status=403)
                        return security_middleware.add_security_headers(response, request)

                    # Update user activity
                    user.update_activity()
                    await session.commit()

            # Store session for request handlers.
            # NOTE(review): the `async with` exits as soon as this middleware
            # returns, so any cleanup get_session() performs has already run
            # before handlers use this session. Confirm its semantics, or
            # open/close the session around the handler instead.
            request.ctx.db_session = session


async def response_middleware(request: Request, response: HTTPResponse):
    """Add security, rate-limit and request-id headers, then log completion."""
    # Add security headers
    response = security_middleware.add_security_headers(response, request)

    # Add rate limit headers
    rate_info = getattr(request.ctx, 'rate_limit_info', None)
    if rate_info:
        response.headers.update({
            "X-RateLimit-Limit": str(rate_info.get('limit', 0)),
            "X-RateLimit-Remaining": str(rate_info.get('remaining', 0)),
            "X-RateLimit-Reset": str(rate_info.get('reset_time', 0))
        })

    # Add request ID to response
    if hasattr(request.ctx, 'request_id'):
        response.headers["X-Request-ID"] = request.ctx.request_id

    # Log request completion
    await context_middleware.log_request_completion(request, response)

    return response


async def exception_middleware(request: Request, exception: Exception):
    """Turn an exception into a JSON error response tagged with a fresh error_id.

    Known auth/rate-limit/validation errors and other 4xx Sanic exceptions keep
    their status and message; anything else becomes a generic 500 (with debug
    details only when settings.DEBUG is set).
    """
    error_id = str(uuid.uuid4())
    user = getattr(request.ctx, 'user', None)

    # Log the exception
    logger.error(
        "Unhandled exception",
        error_id=error_id,
        exception_type=type(exception).__name__,
        exception_message=str(exception),
        path=request.path,
        method=request.method,
        user_id=str(user.id) if user else None
    )

    # Handle different exception types
    if isinstance(exception, Unauthorized):
        response_data = {"error": "Authentication required", "error_id": error_id}
        status = 401
    elif isinstance(exception, Forbidden):
        response_data = {"error": "Access forbidden", "error_id": error_id}
        status = 403
    elif isinstance(exception, TooManyRequests):
        response_data = {"error": "Rate limit exceeded", "error_id": error_id}
        status = 429
    elif isinstance(exception, BadRequest):
        response_data = {"error": str(exception), "error_id": error_id}
        status = 400
    elif (
        isinstance(exception, SanicException)
        and 400 <= getattr(exception, 'status_code', 500) < 500
    ):
        # Other client errors (NotFound, MethodNotAllowed, ...) used to be
        # reported as 500; keep their own status and message.
        response_data = {"error": str(exception), "error_id": error_id}
        status = exception.status_code
    else:
        # Generic server error
        response_data = {
            "error": "Internal server error",
            "error_id": error_id
        }
        status = 500

        if settings.DEBUG:
            response_data["debug"] = {
                "type": type(exception).__name__,
                "message": str(exception)
            }

    response = json_response(response_data, status=status)
    return security_middleware.add_security_headers(response, request)


# Maintenance mode middleware
async def maintenance_middleware(request: Request):
    """Return a 503 while settings.MAINTENANCE_MODE is on (except /api/system paths)."""
    if settings.MAINTENANCE_MODE and not request.path.startswith('/api/system'):
        response = json_response({
            "error": "Service temporarily unavailable",
            "message": settings.MAINTENANCE_MESSAGE
        }, status=503)
        return security_middleware.add_security_headers(response, request)


# Helper functions for route decorators
async def check_auth(request: Request) -> User:
    """Return the authenticated user set by request_middleware.

    Raises:
        Unauthorized: if no user is attached to the request context.
    """
    if not hasattr(request.ctx, 'user') or not request.ctx.user:
        raise Unauthorized("Authentication required")
    return request.ctx.user


async def validate_request_data(request: Request, schema: Optional[Any] = None) -> Dict[str, Any]:
    """Return the parsed JSON body for POST/PUT/PATCH requests ({} otherwise).

    *schema* is accepted but not yet applied.

    Raises:
        BadRequest: if reading the body fails.
    """
    try:
        if request.method in ['POST', 'PUT', 'PATCH']:
            # Get JSON data
            if hasattr(request, 'json') and request.json:
                data = request.json
            else:
                data = {}

            # Basic validation - can be extended with pydantic schemas
            if schema:
                # Here you would implement schema validation
                # For now, just return the data
                pass

            return data

        return {}

    except Exception as e:
        raise BadRequest(f"Invalid request data: {str(e)}")


async def check_rate_limit(request: Request, pattern: str = "api") -> bool:
    """Apply the configured limit for *pattern* to the client IP.

    Raises:
        TooManyRequests: if the limit has been reached.
    """
    client_identifier = context_middleware.get_client_ip(request)

    if not await rate_limit_middleware.check_rate_limit(request, client_identifier, pattern):
        rate_info = await rate_limit_middleware.get_rate_limit_info(client_identifier, pattern)
        raise TooManyRequests(f"Rate limit exceeded: {rate_info}")

    return True


# Decorator functions for convenience
# Each wrapper is given its own __name__ suffix instead of copying the wrapped
# function's name.
def auth_required(func):
    """Decorator to require authentication"""
    async def auth_wrapper(request: Request, *args, **kwargs):
        await check_auth(request)
        return await func(request, *args, **kwargs)

    auth_wrapper.__name__ = f"{func.__name__}_auth_required"
    return auth_wrapper


def require_auth(permissions=None):
    """Decorator requiring an authenticated user and, optionally, permissions.

    Args:
        permissions: A permission name or an iterable of names. The user must
            hold every one of them (checked with ``User.has_permission``).

    Raises (from the wrapped handler):
        Unauthorized: if no user is authenticated.
        Forbidden: if a required permission is missing.
    """
    if isinstance(permissions, str):
        required = [permissions]
    else:
        required = list(permissions or [])

    def decorator(func):
        async def require_auth_wrapper(request: Request, *args, **kwargs):
            user = await check_auth(request)

            # This used to be a placeholder that ignored `permissions`, so any
            # authenticated user was let through. Fail closed instead.
            if any(not user.has_permission(permission) for permission in required):
                raise Forbidden("Insufficient permissions")

            return await func(request, *args, **kwargs)

        require_auth_wrapper.__name__ = f"{func.__name__}_require_auth"
        return require_auth_wrapper
    return decorator


def validate_json(schema=None):
    """Decorator to validate JSON request"""
    def decorator(func):
        async def validate_json_wrapper(request: Request, *args, **kwargs):
            await validate_request_data(request, schema)
            return await func(request, *args, **kwargs)

        validate_json_wrapper.__name__ = f"{func.__name__}_validate_json"
        return validate_json_wrapper
    return decorator


def validate_request(schema=None):
    """Decorator to validate request data against schema"""
    def decorator(func):
        async def validate_request_wrapper(request: Request, *args, **kwargs):
            await validate_request_data(request, schema)
            return await func(request, *args, **kwargs)

        validate_request_wrapper.__name__ = f"{func.__name__}_validate_request"
        return validate_request_wrapper
    return decorator


def apply_rate_limit(pattern: str = "api", limit: Optional[int] = None, window: Optional[int] = None):
    """Decorator applying a per-client-IP rate limit before the handler runs.

    When both *limit* and *window* are given they override the configured
    limits for *pattern*; otherwise the configured limits are used. Counting
    goes through RateLimitMiddleware, so it shares its TTL handling and its
    fail-open behavior on cache errors.

    Raises (from the wrapped handler):
        TooManyRequests: if the limit has been reached.
    """
    def decorator(func):
        async def rate_limit_wrapper(request: Request, *args, **kwargs):
            # Use custom limits if provided
            if limit and window:
                client_identifier = context_middleware.get_client_ip(request)
                allowed = await rate_limit_middleware.check_rate_limit(
                    request,
                    client_identifier,
                    pattern,
                    limits={"requests": limit, "window": window},
                )
                if not allowed:
                    raise TooManyRequests(f"Rate limit exceeded: {limit} per {window}s")
            else:
                # Use default rate limiting
                await check_rate_limit(request, pattern)

            return await func(request, *args, **kwargs)

        rate_limit_wrapper.__name__ = f"{func.__name__}_rate_limit"
        return rate_limit_wrapper
    return decorator


# Create compatibility alias for the decorator syntax used in auth_routes
def rate_limit(limit: Optional[int] = None, window: Optional[int] = None, pattern: str = "api"):
    """Compatibility decorator for rate limiting with limit/window parameters"""
    return apply_rate_limit(pattern=pattern, limit=limit, window=window)