uploader-bot/app/api/middleware.py

742 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Enhanced API middleware with security, rate limiting, monitoring and ed25519 signatures
"""
import asyncio
import time
import uuid
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Callable
import json
from sanic import Request, HTTPResponse
from sanic.response import json as json_response, text as text_response
from sanic.exceptions import Unauthorized, Forbidden, BadRequest
# Sanic may not provide a TooManyRequests exception in the installed version,
# so this module defines its own for rate-limit failures.
class TooManyRequests(Exception):
    """Raised by the rate-limit helpers when a client exceeds its request budget."""
import structlog
from app.core.config import settings, SecurityConfig, CACHE_KEYS
from app.core.database import get_cache
from app.core.logging import request_id_var, user_id_var, operation_var, log_performance
from app.core.models.user import User
from app.core.models.base import BaseModel
# Optional Ed25519 crypto support. CRYPTO_AVAILABLE records whether the
# project's crypto module could be imported.
try:
    from app.core.crypto import get_ed25519_manager
except ImportError:
    CRYPTO_AVAILABLE = False
else:
    CRYPTO_AVAILABLE = True

logger = structlog.get_logger(__name__)
class SecurityMiddleware:
    """Security middleware for request validation and protection.

    All helpers are static; the class only groups them.
    """

    @staticmethod
    def add_security_headers(
        response: HTTPResponse,
        request: Optional[Request] = None,
    ) -> HTTPResponse:
        """Add CORS, hardening and Content-Security-Policy headers to ``response``.

        Args:
            response: Response whose headers are updated in place.
            request: Optional originating request. When given and its ``ctx``
                carries a ``request_id``, that ID is sent as ``X-Request-ID``;
                otherwise the header value is ``"unknown"``.

        Returns:
            The same ``response`` object, for chaining.
        """
        # Read the ID from the *request instance*: the class-level ``Request``
        # never carries per-request context.
        request_id = getattr(getattr(request, 'ctx', None), 'request_id', 'unknown')

        response.headers.update({
            # CORS headers. NOTE(review): the origin is always "*" here;
            # origin filtering happens only in check_origin().
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": (
                "Origin, Content-Type, Accept, Authorization, "
                "X-Requested-With, X-API-Key, X-Request-ID"
            ),
            "Access-Control-Max-Age": "86400",
            # Security headers
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
            # Custom headers
            "X-API-Version": settings.PROJECT_VERSION,
            "X-Request-ID": request_id,
        })

        # CSP header built from the configured directive -> sources mapping.
        response.headers["Content-Security-Policy"] = "; ".join(
            f"{directive} {' '.join(sources)}"
            for directive, sources in SecurityConfig.CSP_DIRECTIVES.items()
        )
        return response

    @staticmethod
    def validate_request_size(request: Request) -> None:
        """Reject requests whose declared Content-Length is invalid or too large.

        Raises:
            BadRequest: if Content-Length is not a non-negative integer, or
                exceeds ``SecurityConfig.MAX_REQUEST_SIZE``.
        """
        content_length = request.headers.get('content-length')
        if not content_length:
            return
        try:
            size = int(content_length)
        except ValueError:
            # Report a clean client error instead of a raw int() parse message.
            raise BadRequest("Invalid Content-Length header") from None
        if size < 0:
            raise BadRequest("Invalid Content-Length header")
        if size > SecurityConfig.MAX_REQUEST_SIZE:
            raise BadRequest(f"Request too large: {size} bytes")

    @staticmethod
    def validate_content_type(request: Request) -> None:
        """Enforce the JSON body size limit for POST/PUT/PATCH JSON requests.

        Raises:
            BadRequest: if the body is larger than ``SecurityConfig.MAX_JSON_SIZE``.
        """
        if request.method not in ('POST', 'PUT', 'PATCH'):
            return
        if 'application/json' not in request.headers.get('content-type', ''):
            return
        # Checked outside any blanket try/except so the "too large" error is
        # reported as such rather than being re-labelled as invalid JSON.
        body = getattr(request, 'body', None) or b''
        if len(body) > SecurityConfig.MAX_JSON_SIZE:
            raise BadRequest("JSON payload too large")

    @staticmethod
    def check_origin(request: Request) -> bool:
        """Return True if the request's Origin is allowed by ``SecurityConfig.CORS_ORIGINS``.

        Requests without an Origin header (direct API calls) are allowed.
        Each configured entry is compared after stripping trailing '/' and
        '*' characters:

        * an entry that becomes empty (e.g. ``"*"``) allows every origin;
        * otherwise the origin must equal the entry, or extend it only by a
          port (``"https://a.example"`` admits ``"https://a.example:8443"``,
          ``"http://localhost:*"`` admits ``"http://localhost:3000"``).
        """
        origin = request.headers.get('origin')
        if not origin:
            return True  # Allow requests without origin (direct API calls)
        return any(
            SecurityMiddleware._origin_matches(origin, allowed_origin)
            for allowed_origin in SecurityConfig.CORS_ORIGINS
        )

    @staticmethod
    def _origin_matches(origin: str, allowed_origin: str) -> bool:
        """Match a single Origin header value against one configured CORS entry."""
        base = allowed_origin.rstrip('/*')
        if not base:
            return True
        if origin == base:
            return True
        if not origin.startswith(base):
            return False
        # Plain prefix matching would let "https://a.example" admit
        # "https://a.example.evil.com"; only a port suffix is accepted.
        rest = origin[len(base):]
        if base.endswith(':'):
            return rest.isdigit()
        return rest.startswith(':') and rest[1:].isdigit()
class RateLimitMiddleware:
    """Fixed-window request counter per (pattern, identifier), stored in the shared cache."""

    def __init__(self):
        # Cache client, resolved lazily by get_cache().
        self.cache = None

    async def get_cache(self):
        """Return the cache client, fetching and memoising it on first use."""
        if self.cache:
            return self.cache
        self.cache = await get_cache()
        return self.cache

    @staticmethod
    def _limits_for(pattern: str) -> Dict[str, Any]:
        """Return ``{"requests": ..., "window": ...}`` for ``pattern``, defaulting to global settings."""
        return SecurityConfig.RATE_LIMIT_PATTERNS.get(pattern, {
            "requests": settings.RATE_LIMIT_REQUESTS,
            "window": settings.RATE_LIMIT_WINDOW
        })

    @staticmethod
    def _counter_key(pattern: str, identifier: str) -> str:
        """Build the cache key holding the request counter for this bucket."""
        return CACHE_KEYS["rate_limit"].format(
            pattern=pattern,
            identifier=identifier
        )

    async def check_rate_limit(
        self,
        request: Request,
        identifier: str,
        pattern: str = "api"
    ) -> bool:
        """Count one request for ``identifier``; return False once the budget is used up.

        ``request`` is accepted for interface compatibility but not read.
        Any failure while talking to the cache is logged and the request is
        allowed (returns True).

        NOTE(review): the read-then-increment below is not atomic, so
        concurrent requests can slip past the limit; and if the key expires
        between ``get`` and ``incr``, whether the recreated counter gets a TTL
        depends on the cache's ``incr`` implementation — confirm.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            key = self._counter_key(pattern, identifier)

            raw_count = await cache.get(key)
            if raw_count is None:
                # First request in this window: start the counter with the window as TTL.
                await cache.set(key, "1", ttl=limits["window"])
                return True

            count = int(raw_count)
            if count < limits["requests"]:
                await cache.incr(key)
                return True

            logger.warning(
                "Rate limit exceeded",
                identifier=identifier,
                pattern=pattern,
                count=count,
                limit=limits["requests"]
            )
            return False
        except Exception as e:
            logger.error("Rate limit check failed", error=str(e))
            return True  # Fail open: never block traffic because the cache is down

    async def get_rate_limit_info(
        self,
        identifier: str,
        pattern: str = "api"
    ) -> Dict[str, Any]:
        """Describe the bucket: limit, remaining requests, reset epoch seconds, window.

        Returns an empty dict if the cache cannot be queried.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            key = self._counter_key(pattern, identifier)

            raw_count = await cache.get(key) or "0"
            # Redis TTL is negative when the key is missing or has no expiry.
            seconds_left = await cache.redis.ttl(key)
            budget = limits["requests"]
            return {
                "limit": budget,
                "remaining": max(0, budget - int(raw_count)),
                "reset_time": int(time.time()) + max(0, seconds_left),
                "window": limits["window"]
            }
        except Exception as e:
            logger.error("Failed to get rate limit info", error=str(e))
            return {}
class AuthenticationMiddleware:
    """Authentication and coarse authorization helpers for API access."""

    @staticmethod
    async def extract_token(request: Request) -> Optional[str]:
        """Return the credential sent with ``request``, or None.

        Sources, in priority order:

        1. ``Authorization: Bearer <token>``; the scheme name is matched
           case-insensitively, as HTTP authentication schemes require;
        2. the ``X-API-Key`` header;
        3. the ``token`` query parameter (less secure; kept for backward
           compatibility).
        """
        auth_header = request.headers.get('authorization')
        if auth_header and auth_header[:7].lower() == 'bearer ':
            return auth_header[7:]  # Strip the "Bearer " prefix

        api_key = request.headers.get('x-api-key')
        if api_key:
            return api_key

        return request.args.get('token')

    @staticmethod
    async def validate_token(token: str, session) -> Optional[User]:
        """Resolve ``token`` to a User, or return None if it is not valid.

        Only the ``"<user-uuid>:<secret>"`` form is handled: the user is
        loaded with ``User.get_by_id(session, ...)`` and the secret is checked
        with ``user.verify_token``. Every failure yields None; unexpected
        errors are logged with only the first 8 characters of the token.
        """
        if not token:
            return None

        try:
            if ':' in token:
                user_id_str, token_hash = token.split(':', 1)
                try:
                    user_id = uuid.UUID(user_id_str)
                    user = await User.get_by_id(session, user_id)
                    # NOTE(review): relies on User.verify_token; if it is
                    # missing, the AttributeError is swallowed below.
                    if user and user.verify_token(token_hash):
                        return user
                except (ValueError, AttributeError):
                    pass

            # Opaque API tokens are not supported yet (no token storage on User).
            return None

        except Exception as e:
            logger.error("Token validation failed", token=token[:8] + "...", error=str(e))
            return None

    @staticmethod
    async def check_permissions(user: User, request: Request) -> bool:
        """Return True if ``user`` may access the endpoint of ``request``.

        Rules match substrings of the path; the first match wins:
        ``/admin/`` -> ``user.is_admin``; ``/mod/`` -> ``user.is_moderator``;
        ``/user/`` with POST/PUT/DELETE -> ``user.has_permission('user:write')``;
        any ``/upload`` path, or a POST to a ``/content`` path ->
        ``user.can_upload_content()``; anything else -> True.
        """
        endpoint = request.path
        method = request.method

        if '/admin/' in endpoint:
            return user.is_admin

        if '/mod/' in endpoint:
            return user.is_moderator

        if '/user/' in endpoint and method in ['POST', 'PUT', 'DELETE']:
            return user.has_permission('user:write')

        # NOTE(review): ``and`` binds tighter than ``or``, so this rule covers
        # every method on '/upload' paths but only POST on '/content' paths.
        # The parentheses spell out that existing behaviour; confirm it is
        # the intended rule.
        if '/upload' in endpoint or ('/content' in endpoint and method == 'POST'):
            return user.can_upload_content()

        # Default: allow read access for authenticated users
        return True
class CryptographicMiddleware:
    """Ed25519 signing and verification for inter-node communication."""

    @staticmethod
    async def verify_inter_node_signature(request: Request) -> bool:
        """Verify the ed25519 signature of an inter-node request.

        Requests without ``X-Node-Communication: true`` are not inter-node
        messages and always pass. Inter-node requests must carry the
        ``X-Node-Signature``, ``X-Node-ID`` and ``X-Node-Public-Key`` headers
        and a JSON body. The decoded body is passed with the signature and
        public key to ``crypto_manager.verify_signature``.

        On success the node identity is stored on ``request.ctx``
        (``inter_node_communication``, ``source_node_id``,
        ``source_public_key``).

        Returns:
            True if the request may proceed, False if verification failed.
        """
        # Only inter-node messages need verification.
        if request.headers.get("X-Node-Communication") != "true":
            return True

        if not CRYPTO_AVAILABLE:
            # Checked after the header test so the warning is logged only for
            # inter-node requests, not for every API call.
            # NOTE(review): this fails open; signed traffic is accepted
            # unverified when the crypto module is missing.
            logger.warning("Crypto module not available, skipping signature verification")
            return True

        try:
            crypto_manager = get_ed25519_manager()

            signature = request.headers.get("X-Node-Signature")
            node_id = request.headers.get("X-Node-ID")
            public_key = request.headers.get("X-Node-Public-Key")
            if not all([signature, node_id, public_key]):
                logger.warning("Missing cryptographic headers in inter-node request")
                return False

            body = getattr(request, 'body', None)
            if not body:
                logger.warning("Empty body in inter-node request")
                return False

            try:
                message_data = json.loads(body.decode())
            except (json.JSONDecodeError, UnicodeDecodeError):
                # A non-UTF-8 body is as malformed as a non-JSON one.
                logger.warning("Invalid JSON in inter-node request")
                return False

            if not crypto_manager.verify_signature(message_data, signature, public_key):
                logger.warning(f"Invalid signature from node {node_id}")
                return False

            logger.debug(f"Valid signature verified for node {node_id}")
            # Record the verified node identity for downstream handlers.
            request.ctx.inter_node_communication = True
            request.ctx.source_node_id = node_id
            request.ctx.source_public_key = public_key
            return True

        except Exception as e:
            logger.error(f"Crypto verification error: {e}")
            return False

    @staticmethod
    async def add_inter_node_headers(request: Request, response: HTTPResponse) -> HTTPResponse:
        """Attach this node's identity, and a signature for JSON bodies, to inter-node responses.

        Applies only when ``request.ctx.inter_node_communication`` is truthy
        (set by verify_inter_node_signature). Errors are logged, never raised.

        Returns:
            The same ``response`` object.
        """
        if not CRYPTO_AVAILABLE:
            return response

        if not getattr(request.ctx, 'inter_node_communication', False):
            return response

        try:
            crypto_manager = get_ed25519_manager()

            response.headers.update({
                "X-Node-ID": crypto_manager.node_id,
                "X-Node-Public-Key": crypto_manager.public_key_hex,
                "X-Node-Communication": "true"
            })

            if response.body:
                try:
                    response_data = json.loads(response.body.decode())
                except (json.JSONDecodeError, UnicodeDecodeError):
                    # Non-JSON (or non-UTF-8) bodies are sent unsigned.
                    pass
                else:
                    response.headers["X-Node-Signature"] = crypto_manager.sign_message(response_data)

        except Exception as e:
            logger.error(f"Error adding inter-node headers: {e}")

        return response
class RequestContextMiddleware:
    """Per-request context (ID, timing, client info) and access logging."""

    @staticmethod
    def generate_request_id() -> str:
        """Return a new random request ID (UUID4 string)."""
        return str(uuid.uuid4())

    @staticmethod
    async def add_request_context(request: Request) -> None:
        """Initialise ``request.ctx`` and log the start of the request.

        Sets ``request_id`` (also stored in ``request_id_var``), ``start_time``,
        ``client_ip``, ``user_agent``, ``user`` (None) and
        ``rate_limit_info`` ({}).
        """
        request_id = RequestContextMiddleware.generate_request_id()
        request.ctx.request_id = request_id
        request_id_var.set(request_id)

        request.ctx.start_time = time.time()

        request.ctx.client_ip = RequestContextMiddleware.get_client_ip(request)
        request.ctx.user_agent = request.headers.get('user-agent', 'Unknown')

        request.ctx.user = None
        request.ctx.rate_limit_info = {}

        logger.info(
            "Request started",
            method=request.method,
            path=request.path,
            client_ip=request.ctx.client_ip,
            user_agent=request.ctx.user_agent
        )

    @staticmethod
    def get_client_ip(request: Request) -> str:
        """Return the client IP: first X-Forwarded-For entry, else X-Real-IP, else the peer IP.

        NOTE(review): forwarding headers are trusted unconditionally, so a
        client reaching the app directly can spoof its IP (and therefore its
        rate-limit bucket). Consider Sanic's proxy-aware
        ``request.remote_addr`` once the proxy setup is confirmed.
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()

        real_ip = request.headers.get('x-real-ip')
        if real_ip:
            return real_ip

        return getattr(request, 'ip', '127.0.0.1')

    @staticmethod
    async def log_request_completion(request: Request, response: HTTPResponse) -> None:
        """Log method, path, status, duration, size, client IP and user of a finished request.

        Tolerates a ``request.ctx`` that add_request_context never populated
        (responses produced before it ran, e.g. the early OPTIONS reply): the
        missing values fall back to 0 ms / 'unknown' / None.
        """
        now = time.time()
        # Using one timestamp keeps the fallback duration at exactly 0 rather
        # than a tiny negative number.
        duration = now - getattr(request.ctx, 'start_time', now)
        user = getattr(request.ctx, 'user', None)

        logger.info(
            "Request completed",
            method=request.method,
            path=request.path,
            status_code=response.status,
            duration_ms=round(duration * 1000, 2),
            response_size=len(response.body) if response.body else 0,
            client_ip=getattr(request.ctx, 'client_ip', 'unknown'),
            user_id=str(user.id) if user else None
        )
# Module-level middleware singletons shared by the pipeline functions and
# route decorators in this module (constructed in this order).
(
    security_middleware,
    rate_limit_middleware,
    auth_middleware,
    context_middleware,
    crypto_middleware,
) = (
    SecurityMiddleware(),
    RateLimitMiddleware(),
    AuthenticationMiddleware(),
    RequestContextMiddleware(),
    CryptographicMiddleware(),
)
def _select_rate_limit_pattern(path: str) -> str:
    """Map a request path to the rate-limit pattern name for its bucket."""
    if '/auth/' in path:
        return "auth"
    if '/upload' in path:
        return "upload"
    if '/admin/' in path:
        return "heavy"
    return "api"


async def request_middleware(request: Request):
    """Main request middleware pipeline.

    Steps, in order:

    1. answer CORS preflight (OPTIONS) directly;
    2. initialise request context;
    3. verify inter-node ed25519 signatures (403 on failure);
    4. validate size, JSON size and origin (400, or 403 for a rejected origin);
    5. rate limiting per client IP (429 when exceeded);
    6. outside ``/api/system`` and ``/``: open a DB session, authenticate an
       optional token and check endpoint permissions (403 when insufficient).

    Returns:
        An ``HTTPResponse`` to short-circuit the request, or None to continue
        to the route handler.
    """
    if request.method == 'OPTIONS':
        response = text_response('OK')
        return security_middleware.add_security_headers(response)

    await context_middleware.add_request_context(request)

    if not await crypto_middleware.verify_inter_node_signature(request):
        logger.warning("Inter-node signature verification failed")
        response = json_response({
            "error": "Invalid cryptographic signature",
            "message": "Inter-node communication requires valid ed25519 signature"
        }, status=403)
        return security_middleware.add_security_headers(response)

    try:
        security_middleware.validate_request_size(request)
        security_middleware.validate_content_type(request)

        if not security_middleware.check_origin(request):
            raise Forbidden("Origin not allowed")
    except Exception as e:
        logger.warning("Security validation failed", error=str(e))
        # A rejected origin is an authorization failure, not a malformed request.
        status = 403 if isinstance(e, Forbidden) else 400
        response = json_response({"error": str(e)}, status=status)
        return security_middleware.add_security_headers(response)

    if settings.RATE_LIMIT_ENABLED:
        client_identifier = context_middleware.get_client_ip(request)
        pattern = _select_rate_limit_pattern(request.path)

        if not await rate_limit_middleware.check_rate_limit(request, client_identifier, pattern):
            rate_info = await rate_limit_middleware.get_rate_limit_info(client_identifier, pattern)
            response = json_response(
                {
                    "error": "Rate limit exceeded",
                    "rate_limit": rate_info
                },
                status=429
            )
            return security_middleware.add_security_headers(response)

        # Kept for the X-RateLimit-* headers added by the response middleware.
        request.ctx.rate_limit_info = await rate_limit_middleware.get_rate_limit_info(
            client_identifier, pattern
        )

    if not request.path.startswith('/api/system') and request.path != '/':
        from app.core.database import db_manager

        async with db_manager.get_session() as session:
            token = await auth_middleware.extract_token(request)
            if token:
                user = await auth_middleware.validate_token(token, session)
                if user:
                    request.ctx.user = user
                    user_id_var.set(str(user.id))

                    if not await auth_middleware.check_permissions(user, request):
                        response = json_response({"error": "Insufficient permissions"}, status=403)
                        return security_middleware.add_security_headers(response)

                    user.update_activity()
                    await session.commit()

            # NOTE(review): this session's ``async with`` scope ends when this
            # middleware returns, i.e. before the route handler runs, so
            # handlers receive an already-exited session. Confirm that
            # get_session() tolerates reuse, or move session lifetime to the
            # response middleware.
            request.ctx.db_session = session
async def response_middleware(request: Request, response: HTTPResponse):
    """Finalize an outgoing response and log the request.

    Adds security headers, inter-node identity/signature headers,
    X-RateLimit-* headers (when the request middleware stored rate-limit info)
    and X-Request-ID, then logs completion.

    Returns:
        The response to send.
    """
    response = security_middleware.add_security_headers(response)
    response = await crypto_middleware.add_inter_node_headers(request, response)

    ctx = request.ctx

    rate_info = getattr(ctx, 'rate_limit_info', None)
    if rate_info:
        response.headers.update({
            "X-RateLimit-Limit": str(rate_info.get('limit', 0)),
            "X-RateLimit-Remaining": str(rate_info.get('remaining', 0)),
            "X-RateLimit-Reset": str(rate_info.get('reset_time', 0))
        })

    if hasattr(ctx, 'request_id'):
        response.headers["X-Request-ID"] = ctx.request_id

    await context_middleware.log_request_completion(request, response)
    return response
async def exception_middleware(request: Request, exception: Exception):
    """Global exception handler: log the error and return a JSON error response.

    Every response includes an ``error_id`` matching the log entry. Status mapping:

    * ``Unauthorized`` -> 401, ``Forbidden`` -> 403, ``TooManyRequests`` -> 429,
      ``BadRequest`` -> 400 (with its message);
    * any other Sanic exception with a 4xx ``status_code`` (e.g. NotFound)
      -> that status, with its message;
    * everything else -> 500, with type/message details only when
      ``settings.DEBUG`` is on.
    """
    # Imported locally; used only to recognise Sanic's own HTTP errors.
    from sanic.exceptions import SanicException

    error_id = str(uuid.uuid4())
    current_user = getattr(request.ctx, 'user', None)

    logger.error(
        "Unhandled exception",
        error_id=error_id,
        exception_type=type(exception).__name__,
        exception_message=str(exception),
        path=request.path,
        method=request.method,
        user_id=str(current_user.id) if current_user else None
    )

    status_code = getattr(exception, 'status_code', None)

    if isinstance(exception, Unauthorized):
        response_data = {"error": "Authentication required", "error_id": error_id}
        status = 401
    elif isinstance(exception, Forbidden):
        response_data = {"error": "Access forbidden", "error_id": error_id}
        status = 403
    elif isinstance(exception, TooManyRequests):
        response_data = {"error": "Rate limit exceeded", "error_id": error_id}
        status = 429
    elif isinstance(exception, BadRequest):
        response_data = {"error": str(exception), "error_id": error_id}
        status = 400
    elif (
        isinstance(exception, SanicException)
        and isinstance(status_code, int)
        and 400 <= status_code < 500
    ):
        # Client errors raised by Sanic itself (e.g. NotFound) keep their
        # status instead of being reported as a 500 server error.
        response_data = {"error": str(exception), "error_id": error_id}
        status = status_code
    else:
        response_data = {
            "error": "Internal server error",
            "error_id": error_id
        }
        status = 500

        if settings.DEBUG:
            response_data["debug"] = {
                "type": type(exception).__name__,
                "message": str(exception)
            }

    response = json_response(response_data, status=status)
    return security_middleware.add_security_headers(response)
# Maintenance-mode gate
async def maintenance_middleware(request: Request):
    """Reply 503 to every request while maintenance mode is enabled.

    ``/api/system`` paths stay reachable. Returns None (let the request
    continue) when maintenance mode is off or the path is exempt.
    """
    if not settings.MAINTENANCE_MODE or request.path.startswith('/api/system'):
        return None

    unavailable = json_response(
        {
            "error": "Service temporarily unavailable",
            "message": settings.MAINTENANCE_MESSAGE,
        },
        status=503,
    )
    return security_middleware.add_security_headers(unavailable)
# Helpers used by the route decorators below
async def check_auth(request: Request) -> User:
    """Return the authenticated user attached to ``request.ctx``.

    Raises:
        Unauthorized: if no (truthy) user is attached to the request.
    """
    user = getattr(request.ctx, 'user', None)
    if user:
        return user
    raise Unauthorized("Authentication required")
async def validate_request_data(request: Request, schema: Optional[Any] = None) -> Dict[str, Any]:
    """Return the JSON body of a POST/PUT/PATCH request, or ``{}``.

    Other methods, a missing ``json`` attribute and a falsy JSON value all
    yield ``{}``. ``schema`` is accepted but not applied yet.

    Raises:
        BadRequest: if reading the request data fails for any reason.
    """
    try:
        if request.method not in ['POST', 'PUT', 'PATCH']:
            return {}

        payload = request.json if hasattr(request, 'json') else None
        # TODO: validate ``payload`` against ``schema`` (e.g. pydantic) once defined.
        return payload or {}
    except Exception as e:
        raise BadRequest(f"Invalid request data: {str(e)}")
async def check_rate_limit(request: Request, pattern: str = "api") -> bool:
    """Enforce the ``pattern`` rate limit for the request's client IP.

    Returns:
        True when the request is within its budget.

    Raises:
        TooManyRequests: when the budget is exhausted; the message embeds the
            current limit information.
    """
    client_ip = context_middleware.get_client_ip(request)
    allowed = await rate_limit_middleware.check_rate_limit(request, client_ip, pattern)
    if allowed:
        return True

    info = await rate_limit_middleware.get_rate_limit_info(client_ip, pattern)
    raise TooManyRequests(f"Rate limit exceeded: {info}")
# Decorator functions for convenience
def auth_required(func):
    """Decorator: raise Unauthorized unless an authenticated user is attached.

    The wrapper carries ``func``'s metadata (docstring, module, attributes set
    by other decorators, ``__wrapped__``) via ``functools.wraps``, but keeps
    the historical ``"<func>_auth_required"`` ``__name__`` so names derived
    from it (presumably route names) stay unchanged.
    """
    import functools

    @functools.wraps(func)
    async def auth_wrapper(request: Request, *args, **kwargs):
        await check_auth(request)
        return await func(request, *args, **kwargs)

    auth_wrapper.__name__ = f"{func.__name__}_auth_required"
    return auth_wrapper
def require_auth(permissions=None):
    """Decorator factory: require an authenticated user and, optionally, permissions.

    Args:
        permissions: None or empty for authentication only; a single
            permission string; or an iterable of permission strings. Every
            listed permission must satisfy ``user.has_permission``. Bare
            ``@require_auth`` usage (no parentheses) is also accepted.

    The returned wrapper raises:
        Unauthorized: when no authenticated user is attached to the request.
        Forbidden: when the user lacks any of ``permissions``.
    """
    import functools

    if callable(permissions):
        # ``@require_auth`` used without parentheses: permissions is the handler.
        return require_auth()(permissions)

    if isinstance(permissions, str):
        required = (permissions,)
    else:
        required = tuple(permissions or ())

    def decorator(func):
        @functools.wraps(func)
        async def require_auth_wrapper(request: Request, *args, **kwargs):
            user = await check_auth(request)

            # Enforce the requested permissions with the same
            # ``User.has_permission`` API used by
            # AuthenticationMiddleware.check_permissions. NOTE(review): confirm
            # the permission names passed by routes are ones has_permission knows.
            missing = [perm for perm in required if not user.has_permission(perm)]
            if missing:
                raise Forbidden(
                    f"Missing required permissions: {', '.join(map(str, missing))}"
                )

            return await func(request, *args, **kwargs)

        require_auth_wrapper.__name__ = f"{func.__name__}_require_auth"
        return require_auth_wrapper

    return decorator
def validate_json(schema=None):
    """Decorator factory: run ``validate_request_data`` before the handler.

    A BadRequest from validation propagates to the caller. The parsed data is
    not passed to the handler. The wrapper keeps ``func``'s metadata via
    ``functools.wraps``, with the historical ``"<func>_validate_json"``
    ``__name__``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        async def validate_json_wrapper(request: Request, *args, **kwargs):
            await validate_request_data(request, schema)
            return await func(request, *args, **kwargs)

        validate_json_wrapper.__name__ = f"{func.__name__}_validate_json"
        return validate_json_wrapper

    return decorator
def validate_request(schema=None):
    """Decorator factory: validate request data against ``schema`` before the handler.

    Delegates to ``validate_request_data``; a BadRequest propagates to the
    caller. The wrapper keeps ``func``'s metadata via ``functools.wraps``,
    with the historical ``"<func>_validate_request"`` ``__name__``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        async def validate_request_wrapper(request: Request, *args, **kwargs):
            await validate_request_data(request, schema)
            return await func(request, *args, **kwargs)

        validate_request_wrapper.__name__ = f"{func.__name__}_validate_request"
        return validate_request_wrapper

    return decorator
def apply_rate_limit(pattern: str = "api", limit: Optional[int] = None, window: Optional[int] = None):
    """Decorator factory: rate-limit a route per client IP.

    Args:
        pattern: Rate-limit pattern name.
        limit: Maximum requests per ``window``; used only together with ``window``.
        window: Window length in seconds.

    With both ``limit`` and ``window`` set, the route uses its own counter,
    keyed by ``pattern`` plus the handler's module and name. Otherwise the
    shared ``pattern`` limit is enforced through ``check_rate_limit``.

    The returned wrapper raises:
        TooManyRequests: when the applicable limit has been reached.
    """
    import functools

    def decorator(func):
        # A custom limit must not share the per-pattern counter that
        # request_middleware (and other routes) increment; otherwise unrelated
        # traffic consumes this route's budget, and whoever creates the key
        # first decides its TTL.
        route_bucket = f"{pattern}:{func.__module__}.{func.__name__}"

        @functools.wraps(func)
        async def rate_limit_wrapper(request: Request, *args, **kwargs):
            if limit and window:
                client_identifier = context_middleware.get_client_ip(request)
                cache = await rate_limit_middleware.get_cache()
                cache_key = CACHE_KEYS["rate_limit"].format(
                    pattern=route_bucket,
                    identifier=client_identifier
                )

                current_count = await cache.get(cache_key)
                if current_count is None:
                    await cache.set(cache_key, "1", ttl=window)
                elif int(current_count) >= limit:
                    raise TooManyRequests(f"Rate limit exceeded: {limit} per {window}s")
                else:
                    await cache.incr(cache_key)
            else:
                # NOTE(review): this uses the same per-pattern counter that
                # request_middleware may already have incremented for this
                # request, so a request can be counted twice.
                await check_rate_limit(request, pattern)

            return await func(request, *args, **kwargs)

        rate_limit_wrapper.__name__ = f"{func.__name__}_rate_limit"
        return rate_limit_wrapper

    return decorator
# Backward-compatible alias taking (limit, window, pattern), as used by auth_routes.
def rate_limit(limit: Optional[int] = None, window: Optional[int] = None, pattern: str = "api"):
    """Rate-limit decorator with ``limit``/``window`` first; delegates to ``apply_rate_limit``."""
    decorator_factory_args = {"pattern": pattern, "limit": limit, "window": window}
    return apply_rate_limit(**decorator_factory_args)