594 lines
24 KiB
Python
594 lines
24 KiB
Python
"""
FastAPI middleware, адаптированный из Sanic middleware.

Обеспечивает полную совместимость с существующей функциональностью.
"""
|
||
|
||
import asyncio
|
||
import time
|
||
import uuid
|
||
import json
|
||
from datetime import datetime, timedelta
|
||
from typing import Optional, Dict, Any, Callable
|
||
from fastapi import Request, Response, HTTPException
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
import structlog
|
||
|
||
from app.core.config import settings, SecurityConfig, CACHE_KEYS
|
||
from app.core.database import get_cache
|
||
from app.core.logging import request_id_var, user_id_var, operation_var
|
||
from app.core.models.user import User
|
||
|
||
# Optional Ed25519 crypto module. When it cannot be imported, the
# cryptographic middleware below degrades to pass-through behaviour.
try:
    from app.core.crypto import get_ed25519_manager
except ImportError:
    CRYPTO_AVAILABLE = False
else:
    CRYPTO_AVAILABLE = True

logger = structlog.get_logger(__name__)
|
||
|
||
|
||
class FastAPISecurityMiddleware(BaseHTTPMiddleware):
    """FastAPI security middleware: request validation plus protective response headers.

    For every request it:

    * answers CORS preflight (``OPTIONS``) requests directly with ``OK``;
    * rejects bodies whose declared ``Content-Length`` exceeds
      ``SecurityConfig.MAX_REQUEST_SIZE`` (or ``SecurityConfig.MAX_JSON_SIZE``
      for JSON bodies on POST/PUT/PATCH);
    * rejects requests whose ``Origin`` is not allowed by ``SecurityConfig.CORS_ORIGINS``;
    * adds CORS and security headers to every response it returns.
    """

    async def dispatch(self, request: Request, call_next):
        """Validate *request*, forward it, and add security headers to the response.

        Validation failures are converted into JSON error responses here rather
        than raised. The application's exception handlers run *inside* the
        middleware stack, so an ``HTTPException`` raised from a
        ``BaseHTTPMiddleware`` is not handled by them and would reach the client
        as a 500 instead of the intended 4xx status.
        """
        # CORS preflight: answered here, never routed to a handler.
        if request.method == 'OPTIONS':
            return self.add_security_headers(Response(content='OK'))

        try:
            self.validate_request_size(request)
            await self.validate_content_type(request)
            if not self.check_origin(request):
                raise HTTPException(status_code=403, detail="Origin not allowed")
        except HTTPException as exc:
            return self._error_response(exc.status_code, exc.detail)
        except Exception as e:
            logger.warning("Security validation failed", error=str(e))
            return self._error_response(400, str(e))

        response = await call_next(request)
        return self.add_security_headers(response)

    def _error_response(self, status_code: int, detail: Any) -> Response:
        """Build a FastAPI-style ``{"detail": ...}`` JSON error response with security headers."""
        return self.add_security_headers(
            JSONResponse(content={"detail": detail}, status_code=status_code)
        )

    def add_security_headers(self, response: Response) -> Response:
        """Add CORS, security, version and CSP headers to *response* and return it."""
        # CORS headers
        response.headers.update({
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Methods": "GET, POST, PUT, DELETE, OPTIONS",
            "Access-Control-Allow-Headers": (
                "Origin, Content-Type, Accept, Authorization, "
                "X-Requested-With, X-API-Key, X-Request-ID, "
                "X-Node-Communication, X-Node-ID, X-Node-Public-Key, X-Node-Signature"
            ),
            "Access-Control-Max-Age": "86400",

            # Security headers
            "X-Content-Type-Options": "nosniff",
            "X-Frame-Options": "DENY",
            "X-XSS-Protection": "1; mode=block",
            "Strict-Transport-Security": "max-age=31536000; includeSubDomains",
            "Referrer-Policy": "strict-origin-when-cross-origin",
            "Permissions-Policy": "geolocation=(), microphone=(), camera=()",

            # Custom headers
            "X-API-Version": settings.PROJECT_VERSION,
        })

        # Content-Security-Policy: "directive src1 src2; directive2 ..."
        csp_directives = "; ".join([
            f"{directive} {' '.join(sources)}"
            for directive, sources in SecurityConfig.CSP_DIRECTIVES.items()
        ])
        response.headers["Content-Security-Policy"] = csp_directives

        return response

    @staticmethod
    def _declared_content_length(request: Request) -> Optional[int]:
        """Return the parsed ``Content-Length`` header, or None when it is absent or empty.

        Raises:
            HTTPException: 400 if the header is not a non-negative integer.
        """
        raw = request.headers.get('content-length')
        if not raw:
            return None
        try:
            size = int(raw)
        except ValueError:
            raise HTTPException(status_code=400, detail="Invalid Content-Length header") from None
        if size < 0:
            raise HTTPException(status_code=400, detail="Invalid Content-Length header")
        return size

    def validate_request_size(self, request: Request) -> None:
        """Reject requests whose declared size exceeds ``SecurityConfig.MAX_REQUEST_SIZE``.

        Only the ``Content-Length`` header is checked; bodies sent without it
        (e.g. chunked transfer) are not limited here.

        Raises:
            HTTPException: 413 if too large, 400 if ``Content-Length`` is malformed.
        """
        size = self._declared_content_length(request)
        if size is not None and size > SecurityConfig.MAX_REQUEST_SIZE:
            raise HTTPException(status_code=413, detail=f"Request too large: {size} bytes")

    async def validate_content_type(self, request: Request) -> None:
        """Reject JSON bodies whose declared size exceeds ``SecurityConfig.MAX_JSON_SIZE``.

        Applies to POST/PUT/PATCH requests with a JSON content type. The body is
        deliberately not read here so the route handler can still consume it;
        only the ``Content-Length`` header is inspected.

        Raises:
            HTTPException: 413 if the JSON payload is too large, 400 if
                ``Content-Length`` is malformed.
        """
        if request.method not in ('POST', 'PUT', 'PATCH'):
            return
        # Media types are case-insensitive ("Application/JSON" is valid).
        if 'application/json' not in request.headers.get('content-type', '').lower():
            return
        size = self._declared_content_length(request)
        if size is not None and size > SecurityConfig.MAX_JSON_SIZE:
            raise HTTPException(status_code=413, detail="JSON payload too large")

    def check_origin(self, request: Request) -> bool:
        """Return True if the request's ``Origin`` header is allowed.

        Requests without an ``Origin`` header (direct API calls) are allowed.
        See ``_origin_matches`` for how configured origins are interpreted.
        """
        origin = request.headers.get('origin')
        if not origin:
            return True  # Allow requests without origin (direct API calls)

        return any(
            self._origin_matches(origin, allowed_origin)
            for allowed_origin in SecurityConfig.CORS_ORIGINS
        )

    @staticmethod
    def _origin_matches(origin: str, allowed_origin: str) -> bool:
        """Return True if *origin* is permitted by the configured *allowed_origin*.

        Trailing ``/`` and ``*`` characters are stripped from *allowed_origin*.
        An entry that becomes empty (e.g. ``"*"``) allows every origin.
        Otherwise *origin* must equal the entry, optionally followed by a port:
        ``"http://localhost"`` also allows ``"http://localhost:3000"``, and
        ``"http://localhost:*"`` allows any port.
        """
        base = allowed_origin.rstrip('/*')
        if not base:
            return True
        if origin == base:
            return True
        if not origin.startswith(base):
            return False
        # A bare prefix match would accept "https://app.example.com.evil.net"
        # for "https://app.example.com"; only a port may follow the host.
        suffix = origin[len(base):]
        if base.endswith(':'):
            return suffix.isdigit()
        return suffix.startswith(':') and suffix[1:].isdigit()
|
||
|
||
|
||
class FastAPIRateLimitMiddleware(BaseHTTPMiddleware):
    """FastAPI fixed-window rate limiting middleware backed by the shared cache.

    Every (pattern, client IP) pair owns a counter key that is created with the
    pattern's window as its TTL. Requests over the limit get a 429 JSON response
    with ``Retry-After``. All responses carry ``X-RateLimit-*`` headers when
    rate-limit info is available. If the cache fails, requests are allowed
    (fail open).
    """

    def __init__(self, app):
        super().__init__(app)
        # Cache client, resolved lazily on first use (see get_cache).
        self.cache = None

    async def get_cache(self):
        """Return the shared cache client, fetching it on first use."""
        if self.cache is None:
            self.cache = await get_cache()
        return self.cache

    async def dispatch(self, request: Request, call_next):
        """Apply the rate limit for this client and endpoint pattern, then forward the request."""
        if not settings.RATE_LIMIT_ENABLED:
            return await call_next(request)

        client_identifier = self.get_client_ip(request)
        pattern = self.get_rate_limit_pattern(request)

        allowed = await self.check_rate_limit(request, client_identifier, pattern)
        # Read after the check so "remaining" already accounts for this request.
        rate_info = await self.get_rate_limit_info(client_identifier, pattern)

        if allowed:
            response = await call_next(request)
        else:
            response = JSONResponse(
                content={
                    "error": "Rate limit exceeded",
                    "rate_limit": rate_info
                },
                status_code=429
            )
            if rate_info:
                # Tell clients how many seconds to back off before retrying.
                retry_after = max(0, rate_info.get('reset_time', 0) - int(time.time()))
                response.headers["Retry-After"] = str(retry_after)

        self._add_rate_limit_headers(response, rate_info)
        return response

    @staticmethod
    def _add_rate_limit_headers(response: Response, rate_info: Dict[str, Any]) -> None:
        """Copy *rate_info* into ``X-RateLimit-*`` response headers (no-op when empty)."""
        if not rate_info:
            return
        response.headers.update({
            "X-RateLimit-Limit": str(rate_info.get('limit', 0)),
            "X-RateLimit-Remaining": str(rate_info.get('remaining', 0)),
            "X-RateLimit-Reset": str(rate_info.get('reset_time', 0))
        })

    def get_client_ip(self, request: Request) -> str:
        """Return the client IP: first ``X-Forwarded-For`` entry, then ``X-Real-IP``, then the socket peer.

        NOTE(review): these headers are client-controlled unless a trusted proxy
        overwrites them. A client reaching the app directly can rotate them to
        evade per-IP limits. Confirm the deployment always sits behind a proxy.
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()

        real_ip = request.headers.get('x-real-ip')
        if real_ip:
            return real_ip

        # Fallback to the connection peer (request.client may be None).
        return getattr(request.client, 'host', '127.0.0.1')

    def get_rate_limit_pattern(self, request: Request) -> str:
        """Map the request path to a rate-limit pattern name (auth/upload/heavy/api)."""
        path = request.url.path

        if '/auth/' in path:
            return "auth"
        elif '/upload' in path:
            return "upload"
        elif '/admin/' in path:
            return "heavy"
        else:
            return "api"

    @staticmethod
    def _limits_for(pattern: str) -> Dict[str, Any]:
        """Return ``{"requests": ..., "window": ...}`` for *pattern*, falling back to global settings."""
        return SecurityConfig.RATE_LIMIT_PATTERNS.get(pattern, {
            "requests": settings.RATE_LIMIT_REQUESTS,
            "window": settings.RATE_LIMIT_WINDOW
        })

    @staticmethod
    def _cache_key(pattern: str, identifier: str) -> str:
        """Return the cache key holding the request counter for (*pattern*, *identifier*)."""
        return CACHE_KEYS["rate_limit"].format(
            pattern=pattern,
            identifier=identifier
        )

    async def check_rate_limit(
        self,
        request: Request,
        identifier: str,
        pattern: str = "api"
    ) -> bool:
        """Count this request for *identifier* and return False once the limit is reached.

        *request* is not used; it is kept for interface compatibility.
        Returns True (allow) if the cache is unavailable or errors.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(pattern, identifier)

            current_count = await cache.get(cache_key)
            if current_count is None:
                # First request in the window: start the counter with the window TTL.
                await cache.set(cache_key, "1", ttl=limits["window"])
                return True

            current_count = int(current_count)
            if current_count >= limits["requests"]:
                logger.warning(
                    "Rate limit exceeded",
                    identifier=identifier,
                    pattern=pattern,
                    count=current_count,
                    limit=limits["requests"]
                )
                return False

            await cache.incr(cache_key)
            # If the key expired between GET and INCR, INCR recreated it with no
            # TTL. The counter would then never reset and the client would be
            # locked out permanently. Redis TTL returns -1 for "exists, no expiry".
            # Assumes cache.redis is the underlying Redis client, as
            # get_rate_limit_info already relies on via .ttl().
            if await cache.redis.ttl(cache_key) == -1:
                await cache.redis.expire(cache_key, limits["window"])
            return True

        except Exception as e:
            logger.error("Rate limit check failed", error=str(e))
            return True  # Fail open: allow the request if rate limiting fails

    async def get_rate_limit_info(
        self,
        identifier: str,
        pattern: str = "api"
    ) -> Dict[str, Any]:
        """Return limit, remaining, reset_time (unix seconds) and window for *identifier*.

        Returns an empty dict if the cache cannot be queried.
        """
        try:
            cache = await self.get_cache()
            limits = self._limits_for(pattern)
            cache_key = self._cache_key(pattern, identifier)

            current_count = await cache.get(cache_key) or "0"
            # Redis reports -2 (missing) / -1 (no expiry); clamp those to "now".
            ttl = await cache.redis.ttl(cache_key)

            return {
                "limit": limits["requests"],
                "remaining": max(0, limits["requests"] - int(current_count)),
                "reset_time": int(time.time()) + max(0, ttl),
                "window": limits["window"]
            }

        except Exception as e:
            logger.error("Failed to get rate limit info", error=str(e))
            return {}
|
||
|
||
|
||
class FastAPICryptographicMiddleware(BaseHTTPMiddleware):
    """FastAPI Ed25519 middleware for inter-node communication.

    Requests flagged with ``X-Node-Communication: true`` must carry node identity
    headers. They are rejected with 403 otherwise. Responses to accepted
    inter-node requests are tagged with this node's identity and, when the
    response exposes a JSON ``body``, an Ed25519 signature.
    """

    async def dispatch(self, request: Request, call_next):
        """Gate inter-node requests, forward the request, and tag inter-node responses."""
        if not await self.verify_inter_node_signature(request):
            logger.warning("Inter-node signature verification failed")
            return JSONResponse(
                content={
                    "error": "Invalid cryptographic signature",
                    "message": "Inter-node communication requires valid ed25519 signature"
                },
                status_code=403
            )

        response = await call_next(request)

        # Add cryptographic headers to inter-node responses.
        return await self.add_inter_node_headers(request, response)

    async def verify_inter_node_signature(self, request: Request) -> bool:
        """Decide whether *request* may proceed.

        Requests that are not inter-node always pass. When the crypto module is
        unavailable, inter-node requests pass with a warning. Otherwise they must
        carry ``X-Node-Signature``, ``X-Node-ID`` and ``X-Node-Public-Key``.

        SECURITY NOTE: the signature itself is currently NOT verified. Reading
        the request body in this middleware conflicted with FastAPI, so the check
        was disabled. The node id and public key stored on ``request.state`` come
        straight from client-supplied headers and must not be trusted for
        authorization. ``_verify_body_signature`` holds the intended check.
        """
        # Only inter-node traffic is subject to these checks.
        if request.headers.get("x-node-communication") != "true":
            return True

        if not CRYPTO_AVAILABLE:
            logger.warning("Crypto module not available, skipping signature verification")
            return True

        try:
            # Fails closed (via the except below) if the node's key manager cannot be created.
            get_ed25519_manager()

            signature = request.headers.get("x-node-signature")
            node_id = request.headers.get("x-node-id")
            public_key = request.headers.get("x-node-public-key")

            if not (signature and node_id and public_key):
                logger.warning("Missing cryptographic headers in inter-node request")
                return False

            # Body-signature verification is disabled (body reading conflict);
            # see _verify_body_signature for the check to re-enable.
            logger.debug("Inter-node signature verification skipped (body reading conflict)")
            request.state.inter_node_communication = True
            request.state.source_node_id = node_id
            request.state.source_public_key = public_key
            return True

        except Exception as e:
            logger.error(f"Crypto verification error: {e}")
            return False

    def _verify_body_signature(
        self,
        crypto_manager,
        body: bytes,
        signature: str,
        node_id: str,
        public_key: str,
    ) -> bool:
        """Verify *signature* over the JSON-decoded request *body*. Not currently called.

        Kept so signature checking can be re-enabled once the body can be read in
        this middleware without breaking downstream handlers. Returns False for
        non-JSON bodies or invalid signatures.
        """
        try:
            message_data = json.loads(body)
        except ValueError:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
            logger.warning("Invalid JSON in inter-node request")
            return False

        if crypto_manager.verify_signature(message_data, signature, public_key):
            logger.debug(f"Valid signature verified for node {node_id}")
            return True

        logger.warning(f"Invalid signature from node {node_id}")
        return False

    async def add_inter_node_headers(self, request: Request, response: Response) -> Response:
        """Add this node's identity (and a body signature when possible) to inter-node responses.

        Only applies when ``verify_inter_node_signature`` marked the request as
        inter-node. Errors are logged and the response is returned unchanged
        past that point.
        """
        if not CRYPTO_AVAILABLE:
            return response

        if not getattr(request.state, 'inter_node_communication', False):
            return response

        try:
            crypto_manager = get_ed25519_manager()

            # Identify this node to the caller.
            response.headers.update({
                "X-Node-ID": crypto_manager.node_id,
                "X-Node-Public-Key": crypto_manager.public_key_hex,
                "X-Node-Communication": "true"
            })

            # Sign JSON bodies when the response object exposes one. Responses
            # produced by call_next in Starlette are streaming and typically have
            # no .body, in which case no signature is added.
            body = getattr(response, 'body', None)
            if body:
                try:
                    response_data = json.loads(body)
                except (ValueError, TypeError):
                    # Not JSON / not decodable: leave the response unsigned.
                    pass
                else:
                    response.headers["X-Node-Signature"] = crypto_manager.sign_message(response_data)

        except Exception as e:
            logger.error(f"Error adding inter-node headers: {e}")

        return response
|
||
|
||
|
||
class FastAPIRequestContextMiddleware(BaseHTTPMiddleware):
    """FastAPI request-context middleware for tracing and logging.

    Assigns a fresh request id, exposed via ``request.state.request_id``, the
    ``request_id_var`` context variable and the ``X-Request-ID`` response
    header. It also records client IP and user agent on ``request.state`` and
    logs request start, completion (with duration) or failure.
    """

    async def dispatch(self, request: Request, call_next):
        """Set up per-request context, forward the request, and log the outcome."""
        request_id = str(uuid.uuid4())
        request.state.request_id = request_id
        request_id_var.set(request_id)

        # Wall-clock start for consumers of request.state; monotonic clock for timing.
        request.state.start_time = time.time()
        started = time.perf_counter()

        request.state.client_ip = self.get_client_ip(request)
        request.state.user_agent = request.headers.get('user-agent', 'Unknown')

        # Filled in later by the authentication middleware, if any.
        request.state.user = None

        logger.info(
            "Request started",
            method=request.method,
            path=request.url.path,
            client_ip=request.state.client_ip,
            user_agent=request.state.user_agent
        )

        try:
            response = await call_next(request)
        except Exception as e:
            # Without this, a failing request would leave only a "Request started" line.
            logger.error(
                "Request failed",
                method=request.method,
                path=request.url.path,
                duration_ms=round((time.perf_counter() - started) * 1000, 2),
                client_ip=request.state.client_ip,
                error=str(e)
            )
            raise

        response.headers["X-Request-ID"] = request_id

        user = getattr(request.state, 'user', None)
        logger.info(
            "Request completed",
            method=request.method,
            path=request.url.path,
            status_code=response.status_code,
            duration_ms=round((time.perf_counter() - started) * 1000, 2),
            client_ip=request.state.client_ip,
            user_id=str(user.id) if user else None
        )

        return response

    def get_client_ip(self, request: Request) -> str:
        """Return the client IP: first ``X-Forwarded-For`` entry, then ``X-Real-IP``, then the socket peer.

        NOTE(review): forwarded headers are client-controlled unless a trusted
        proxy overwrites them, so this value is informational only.
        """
        forwarded_for = request.headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()

        real_ip = request.headers.get('x-real-ip')
        if real_ip:
            return real_ip

        # Fallback to the connection peer (request.client may be None).
        return getattr(request.client, 'host', '127.0.0.1')
|
||
|
||
|
||
class FastAPIAuthenticationMiddleware(BaseHTTPMiddleware):
    """FastAPI authentication middleware for API access.

    Resolves the caller from a token (Bearer header, ``X-API-Key`` or
    ``?token=``). A resolved user is stored on ``request.state.user`` and checked
    against coarse path-based permissions. Requests without a valid token are
    passed through unauthenticated; routes must demand a user themselves
    (e.g. via ``require_auth``).
    """

    async def dispatch(self, request: Request, call_next):
        """Authenticate the request (if a token is present), enforce permissions, then forward it."""
        path = request.url.path
        # Skip authentication for system endpoints and the root page.
        if path.startswith('/api/system') or path == '/':
            return await call_next(request)

        token = await self.extract_token(request)
        if token:
            # Imported here as in the original (presumably to avoid an import cycle).
            from app.core.database import db_manager

            async with db_manager.get_session() as session:
                user = await self.validate_token(token, session)
                if user:
                    request.state.user = user
                    user_id_var.set(str(user.id))

                    if not await self.check_permissions(user, request):
                        return JSONResponse(
                            content={"error": "Insufficient permissions"},
                            status_code=403
                        )

                    # Record user activity.
                    user.update_activity()
                    await session.commit()

        return await call_next(request)

    async def extract_token(self, request: Request) -> Optional[str]:
        """Return the caller's token, or None if none is present.

        Sources, in order: ``Authorization: Bearer <token>`` (scheme matched
        case-insensitively, per RFC 7235), the ``X-API-Key`` header, and the
        ``token`` query parameter.
        """
        scheme, _, credentials = request.headers.get('authorization', '').partition(' ')
        credentials = credentials.strip()
        if scheme.lower() == 'bearer' and credentials:
            return credentials

        api_key = request.headers.get('x-api-key')
        if api_key:
            return api_key

        # Query-string tokens are less secure (they end up in access logs);
        # kept for backward compatibility.
        return request.query_params.get('token')

    async def validate_token(self, token: str, session) -> Optional[User]:
        """Return the active user identified by *token*, or None.

        Tries a JWT access token first, then the legacy ``"<user_id>:<hash>"``
        format. Inactive users are rejected on both paths. Never raises;
        unexpected errors are logged and yield None.
        """
        if not token:
            return None

        try:
            from app.core.security import verify_access_token

            # 1) JWT access token (preferred; issued by auth.twa).
            try:
                payload = verify_access_token(token)
                if payload and 'user_id' in payload:
                    user_id = uuid.UUID(payload['user_id'])
                    user = await User.get_by_id(session, user_id)
                    if user and user.is_active:
                        return user
            except Exception as jwt_error:
                logger.debug("JWT validation failed, trying legacy format", error=str(jwt_error))

            # 2) Legacy token format: "<user_id>:<hash>".
            if ':' in token:
                user_id_str, token_hash = token.split(':', 1)
                try:
                    user_id = uuid.UUID(user_id_str)
                    user = await User.get_by_id(session, user_id)
                    # Deactivated accounts must be rejected here too, as on the JWT path.
                    if (
                        user
                        and user.is_active
                        and hasattr(user, 'verify_token')
                        and user.verify_token(token_hash)
                    ):
                        return user
                except (ValueError, AttributeError):
                    pass

            # Lookup of stored API tokens on the User model is not implemented.
            return None

        except Exception as e:
            logger.error("Token validation failed", token=token[:8] + "...", error=str(e))
            return None

    async def check_permissions(self, user: User, request: Request) -> bool:
        """Return True if *user* may access the request's endpoint, based on path and method."""
        endpoint = request.url.path
        method = request.method

        # Admin endpoints
        if '/admin/' in endpoint:
            return user.is_admin

        # Moderator endpoints
        if '/mod/' in endpoint:
            return user.is_moderator

        # User-specific write endpoints
        if '/user/' in endpoint and method in ['POST', 'PUT', 'DELETE']:
            return user.has_permission('user:write')

        # Content upload endpoints: any /upload request, or a POST to /content.
        # (`and` binds tighter than `or`; parenthesised to make that explicit.)
        if '/upload' in endpoint or ('/content' in endpoint and method == 'POST'):
            return user.can_upload_content()

        # Default: allow read access for authenticated users.
        return True
|
||
|
||
|
||
# FastAPI dependencies for use in route handlers
from fastapi import Depends, HTTPException


async def get_current_user(request: Request) -> Optional[User]:
    """Dependency returning the user attached by the authentication middleware, if any.

    Returns None when no user was attached or the stored value is falsy.
    """
    attached_user = getattr(request.state, 'user', None)
    return attached_user if attached_user else None
|
||
|
||
async def require_auth(request: Request) -> User:
    """Dependency returning the authenticated user.

    Raises:
        HTTPException: 401 when the request is not authenticated.
    """
    current_user = await get_current_user(request)
    if current_user:
        return current_user
    raise HTTPException(status_code=401, detail="Authentication required")
|
||
|
||
def check_permissions(permission: str):
    """Build a FastAPI dependency that requires the authenticated user to hold *permission*.

    Usage: ``user: User = Depends(check_permissions("content:write"))``.

    Deliberately a plain (synchronous) factory. ``Depends`` needs the checker
    callable when the route is declared, and an ``async def`` factory would hand
    it a coroutine object instead.

    The returned checker raises ``HTTPException`` 403 if the user lacks the
    permission (and 401 via ``require_auth`` if unauthenticated).
    """
    def permission_checker(user: User = Depends(require_auth)) -> User:
        if not user.has_permission(permission):
            raise HTTPException(status_code=403, detail=f"Permission required: {permission}")
        return user

    return permission_checker
|
||
|
||
async def require_admin(user: User = Depends(require_auth)) -> User:
    """Dependency returning the authenticated user only if they are an administrator.

    Raises:
        HTTPException: 403 when the user has no truthy ``is_admin`` attribute.
    """
    if getattr(user, 'is_admin', False):
        return user
    raise HTTPException(status_code=403, detail="Administrative privileges required")
|
||
|
||
def check_rate_limit(pattern: str = "api"):
    """Build a per-route rate-limit dependency (currently a no-op hook).

    Usage: ``Depends(check_rate_limit("upload"))``. Rate limiting itself is
    enforced by ``FastAPIRateLimitMiddleware``. The returned checker always
    returns True and exists as a place for additional per-route checks.
    *pattern* is accepted for interface compatibility but is not used yet.

    Deliberately synchronous: an ``async def`` factory would return a coroutine,
    which ``Depends`` cannot use as a dependency.
    """
    def rate_limit_checker(request: Request) -> bool:
        # Rate limiting is already enforced in middleware; extra checks go here.
        return True

    return rate_limit_checker