495 lines
18 KiB
Python
495 lines
18 KiB
Python
"""
|
|
Enhanced API middleware with security, rate limiting, and monitoring
|
|
"""
|
|
import asyncio
|
|
import time
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Dict, Any, Callable
|
|
import json
|
|
|
|
from sanic import Request, HTTPResponse
|
|
from sanic.response import json as json_response, text as text_response
|
|
from sanic.exceptions import Unauthorized, Forbidden, TooManyRequests, BadRequest
|
|
import structlog
|
|
|
|
from app.core.config import settings, SecurityConfig, CACHE_KEYS
|
|
from app.core.database import get_db_session, get_cache
|
|
from app.core.logging import request_id_var, user_id_var, operation_var, log_performance
|
|
from app.core.models.user import User
|
|
from app.core.models.base import BaseModel
|
|
|
|
# Module-level structlog logger, named after this module and shared by the
# middleware classes and functions defined in this file.
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
class SecurityMiddleware:
    """Security middleware for request validation and protection.

    All helpers are static; the class is used through the module-level
    ``security_middleware`` instance as well as directly.
    """

    @staticmethod
    def add_security_headers(
        response: HTTPResponse,
        request: Optional[Request] = None,
    ) -> HTTPResponse:
        """Add CORS, hardening and Content-Security-Policy headers.

        Args:
            response: Response to decorate; it is mutated in place.
            request: Optional request the response belongs to. When given,
                its ``ctx.request_id`` (if set) is sent as ``X-Request-ID``;
                otherwise that header is ``"unknown"``.

        Returns:
            The same ``response`` object, for chaining.
        """
        # The previous code looked up ``ctx`` on the ``Request`` *class*, which
        # never carries a per-request id, so the header was always "unknown".
        request_id = "unknown"
        if request is not None:
            request_id = getattr(getattr(request, 'ctx', None), 'request_id', 'unknown')

        response.headers.update({
            # CORS headers
            # NOTE(review): the allowed origin is still a wildcard here;
            # check_origin() only rejects disallowed origins request-side.
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": (
                "Origin, Content-Type, Accept, Authorization, "
                "X-Requested-With, X-API-Key, X-Request-ID"
            ),
            "Access-Control-Max-Age": "86400",

            # Security headers
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",

            # Custom headers
            "X-API-Version": settings.PROJECT_VERSION,
            "X-Request-ID": request_id,
        })

        # CSP header: "directive src1 src2; directive2 ..."
        csp_directives = "; ".join([
            f"{directive} {' '.join(sources)}"
            for directive, sources in SecurityConfig.CSP_DIRECTIVES.items()
        ])
        response.headers["Content-Security-Policy"] = csp_directives

        return response

    @staticmethod
    def validate_request_size(request: Request) -> None:
        """Reject requests whose declared Content-Length exceeds the limit.

        Requests without a Content-Length header are not checked.

        Raises:
            BadRequest: if Content-Length is not a non-negative integer, or
                exceeds ``SecurityConfig.MAX_REQUEST_SIZE``.
        """
        content_length = request.headers.get('content-length')
        if not content_length:
            return

        try:
            size = int(content_length)
        except ValueError:
            # Previously a bare ValueError escaped and its message
            # ("invalid literal for int() ...") was echoed to the client.
            raise BadRequest("Invalid Content-Length header") from None

        if size < 0:
            raise BadRequest("Invalid Content-Length header")
        if size > SecurityConfig.MAX_REQUEST_SIZE:
            raise BadRequest(f"Request too large: {size} bytes")

    @staticmethod
    def validate_content_type(request: Request) -> None:
        """Enforce the JSON body size limit on JSON write requests.

        Only POST/PUT/PATCH requests whose Content-Type contains
        ``application/json`` are checked. The body is not parsed here.

        Raises:
            BadRequest: if the body exceeds ``SecurityConfig.MAX_JSON_SIZE``.
        """
        if request.method not in ('POST', 'PUT', 'PATCH'):
            return
        if 'application/json' not in request.headers.get('content-type', ''):
            return

        # The size check used to sit inside ``try/except Exception``, which
        # caught its own BadRequest and replaced "JSON payload too large"
        # with a misleading "Invalid JSON payload".
        body = getattr(request, 'body', None) or b''
        if len(body) > SecurityConfig.MAX_JSON_SIZE:
            raise BadRequest("JSON payload too large")

    @staticmethod
    def check_origin(request: Request) -> bool:
        """Return True if the request's Origin header is allowed.

        Requests without an Origin header (non-browser clients) are allowed.
        Each entry of ``SecurityConfig.CORS_ORIGINS`` is matched with
        ``_origin_matches``.
        """
        origin = request.headers.get('origin')
        if not origin:
            return True  # Allow requests without origin (direct API calls)

        return any(
            SecurityMiddleware._origin_matches(origin, allowed_origin)
            for allowed_origin in SecurityConfig.CORS_ORIGINS
        )

    @staticmethod
    def _origin_matches(origin: str, allowed_origin: str) -> bool:
        """Match one Origin value against one configured origin pattern.

        * ``"*"`` matches every origin.
        * A pattern ending in ``"/*"`` (``"https://app.example.com/*"``)
          matches exactly that origin.
        * Any other pattern ending in ``"*"`` (``"http://localhost:*"``) is a
          prefix match.
        * Anything else must match exactly (a trailing ``/`` is ignored).

        The previous pure prefix match also accepted look-alike hosts such as
        ``https://example.com.evil.org`` for ``https://example.com``.
        """
        origin = origin.rstrip('/')
        if allowed_origin.endswith('*'):
            prefix = allowed_origin.rstrip('*')
            if prefix.endswith('/'):
                return origin == prefix.rstrip('/')
            return origin.startswith(prefix)
        return origin == allowed_origin.rstrip('/')
|
|
|
|
|
|
class RateLimitMiddleware:
    """Fixed-window rate limiting backed by the shared cache's Redis client.

    Each (pattern, identifier) pair has a counter under a
    ``CACHE_KEYS["rate_limit"]`` key whose TTL is the pattern's window. All
    counter reads/writes use ``cache.redis`` so the counter, its TTL and the
    reported usage always address the same key.
    """

    def __init__(self):
        # Cache handle, fetched lazily by get_cache() on first use.
        self.cache = None

    async def get_cache(self):
        """Return the cache instance, fetching and memoising it on first use."""
        if self.cache is None:
            self.cache = await get_cache()
        return self.cache

    @staticmethod
    def _limits_for(pattern: str) -> Dict[str, Any]:
        """Return the ``{"requests": ..., "window": ...}`` limits for ``pattern``.

        Falls back to ``settings.RATE_LIMIT_REQUESTS`` /
        ``settings.RATE_LIMIT_WINDOW`` for patterns missing from
        ``SecurityConfig.RATE_LIMIT_PATTERNS``.
        """
        return SecurityConfig.RATE_LIMIT_PATTERNS.get(pattern, {
            "requests": settings.RATE_LIMIT_REQUESTS,
            "window": settings.RATE_LIMIT_WINDOW
        })

    @staticmethod
    def _cache_key(identifier: str, pattern: str) -> str:
        """Return the cache key of the counter for ``(pattern, identifier)``."""
        return CACHE_KEYS["rate_limit"].format(
            pattern=pattern,
            identifier=identifier
        )

    async def check_rate_limit(
        self,
        request: Request,
        identifier: str,
        pattern: str = "api"
    ) -> bool:
        """Count one request for ``identifier`` and report whether it is allowed.

        Args:
            request: Unused; kept for interface compatibility.
            identifier: Who is being limited (e.g. a client IP).
            pattern: Rate-limit pattern name selecting the limits.

        Returns:
            False when the identifier has exceeded ``limits["requests"]`` in
            the current window; True otherwise, including when the cache is
            unavailable (the limiter fails open).
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(identifier, pattern)

            # INCR is atomic, so concurrent requests can no longer both see
            # "no key" and each reset the counter to 1 (lost updates).
            current_count = int(await cache.redis.incr(cache_key))
            if current_count == 1:
                # First request of the window starts the window.
                await cache.redis.expire(cache_key, limits["window"])

            if current_count > limits["requests"]:
                # TTL -1 means the key exists without an expiry (e.g. EXPIRE
                # was lost, or the old get/incr sequence re-created an expired
                # key). Without this the client would stay blocked forever.
                if await cache.redis.ttl(cache_key) == -1:
                    await cache.redis.expire(cache_key, limits["window"])

                logger.warning(
                    "Rate limit exceeded",
                    identifier=identifier,
                    pattern=pattern,
                    count=current_count,
                    limit=limits["requests"]
                )
                return False

            return True

        except Exception as e:
            logger.error("Rate limit check failed", error=str(e))
            return True  # Allow request if rate limiting fails

    async def get_rate_limit_info(
        self,
        identifier: str,
        pattern: str = "api"
    ) -> Dict[str, Any]:
        """Return the current rate-limit usage for ``identifier``.

        Returns:
            ``{"limit", "remaining", "reset_time", "window"}`` where
            ``reset_time`` is a Unix timestamp (now + remaining key TTL), or
            an empty dict if the cache could not be queried.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(identifier, pattern)

            current_count = await cache.redis.get(cache_key) or 0
            # Redis returns -2 for a missing key and -1 for no expiry; both
            # are clamped to 0 below.
            ttl = await cache.redis.ttl(cache_key)

            return {
                "limit": limits["requests"],
                "remaining": max(0, limits["requests"] - int(current_count)),
                "reset_time": int(time.time()) + max(0, ttl),
                "window": limits["window"]
            }

        except Exception as e:
            logger.error("Failed to get rate limit info", error=str(e))
            return {}
|
|
|
|
|
|
class AuthenticationMiddleware:
    """Authentication middleware for API access."""

    @staticmethod
    async def extract_token(request: Request) -> Optional[str]:
        """Return the authentication token carried by ``request``, if any.

        Lookup order: ``Authorization: Bearer <token>``, then the
        ``X-API-Key`` header, then the ``token`` query parameter.
        """
        auth_header = request.headers.get('authorization')
        if auth_header:
            scheme, _, credentials = auth_header.partition(' ')
            credentials = credentials.strip()
            # Auth schemes are case-insensitive (RFC 7235 section 2.1), so
            # "bearer" and "BEARER" are accepted too. An empty Bearer value
            # falls through to the other sources instead of returning "".
            if scheme.lower() == 'bearer' and credentials:
                return credentials

        api_key = request.headers.get('x-api-key')
        if api_key:
            return api_key

        # Query parameter (less secure: tokens end up in URLs and access logs);
        # kept for backward compatibility.
        return request.args.get('token')

    @staticmethod
    async def validate_token(token: str, session) -> Optional[User]:
        """Resolve ``token`` to a ``User``, or return None if it is not valid.

        Only the ``"<user uuid>:<secret>"`` format is handled: the UUID is
        loaded with ``User.get_by_id`` and the secret is checked with
        ``user.verify_token``. Every other format yields None.

        Args:
            token: Raw token as returned by ``extract_token``.
            session: Database session passed to ``User.get_by_id``.
        """
        if not token:
            return None

        try:
            # NOTE(review): placeholder scheme; production should use JWT or
            # database-backed token validation.
            if ':' in token:
                user_id_str, token_hash = token.split(':', 1)
                try:
                    user_id = uuid.UUID(user_id_str)
                    user = await User.get_by_id(session, user_id)
                    # verify_token is expected on the User model (defined elsewhere).
                    if user and user.verify_token(token_hash):
                        return user
                except (ValueError, AttributeError):
                    # e.g. a malformed UUID (ValueError) or a missing
                    # attribute on the user model (AttributeError).
                    pass

            # Lookup by a stored API token is not implemented yet.
            return None

        except Exception as e:
            # Never log token material, not even a prefix: for opaque API keys
            # the leading characters are part of the secret.
            logger.error("Token validation failed", error=str(e))
            return None

    @staticmethod
    async def check_permissions(user: User, request: Request) -> bool:
        """Return True if ``user`` may access ``request``'s endpoint.

        Rules are matched on substrings of the path, first match wins:
        ``/admin/`` needs admin, ``/mod/`` needs moderator, writes under
        ``/user/`` need ``user:write``, upload/content endpoints need upload
        rights, and everything else is allowed for authenticated users.
        """
        endpoint = request.path
        method = request.method

        # Admin endpoints
        if '/admin/' in endpoint:
            return user.is_admin

        # Moderator endpoints
        if '/mod/' in endpoint:
            return user.is_moderator

        # User-specific write endpoints. PATCH is a write too; it previously
        # bypassed this check.
        if '/user/' in endpoint and method in ('POST', 'PUT', 'PATCH', 'DELETE'):
            return user.has_permission('user:write')

        # Content upload endpoints. Parenthesised to make the existing
        # precedence explicit: any method on /upload, only POST on /content.
        # NOTE(review): confirm whether non-POST /upload requests were meant
        # to require upload rights.
        if '/upload' in endpoint or ('/content' in endpoint and method == 'POST'):
            return user.can_upload_content()

        # Default: allow read access for authenticated users
        return True
|
|
|
|
|
|
class RequestContextMiddleware:
    """Request context middleware for tracking and logging."""

    @staticmethod
    def generate_request_id() -> str:
        """Return a new random request ID (a UUID4 string)."""
        return str(uuid.uuid4())

    @staticmethod
    async def add_request_context(request: Request) -> None:
        """Populate ``request.ctx`` and logging context vars, then log the start.

        Sets ``request_id``, ``start_time``, ``client_ip``, ``user_agent``,
        ``user`` (None) and ``rate_limit_info`` ({}) on ``request.ctx`` and
        binds the request id to ``request_id_var``.
        """
        request_id = RequestContextMiddleware.generate_request_id()
        request.ctx.request_id = request_id
        request_id_var.set(request_id)

        # Start time used by log_request_completion() to compute duration.
        request.ctx.start_time = time.time()

        request.ctx.client_ip = RequestContextMiddleware.get_client_ip(request)
        request.ctx.user_agent = request.headers.get('user-agent', 'Unknown')

        # Defaults that later middleware may overwrite.
        request.ctx.user = None
        request.ctx.rate_limit_info = {}

        logger.info(
            "Request started",
            method=request.method,
            path=request.path,
            client_ip=request.ctx.client_ip,
            user_agent=request.ctx.user_agent
        )

    @staticmethod
    def get_client_ip(request: Request) -> str:
        """Return the client IP: X-Forwarded-For, then X-Real-IP, then the socket IP.

        SECURITY NOTE(review): both headers are trusted unconditionally, so a
        client that reaches the app directly can choose the IP it is
        rate-limited and logged as. This is only safe behind a proxy that
        overwrites them; consider Sanic's ``request.remote_addr`` with the
        PROXIES_COUNT / FORWARDED_SECRET settings.
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            # Left-most non-empty entry; an empty first entry used to be
            # returned as the client IP ("").
            for hop in forwarded_for.split(','):
                candidate = hop.strip()
                if candidate:
                    return candidate

        real_ip = (request.headers.get('x-real-ip') or '').strip()
        if real_ip:
            return real_ip

        # Fallback to the socket-level request IP
        return getattr(request, 'ip', '127.0.0.1')

    @staticmethod
    async def log_request_completion(request: Request, response: HTTPResponse) -> None:
        """Log method, path, status, duration, size, client IP and user id."""
        ctx = request.ctx
        duration = time.time() - getattr(ctx, 'start_time', time.time())

        # ``user`` is only set once add_request_context() has run. Responses
        # for requests that skipped it (e.g. the early OPTIONS return) made
        # the bare ``request.ctx.user`` access raise AttributeError here.
        user = getattr(ctx, 'user', None)

        logger.info(
            "Request completed",
            method=request.method,
            path=request.path,
            status_code=response.status,
            duration_ms=round(duration * 1000, 2),
            response_size=len(response.body) if response.body else 0,
            client_ip=getattr(ctx, 'client_ip', 'unknown'),
            user_id=str(user.id) if user else None
        )
|
|
|
|
|
|
# Shared middleware singletons used by the pipeline functions in this module.
# Sharing one RateLimitMiddleware means its lazily fetched cache handle is
# fetched once instead of per request.
context_middleware = RequestContextMiddleware()
security_middleware = SecurityMiddleware()
auth_middleware = AuthenticationMiddleware()
rate_limit_middleware = RateLimitMiddleware()
|
|
|
|
|
|
async def request_middleware(request: Request):
    """Main request middleware pipeline.

    Steps, in order: CORS preflight short-circuit, request context setup,
    security validation, rate limiting, then optional token authentication.

    Returns:
        An ``HTTPResponse`` that short-circuits the request (preflight,
        rejected request, rate limit, insufficient permissions), or None to
        let the route handler run.
    """
    # Handle OPTIONS requests for CORS
    if request.method == 'OPTIONS':
        response = text_response('OK')
        return security_middleware.add_security_headers(response)

    await context_middleware.add_request_context(request)

    # Security validations
    try:
        security_middleware.validate_request_size(request)
        security_middleware.validate_content_type(request)

        if not security_middleware.check_origin(request):
            raise Forbidden("Origin not allowed")

    except Exception as e:
        logger.warning("Security validation failed", error=str(e))
        # A disallowed origin was previously answered with 400; it is a 403.
        status = 403 if isinstance(e, Forbidden) else 400
        response = json_response({"error": str(e)}, status=status)
        return security_middleware.add_security_headers(response)

    # Rate limiting
    if settings.RATE_LIMIT_ENABLED:
        # Reuse the IP already resolved by add_request_context().
        client_identifier = request.ctx.client_ip

        # Determine rate limit pattern based on endpoint
        pattern = "api"
        if '/auth/' in request.path:
            pattern = "auth"
        elif '/upload' in request.path:
            pattern = "upload"
        elif '/admin/' in request.path:
            pattern = "heavy"

        if not await rate_limit_middleware.check_rate_limit(request, client_identifier, pattern):
            rate_info = await rate_limit_middleware.get_rate_limit_info(client_identifier, pattern)
            # Expose the info so response middleware adds X-RateLimit-* headers.
            request.ctx.rate_limit_info = rate_info
            response = json_response(
                {
                    "error": "Rate limit exceeded",
                    "rate_limit": rate_info
                },
                status=429
            )
            reset_time = rate_info.get("reset_time")
            if reset_time:
                # RFC 6585 allows a Retry-After header on 429 responses.
                response.headers["Retry-After"] = str(max(0, int(reset_time - time.time())))
            return security_middleware.add_security_headers(response)

        # Store rate limit info for response headers
        request.ctx.rate_limit_info = await rate_limit_middleware.get_rate_limit_info(
            client_identifier, pattern
        )

    # Authentication. Requests without a valid token continue anonymously
    # (request.ctx.user stays None); route handlers must enforce login.
    if not request.path.startswith('/api/system') and request.path != '/':
        # NOTE(review): the session is closed when this ``async with`` exits,
        # yet it is stored on request.ctx for handlers. Whether handlers can
        # still use it depends on get_db_session(); confirm, or open it here
        # and close it in response middleware instead.
        async with get_db_session() as session:
            token = await auth_middleware.extract_token(request)
            if token:
                user = await auth_middleware.validate_token(token, session)
                if user:
                    request.ctx.user = user
                    user_id_var.set(str(user.id))

                    if not await auth_middleware.check_permissions(user, request):
                        response = json_response({"error": "Insufficient permissions"}, status=403)
                        return security_middleware.add_security_headers(response)

                    # Update user activity
                    user.update_activity()
                    await session.commit()

            # Store session for request handlers
            request.ctx.db_session = session
|
|
|
|
|
|
async def response_middleware(request: Request, response: HTTPResponse):
    """Decorate an outgoing response, log the request, and return the response.

    Adds the security/CORS headers, the ``X-RateLimit-*`` headers when
    non-empty rate-limit info was stored on the request, and the
    ``X-Request-ID`` header when the request has an id.
    """
    ctx = request.ctx
    response = security_middleware.add_security_headers(response)

    rate_info = getattr(ctx, 'rate_limit_info', None)
    if rate_info:
        header_fields = (
            ("X-RateLimit-Limit", 'limit'),
            ("X-RateLimit-Remaining", 'remaining'),
            ("X-RateLimit-Reset", 'reset_time'),
        )
        for header_name, info_key in header_fields:
            response.headers[header_name] = str(rate_info.get(info_key, 0))

    # Overrides the placeholder id set by add_security_headers().
    if hasattr(ctx, 'request_id'):
        response.headers["X-Request-ID"] = ctx.request_id

    await context_middleware.log_request_completion(request, response)
    return response
|
|
|
|
|
|
async def exception_middleware(request: Request, exception: Exception):
    """Convert an exception raised while handling ``request`` into JSON.

    Each response carries a fresh ``error_id`` that is also logged, so a
    client-reported id can be matched to the server log entry.

    Status mapping:
        * Unauthorized / Forbidden / TooManyRequests -> 401 / 403 / 429 with
          a fixed message; BadRequest -> 400 with the exception's message.
        * Any other Sanic exception with a 4xx status (NotFound,
          MethodNotAllowed, ...) keeps its status and message.
        * Everything else -> 500 "Internal server error".
    With ``settings.DEBUG`` the exception type and message are added under
    ``"debug"``.
    """
    # Function-scope import: SanicException is the common base of the
    # sanic.exceptions classes this module imports at the top.
    from sanic.exceptions import SanicException

    error_id = str(uuid.uuid4())
    user = getattr(request.ctx, 'user', None)

    logger.error(
        "Unhandled exception",
        error_id=error_id,
        exception_type=type(exception).__name__,
        exception_message=str(exception),
        path=request.path,
        method=request.method,
        user_id=str(user.id) if user else None
    )

    if isinstance(exception, Unauthorized):
        response_data = {"error": "Authentication required", "error_id": error_id}
        status = 401
    elif isinstance(exception, Forbidden):
        response_data = {"error": "Access forbidden", "error_id": error_id}
        status = 403
    elif isinstance(exception, TooManyRequests):
        response_data = {"error": "Rate limit exceeded", "error_id": error_id}
        status = 429
    elif isinstance(exception, BadRequest):
        response_data = {"error": str(exception), "error_id": error_id}
        status = 400
    else:
        client_status = (
            getattr(exception, 'status_code', None)
            if isinstance(exception, SanicException)
            else None
        )
        if isinstance(client_status, int) and 400 <= client_status < 500:
            # Previously every other Sanic client error (e.g. a 404 for an
            # unknown route) was reported as a 500.
            response_data = {"error": str(exception), "error_id": error_id}
            status = client_status
        else:
            # Generic server error: no internal details unless DEBUG.
            response_data = {
                "error": "Internal server error",
                "error_id": error_id
            }
            status = 500

    if settings.DEBUG:
        response_data["debug"] = {
            "type": type(exception).__name__,
            "message": str(exception)
        }

    response = json_response(response_data, status=status)
    return security_middleware.add_security_headers(response)
|
|
|
|
|
|
# Maintenance mode middleware
async def maintenance_middleware(request: Request):
    """Short-circuit requests with a 503 while maintenance mode is enabled.

    Paths starting with ``/api/system`` stay reachable. Returns None when the
    request should proceed normally.
    """
    if not settings.MAINTENANCE_MODE or request.path.startswith('/api/system'):
        return None

    payload = {
        "error": "Service temporarily unavailable",
        "message": settings.MAINTENANCE_MESSAGE
    }
    unavailable = json_response(payload, status=503)
    return security_middleware.add_security_headers(unavailable)
|