uploader-bot/app/api/fastapi_middleware.py

594 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FastAPI middleware адаптированный из Sanic middleware
Обеспечивает полную совместимость с существующей функциональностью
"""
import asyncio
import time
import uuid
import json
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Callable
from fastapi import Request, Response, HTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from fastapi.responses import JSONResponse
import structlog
from app.core.config import settings, SecurityConfig, CACHE_KEYS
from app.core.database import get_cache
from app.core.logging import request_id_var, user_id_var, operation_var
from app.core.models.user import User
# Ed25519 crypto module is optional: CRYPTO_AVAILABLE records whether it could be imported
try:
from app.core.crypto import get_ed25519_manager
CRYPTO_AVAILABLE = True
except ImportError:
CRYPTO_AVAILABLE = False
# structlog logger used by the middleware classes in this module.
logger = structlog.get_logger(__name__)
class FastAPISecurityMiddleware(BaseHTTPMiddleware):
    """Validate incoming requests and attach CORS/security headers to responses.

    Checks performed before the request reaches the route handler:

    * the declared ``Content-Length`` against ``SecurityConfig.MAX_REQUEST_SIZE``;
    * for JSON ``POST``/``PUT``/``PATCH`` requests, the declared
      ``Content-Length`` against ``SecurityConfig.MAX_JSON_SIZE`` (the body
      itself is not read here);
    * the ``Origin`` header against ``SecurityConfig.CORS_ORIGINS``.

    A failed check is answered directly with a JSON error response shaped like
    FastAPI's own ``HTTPException`` response (``{"detail": ...}``). It is not
    raised: an exception escaping ``BaseHTTPMiddleware.dispatch`` bypasses
    FastAPI's exception handlers and reaches the client as HTTP 500.
    """

    async def dispatch(self, request: Request, call_next):
        # CORS preflight requests are answered here; there is nothing to validate.
        if request.method == 'OPTIONS':
            response = Response(content='OK')
            return self.add_security_headers(response)

        try:
            self.validate_request_size(request)
            await self.validate_content_type(request)
            if not self.check_origin(request):
                raise HTTPException(status_code=403, detail="Origin not allowed")
        except HTTPException as exc:
            return self._error_response(exc.status_code, exc.detail)
        except Exception as e:
            logger.warning("Security validation failed", error=str(e))
            return self._error_response(400, str(e))

        response = await call_next(request)
        return self.add_security_headers(response)

    def _error_response(self, status_code: int, detail: Any) -> Response:
        """Build a ``{"detail": ...}`` JSON error response with security headers."""
        response = JSONResponse(content={"detail": detail}, status_code=status_code)
        return self.add_security_headers(response)

    def add_security_headers(self, response: Response) -> Response:
        """Add CORS, hardening and CSP headers to *response* and return it."""
        response.headers.update({
            # CORS headers
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": (
                "Origin, Content-Type, Accept, Authorization, "
                "X-Requested-With, X-API-Key, X-Request-ID, "
                "X-Node-Communication, X-Node-ID, X-Node-Public-Key, X-Node-Signature"
            ),
            "Access-Control-Max-Age": "86400",
            # Security headers
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",
            # Custom headers
            "X-API-Version": settings.PROJECT_VERSION,
        })
        # Content-Security-Policy: "<directive> <source> <source>; <directive> ..."
        csp_directives = "; ".join(
            f"{directive} {' '.join(sources)}"
            for directive, sources in SecurityConfig.CSP_DIRECTIVES.items()
        )
        response.headers["Content-Security-Policy"] = csp_directives
        return response

    @staticmethod
    def _content_length(request: Request) -> Optional[int]:
        """Return the declared Content-Length, or None when the header is absent/empty.

        Raises:
            HTTPException: 400 when the header is not a non-negative integer.
        """
        raw = request.headers.get('content-length')
        if not raw:
            return None
        try:
            size = int(raw)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid Content-Length header") from None
        if size < 0:
            raise HTTPException(status_code=400, detail="Invalid Content-Length header")
        return size

    def validate_request_size(self, request: Request) -> None:
        """Reject requests whose declared size exceeds ``MAX_REQUEST_SIZE`` (413)."""
        size = self._content_length(request)
        if size is not None and size > SecurityConfig.MAX_REQUEST_SIZE:
            raise HTTPException(status_code=413, detail=f"Request too large: {size} bytes")

    async def validate_content_type(self, request: Request) -> None:
        """Reject JSON write requests whose declared size exceeds ``MAX_JSON_SIZE`` (413)."""
        if request.method not in ('POST', 'PUT', 'PATCH'):
            return
        if 'application/json' not in request.headers.get('content-type', ''):
            return
        # The body is deliberately not read here (the route handler reads it);
        # only the declared Content-Length is checked.
        size = self._content_length(request)
        if size is not None and size > SecurityConfig.MAX_JSON_SIZE:
            raise HTTPException(status_code=413, detail="JSON payload too large")

    def check_origin(self, request: Request) -> bool:
        """Return True when the request's Origin is allowed (or absent)."""
        origin = request.headers.get('origin')
        if not origin:
            return True  # Allow requests without origin (direct API calls)
        return any(
            self._origin_matches(origin, allowed_origin)
            for allowed_origin in SecurityConfig.CORS_ORIGINS
        )

    @staticmethod
    def _origin_matches(origin: str, allowed_origin: str) -> bool:
        """Match *origin* against one configured entry.

        Trailing ``/`` and ``*`` characters are stripped from the entry, so
        ``"*"`` allows everything and ``"https://app.com/*"`` means
        ``"https://app.com"``. The match must end on a host boundary: the same
        origin, optionally followed by a port. A bare prefix would also let
        ``"https://app.com.evil.net"`` through.
        """
        prefix = allowed_origin.rstrip('/*')
        if not prefix:
            return True
        if not origin.startswith(prefix):
            return False
        remainder = origin[len(prefix):]
        return remainder == '' or remainder[0] in ':/'
class FastAPIRateLimitMiddleware(BaseHTTPMiddleware):
    """Fixed-window rate limiting backed by the shared cache (Redis).

    Each client IP gets one counter per endpoint pattern (``auth``, ``upload``,
    ``heavy`` or ``api``). Once the counter reaches the pattern's limit,
    requests are rejected with HTTP 429 until the window's TTL expires.
    Allowed responses carry ``X-RateLimit-*`` headers. If the cache fails, the
    middleware fails open and lets the request through.
    """

    def __init__(self, app):
        super().__init__(app)
        self.cache = None  # resolved lazily on first use

    async def get_cache(self):
        """Return the cache client, fetching it on first use."""
        if self.cache is None:
            self.cache = await get_cache()
        return self.cache

    async def dispatch(self, request: Request, call_next):
        if not settings.RATE_LIMIT_ENABLED:
            return await call_next(request)

        client_identifier = self.get_client_ip(request)
        pattern = self.get_rate_limit_pattern(request)

        if not await self.check_rate_limit(request, client_identifier, pattern):
            rate_info = await self.get_rate_limit_info(client_identifier, pattern)
            headers = self._rate_limit_headers(rate_info)
            if rate_info:
                # Tell well-behaved clients how many seconds remain in the window.
                headers["Retry-After"] = str(
                    max(0, rate_info.get('reset_time', 0) - int(time.time()))
                )
            return JSONResponse(
                content={
                    "error": "Rate limit exceeded",
                    "rate_limit": rate_info
                },
                status_code=429,
                headers=headers,
            )

        # Snapshot the counter (already including this request) before the handler runs.
        rate_info = await self.get_rate_limit_info(client_identifier, pattern)
        response = await call_next(request)
        response.headers.update(self._rate_limit_headers(rate_info))
        return response

    @staticmethod
    def _rate_limit_headers(rate_info: Dict[str, Any]) -> Dict[str, str]:
        """Translate a rate-limit info dict into ``X-RateLimit-*`` headers ({} if empty)."""
        if not rate_info:
            return {}
        return {
            "X-RateLimit-Limit": str(rate_info.get('limit', 0)),
            "X-RateLimit-Remaining": str(rate_info.get('remaining', 0)),
            "X-RateLimit-Reset": str(rate_info.get('reset_time', 0)),
        }

    def get_client_ip(self, request: Request) -> str:
        """Return the client IP, preferring proxy headers over the socket peer.

        NOTE(review): X-Forwarded-For / X-Real-IP are trusted as sent. Unless a
        trusted proxy overwrites them, a client can rotate these headers to
        dodge rate limits — confirm the deployment strips/sets them.
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        real_ip = request.headers.get('x-real-ip')
        if real_ip:
            return real_ip
        # Fallback to the socket peer (request.client may be None).
        return getattr(request.client, 'host', '127.0.0.1')

    def get_rate_limit_pattern(self, request: Request) -> str:
        """Map the request path to a rate-limit pattern name."""
        path = request.url.path
        if '/auth/' in path:
            return "auth"
        elif '/upload' in path:
            return "upload"
        elif '/admin/' in path:
            return "heavy"
        else:
            return "api"

    @staticmethod
    def _limits_for(pattern: str) -> Dict[str, Any]:
        """Return ``{"requests": ..., "window": ...}`` for *pattern*, or the global defaults."""
        return SecurityConfig.RATE_LIMIT_PATTERNS.get(pattern, {
            "requests": settings.RATE_LIMIT_REQUESTS,
            "window": settings.RATE_LIMIT_WINDOW
        })

    @staticmethod
    def _cache_key(pattern: str, identifier: str) -> str:
        """Build the cache key holding the counter for (*pattern*, *identifier*)."""
        return CACHE_KEYS["rate_limit"].format(pattern=pattern, identifier=identifier)

    async def check_rate_limit(
        self,
        request: Request,
        identifier: str,
        pattern: str = "api"
    ) -> bool:
        """Count this request and report whether it is within the limit.

        The first request of a window creates the counter with the window as
        its TTL; later requests increment it. Returns False once the counter
        has reached the limit. Returns True when the cache is unavailable
        (fail open). ``request`` is kept for interface compatibility and is
        not used.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(pattern, identifier)

            current_count = await cache.get(cache_key)
            if current_count is None:
                # First request in the window starts the TTL clock.
                await cache.set(cache_key, "1", ttl=limits["window"])
                return True

            current_count = int(current_count)
            if current_count >= limits["requests"]:
                logger.warning(
                    "Rate limit exceeded",
                    identifier=identifier,
                    pattern=pattern,
                    count=current_count,
                    limit=limits["requests"]
                )
                return False

            # NOTE(review): get-then-incr is not atomic. Concurrent requests can
            # overshoot the limit. If the key expires between the get and the
            # incr, the incr may recreate it without a TTL — confirm the cache
            # wrapper's incr semantics. An atomic INCR + EXPIRE avoids both.
            await cache.incr(cache_key)
            return True
        except Exception as e:
            logger.error("Rate limit check failed", error=str(e))
            return True  # Allow request if rate limiting fails

    async def get_rate_limit_info(
        self,
        identifier: str,
        pattern: str = "api"
    ) -> Dict[str, Any]:
        """Return limit, remaining, reset_time (epoch seconds) and window; {} on error."""
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(pattern, identifier)

            current_count = await cache.get(cache_key) or "0"
            ttl = await cache.redis.ttl(cache_key)
            return {
                "limit": limits["requests"],
                "remaining": max(0, limits["requests"] - int(current_count)),
                # A negative TTL (missing key / no expiry) is treated as "resets now".
                "reset_time": int(time.time()) + max(0, ttl),
                "window": limits["window"]
            }
        except Exception as e:
            logger.error("Failed to get rate limit info", error=str(e))
            return {}
class FastAPICryptographicMiddleware(BaseHTTPMiddleware):
    """Ed25519 handling for inter-node (node-to-node) requests.

    A request counts as inter-node when it sends ``X-Node-Communication: true``.
    Such requests must also send ``X-Node-Signature``, ``X-Node-ID`` and
    ``X-Node-Public-Key``; otherwise they are rejected with HTTP 403.
    Responses to accepted inter-node requests get this node's identity
    headers, plus a signature when the body is JSON.

    SECURITY: the signature value is currently NOT verified; only the presence
    of the headers is checked (see ``verify_inter_node_signature``).
    """

    async def dispatch(self, request: Request, call_next):
        if not await self.verify_inter_node_signature(request):
            logger.warning("Inter-node signature verification failed")
            return JSONResponse(
                content={
                    "error": "Invalid cryptographic signature",
                    "message": "Inter-node communication requires valid ed25519 signature"
                },
                status_code=403
            )

        response = await call_next(request)
        # Attach this node's identity (and signature) to inter-node responses.
        return await self.add_inter_node_headers(request, response)

    async def verify_inter_node_signature(self, request: Request) -> bool:
        """Decide whether the request may proceed.

        Returns True for ordinary (non-inter-node) requests. Also returns True
        for inter-node requests while the crypto module is unavailable. For
        other inter-node requests it returns False when a required header is
        missing or the key manager cannot be obtained. Otherwise it records the
        sender on ``request.state`` and returns True without verifying the
        signature (see the TODO below).
        """
        # Only requests explicitly marked as inter-node are subject to the check.
        if request.headers.get("x-node-communication") != "true":
            return True

        if not CRYPTO_AVAILABLE:
            logger.warning("Crypto module not available, skipping signature verification")
            return True

        try:
            # Called so that a broken key manager still fails closed (403),
            # even though verification itself is currently skipped.
            get_ed25519_manager()

            signature = request.headers.get("x-node-signature")
            node_id = request.headers.get("x-node-id")
            public_key = request.headers.get("x-node-public-key")
            if not all([signature, node_id, public_key]):
                logger.warning("Missing cryptographic headers in inter-node request")
                return False

            # TODO(security): verify `signature` over the JSON-decoded request body with
            # get_ed25519_manager().verify_signature(message_data, signature, public_key).
            # Body reading in this middleware was disabled because it conflicted with
            # FastAPI. The old unreachable implementation referenced an undefined `body`.
            # Until then, ANY caller presenting these three headers is accepted as a node.
            logger.debug("Inter-node signature verification skipped (body reading conflict)")
            request.state.inter_node_communication = True
            request.state.source_node_id = node_id
            request.state.source_public_key = public_key
            return True
        except Exception as e:
            logger.error(f"Crypto verification error: {e}")
            return False

    async def add_inter_node_headers(self, request: Request, response: Response) -> Response:
        """Add this node's identity headers to responses for accepted inter-node requests.

        If the response exposes a JSON ``body``, that body is also signed and the
        signature sent in ``X-Node-Signature``. Errors are logged and the
        response is returned as is.
        """
        if not CRYPTO_AVAILABLE:
            return response
        if not getattr(request.state, 'inter_node_communication', False):
            return response

        try:
            crypto_manager = get_ed25519_manager()
            response.headers.update({
                "X-Node-ID": crypto_manager.node_id,
                "X-Node-Public-Key": crypto_manager.public_key_hex,
                "X-Node-Communication": "true"
            })

            # NOTE(review): responses returned by BaseHTTPMiddleware's call_next are
            # typically streamed and may not expose `.body`, in which case no
            # signature is ever added — confirm.
            body = getattr(response, 'body', None)
            if body:
                try:
                    response_data = json.loads(body.decode())
                    response.headers["X-Node-Signature"] = crypto_manager.sign_message(response_data)
                except (ValueError, AttributeError):
                    # Non-JSON or non-UTF-8 body (ValueError covers both
                    # JSONDecodeError and UnicodeDecodeError): leave it unsigned.
                    pass
        except Exception as e:
            logger.error(f"Error adding inter-node headers: {e}")
        return response
class FastAPIRequestContextMiddleware(BaseHTTPMiddleware):
    """Per-request context setup and access logging.

    Assigns a fresh UUID request id and stores it on ``request.state``, in
    ``request_id_var`` and in the ``X-Request-ID`` response header. Also
    records client IP, user agent and start time on ``request.state``. Logs
    the start of every request and then either its completion (with status
    and duration) or its failure.
    """

    async def dispatch(self, request: Request, call_next):
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id
        request_id_var.set(request_id)

        # The epoch start time stays on request.state for other code. Durations
        # use a monotonic clock so wall-clock adjustments cannot skew them.
        start_time = time.time()
        request.state.start_time = start_time
        started = time.perf_counter()

        request.state.client_ip = self.get_client_ip(request)
        request.state.user_agent = request.headers.get('user-agent', 'Unknown')
        # May be set later by the authentication middleware when a token is valid.
        request.state.user = None

        logger.info(
            "Request started",
            method=request.method,
            path=request.url.path,
            client_ip=request.state.client_ip,
            user_agent=request.state.user_agent
        )

        try:
            response = await call_next(request)
        except Exception:
            # Without this, a failing request leaves only a "Request started" entry.
            logger.exception(
                "Request failed",
                method=request.method,
                path=request.url.path,
                duration_ms=self._elapsed_ms(started),
                client_ip=request.state.client_ip,
                user_id=self._user_id(request)
            )
            raise

        response.headers["X-Request-ID"] = request_id

        logger.info(
            "Request completed",
            method=request.method,
            path=request.url.path,
            status_code=response.status_code,
            duration_ms=self._elapsed_ms(started),
            client_ip=request.state.client_ip,
            user_id=self._user_id(request)
        )
        return response

    @staticmethod
    def _elapsed_ms(started: float) -> float:
        """Milliseconds since *started* (a ``time.perf_counter()`` value), 2 decimals."""
        return round((time.perf_counter() - started) * 1000, 2)

    @staticmethod
    def _user_id(request: Request) -> Optional[str]:
        """Return the authenticated user's id as a string, or None."""
        user = getattr(request.state, 'user', None)
        return str(user.id) if user else None

    def get_client_ip(self, request: Request) -> str:
        """Return the client IP, preferring X-Forwarded-For, then X-Real-IP, then the peer."""
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        real_ip = request.headers.get('x-real-ip')
        if real_ip:
            return real_ip
        # Fallback to the socket peer (request.client may be None).
        return getattr(request.client, 'host', '127.0.0.1')
class FastAPIAuthenticationMiddleware(BaseHTTPMiddleware):
    """Resolve the caller from a token and enforce coarse path-based permissions.

    The token comes from ``Authorization: Bearer ...``, ``X-API-Key`` or the
    ``token`` query parameter. A valid token puts the user on
    ``request.state.user``, and the user's activity is recorded. A missing or
    invalid token lets the request continue anonymously. Paths starting with
    ``/api/system`` and ``/`` itself are not inspected.
    """

    async def dispatch(self, request: Request, call_next):
        path = request.url.path
        # Skip authentication for system endpoints and root.
        if path.startswith('/api/system') or path == '/':
            return await call_next(request)

        token = await self.extract_token(request)
        if token:
            from app.core.database import db_manager
            async with db_manager.get_session() as session:
                user = await self.validate_token(token, session)
                if user:
                    request.state.user = user
                    user_id_var.set(str(user.id))

                    if not await self.check_permissions(user, request):
                        return JSONResponse(
                            content={"error": "Insufficient permissions"},
                            status_code=403
                        )

                    user.update_activity()
                    await session.commit()

        return await call_next(request)

    async def extract_token(self, request: Request) -> Optional[str]:
        """Return the token from Bearer auth, X-API-Key or ``?token=`` (in that order)."""
        auth_header = request.headers.get('authorization')
        if auth_header and auth_header.startswith('Bearer '):
            return auth_header[7:]  # Remove 'Bearer ' prefix

        api_key = request.headers.get('x-api-key')
        if api_key:
            return api_key

        # Query parameter: less secure (ends up in URLs/logs), kept for backward compatibility.
        return request.query_params.get('token')

    async def validate_token(self, token: str, session) -> Optional[User]:
        """Return the active user identified by *token*, or None.

        First tries the token as a JWT (``verify_access_token`` must yield a
        ``user_id``). It then falls back to the legacy ``"<user_uuid>:<hash>"``
        format checked by ``User.verify_token``. Both paths require
        ``user.is_active``.
        """
        if not token:
            return None

        try:
            from app.core.security import verify_access_token

            # JWT first (preferred format, e.g. for auth.twa).
            try:
                payload = verify_access_token(token)
                if payload and 'user_id' in payload:
                    user_id = uuid.UUID(payload['user_id'])
                    user = await User.get_by_id(session, user_id)
                    if user and user.is_active:
                        return user
            except Exception as jwt_error:
                logger.debug("JWT validation failed, trying legacy format", error=str(jwt_error))

            # Legacy token format (user_id:hash).
            if ':' in token:
                user_id_str, token_hash = token.split(':', 1)
                try:
                    user_id = uuid.UUID(user_id_str)
                    user = await User.get_by_id(session, user_id)
                    # Deactivated accounts are rejected here exactly as on the JWT path.
                    if (
                        user
                        and user.is_active
                        and hasattr(user, 'verify_token')
                        and user.verify_token(token_hash)
                    ):
                        return user
                except (ValueError, AttributeError):
                    pass

            return None
        except Exception as e:
            logger.error("Token validation failed", token=token[:8] + "...", error=str(e))
            return None

    async def check_permissions(self, user: User, request: Request) -> bool:
        """Coarse permission check based on path substrings and HTTP method."""
        endpoint = request.url.path
        method = request.method

        if '/admin/' in endpoint:
            return user.is_admin

        if '/mod/' in endpoint:
            return user.is_moderator

        if '/user/' in endpoint and method in ('POST', 'PUT', 'DELETE'):
            return user.has_permission('user:write')

        # `and` binds tighter than `or`: this applies to ANY method on '/upload'
        # paths, and only to POST on '/content' paths. The parentheses make the
        # existing behaviour explicit.
        # NOTE(review): if POST-only was intended for '/upload' as well, confirm
        # before relaxing it.
        if '/upload' in endpoint or ('/content' in endpoint and method == 'POST'):
            return user.can_upload_content()

        # Default: allow read access for authenticated users.
        return True
# FastAPI dependencies for use in route handlers
from fastapi import Depends, HTTPException
async def get_current_user(request: Request) -> Optional[User]:
    """FastAPI dependency: return the user stored on ``request.state``, if any.

    Returns None when ``request.state`` has no ``user`` attribute or it is falsy.
    """
    current = getattr(request.state, 'user', None)
    return current if current else None
async def require_auth(request: Request) -> User:
    """FastAPI dependency: return the current user, or reject with HTTP 401."""
    current = await get_current_user(request)
    if current:
        return current
    raise HTTPException(status_code=401, detail="Authentication required")
def check_permissions(permission: str):
    """Build a FastAPI dependency that requires *permission*.

    Usage: ``user: User = Depends(check_permissions("user:write"))``.

    This is a plain function on purpose: ``Depends`` must receive the inner
    checker when the route is declared. Declared ``async``, the factory
    returned a coroutine, which FastAPI cannot use as a dependency.

    The returned checker resolves the authenticated user via ``require_auth``
    and raises HTTP 403 when ``user.has_permission(permission)`` is false.
    """
    def permission_checker(user: User = Depends(require_auth)) -> User:
        if not user.has_permission(permission):
            raise HTTPException(status_code=403, detail=f"Permission required: {permission}")
        return user

    return permission_checker
async def require_admin(user: User = Depends(require_auth)) -> User:
    """FastAPI dependency: require an authenticated user whose ``is_admin`` is truthy (else 403)."""
    if getattr(user, 'is_admin', False):
        return user
    raise HTTPException(status_code=403, detail="Administrative privileges required")
async def check_rate_limit(pattern: str = "api"):
"""FastAPI dependency для проверки rate limit"""
def rate_limit_checker(request: Request) -> bool:
# Rate limiting уже проверяется в middleware
# Это dependency для дополнительных проверок если нужно
return True
return rate_limit_checker