uploader-bot/app/api/middleware.py

630 lines
23 KiB
Python

"""
Enhanced API middleware with security, rate limiting, and monitoring
"""
import asyncio
import time
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Callable
import json
from sanic import Request, HTTPResponse
from sanic.response import json as json_response, text as text_response
from sanic.exceptions import Unauthorized, Forbidden, BadRequest
# Sanic does not ship a TooManyRequests exception in every version, so the
# module defines its own.
class TooManyRequests(Exception):
    """Signal that a client exceeded a rate limit.

    ``exception_middleware`` turns this into an HTTP 429 response.
    """
import structlog
from app.core.config import settings, SecurityConfig, CACHE_KEYS
from app.core.database import get_cache
from app.core.logging import request_id_var, user_id_var, operation_var, log_performance
from app.core.models.user import User
from app.core.models.base import BaseModel
# Module-level structlog logger, named after this module; all middleware
# below logs through it with keyword (structured) fields.
logger = structlog.get_logger(__name__)
class SecurityMiddleware:
    """Security middleware for request validation and protection.

    Every method is a ``@staticmethod``; the class only groups the helpers.
    """

    @staticmethod
    def add_security_headers(
        response: HTTPResponse,
        request: Optional[Request] = None,
    ) -> HTTPResponse:
        """Add CORS, hardening and CSP headers to ``response`` in place.

        Args:
            response: The response to decorate.
            request: Optional request whose ``ctx.request_id`` is copied into
                ``X-Request-ID``. When omitted (or the id is missing) the
                header is set to ``"unknown"``.

        Returns:
            The same response object, for chaining.
        """
        # Read the id from the request *instance*; without a request there is
        # no per-request id to report.
        request_id = getattr(getattr(request, 'ctx', None), 'request_id', 'unknown')

        response.headers.update({
            # CORS headers
            # NOTE(review): wildcard origin; this helper does not narrow it
            # per request.
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": (
                "Origin, Content-Type, Accept, Authorization, "
                "X-Requested-With, X-API-Key, X-Request-ID"
            ),
            "Access-Control-Max-Age": "86400",
            # Security headers
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
            # Custom headers
            "X-API-Version": settings.PROJECT_VERSION,
            "X-Request-ID": request_id,
        })

        # CSP header: "directive src1 src2; directive2 src3 ..."
        csp_directives = "; ".join(
            f"{directive} {' '.join(sources)}"
            for directive, sources in SecurityConfig.CSP_DIRECTIVES.items()
        )
        response.headers["Content-Security-Policy"] = csp_directives
        return response

    @staticmethod
    def validate_request_size(request: Request) -> None:
        """Reject requests whose declared Content-Length exceeds the limit.

        Raises:
            BadRequest: if Content-Length is not a non-negative integer, or
                exceeds ``SecurityConfig.MAX_REQUEST_SIZE``.
        """
        content_length = request.headers.get('content-length')
        if not content_length:
            return
        try:
            size = int(content_length)
        except ValueError:
            raise BadRequest("Invalid Content-Length header") from None
        if size < 0:
            raise BadRequest("Invalid Content-Length header")
        if size > SecurityConfig.MAX_REQUEST_SIZE:
            raise BadRequest(f"Request too large: {size} bytes")

    @staticmethod
    def validate_content_type(request: Request) -> None:
        """Enforce the JSON body size limit on POST/PUT/PATCH JSON requests.

        Only the body size is checked; the body is not parsed here.

        Raises:
            BadRequest: if the body exceeds ``SecurityConfig.MAX_JSON_SIZE``.
        """
        if request.method not in ('POST', 'PUT', 'PATCH'):
            return
        content_type = request.headers.get('content-type', '')
        if 'application/json' not in content_type:
            return
        # No broad try/except here: it would replace the specific
        # "too large" error with a generic message.
        body = getattr(request, 'body', None) or b''
        if len(body) > SecurityConfig.MAX_JSON_SIZE:
            raise BadRequest("JSON payload too large")

    @staticmethod
    def check_origin(request: Request) -> bool:
        """Return True if the request's Origin is allowed.

        Each entry of ``SecurityConfig.CORS_ORIGINS`` is matched as follows:

        * ``"*"`` (or any entry made only of ``/`` and ``*``) allows every origin;
        * otherwise the entry with trailing ``/`` and ``*`` removed must equal
          the origin exactly (so ``https://a.com/*`` allows ``https://a.com``);
        * entries containing ``*`` are also matched as shell-style patterns,
          e.g. ``http://localhost:*`` or ``https://*.example.com``.

        Plain prefix matching is deliberately not used: it would let
        ``https://a.com.attacker.net`` through for ``https://a.com``.
        Requests without an Origin header (direct API calls) are allowed.
        """
        from fnmatch import fnmatchcase

        origin = request.headers.get('origin')
        if not origin:
            return True  # Allow requests without origin (direct API calls)

        normalized = origin.rstrip('/')
        for allowed_origin in SecurityConfig.CORS_ORIGINS:
            base = allowed_origin.rstrip('/*')
            if not base:
                return True  # Wildcard entry
            if normalized == base:
                return True
            if '*' in allowed_origin and fnmatchcase(normalized, allowed_origin):
                return True
        return False
class RateLimitMiddleware:
    """Fixed-window rate limiting backed by the shared cache (Redis).

    One counter per ``(pattern, identifier)`` key holds the number of
    requests seen in the current window; the key's TTL is the window length.
    """

    def __init__(self):
        # Resolved lazily by get_cache() on first use.
        self.cache = None

    async def get_cache(self):
        """Return the cache client, fetching it via ``get_cache()`` on first call."""
        if self.cache is None:
            self.cache = await get_cache()
        return self.cache

    @staticmethod
    def _limits_for(pattern: str) -> Dict[str, Any]:
        """Return ``{"requests", "window"}`` for ``pattern``, defaulting to global settings."""
        return SecurityConfig.RATE_LIMIT_PATTERNS.get(pattern, {
            "requests": settings.RATE_LIMIT_REQUESTS,
            "window": settings.RATE_LIMIT_WINDOW
        })

    @staticmethod
    def _cache_key(pattern: str, identifier: str) -> str:
        """Build the counter key for ``pattern``/``identifier``."""
        return CACHE_KEYS["rate_limit"].format(
            pattern=pattern,
            identifier=identifier
        )

    async def check_rate_limit(
        self,
        request: Request,
        identifier: str,
        pattern: str = "api"
    ) -> bool:
        """Count one request for ``identifier`` and report whether it is allowed.

        Args:
            request: The current request (currently unused).
            identifier: Who is being limited, e.g. the client IP.
            pattern: Name of the limit profile in
                ``SecurityConfig.RATE_LIMIT_PATTERNS``.

        Returns:
            False when the window's request budget is already used up, True
            otherwise. Also True if the cache fails (fail-open by design).
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(pattern, identifier)

            current_count = await cache.get(cache_key)
            if current_count is None:
                # First request in window
                await cache.set(cache_key, "1", ttl=limits["window"])
                return True

            current_count = int(current_count)
            if current_count >= limits["requests"]:
                logger.warning(
                    "Rate limit exceeded",
                    identifier=identifier,
                    pattern=pattern,
                    count=current_count,
                    limit=limits["requests"]
                )
                return False

            await cache.incr(cache_key)
            # If the key expired between get() and incr(), incr() recreates it
            # without an expiry and the counter would never reset. TTL -1 means
            # "exists but has no expiry", so restore the window in that case.
            if await cache.redis.ttl(cache_key) == -1:
                await cache.redis.expire(cache_key, limits["window"])
            return True
        except Exception as e:
            logger.error("Rate limit check failed", error=str(e))
            return True  # Allow request if rate limiting fails

    async def get_rate_limit_info(
        self,
        identifier: str,
        pattern: str = "api"
    ) -> Dict[str, Any]:
        """Describe the current rate-limit state for response headers.

        Returns:
            ``{"limit", "remaining", "reset_time", "window"}`` where
            ``reset_time`` is a Unix timestamp, or ``{}`` if the cache fails.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(pattern, identifier)

            current_count = await cache.get(cache_key) or "0"
            # Negative TTLs (missing key / no expiry) are clamped to 0.
            ttl = await cache.redis.ttl(cache_key)

            return {
                "limit": limits["requests"],
                "remaining": max(0, limits["requests"] - int(current_count)),
                "reset_time": int(time.time()) + max(0, ttl),
                "window": limits["window"]
            }
        except Exception as e:
            logger.error("Failed to get rate limit info", error=str(e))
            return {}
class AuthenticationMiddleware:
    """Authentication middleware for API access.

    Every method is a ``@staticmethod``; the class only groups the helpers.
    """

    @staticmethod
    async def extract_token(request: Request) -> Optional[str]:
        """Extract an authentication token from the request.

        Sources are checked in order:

        1. ``Authorization: Bearer <token>``. The scheme is matched
           case-insensitively, as HTTP authentication schemes are.
        2. the ``X-API-Key`` header;
        3. the ``token`` query parameter (less secure, kept for backward
           compatibility).

        Returns:
            The token string, or None if no source provided one.
        """
        auth_header = request.headers.get('authorization')
        if auth_header:
            scheme, _, credentials = auth_header.strip().partition(' ')
            credentials = credentials.strip()
            if scheme.lower() == 'bearer' and credentials:
                return credentials

        api_key = request.headers.get('x-api-key')
        if api_key:
            return api_key

        return request.args.get('token')

    @staticmethod
    async def validate_token(token: str, session) -> Optional[User]:
        """Resolve ``token`` to a User, or return None if it is not valid.

        Only the ``"<user uuid>:<secret>"`` format is handled: the user is
        loaded by id and ``user.verify_token(secret)`` must accept the secret.
        Malformed tokens and lookup errors yield None instead of raising.
        """
        if not token:
            return None

        try:
            # For now, implement simple token validation
            # In production, implement JWT or database token validation
            if ':' in token:
                user_id_str, token_hash = token.split(':', 1)
                try:
                    user_id = uuid.UUID(user_id_str)
                    user = await User.get_by_id(session, user_id)
                    if user and user.verify_token(token_hash):  # Implement in User model
                        return user
                except (ValueError, AttributeError):
                    # Not a UUID, or the User model lacks verify_token.
                    pass

            # Fallback: lookup by stored API token is not implemented yet.
            return None
        except Exception as e:
            logger.error("Token validation failed", token=token[:8] + "...", error=str(e))
            return None

    @staticmethod
    async def check_permissions(user: User, request: Request) -> bool:
        """Return whether ``user`` may access the requested endpoint.

        The rules use path substrings and the HTTP method. Any path that no
        rule matches is allowed for authenticated users.
        """
        endpoint = request.path
        method = request.method

        # Admin endpoints
        if '/admin/' in endpoint:
            return user.is_admin

        # Moderator endpoints
        if '/mod/' in endpoint:
            return user.is_moderator

        # User-specific endpoints
        if '/user/' in endpoint and method in ['POST', 'PUT', 'DELETE']:
            return user.has_permission('user:write')

        # Content upload endpoints: any request to /upload, or a POST to
        # /content. The parentheses only make Python's `and`-before-`or`
        # precedence explicit.
        # NOTE(review): confirm whether the POST restriction was meant to
        # cover /upload as well.
        if '/upload' in endpoint or ('/content' in endpoint and method == 'POST'):
            return user.can_upload_content()

        # Default: allow read access for authenticated users
        return True
class RequestContextMiddleware:
    """Request context middleware for tracking and logging.

    Every method is a ``@staticmethod``; the class only groups the helpers.
    """

    @staticmethod
    def generate_request_id() -> str:
        """Return a new random UUID4 string to use as the request id."""
        return str(uuid.uuid4())

    @staticmethod
    async def add_request_context(request: Request) -> None:
        """Populate ``request.ctx`` and the logging context vars for this request.

        Sets ``request_id`` (also into ``request_id_var``), ``start_time``,
        ``client_ip``, ``user_agent``, ``user=None`` and
        ``rate_limit_info={}``, then logs "Request started".
        """
        request_id = RequestContextMiddleware.generate_request_id()
        request.ctx.request_id = request_id
        request_id_var.set(request_id)

        # Used by log_request_completion to compute the duration.
        request.ctx.start_time = time.time()

        request.ctx.client_ip = RequestContextMiddleware.get_client_ip(request)
        request.ctx.user_agent = request.headers.get('user-agent', 'Unknown')

        request.ctx.user = None
        request.ctx.rate_limit_info = {}

        logger.info(
            "Request started",
            method=request.method,
            path=request.path,
            client_ip=request.ctx.client_ip,
            user_agent=request.ctx.user_agent
        )

    @staticmethod
    def get_client_ip(request: Request) -> str:
        """Return the client IP.

        Lookup order: the first ``X-Forwarded-For`` entry, then
        ``X-Real-IP``, then ``request.ip``. Falls back to ``'127.0.0.1'``.

        NOTE(review): forwarded headers are trusted unconditionally. That is
        only safe behind a proxy that overwrites them; otherwise clients can
        spoof their IP (and the rate-limit identifier).
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()

        real_ip = request.headers.get('x-real-ip')
        if real_ip:
            return real_ip

        return getattr(request, 'ip', '127.0.0.1')

    @staticmethod
    async def log_request_completion(request: Request, response: HTTPResponse) -> None:
        """Log "Request completed" with status, duration, size and user id.

        Tolerates requests whose context was never populated (e.g. CORS
        preflight requests answered before ``add_request_context`` ran). Also
        tolerates responses without a ``body`` attribute.
        """
        ctx = request.ctx
        duration = time.time() - getattr(ctx, 'start_time', time.time())
        body = getattr(response, 'body', None)
        user = getattr(ctx, 'user', None)

        logger.info(
            "Request completed",
            method=request.method,
            path=request.path,
            status_code=response.status,
            duration_ms=round(duration * 1000, 2),
            response_size=len(body) if body else 0,
            client_ip=getattr(ctx, 'client_ip', 'unknown'),
            user_id=str(user.id) if user else None
        )
# Shared, module-level middleware objects used by the pipeline functions and
# decorators below. Only RateLimitMiddleware holds state (its cache handle).
context_middleware = RequestContextMiddleware()
auth_middleware = AuthenticationMiddleware()
rate_limit_middleware = RateLimitMiddleware()
security_middleware = SecurityMiddleware()
def _rate_limit_pattern(path: str) -> str:
    """Map a request path to the rate-limit pattern name used for it."""
    if '/auth/' in path:
        return "auth"
    if '/upload' in path:
        return "upload"
    if '/admin/' in path:
        return "heavy"
    return "api"


async def request_middleware(request: Request):
    """Main request middleware pipeline.

    Steps, in order:

    1. answer CORS preflight (``OPTIONS``) requests directly;
    2. attach request context (id, timing, client info);
    3. validate request size, JSON body size and Origin;
    4. apply per-IP rate limiting when ``settings.RATE_LIMIT_ENABLED``;
    5. for paths other than ``/`` and ``/api/system*``, resolve the token to
       a user, check endpoint permissions and record user activity.

    Returns:
        An HTTPResponse to short-circuit the handler (preflight, 400/403
        validation failure, 429, 403 permissions), or None to continue.
    """
    # Handle OPTIONS requests for CORS
    if request.method == 'OPTIONS':
        response = text_response('OK')
        return security_middleware.add_security_headers(response)

    await context_middleware.add_request_context(request)

    # Security validations
    try:
        security_middleware.validate_request_size(request)
        security_middleware.validate_content_type(request)
        if not security_middleware.check_origin(request):
            raise Forbidden("Origin not allowed")
    except Exception as e:
        logger.warning("Security validation failed", error=str(e))
        # A rejected origin is an authorization failure (403); every other
        # validation problem is reported as a bad request (400).
        status = 403 if isinstance(e, Forbidden) else 400
        response = json_response({"error": str(e)}, status=status)
        return security_middleware.add_security_headers(response)

    # Rate limiting
    if settings.RATE_LIMIT_ENABLED:
        client_identifier = context_middleware.get_client_ip(request)
        pattern = _rate_limit_pattern(request.path)

        if not await rate_limit_middleware.check_rate_limit(request, client_identifier, pattern):
            rate_info = await rate_limit_middleware.get_rate_limit_info(client_identifier, pattern)
            # Keep the info on the context so rate-limit headers can be added too.
            request.ctx.rate_limit_info = rate_info
            response = json_response(
                {
                    "error": "Rate limit exceeded",
                    "rate_limit": rate_info
                },
                status=429
            )
            return security_middleware.add_security_headers(response)

        # Store rate limit info for response headers
        request.ctx.rate_limit_info = await rate_limit_middleware.get_rate_limit_info(
            client_identifier, pattern
        )

    # Authentication (for protected endpoints)
    if not request.path.startswith('/api/system') and request.path != '/':
        from app.core.database import db_manager

        async with db_manager.get_session() as session:
            token = await auth_middleware.extract_token(request)
            if token:
                user = await auth_middleware.validate_token(token, session)
                if user:
                    request.ctx.user = user
                    user_id_var.set(str(user.id))

                    if not await auth_middleware.check_permissions(user, request):
                        response = json_response({"error": "Insufficient permissions"}, status=403)
                        return security_middleware.add_security_headers(response)

                    user.update_activity()
                    await session.commit()

            # Store session for request handlers.
            # NOTE(review): get_session() presumably closes the session when
            # this block exits, i.e. before the route handler runs. Confirm
            # that handlers using ctx.db_session get a usable session.
            request.ctx.db_session = session
async def response_middleware(request: Request, response: HTTPResponse):
    """Finalize an outgoing response.

    Adds the standard security headers. When the request context carries
    non-empty rate-limit info, adds the ``X-RateLimit-*`` headers. Adds
    ``X-Request-ID`` when the context has a request id. Logs the request
    completion and returns the response.
    """
    response = security_middleware.add_security_headers(response)
    ctx = request.ctx

    rate_info = getattr(ctx, 'rate_limit_info', None)
    if rate_info:
        rate_headers = {
            "X-RateLimit-Limit": str(rate_info.get('limit', 0)),
            "X-RateLimit-Remaining": str(rate_info.get('remaining', 0)),
            "X-RateLimit-Reset": str(rate_info.get('reset_time', 0))
        }
        response.headers.update(rate_headers)

    # Overrides the placeholder id set by add_security_headers.
    if hasattr(ctx, 'request_id'):
        response.headers["X-Request-ID"] = ctx.request_id

    await context_middleware.log_request_completion(request, response)
    return response
async def exception_middleware(request: Request, exception: Exception):
    """Global exception handler: log the error and return a JSON error body.

    Every response carries an ``error_id`` that matches the logged entry.
    The error is mapped as follows:

    * ``Unauthorized`` → 401, ``Forbidden`` → 403,
      ``TooManyRequests`` → 429, ``BadRequest`` → 400;
    * any other Sanic exception with a 4xx ``status_code`` (e.g. 404, 405)
      keeps that status and its message;
    * everything else becomes a 500. Exception details are included only
      when ``settings.DEBUG``.
    """
    from sanic.exceptions import SanicException

    error_id = str(uuid.uuid4())
    user = getattr(request.ctx, 'user', None)

    logger.error(
        "Unhandled exception",
        error_id=error_id,
        exception_type=type(exception).__name__,
        exception_message=str(exception),
        path=request.path,
        method=request.method,
        user_id=str(user.id) if user else None
    )

    sanic_status = getattr(exception, 'status_code', None)

    if isinstance(exception, Unauthorized):
        response_data = {"error": "Authentication required", "error_id": error_id}
        status = 401
    elif isinstance(exception, Forbidden):
        response_data = {"error": "Access forbidden", "error_id": error_id}
        status = 403
    elif isinstance(exception, TooManyRequests):
        response_data = {"error": "Rate limit exceeded", "error_id": error_id}
        status = 429
    elif isinstance(exception, BadRequest):
        response_data = {"error": str(exception), "error_id": error_id}
        status = 400
    elif (
        isinstance(exception, SanicException)
        and isinstance(sanic_status, int)
        and 400 <= sanic_status < 500
    ):
        # Other client errors raised by Sanic (404 Not Found,
        # 405 Method Not Allowed, ...) must not be reported as 500.
        response_data = {"error": str(exception), "error_id": error_id}
        status = sanic_status
    else:
        # Generic server error
        response_data = {
            "error": "Internal server error",
            "error_id": error_id
        }
        status = 500
        if settings.DEBUG:
            response_data["debug"] = {
                "type": type(exception).__name__,
                "message": str(exception)
            }

    response = json_response(response_data, status=status)
    return security_middleware.add_security_headers(response)
# Maintenance mode middleware
async def maintenance_middleware(request: Request):
    """Answer 503 for every request while ``settings.MAINTENANCE_MODE`` is on.

    Paths under ``/api/system`` remain reachable. Returns None when the
    request may proceed.
    """
    if not settings.MAINTENANCE_MODE:
        return None
    if request.path.startswith('/api/system'):
        return None

    payload = {
        "error": "Service temporarily unavailable",
        "message": settings.MAINTENANCE_MESSAGE
    }
    return security_middleware.add_security_headers(json_response(payload, status=503))
# Helper functions for route decorators
async def check_auth(request: Request) -> User:
    """Return the user attached to the request context.

    Raises:
        Unauthorized: if no (truthy) ``request.ctx.user`` is present.
    """
    current_user = getattr(request.ctx, 'user', None)
    if current_user:
        return current_user
    raise Unauthorized("Authentication required")
async def validate_request_data(request: Request, schema: Optional[Any] = None) -> Dict[str, Any]:
    """Return the JSON body of a POST/PUT/PATCH request, or ``{}``.

    Other methods, and bodies that parse to a falsy value, yield ``{}``.
    ``schema`` is accepted but not applied yet (placeholder for
    pydantic-style validation).

    Raises:
        BadRequest: if reading the request data fails for any reason.
    """
    try:
        if request.method not in ['POST', 'PUT', 'PATCH']:
            return {}
        # hasattr() evaluates the property, so parse errors surface here too.
        payload = request.json if hasattr(request, 'json') and request.json else {}
        # Schema validation is not implemented; `schema` is ignored for now.
        return payload
    except Exception as e:
        raise BadRequest(f"Invalid request data: {str(e)}")
async def check_rate_limit(request: Request, pattern: str = "api") -> bool:
    """Apply the ``pattern`` rate limit to the request's client IP.

    Returns:
        True when the request is within the limit.

    Raises:
        TooManyRequests: when the limit is exceeded. The message includes
            the current rate-limit info.
    """
    ip_address = context_middleware.get_client_ip(request)
    within_limit = await rate_limit_middleware.check_rate_limit(request, ip_address, pattern)
    if within_limit:
        return True
    limit_state = await rate_limit_middleware.get_rate_limit_info(ip_address, pattern)
    raise TooManyRequests(f"Rate limit exceeded: {limit_state}")
# Decorator functions for convenience
def auth_required(func):
    """Wrap an async handler so it runs only for authenticated requests.

    The wrapper's ``__name__`` is ``<func name>_auth_required``.
    """
    wrapped_name = f"{func.__name__}_auth_required"

    async def auth_wrapper(request: Request, *args, **kwargs):
        # check_auth raises Unauthorized when no user is attached.
        await check_auth(request)
        return await func(request, *args, **kwargs)

    auth_wrapper.__name__ = wrapped_name
    return auth_wrapper
def require_auth(permissions=None):
    """Decorator factory requiring an authenticated user.

    Args:
        permissions: Accepted for API compatibility.
            NOTE(review): currently NOT enforced; any authenticated user
            passes regardless of this value.

    The wrapper's ``__name__`` is ``<func name>_require_auth``.
    """
    def decorator(func):
        async def require_auth_wrapper(request: Request, *args, **kwargs):
            await check_auth(request)
            # Permission checks are a placeholder; `permissions` is unused.
            return await func(request, *args, **kwargs)

        require_auth_wrapper.__name__ = f"{func.__name__}_require_auth"
        return require_auth_wrapper

    return decorator
def validate_json(schema=None):
    """Decorator factory that reads and checks the JSON body before the handler.

    The wrapper calls ``validate_request_data(request, schema)``, which
    raises BadRequest on unreadable data. The parsed data itself is
    discarded. The wrapper's ``__name__`` is ``<func name>_validate_json``.
    """
    def decorator(func):
        async def validate_json_wrapper(request: Request, *args, **kwargs):
            await validate_request_data(request, schema)
            return await func(request, *args, **kwargs)

        validate_json_wrapper.__name__ = f"{func.__name__}_validate_json"
        return validate_json_wrapper

    return decorator
def validate_request(schema=None):
    """Decorator factory validating request data against ``schema``.

    Behaves like ``validate_json``: ``validate_request_data`` runs first
    (raising BadRequest on failure), then the handler is called unchanged.
    The wrapper's ``__name__`` is ``<func name>_validate_request``.
    """
    def decorator(func):
        async def validate_request_wrapper(request: Request, *args, **kwargs):
            await validate_request_data(request, schema)
            return await func(request, *args, **kwargs)

        validate_request_wrapper.__name__ = f"{func.__name__}_validate_request"
        return validate_request_wrapper

    return decorator
def apply_rate_limit(pattern: str = "api", limit: Optional[int] = None, window: Optional[int] = None):
    """Decorator factory that rate-limits an async route handler per client IP.

    Args:
        pattern: Rate-limit profile name. Without custom limits it selects
            the profile used by ``check_rate_limit``. With custom limits it
            prefixes the handler-specific counter key.
        limit: Maximum requests per window (custom limit; needs ``window``).
        window: Window length in seconds (custom limit; needs ``limit``).

    With both ``limit`` and ``window`` set, each decorated handler gets its
    own fixed-window counter, keyed by ``"<pattern>:<handler name>"`` and
    client IP. It is therefore not shared with the global middleware's
    per-pattern counter. Otherwise ``check_rate_limit(request, pattern)`` is
    used.

    Raises:
        TooManyRequests: when the limit is exceeded.
    """
    def decorator(func):
        # Dedicated namespace for custom limits: sharing the global
        # "<pattern>" key would let unrelated traffic consume this handler's
        # budget and let this window reset the global counter.
        custom_pattern = f"{pattern}:{func.__name__}"

        async def rate_limit_wrapper(request: Request, *args, **kwargs):
            if limit and window:
                client_identifier = context_middleware.get_client_ip(request)
                cache = await rate_limit_middleware.get_cache()
                cache_key = CACHE_KEYS["rate_limit"].format(
                    pattern=custom_pattern,
                    identifier=client_identifier
                )

                current_count = await cache.get(cache_key)
                if current_count is None:
                    await cache.set(cache_key, "1", ttl=window)
                elif int(current_count) >= limit:
                    raise TooManyRequests(f"Rate limit exceeded: {limit} per {window}s")
                else:
                    await cache.incr(cache_key)
                    # If the key expired between get() and incr(), incr()
                    # recreated it without an expiry (TTL -1); restore it so the
                    # counter cannot grow forever. Best effort only.
                    try:
                        if await cache.redis.ttl(cache_key) == -1:
                            await cache.redis.expire(cache_key, window)
                    except Exception as e:
                        logger.warning("Failed to restore rate limit TTL", error=str(e))
            else:
                await check_rate_limit(request, pattern)

            return await func(request, *args, **kwargs)

        rate_limit_wrapper.__name__ = f"{func.__name__}_rate_limit"
        return rate_limit_wrapper

    return decorator
# Compatibility alias for the decorator syntax used in auth_routes
def rate_limit(limit: Optional[int] = None, window: Optional[int] = None, pattern: str = "api"):
    """Same as ``apply_rate_limit``, with ``limit``/``window`` as the leading parameters."""
    options = {"pattern": pattern, "limit": limit, "window": window}
    return apply_rate_limit(**options)