uploader-bot/app/core/database.py

456 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Async SQLAlchemy configuration with connection pooling and Redis integration
"""
import asyncio
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from datetime import timedelta
from sqlalchemy.ext.asyncio import (
create_async_engine,
AsyncSession,
async_sessionmaker,
AsyncEngine
)
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.sql import text
import redis.asyncio as redis
from redis.asyncio.connection import ConnectionPool
import structlog
import os
from app.core.config import (
DATABASE_URL,
REDIS_URL,
DATABASE_POOL_SIZE,
DATABASE_MAX_OVERFLOW,
REDIS_POOL_SIZE
)
# In-memory Mock Redis used for testing
class MockRedis:
    """In-memory stand-in for a ``redis.asyncio.Redis`` client, used in tests.

    Only the commands this module relies on are implemented. Plain values are
    stored as ``str`` (dicts are JSON-encoded), mirroring a real client built
    with ``decode_responses=True``; hashes are stored as ``dict`` objects.
    Expiry set through ``set(ex=...)`` or ``expire()`` is enforced lazily: an
    expired key is dropped the next time any command touches it. Errors inside
    a command are logged and turned into a neutral return value.
    """

    def __init__(self):
        self._data = {}
        # key -> absolute expiry deadline, in time.time() seconds
        self._ttl_data = {}

    @staticmethod
    def _to_seconds(value):
        """Return an expiry given as seconds or as a ``timedelta`` in seconds."""
        if isinstance(value, timedelta):
            return value.total_seconds()
        return value

    def _purge_if_expired(self, key):
        """Remove *key* if its deadline has passed; return True if it was removed."""
        import time

        deadline = self._ttl_data.get(key)
        if deadline is not None and deadline <= time.time():
            self._data.pop(key, None)
            self._ttl_data.pop(key, None)
            return True
        return False

    async def ping(self):
        """Ping redis server (always succeeds for the mock)."""
        return True

    async def get(self, key):
        """Return the value stored at *key*, or None if missing or expired."""
        try:
            self._purge_if_expired(key)
            return self._data.get(key)
        except Exception as e:
            logger.error("MockRedis get error", key=key, error=str(e))
            return None

    async def set(self, key, value, ex=None, nx=False):
        """Store *value* at *key*.

        Args:
            key: key to write.
            value: value to store; dicts are JSON-encoded, anything else except
                None is converted with ``str()``.
            ex: optional expiry in seconds (or ``timedelta``). When omitted, any
                previous TTL is cleared, as a plain Redis ``SET`` does.
            nx: only write when the key does not already exist.

        Returns:
            True on success, False when *nx* blocked the write or on error.
        """
        try:
            self._purge_if_expired(key)
            if nx and key in self._data:
                return False
            # Compute the deadline before writing so an invalid ``ex`` leaves
            # the store untouched.
            deadline = None
            if ex:
                import time
                deadline = time.time() + self._to_seconds(ex)
            # Convert value to string to match Redis behavior
            if isinstance(value, dict):
                import json
                self._data[key] = json.dumps(value)
            else:
                self._data[key] = str(value) if value is not None else None
            if deadline is None:
                self._ttl_data.pop(key, None)
            else:
                self._ttl_data[key] = deadline
            return True
        except Exception as e:
            logger.error("MockRedis set error", key=key, error=str(e))
            return False

    async def delete(self, key):
        """Delete *key*; return 1 if a live key was removed, otherwise 0."""
        try:
            self._purge_if_expired(key)
            existed = key in self._data
            self._data.pop(key, None)
            self._ttl_data.pop(key, None)
            return 1 if existed else 0
        except Exception as e:
            logger.error("MockRedis delete error", key=key, error=str(e))
            return 0

    async def exists(self, key):
        """Return 1 if *key* exists and has not expired, otherwise 0."""
        try:
            self._purge_if_expired(key)
            return 1 if key in self._data else 0
        except Exception as e:
            logger.error("MockRedis exists error", key=key, error=str(e))
            return 0

    async def incr(self, key, amount=1):
        """Add *amount* to the integer at *key* (a missing key counts as 0).

        Any existing TTL on the key is kept, as in Redis. Returns the new value,
        or 0 if the stored value is not an integer.
        """
        try:
            self._purge_if_expired(key)
            new_value = int(self._data.get(key, 0)) + amount
            self._data[key] = str(new_value)
            return new_value
        except (ValueError, TypeError) as e:
            logger.error("MockRedis incr error", key=key, error=str(e))
            return 0

    async def expire(self, key, ttl):
        """Set a TTL (seconds or ``timedelta``) on an existing key.

        Returns True if the key exists and the TTL was set, otherwise False.
        """
        try:
            self._purge_if_expired(key)
            if key not in self._data:
                return False
            import time
            self._ttl_data[key] = time.time() + self._to_seconds(ttl)
            return True
        except Exception as e:
            logger.error("MockRedis expire error", key=key, error=str(e))
            return False

    async def hget(self, name, key):
        """Return field *key* of hash *name*, or None if absent.

        A JSON-encoded string stored at *name* is also accepted as a hash.
        """
        import json

        try:
            self._purge_if_expired(name)
            hash_data = self._data.get(name)
            if not hash_data:
                return None
            # Try to parse as JSON if it's a string
            if isinstance(hash_data, str):
                try:
                    hash_data = json.loads(hash_data)
                except (json.JSONDecodeError, TypeError):
                    return None
            if isinstance(hash_data, dict):
                return hash_data.get(key)
            return None
        except Exception as e:
            logger.error("MockRedis hget error", name=name, key=key, error=str(e))
            return None

    async def hset(self, name, key, value):
        """Set field *key* of hash *name* to ``str(value)``.

        Returns 1 if the field is new and 0 if an existing field was
        overwritten (matching Redis ``HSET``), or 0 on error.
        """
        try:
            self._purge_if_expired(name)
            bucket = self._data.get(name)
            # A non-hash value is replaced by an empty hash (real Redis would
            # raise WRONGTYPE instead).
            if not isinstance(bucket, dict):
                bucket = {}
                self._data[name] = bucket
            is_new = key not in bucket
            bucket[key] = str(value)
            return 1 if is_new else 0
        except Exception as e:
            logger.error("MockRedis hset error", name=name, key=key, error=str(e))
            return 0

    async def hdel(self, name, key):
        """Delete field *key* of hash *name*; return 1 if it existed, else 0."""
        try:
            self._purge_if_expired(name)
            bucket = self._data.get(name)
            if isinstance(bucket, dict):
                existed = key in bucket
                bucket.pop(key, None)
                return 1 if existed else 0
            return 0
        except Exception as e:
            logger.error("MockRedis hdel error", name=name, key=key, error=str(e))
            return 0

    async def ttl(self, key):
        """Return remaining TTL in whole seconds.

        Returns -2 if the key does not exist (or just expired) and -1 if it
        exists without a TTL.
        """
        import time

        try:
            if self._purge_if_expired(key) or key not in self._data:
                return -2  # Key doesn't exist
            if key not in self._ttl_data:
                return -1  # Key exists but no TTL
            return int(self._ttl_data[key] - time.time())
        except Exception as e:
            logger.error("MockRedis ttl error", key=key, error=str(e))
            return -1
# Module-level structlog logger used by MockRedis, DatabaseManager and
# CacheManager. Method bodies only look the name up when they run, so
# defining it after MockRedis is safe.
logger = structlog.get_logger(__name__)
class DatabaseManager:
    """Owns the async SQLAlchemy engine/session factory and the Redis client.

    Connections are created by ``initialize()`` (called lazily by
    ``get_session()`` / ``get_redis()``) and released by ``close()``.
    """

    def __init__(self):
        self._engine: Optional[AsyncEngine] = None
        self._session_factory: Optional[async_sessionmaker[AsyncSession]] = None
        # None when MockRedis is in use
        self._redis_pool: Optional[ConnectionPool] = None
        # Either a real redis.asyncio client or a MockRedis instance
        self._redis: Optional[redis.Redis] = None
        self._initialized = False

    async def initialize(self) -> None:
        """Create the engine, session factory and Redis client, then verify them.

        Returns immediately if already initialised. ``MockRedis`` is used
        instead of a real client when ``MOCK_REDIS=1``, when ``REDIS_URL``
        contains "mock", or when building the Redis connection pool raises.
        """
        if self._initialized:
            return

        # Async engines need an asyncio-compatible pool. Leaving poolclass
        # unset selects AsyncAdaptedQueuePool, which still honours pool_size
        # and max_overflow; the sync QueuePool is rejected by
        # create_async_engine.
        self._engine = create_async_engine(
            DATABASE_URL,
            pool_size=DATABASE_POOL_SIZE,
            max_overflow=DATABASE_MAX_OVERFLOW,
            pool_pre_ping=True,
            pool_recycle=3600,  # 1 hour
            echo=False,  # Set to True for SQL debugging
            future=True,
            # NOTE(review): identity (de)serializers hand JSON column values
            # to/from the driver untouched — confirm the driver expects that.
            json_serializer=lambda obj: obj,
            json_deserializer=lambda obj: obj,
        )

        self._session_factory = async_sessionmaker(
            self._engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        use_mock_redis = os.getenv('MOCK_REDIS', '0') == '1' or 'mock' in REDIS_URL
        if use_mock_redis:
            logger.warning("Using MockRedis for testing")
            self._redis = MockRedis()
            self._redis_pool = None
        else:
            try:
                # decode_responses must be configured on the pool: when an
                # explicit connection_pool is supplied, redis.Redis ignores its
                # own connection kwargs, so replies would come back as bytes.
                self._redis_pool = ConnectionPool.from_url(
                    REDIS_URL,
                    max_connections=REDIS_POOL_SIZE,
                    retry_on_timeout=True,
                    health_check_interval=30,
                    decode_responses=True,
                )
                self._redis = redis.Redis(connection_pool=self._redis_pool)
            except Exception as e:
                # NOTE(review): redis-py pools appear to connect lazily, so an
                # unreachable server likely surfaces in ping() below (and
                # raises) rather than here — confirm the intended fallback.
                logger.warning(f"Failed to connect to Redis, using mock: {e}")
                self._redis = MockRedis()
                self._redis_pool = None

        await self._test_connections()
        self._initialized = True
        logger.info("Database and Redis connections initialized")

    async def _test_connections(self) -> None:
        """Run ``SELECT 1`` on the database and ``PING`` on Redis.

        Raises:
            RuntimeError: if the database does not answer ``SELECT 1`` with 1.
            Any driver/Redis error raised while connecting propagates as-is.
        """
        async with self._engine.begin() as conn:
            result = await conn.execute(text("SELECT 1"))
            # Explicit check instead of assert, which is stripped under -O
            if result.scalar() != 1:
                raise RuntimeError("Database connectivity check failed")
        await self._redis.ping()
        logger.info("Database and Redis connections tested successfully")

    async def close(self) -> None:
        """Dispose of the engine and disconnect the Redis pool, if any."""
        if self._engine:
            await self._engine.dispose()
        if self._redis_pool:
            await self._redis_pool.disconnect()
        self._initialized = False
        logger.info("Database and Redis connections closed")

    @asynccontextmanager
    async def get_session(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield an ``AsyncSession``, initialising connections on first use.

        On an exception inside the ``async with`` body the session is rolled
        back, the error is logged and re-raised. The session is closed by its
        own context manager on exit.
        """
        if not self._initialized:
            await self.initialize()
        async with self._session_factory() as session:
            try:
                yield session
            except Exception as e:
                await session.rollback()
                logger.error("Database session error", error=str(e))
                raise

    @asynccontextmanager
    async def get_transaction(self) -> AsyncGenerator[AsyncSession, None]:
        """Yield a session wrapped in ``session.begin()``.

        The transaction is committed when the block exits normally and rolled
        back if it raises.
        """
        async with self.get_session() as session:
            async with session.begin():
                yield session

    async def get_redis(self) -> redis.Redis:
        """Return the Redis client (real or MockRedis), initialising if needed."""
        if not self._initialized:
            await self.initialize()
        return self._redis

    @property
    def engine(self) -> AsyncEngine:
        """Return the SQLAlchemy engine.

        Raises:
            RuntimeError: if ``initialize()`` has never created an engine.
        """
        if not self._engine:
            raise RuntimeError("Database not initialized")
        return self._engine
class CacheManager:
    """Async cache facade over a Redis client (real or MockRedis).

    Every method logs and swallows client errors, returning a neutral value
    (``default`` / ``False`` / ``0`` / ``None``) so cache failures never
    propagate to callers. Boolean results mean "the command succeeded".
    """

    def __init__(self, redis_client: "redis.Redis"):
        self.redis = redis_client

    async def get(self, key: str, default=None):
        """Return the cached value for *key*, or *default* if missing or on error."""
        try:
            value = await self.redis.get(key)
            return value if value is not None else default
        except Exception as e:
            logger.error("Cache get error", key=key, error=str(e))
            return default

    async def set(
        self,
        key: str,
        value: str,
        ttl: Optional[int] = None,
        nx: bool = False
    ) -> bool:
        """Store *value* under *key* with an optional TTL in seconds.

        With ``nx=True`` the write only happens if the key is absent.
        Returns True if the value was written, False otherwise or on error.
        """
        try:
            # redis-py replies None (not False) when nx blocks the write
            return bool(await self.redis.set(key, value, ex=ttl, nx=nx))
        except Exception as e:
            logger.error("Cache set error", key=key, error=str(e))
            return False

    async def delete(self, key: str) -> bool:
        """Delete *key*; return True if something was removed."""
        try:
            return bool(await self.redis.delete(key))
        except Exception as e:
            logger.error("Cache delete error", key=key, error=str(e))
            return False

    async def exists(self, key: str) -> bool:
        """Return True if *key* is present in the cache."""
        try:
            return bool(await self.redis.exists(key))
        except Exception as e:
            logger.error("Cache exists error", key=key, error=str(e))
            return False

    async def incr(self, key: str, amount: int = 1) -> int:
        """Increment the counter at *key*; return the new value, or 0 on error."""
        try:
            return await self.redis.incr(key, amount)
        except Exception as e:
            logger.error("Cache incr error", key=key, error=str(e))
            return 0

    async def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int:
        """Increment the counter at *key*, setting *ttl* when it is first created.

        Returns the new value, or 0 on error.
        """
        try:
            result = await self.redis.incr(key, amount)
            # result == amount means the counter did not exist before this call
            # (or was 0), so this is where the expiry window starts.
            if ttl and result == amount:
                await self.redis.expire(key, ttl)
            return result
        except Exception as e:
            logger.error("Cache increment error", key=key, error=str(e))
            return 0

    async def expire(self, key: str, ttl: int) -> bool:
        """Set a TTL in seconds on an existing key; return True if it was set."""
        try:
            return bool(await self.redis.expire(key, ttl))
        except Exception as e:
            logger.error("Cache expire error", key=key, error=str(e))
            return False

    async def hget(self, name: str, key: str):
        """Return field *key* of hash *name*, or None if absent or on error."""
        try:
            return await self.redis.hget(name, key)
        except Exception as e:
            logger.error("Cache hget error", name=name, key=key, error=str(e))
            return None

    async def hset(self, name: str, key: str, value: str) -> bool:
        """Set field *key* of hash *name*; return True on success.

        Redis ``HSET`` replies with the number of *new* fields (0 when an
        existing field is overwritten), so success is signalled by the call
        not raising rather than by its reply.
        """
        try:
            await self.redis.hset(name, key, value)
            return True
        except Exception as e:
            logger.error("Cache hset error", name=name, key=key, error=str(e))
            return False

    async def hdel(self, name: str, key: str) -> bool:
        """Delete field *key* of hash *name*; return True if it existed."""
        try:
            return bool(await self.redis.hdel(name, key))
        except Exception as e:
            logger.error("Cache hdel error", name=name, key=key, error=str(e))
            return False
# Module-level singletons shared by the helper functions that follow.
# cache_manager starts empty and is built lazily by get_cache().
cache_manager: Optional[CacheManager] = None
db_manager = DatabaseManager()
def get_db_session():
    """Hand out the database session context manager from ``db_manager``.

    Intended as a dependency; the result must be entered with
    ``async with get_db_session() as session:``.
    """
    session_cm = db_manager.get_session()
    return session_cm
async def get_cache() -> CacheManager:
    """Return the shared CacheManager, building it on first use.

    The first call fetches the Redis client from ``db_manager`` (which
    initialises connections if needed) and stores the wrapper in the
    module-level ``cache_manager`` for later calls.
    """
    global cache_manager
    if cache_manager is None:
        cache_manager = CacheManager(await db_manager.get_redis())
    return cache_manager
async def init_database():
    """Open the database engine and Redis client held by ``db_manager``.

    Repeated calls are harmless: ``initialize`` returns early once done.
    """
    manager = db_manager
    await manager.initialize()
async def close_database():
    """Release the engine and Redis pool owned by ``db_manager``."""
    manager = db_manager
    await manager.close()
# Aliases kept for compatibility with existing code.
# REMOVED: get_async_session() - it caused context manager protocol errors.
# All former call sites now use db_manager.get_session().
async def get_cache_manager() -> CacheManager:
    """Backward-compatible alias that delegates to ``get_cache()``."""
    manager = await get_cache()
    return manager