742 lines
28 KiB
Python
742 lines
28 KiB
Python
"""
|
||
Enhanced API middleware with security, rate limiting, monitoring and ed25519 signatures
|
||
"""
|
||
import asyncio
|
||
import time
|
||
import uuid
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional, Dict, Any, Callable
|
||
import json
|
||
|
||
from sanic import Request, HTTPResponse
|
||
from sanic.response import json as json_response, text as text_response
|
||
from sanic.exceptions import Unauthorized, Forbidden, BadRequest
|
||
|
||
# TooManyRequests may not exist in this Sanic version, so define our own.
class TooManyRequests(Exception):
    """Raised when a client exceeds a rate limit (project-local exception)."""
|
||
import structlog
|
||
|
||
from app.core.config import settings, SecurityConfig, CACHE_KEYS
|
||
from app.core.database import get_cache
|
||
from app.core.logging import request_id_var, user_id_var, operation_var, log_performance
|
||
from app.core.models.user import User
|
||
from app.core.models.base import BaseModel
|
||
|
||
# Optional Ed25519 cryptography module. When it cannot be imported,
# CRYPTO_AVAILABLE is False and the cryptographic middleware skips its work.
try:
    from app.core.crypto import get_ed25519_manager
except ImportError:
    CRYPTO_AVAILABLE = False
else:
    CRYPTO_AVAILABLE = True
|
||
|
||
# Module-level structlog logger shared by every middleware class and helper
# in this module (called with keyword-style structured fields).
logger = structlog.get_logger(__name__)
|
||
|
||
|
||
class SecurityMiddleware:
    """Security middleware for request validation and protection.

    All methods are static; the class only groups related helpers.
    """

    @staticmethod
    def add_security_headers(
        response: HTTPResponse,
        request: Optional[Request] = None,
    ) -> HTTPResponse:
        """Add CORS, security, versioning and CSP headers to ``response``.

        Args:
            response: Response to decorate; its headers are mutated in place.
            request: Optional originating request. When it carries
                ``ctx.request_id``, that value is used for ``X-Request-ID``.

        Returns:
            The same ``response`` object, for chaining.
        """
        response.headers.update({
            # CORS headers
            # NOTE(review): the allowed origin is not yet restricted per
            # request; disallowed origins are only rejected via check_origin().
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": (
                "Origin, Content-Type, Accept, Authorization, "
                "X-Requested-With, X-API-Key, X-Request-ID"
            ),
            "Access-Control-Max-Age": "86400",

            # Security headers
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",

            # Custom headers
            "X-API-Version": settings.PROJECT_VERSION,
        })

        # The request ID lives on the request *instance* (the Request class has
        # none). Never clobber an ID that is already set on the response.
        request_id = getattr(getattr(request, "ctx", None), "request_id", None)
        if request_id:
            response.headers["X-Request-ID"] = request_id
        elif "X-Request-ID" not in response.headers:
            response.headers["X-Request-ID"] = "unknown"

        # CSP header assembled from SecurityConfig.CSP_DIRECTIVES
        csp_directives = "; ".join([
            f"{directive} {' '.join(sources)}"
            for directive, sources in SecurityConfig.CSP_DIRECTIVES.items()
        ])
        response.headers["Content-Security-Policy"] = csp_directives

        return response

    @staticmethod
    def validate_request_size(request: Request) -> None:
        """Reject requests whose declared Content-Length exceeds the limit.

        Raises:
            BadRequest: if Content-Length is not a non-negative integer or is
                larger than ``SecurityConfig.MAX_REQUEST_SIZE``.
        """
        content_length = request.headers.get('content-length')
        if not content_length:
            return
        try:
            size = int(content_length)
        except ValueError:
            raise BadRequest("Invalid Content-Length header") from None
        if size < 0:
            raise BadRequest("Invalid Content-Length header")
        if size > SecurityConfig.MAX_REQUEST_SIZE:
            raise BadRequest(f"Request too large: {size} bytes")

    @staticmethod
    def validate_content_type(request: Request) -> None:
        """Enforce the JSON body size limit on POST/PUT/PATCH JSON requests.

        Raises:
            BadRequest: if the body exceeds ``SecurityConfig.MAX_JSON_SIZE``.
        """
        if request.method not in ('POST', 'PUT', 'PATCH'):
            return
        content_type = request.headers.get('content-type', '')
        if 'application/json' not in content_type:
            return
        # The size error is raised directly; it used to be swallowed by a broad
        # ``except`` and misreported as "Invalid JSON payload".
        body = getattr(request, 'body', None) or b""
        if len(body) > SecurityConfig.MAX_JSON_SIZE:
            raise BadRequest("JSON payload too large")

    @staticmethod
    def check_origin(request: Request) -> bool:
        """Return True if the request's Origin is allowed.

        Requests without an Origin header (direct API calls) are allowed.
        """
        origin = request.headers.get('origin')
        if not origin:
            return True  # Allow requests without origin (direct API calls)

        return any(
            SecurityMiddleware._origin_matches(origin, allowed_origin)
            for allowed_origin in SecurityConfig.CORS_ORIGINS
        )

    @staticmethod
    def _origin_matches(origin: str, allowed_origin: str) -> bool:
        """Match one Origin value against one configured CORS origin entry.

        - ``"*"`` allows everything.
        - An entry ending in ``*`` not preceded by ``/`` (e.g.
          ``"http://localhost:*"``) is a prefix wildcard.
        - Any other entry must match up to a host boundary. That is an exact
          match, or the configured value followed by ``:port`` or ``/``.
          So ``"https://example.com"`` does not also admit
          ``"https://example.com.evil.net"``.
        """
        base = allowed_origin.rstrip('/*')
        if not base:
            return True

        head = allowed_origin.rstrip('*')
        if head != allowed_origin and not head.endswith('/'):
            return origin.startswith(base)

        if not origin.startswith(base):
            return False
        rest = origin[len(base):]
        return rest == "" or rest[0] in ":/"
|
||
|
||
|
||
class RateLimitMiddleware:
    """Fixed-window rate limiter backed by the shared cache.

    Counters are stored under ``CACHE_KEYS["rate_limit"]`` (formatted with the
    pattern name and client identifier). The first request of a window stores
    "1" with a TTL equal to the window length. Later requests increment it.
    """

    def __init__(self):
        # Resolved lazily on first use by get_cache().
        self.cache = None

    async def get_cache(self):
        """Return the cache client, fetching and memoising it on first call."""
        if not self.cache:
            self.cache = await get_cache()
        return self.cache

    @staticmethod
    def _limits_for(pattern: str) -> Dict[str, Any]:
        """Return ``{"requests", "window"}`` for ``pattern``, defaulting to global settings."""
        fallback = {
            "requests": settings.RATE_LIMIT_REQUESTS,
            "window": settings.RATE_LIMIT_WINDOW,
        }
        return SecurityConfig.RATE_LIMIT_PATTERNS.get(pattern, fallback)

    @staticmethod
    def _key_for(pattern: str, identifier: str) -> str:
        """Build the cache key holding the counter for (pattern, identifier)."""
        return CACHE_KEYS["rate_limit"].format(pattern=pattern, identifier=identifier)

    async def check_rate_limit(
        self,
        request: Request,
        identifier: str,
        pattern: str = "api"
    ) -> bool:
        """Count one request for ``identifier`` and report whether it is allowed.

        Returns:
            False once the window already holds ``requests`` hits, otherwise
            True. Also True when the cache lookup fails (fail-open).
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            key = self._key_for(pattern, identifier)

            # NOTE(review): get -> set/incr is not atomic, so concurrent first
            # requests can race. If cache.incr maps to Redis INCR and the key
            # expires between get and incr, the recreated key has no TTL.
            raw_count = await cache.get(key)
            if raw_count is None:
                # Opening a fresh window.
                await cache.set(key, "1", ttl=limits["window"])
                return True

            count = int(raw_count)
            max_requests = limits["requests"]
            if count < max_requests:
                await cache.incr(key)
                return True

            logger.warning(
                "Rate limit exceeded",
                identifier=identifier,
                pattern=pattern,
                count=count,
                limit=max_requests,
            )
            return False

        except Exception as e:
            logger.error("Rate limit check failed", error=str(e))
            return True  # Fail open: never block traffic because the limiter broke

    async def get_rate_limit_info(
        self,
        identifier: str,
        pattern: str = "api"
    ) -> Dict[str, Any]:
        """Describe the current window for ``identifier``.

        Returns:
            ``{"limit", "remaining", "reset_time", "window"}``. ``reset_time``
            is a Unix timestamp. Returns ``{}`` when the lookup fails.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            key = self._key_for(pattern, identifier)

            raw_count = await cache.get(key) or "0"
            ttl = await cache.redis.ttl(key)

            max_requests = limits["requests"]
            remaining = max(0, max_requests - int(raw_count))
            # A negative TTL (no key / no expiry) counts as "resets now".
            reset_time = int(time.time()) + max(0, ttl)
            return {
                "limit": max_requests,
                "remaining": remaining,
                "reset_time": reset_time,
                "window": limits["window"],
            }

        except Exception as e:
            logger.error("Failed to get rate limit info", error=str(e))
            return {}
|
||
|
||
|
||
class AuthenticationMiddleware:
    """Authentication middleware for API access.

    Static helpers used by the request pipeline: token extraction, token
    validation and per-endpoint permission checks.
    """

    @staticmethod
    async def extract_token(request: Request) -> Optional[str]:
        """Extract an authentication token from the request.

        Sources, in order:
        1. ``Authorization: Bearer <token>``. The scheme is matched
           case-insensitively and surrounding whitespace is ignored.
        2. The ``X-API-Key`` header.
        3. The ``token`` query parameter (less secure; backward compatibility).

        Returns:
            The token, or ``None`` when no source supplies one.
        """
        auth_header = request.headers.get('authorization')
        if auth_header:
            # The auth-scheme is case-insensitive (RFC 7235 section 2.1).
            scheme, _, credentials = auth_header.strip().partition(' ')
            credentials = credentials.strip()
            if scheme.lower() == 'bearer' and credentials:
                return credentials

        api_key = request.headers.get('x-api-key')
        if api_key:
            return api_key

        # Tokens in URLs can leak into logs; kept for backward compatibility.
        return request.args.get('token')

    @staticmethod
    async def validate_token(token: str, session) -> Optional[User]:
        """Resolve ``token`` to a ``User``, or return ``None``.

        Only the ``"<user_uuid>:<secret>"`` format is handled. The user is
        loaded with ``User.get_by_id`` and the secret checked with
        ``user.verify_token``. A malformed UUID, unknown user, failed check or
        any lookup error yields ``None``.

        Args:
            token: Raw token as returned by ``extract_token``.
            session: Database session passed to ``User.get_by_id``.
        """
        if not token:
            return None

        try:
            # TODO: replace with JWT or database-backed token validation.
            if ':' in token:
                user_id_str, token_hash = token.split(':', 1)
                try:
                    user_id = uuid.UUID(user_id_str)
                    user = await User.get_by_id(session, user_id)
                    # NOTE(review): verify_token must exist on User; until it
                    # does, the AttributeError is swallowed and auth fails.
                    if user and user.verify_token(token_hash):
                        return user
                except (ValueError, AttributeError):
                    pass

            # Fallback lookup by a stored API token is not implemented yet.
            return None

        except Exception as e:
            logger.error("Token validation failed", token=token[:8] + "...", error=str(e))
            return None

    @staticmethod
    def _path_segments(path: str) -> list:
        """Split a URL path into its non-empty segments."""
        return [segment for segment in path.split('/') if segment]

    @staticmethod
    async def check_permissions(user: User, request: Request) -> bool:
        """Return True if ``user`` may access ``request``'s endpoint.

        Rules (first match wins):
        - any path with an ``admin`` segment requires ``user.is_admin``;
        - any path with a ``mod`` segment requires ``user.is_moderator``;
        - POST/PUT/DELETE under ``/user/`` requires ``user:write``;
        - any ``/upload`` path, or POST to ``/content``, requires
          ``user.can_upload_content()``;
        - everything else is allowed for authenticated users.
        """
        endpoint = request.path
        method = request.method
        segments = AuthenticationMiddleware._path_segments(endpoint)

        # Match whole segments so "/api/admin" (no trailing slash) is guarded
        # exactly like "/api/admin/...".
        if 'admin' in segments:
            return user.is_admin

        if 'mod' in segments:
            return user.is_moderator

        # User-specific endpoints
        if '/user/' in endpoint and method in ['POST', 'PUT', 'DELETE']:
            return user.has_permission('user:write')

        # Content upload endpoints. The parentheses make the existing
        # precedence explicit: every /upload request, but only POST to /content.
        # NOTE(review): confirm whether /upload should also be POST-only.
        if '/upload' in endpoint or ('/content' in endpoint and method == 'POST'):
            return user.can_upload_content()

        # Default: allow read access for authenticated users
        return True
|
||
|
||
|
||
class CryptographicMiddleware:
    """Ed25519 signing/verification for inter-node communication."""

    @staticmethod
    async def verify_inter_node_signature(request: Request) -> bool:
        """Verify the Ed25519 signature of an inter-node request.

        Only requests with ``X-Node-Communication: true`` are checked. The JSON
        body is parsed and passed, with the ``X-Node-Signature`` and
        ``X-Node-Public-Key`` headers, to the crypto manager's
        ``verify_signature``. On success ``request.ctx`` receives
        ``inter_node_communication``, ``source_node_id`` and
        ``source_public_key``.

        Returns:
            True when the request is not inter-node, when the crypto module is
            unavailable, or when the signature is valid. False otherwise.
        """
        # Check the marker header first so the "crypto not available" warning
        # is logged only for inter-node requests, not on every request.
        if request.headers.get("X-Node-Communication") != "true":
            return True  # Not inter-node traffic: nothing to verify

        if not CRYPTO_AVAILABLE:
            # NOTE(review): this fails open; unsigned inter-node requests are
            # accepted when the crypto module is missing.
            logger.warning("Crypto module not available, skipping signature verification")
            return True

        try:
            crypto_manager = get_ed25519_manager()

            signature = request.headers.get("X-Node-Signature")
            node_id = request.headers.get("X-Node-ID")
            # NOTE(review): the public key comes from the sender's own headers.
            # Unless it is checked against a registry of known nodes, anyone can
            # sign with their own key and pass this verification.
            public_key = request.headers.get("X-Node-Public-Key")

            if not all([signature, node_id, public_key]):
                logger.warning("Missing cryptographic headers in inter-node request")
                return False

            body = getattr(request, 'body', None)
            if not body:
                logger.warning("Empty body in inter-node request")
                return False

            try:
                message_data = json.loads(body.decode())
            except (UnicodeDecodeError, json.JSONDecodeError):
                logger.warning("Invalid JSON in inter-node request")
                return False

            is_valid = crypto_manager.verify_signature(message_data, signature, public_key)
            if not is_valid:
                logger.warning(f"Invalid signature from node {node_id}")
                return False

            logger.debug(f"Valid signature verified for node {node_id}")
            # Expose the verified peer to later middleware and handlers.
            request.ctx.inter_node_communication = True
            request.ctx.source_node_id = node_id
            request.ctx.source_public_key = public_key
            return True

        except Exception as e:
            logger.error(f"Crypto verification error: {e}")
            return False

    @staticmethod
    async def add_inter_node_headers(request: Request, response: HTTPResponse) -> HTTPResponse:
        """Add node identity headers, plus a signature, to inter-node responses.

        Acts only on requests marked ``ctx.inter_node_communication`` by
        ``verify_inter_node_signature``. JSON bodies are signed with
        ``sign_message`` into ``X-Node-Signature``; non-JSON bodies get identity
        headers only. Errors are logged and the response is returned as is.
        """
        if not CRYPTO_AVAILABLE:
            return response

        if not getattr(request.ctx, 'inter_node_communication', False):
            return response

        try:
            crypto_manager = get_ed25519_manager()

            response.headers.update({
                "X-Node-ID": crypto_manager.node_id,
                "X-Node-Public-Key": crypto_manager.public_key_hex,
                "X-Node-Communication": "true"
            })

            if response.body:
                try:
                    response_data = json.loads(response.body.decode())
                except (UnicodeDecodeError, json.JSONDecodeError):
                    # Not a JSON (or not UTF-8) body: skip the signature.
                    pass
                else:
                    response.headers["X-Node-Signature"] = crypto_manager.sign_message(response_data)

        except Exception as e:
            logger.error(f"Error adding inter-node headers: {e}")

        return response
|
||
|
||
|
||
class RequestContextMiddleware:
    """Per-request context: request IDs, timing, client info and logging."""

    @staticmethod
    def generate_request_id() -> str:
        """Return a new random (UUID4) request ID as a string."""
        return str(uuid.uuid4())

    @staticmethod
    async def add_request_context(request: Request) -> None:
        """Initialise ``request.ctx`` and logging context for a new request.

        Sets ``request_id`` (also stored in ``request_id_var``), ``start_time``,
        ``client_ip`` and ``user_agent``. Initialises ``user`` to ``None`` and
        ``rate_limit_info`` to ``{}``, then logs "Request started".
        """
        request_id = RequestContextMiddleware.generate_request_id()
        request.ctx.request_id = request_id
        request_id_var.set(request_id)

        request.ctx.start_time = time.time()

        request.ctx.client_ip = RequestContextMiddleware.get_client_ip(request)
        request.ctx.user_agent = request.headers.get('user-agent', 'Unknown')

        request.ctx.user = None
        request.ctx.rate_limit_info = {}

        logger.info(
            "Request started",
            method=request.method,
            path=request.path,
            client_ip=request.ctx.client_ip,
            user_agent=request.ctx.user_agent
        )

    @staticmethod
    def get_client_ip(request: Request) -> str:
        """Return the client IP from X-Forwarded-For, then X-Real-IP, then ``request.ip``.

        Only the first X-Forwarded-For entry is used.

        NOTE(review): forwarded headers are trusted unconditionally, so a client
        that is not behind a trusted proxy can spoof its IP, and with it its
        rate-limit identity. Consider honouring them only from known proxies.
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()

        real_ip = request.headers.get('x-real-ip')
        if real_ip:
            return real_ip

        return getattr(request, 'ip', '127.0.0.1')

    @staticmethod
    async def log_request_completion(request: Request, response: HTTPResponse) -> None:
        """Log request completion with status, duration, size and user.

        Works even when ``add_request_context`` never ran for this request. For
        example, ``request_middleware`` answers OPTIONS before setting up the
        context. A missing ``start_time`` yields ~0 ms, and a missing ``user``
        logs ``user_id=None`` instead of raising AttributeError.
        """
        duration = time.time() - getattr(request.ctx, 'start_time', time.time())
        user = getattr(request.ctx, 'user', None)

        logger.info(
            "Request completed",
            method=request.method,
            path=request.path,
            status_code=response.status,
            duration_ms=round(duration * 1000, 2),
            response_size=len(response.body) if response.body else 0,
            client_ip=getattr(request.ctx, 'client_ip', 'unknown'),
            user_id=str(user.id) if user else None
        )
|
||
|
||
|
||
# Module-level middleware singletons shared by the pipeline functions and the
# route decorators below. Only RateLimitMiddleware carries state (its lazily
# fetched cache client).
context_middleware = RequestContextMiddleware()
security_middleware = SecurityMiddleware()
crypto_middleware = CryptographicMiddleware()
auth_middleware = AuthenticationMiddleware()
rate_limit_middleware = RateLimitMiddleware()
|
||
|
||
|
||
def _rate_limit_pattern(path: str) -> str:
    """Map a request path to the rate-limit pattern name it is counted under."""
    if '/auth/' in path:
        return "auth"
    if '/upload' in path:
        return "upload"
    if '/admin/' in path:
        return "heavy"
    return "api"


async def request_middleware(request: Request):
    """Main request middleware pipeline.

    Steps, in order:
    1. CORS preflight (OPTIONS) is answered immediately.
    2. Request context (ID, timing, client info) is set up.
    3. Inter-node Ed25519 signature is verified.
    4. Request size, JSON size and Origin are validated.
    5. Per-client rate limiting runs, if enabled.
    6. Token authentication and permission checks run for non-system paths.

    Returns:
        An ``HTTPResponse`` to short-circuit the request, or ``None`` to let it
        reach the route handler.
    """
    # Handle OPTIONS requests for CORS
    if request.method == 'OPTIONS':
        response = text_response('OK')
        return security_middleware.add_security_headers(response)

    await context_middleware.add_request_context(request)

    # Cryptographic signature verification for inter-node communication
    if not await crypto_middleware.verify_inter_node_signature(request):
        logger.warning("Inter-node signature verification failed")
        response = json_response({
            "error": "Invalid cryptographic signature",
            "message": "Inter-node communication requires valid ed25519 signature"
        }, status=403)
        return security_middleware.add_security_headers(response)

    # Security validations
    try:
        security_middleware.validate_request_size(request)
        security_middleware.validate_content_type(request)

        if not security_middleware.check_origin(request):
            raise Forbidden("Origin not allowed")

    except Exception as e:
        logger.warning("Security validation failed", error=str(e))
        # Use the exception's own 4xx status (403 for Forbidden, 400 for
        # BadRequest). Anything else is reported as a bad request.
        status = getattr(e, 'status_code', None)
        if not (isinstance(status, int) and 400 <= status < 500):
            status = 400
        response = json_response({"error": str(e)}, status=status)
        return security_middleware.add_security_headers(response)

    # Rate limiting
    if settings.RATE_LIMIT_ENABLED:
        client_identifier = context_middleware.get_client_ip(request)
        pattern = _rate_limit_pattern(request.path)

        if not await rate_limit_middleware.check_rate_limit(request, client_identifier, pattern):
            rate_info = await rate_limit_middleware.get_rate_limit_info(client_identifier, pattern)
            response = json_response(
                {
                    "error": "Rate limit exceeded",
                    "rate_limit": rate_info
                },
                status=429
            )
            return security_middleware.add_security_headers(response)

        # Stored for the X-RateLimit-* response headers.
        request.ctx.rate_limit_info = await rate_limit_middleware.get_rate_limit_info(
            client_identifier, pattern
        )

    # Authentication (for protected endpoints)
    if not request.path.startswith('/api/system') and request.path != '/':
        from app.core.database import db_manager

        async with db_manager.get_session() as session:
            token = await auth_middleware.extract_token(request)
            if token:
                user = await auth_middleware.validate_token(token, session)
                if user:
                    request.ctx.user = user
                    user_id_var.set(str(user.id))

                    if not await auth_middleware.check_permissions(user, request):
                        response = json_response({"error": "Insufficient permissions"}, status=403)
                        return security_middleware.add_security_headers(response)

                    user.update_activity()
                    await session.commit()

            # NOTE(review): handlers run after this middleware returns, i.e.
            # after the ``async with`` has exited. Whether this session (and the
            # user loaded from it) is still usable there depends on
            # db_manager.get_session(); verify.
            request.ctx.db_session = session
|
||
|
||
|
||
async def response_middleware(request: Request, response: HTTPResponse):
    """Decorate an outgoing response and log the completed request.

    Applies the security headers and the inter-node crypto headers. Adds the
    X-RateLimit-* headers when the request recorded rate-limit info, and the
    request ID when one was assigned. Finally logs completion.

    Returns:
        The decorated response.
    """
    response = security_middleware.add_security_headers(response)
    response = await crypto_middleware.add_inter_node_headers(request, response)

    ctx = request.ctx

    rate_info = getattr(ctx, 'rate_limit_info', None)
    if rate_info:
        response.headers.update({
            "X-RateLimit-Limit": str(rate_info.get('limit', 0)),
            "X-RateLimit-Remaining": str(rate_info.get('remaining', 0)),
            "X-RateLimit-Reset": str(rate_info.get('reset_time', 0))
        })

    if hasattr(ctx, 'request_id'):
        response.headers["X-Request-ID"] = ctx.request_id

    await context_middleware.log_request_completion(request, response)
    return response
|
||
|
||
|
||
async def exception_middleware(request: Request, exception: Exception):
    """Global exception handler: turn an exception into a JSON error response.

    Every response carries a fresh ``error_id`` that is also logged. Status
    mapping:
    - Unauthorized → 401, Forbidden → 403, TooManyRequests → 429,
      BadRequest → 400 (with the exception message);
    - any other exception exposing an integer 4xx ``status_code`` (such as a
      Sanic NotFound or MethodNotAllowed) keeps that status and message
      instead of being reported as a server error;
    - everything else → 500, with type/message details only when
      ``settings.DEBUG`` is on.
    """
    error_id = str(uuid.uuid4())
    user = getattr(request.ctx, 'user', None)

    logger.error(
        "Unhandled exception",
        error_id=error_id,
        exception_type=type(exception).__name__,
        exception_message=str(exception),
        path=request.path,
        method=request.method,
        user_id=str(user.id) if user else None
    )

    status_code = getattr(exception, 'status_code', None)

    if isinstance(exception, Unauthorized):
        response_data = {"error": "Authentication required", "error_id": error_id}
        status = 401
    elif isinstance(exception, Forbidden):
        response_data = {"error": "Access forbidden", "error_id": error_id}
        status = 403
    elif isinstance(exception, TooManyRequests):
        response_data = {"error": "Rate limit exceeded", "error_id": error_id}
        status = 429
    elif isinstance(exception, BadRequest):
        response_data = {"error": str(exception), "error_id": error_id}
        status = 400
    elif isinstance(status_code, int) and 400 <= status_code < 500:
        # Other client errors (e.g. 404, 405) are not server failures.
        response_data = {"error": str(exception), "error_id": error_id}
        status = status_code
    else:
        response_data = {
            "error": "Internal server error",
            "error_id": error_id
        }
        status = 500

        if settings.DEBUG:
            response_data["debug"] = {
                "type": type(exception).__name__,
                "message": str(exception)
            }

    response = json_response(response_data, status=status)
    return security_middleware.add_security_headers(response)
|
||
|
||
|
||
# Maintenance mode middleware
async def maintenance_middleware(request: Request):
    """Answer with 503 while maintenance mode is enabled.

    ``/api/system`` paths remain reachable. Returns ``None`` (the request
    proceeds) when maintenance mode is off or the path is exempt.
    """
    if not settings.MAINTENANCE_MODE or request.path.startswith('/api/system'):
        return None

    payload = {
        "error": "Service temporarily unavailable",
        "message": settings.MAINTENANCE_MESSAGE
    }
    return security_middleware.add_security_headers(json_response(payload, status=503))
|
||
|
||
|
||
# Helper functions for route decorators
async def check_auth(request: Request) -> User:
    """Return the user attached to ``request.ctx`` by the request pipeline.

    Raises:
        Unauthorized: if no (truthy) user is attached.
    """
    user = getattr(request.ctx, 'user', None)
    if user:
        return user
    raise Unauthorized("Authentication required")
|
||
|
||
|
||
async def validate_request_data(request: Request, schema: Optional[Any] = None) -> Dict[str, Any]:
    """Return the JSON body of a POST/PUT/PATCH request.

    Other methods, and a missing or falsy JSON body, yield ``{}``. ``schema``
    is accepted for future validation but is not applied yet.

    Raises:
        BadRequest: if reading the body raises (for instance when accessing
            ``request.json`` fails).
    """
    try:
        if request.method not in ['POST', 'PUT', 'PATCH']:
            return {}

        payload = request.json if hasattr(request, 'json') else None
        data = payload if payload else {}

        if schema:
            pass  # TODO: schema validation (e.g. pydantic) is not implemented yet

        return data

    except Exception as e:
        raise BadRequest(f"Invalid request data: {str(e)}")
|
||
|
||
|
||
async def check_rate_limit(request: Request, pattern: str = "api") -> bool:
    """Enforce the ``pattern`` rate limit for the requesting client's IP.

    Returns:
        True when the request is within its limit.

    Raises:
        TooManyRequests: when the limit is exhausted. The message embeds the
            current rate-limit info.
    """
    client_ip = context_middleware.get_client_ip(request)
    allowed = await rate_limit_middleware.check_rate_limit(request, client_ip, pattern)
    if allowed:
        return True

    info = await rate_limit_middleware.get_rate_limit_info(client_ip, pattern)
    raise TooManyRequests(f"Rate limit exceeded: {info}")
|
||
|
||
|
||
# Decorator functions for convenience
def auth_required(func):
    """Wrap an async route handler so it only runs for authenticated requests.

    The wrapper's ``__name__`` is the handler's name plus ``_auth_required``.
    Presumably this keeps route names distinct; confirm against route
    registration.
    """
    wrapped_name = f"{func.__name__}_auth_required"

    async def auth_wrapper(request: Request, *args, **kwargs):
        await check_auth(request)  # raises Unauthorized when no user is attached
        return await func(request, *args, **kwargs)

    auth_wrapper.__name__ = wrapped_name
    return auth_wrapper
|
||
|
||
|
||
def require_auth(permissions=None):
    """Decorator factory: require authentication and, optionally, permissions.

    Args:
        permissions: ``None`` (authentication only), a single permission, or a
            list/tuple/set of permissions. Each one is checked with
            ``user.has_permission`` and all of them must pass.

    The wrapped handler raises:
        Unauthorized: when no authenticated user is attached to the request.
        Forbidden: when the user lacks any required permission.
    """
    # Normalise once at decoration time. Strings and other scalars count as a
    # single permission.
    if permissions is None:
        required = ()
    elif isinstance(permissions, (list, tuple, set, frozenset)):
        required = tuple(permissions)
    else:
        required = (permissions,)

    def decorator(func):
        async def require_auth_wrapper(request: Request, *args, **kwargs):
            user = await check_auth(request)

            # Enforce the requested permissions. Accepting them without
            # checking would grant access to any authenticated user.
            missing = [perm for perm in required if not user.has_permission(perm)]
            if missing:
                logger.warning(
                    "Permission denied",
                    user_id=str(user.id),
                    missing_permissions=missing,
                )
                raise Forbidden("Insufficient permissions")

            return await func(request, *args, **kwargs)

        require_auth_wrapper.__name__ = f"{func.__name__}_require_auth"
        return require_auth_wrapper

    return decorator
|
||
|
||
|
||
def validate_json(schema=None):
    """Decorator factory: run ``validate_request_data`` before the handler.

    Validation failures surface as ``BadRequest`` from
    ``validate_request_data``. The parsed data is discarded, and the handler
    reads the request itself.
    """
    def decorator(func):
        wrapped_name = f"{func.__name__}_validate_json"

        async def validate_json_wrapper(request: Request, *args, **kwargs):
            await validate_request_data(request, schema)
            return await func(request, *args, **kwargs)

        validate_json_wrapper.__name__ = wrapped_name
        return validate_json_wrapper

    return decorator
|
||
|
||
|
||
def validate_request(schema=None):
    """Decorator factory: validate request data against ``schema`` before the handler.

    Delegates to ``validate_request_data`` (``BadRequest`` on failure) and
    discards its result before calling the handler.
    """
    def decorator(func):
        wrapped_name = f"{func.__name__}_validate_request"

        async def validate_request_wrapper(request: Request, *args, **kwargs):
            await validate_request_data(request, schema)
            return await func(request, *args, **kwargs)

        validate_request_wrapper.__name__ = wrapped_name
        return validate_request_wrapper

    return decorator
|
||
|
||
|
||
async def _enforce_custom_rate_limit(request, counter_pattern: str, limit: int, window: int) -> None:
    """Apply a fixed-window ``limit`` per ``window`` seconds for one client IP.

    Fails open (logs and allows) when the cache is unavailable. This matches
    ``RateLimitMiddleware.check_rate_limit``.

    Raises:
        TooManyRequests: when the client already used ``limit`` requests in
            the current window.
    """
    client_identifier = context_middleware.get_client_ip(request)
    try:
        cache = await rate_limit_middleware.get_cache()
        cache_key = CACHE_KEYS["rate_limit"].format(
            pattern=counter_pattern,
            identifier=client_identifier
        )

        current_count = await cache.get(cache_key)
        if current_count is None:
            await cache.set(cache_key, "1", ttl=window)
            return

        count = int(current_count)
        if count < limit:
            await cache.incr(cache_key)
            return
    except Exception as e:
        logger.error("Rate limit check failed", error=str(e))
        return  # Fail open

    logger.warning(
        "Rate limit exceeded",
        identifier=client_identifier,
        pattern=counter_pattern,
        count=count,
        limit=limit,
    )
    raise TooManyRequests(f"Rate limit exceeded: {limit} per {window}s")


def apply_rate_limit(pattern: str = "api", limit: Optional[int] = None, window: Optional[int] = None):
    """Decorator factory: rate-limit an async route handler per client IP.

    Args:
        pattern: Rate-limit pattern name. Without custom limits it selects the
            shared limit used by ``check_rate_limit``.
        limit: Custom maximum number of requests per window. It applies only
            when ``window`` is also given.
        window: Custom window length in seconds.

    With both ``limit`` and ``window``, the decorated handler gets its own
    counter. Otherwise the shared ``pattern`` limit is enforced.

    The wrapped handler raises:
        TooManyRequests: when the client has exhausted its limit.
    """
    def decorator(func):
        # Custom limits get a per-handler counter namespace. The bare
        # ``pattern`` key is the same one the global request middleware
        # increments for every request of that pattern. Reusing it would let
        # unrelated requests exhaust a strict custom limit and would leave the
        # TTL to whichever writer created the key first.
        counter_pattern = f"{pattern}:{getattr(func, '__module__', '')}.{func.__name__}"

        async def rate_limit_wrapper(request: Request, *args, **kwargs):
            if limit and window:
                await _enforce_custom_rate_limit(request, counter_pattern, limit, window)
            else:
                await check_rate_limit(request, pattern)

            return await func(request, *args, **kwargs)

        rate_limit_wrapper.__name__ = f"{func.__name__}_rate_limit"
        return rate_limit_wrapper

    return decorator
|
||
|
||
|
||
# Compatibility alias matching the @rate_limit(limit=..., window=...) syntax used in auth_routes
def rate_limit(limit: Optional[int] = None, window: Optional[int] = None, pattern: str = "api"):
    """Backward-compatible alias for ``apply_rate_limit`` with ``limit``/``window`` first.

    Equivalent to ``apply_rate_limit(pattern=pattern, limit=limit, window=window)``.
    """
    return apply_rate_limit(limit=limit, window=window, pattern=pattern)
|