""" Async SQLAlchemy configuration with connection pooling and Redis integration """ import asyncio import logging from contextlib import asynccontextmanager from typing import AsyncGenerator, Optional from datetime import timedelta from sqlalchemy.ext.asyncio import ( create_async_engine, AsyncSession, async_sessionmaker, AsyncEngine ) from sqlalchemy.pool import NullPool, QueuePool from sqlalchemy.sql import text import redis.asyncio as redis from redis.asyncio.connection import ConnectionPool import structlog import os from app.core.config import ( DATABASE_URL, REDIS_URL, DATABASE_POOL_SIZE, DATABASE_MAX_OVERFLOW, REDIS_POOL_SIZE ) # Mock Redis для тестирования class MockRedis: def __init__(self): self._data = {} self._ttl_data = {} # Store TTL information async def ping(self): """Ping redis server""" return True async def get(self, key): """Get value by key""" try: value = self._data.get(key) return value if value is not None else None except Exception as e: logger.error("MockRedis get error", key=key, error=str(e)) return None async def set(self, key, value, ex=None, nx=False): """Set key-value with optional expiration and nx flag""" try: if nx and key in self._data: return False # Convert value to string to match Redis behavior if isinstance(value, dict): import json self._data[key] = json.dumps(value) else: self._data[key] = str(value) if value is not None else None # Handle TTL if ex: import time self._ttl_data[key] = time.time() + ex return True except Exception as e: logger.error("MockRedis set error", key=key, error=str(e)) return False async def delete(self, key): """Delete key""" try: existed = key in self._data self._data.pop(key, None) self._ttl_data.pop(key, None) return 1 if existed else 0 except Exception as e: logger.error("MockRedis delete error", key=key, error=str(e)) return 0 async def exists(self, key): """Check if key exists""" try: return 1 if key in self._data else 0 except Exception as e: logger.error("MockRedis exists error", key=key, error=str(e)) return 0 async def incr(self, key, amount=1): """Increment counter""" try: current = int(self._data.get(key, 0)) new_value = current + amount self._data[key] = str(new_value) return new_value except (ValueError, TypeError) as e: logger.error("MockRedis incr error", key=key, error=str(e)) return 0 async def expire(self, key, ttl): """Set TTL for key""" try: if key in self._data: import time self._ttl_data[key] = time.time() + ttl return True return False except Exception as e: logger.error("MockRedis expire error", key=key, error=str(e)) return False async def hget(self, name, key): """Get hash field value""" try: hash_data = self._data.get(name) if not hash_data: return None # Try to parse as JSON if it's a string if isinstance(hash_data, str): try: import json hash_data = json.loads(hash_data) except (json.JSONDecodeError, TypeError): return None if isinstance(hash_data, dict): return hash_data.get(key) return None except Exception as e: logger.error("MockRedis hget error", name=name, key=key, error=str(e)) return None async def hset(self, name, key, value): """Set hash field value""" try: if name not in self._data: self._data[name] = {} # Ensure we have a dict if not isinstance(self._data[name], dict): self._data[name] = {} self._data[name][key] = str(value) return 1 except Exception as e: logger.error("MockRedis hset error", name=name, key=key, error=str(e)) return 0 async def hdel(self, name, key): """Delete hash field""" try: if name in self._data and isinstance(self._data[name], dict): existed = key in self._data[name] self._data[name].pop(key, None) return 1 if existed else 0 return 0 except Exception as e: logger.error("MockRedis hdel error", name=name, key=key, error=str(e)) return 0 async def ttl(self, key): """Get TTL for key""" try: if key not in self._data: return -2 # Key doesn't exist if key not in self._ttl_data: return -1 # Key exists but no TTL import time remaining = self._ttl_data[key] - time.time() if remaining <= 0: # Key expired, remove it self._data.pop(key, None) self._ttl_data.pop(key, None) return -2 return int(remaining) except Exception as e: logger.error("MockRedis ttl error", key=key, error=str(e)) return -1 logger = structlog.get_logger(__name__) class DatabaseManager: """Async database manager with connection pooling""" def __init__(self): self._engine: Optional[AsyncEngine] = None self._session_factory: Optional[async_sessionmaker[AsyncSession]] = None self._redis_pool: Optional[ConnectionPool] = None self._redis: Optional[redis.Redis] = None self._initialized = False async def initialize(self) -> None: """Initialize database connections and Redis""" if self._initialized: return # Initialize async SQLAlchemy engine self._engine = create_async_engine( DATABASE_URL, poolclass=QueuePool, pool_size=DATABASE_POOL_SIZE, max_overflow=DATABASE_MAX_OVERFLOW, pool_pre_ping=True, pool_recycle=3600, # 1 hour echo=False, # Set to True for SQL debugging future=True, json_serializer=lambda obj: obj, json_deserializer=lambda obj: obj, ) # Create session factory self._session_factory = async_sessionmaker( self._engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False ) # Initialize Redis connection pool use_mock_redis = ( os.getenv('MOCK_REDIS', '0') == '1' or 'mock' in REDIS_URL or REDIS_URL.startswith('redis://mock') ) if use_mock_redis: logger.warning("Using MockRedis for testing") self._redis = MockRedis() self._redis_pool = None else: try: self._redis_pool = ConnectionPool.from_url( REDIS_URL, max_connections=REDIS_POOL_SIZE, retry_on_timeout=True, health_check_interval=30 ) self._redis = redis.Redis( connection_pool=self._redis_pool, decode_responses=True ) except Exception as e: logger.warning(f"Failed to connect to Redis, using mock: {e}") self._redis = MockRedis() self._redis_pool = None # Test connections await self._test_connections() self._initialized = True logger.info("Database and Redis connections initialized") async def _test_connections(self) -> None: """Test database and Redis connections""" # Test database async with self._engine.begin() as conn: result = await conn.execute(text("SELECT 1")) assert result.scalar() == 1 # Test Redis await self._redis.ping() logger.info("Database and Redis connections tested successfully") async def close(self) -> None: """Close all connections gracefully""" if self._engine: await self._engine.dispose() if self._redis_pool: await self._redis_pool.disconnect() self._initialized = False logger.info("Database and Redis connections closed") @asynccontextmanager async def get_session(self) -> AsyncGenerator[AsyncSession, None]: """Get async database session with automatic cleanup""" if not self._initialized: await self.initialize() async with self._session_factory() as session: try: yield session except Exception as e: await session.rollback() logger.error("Database session error", error=str(e)) raise finally: await session.close() @asynccontextmanager async def get_transaction(self) -> AsyncGenerator[AsyncSession, None]: """Get async database session with automatic transaction management""" async with self.get_session() as session: async with session.begin(): yield session async def get_redis(self) -> redis.Redis: """Get Redis client""" if not self._initialized: await self.initialize() return self._redis @property def engine(self) -> AsyncEngine: """Get SQLAlchemy engine""" if not self._engine: raise RuntimeError("Database not initialized") return self._engine class CacheManager: """Redis-based cache manager with TTL and serialization""" def __init__(self, redis_client: redis.Redis): self.redis = redis_client async def get(self, key: str, default=None): """Get value from cache""" try: value = await self.redis.get(key) return value if value is not None else default except Exception as e: logger.error("Cache get error", key=key, error=str(e)) return default async def set( self, key: str, value: str, ttl: Optional[int] = None, nx: bool = False ) -> bool: """Set value in cache with optional TTL""" try: return await self.redis.set(key, value, ex=ttl, nx=nx) except Exception as e: logger.error("Cache set error", key=key, error=str(e)) return False async def delete(self, key: str) -> bool: """Delete key from cache""" try: return bool(await self.redis.delete(key)) except Exception as e: logger.error("Cache delete error", key=key, error=str(e)) return False async def exists(self, key: str) -> bool: """Check if key exists in cache""" try: return bool(await self.redis.exists(key)) except Exception as e: logger.error("Cache exists error", key=key, error=str(e)) return False async def incr(self, key: str, amount: int = 1) -> int: """Increment counter in cache""" try: return await self.redis.incr(key, amount) except Exception as e: logger.error("Cache incr error", key=key, error=str(e)) return 0 async def increment(self, key: str, amount: int = 1, ttl: Optional[int] = None) -> int: """Increment counter in cache with optional TTL""" try: result = await self.redis.incr(key, amount) # If this is the first increment and TTL is specified, set expiration if ttl and result == amount: await self.redis.expire(key, ttl) return result except Exception as e: logger.error("Cache increment error", key=key, error=str(e)) return 0 async def expire(self, key: str, ttl: int) -> bool: """Set TTL for existing key""" try: return await self.redis.expire(key, ttl) except Exception as e: logger.error("Cache expire error", key=key, error=str(e)) return False async def hget(self, name: str, key: str): """Get hash field value""" try: return await self.redis.hget(name, key) except Exception as e: logger.error("Cache hget error", name=name, key=key, error=str(e)) return None async def hset(self, name: str, key: str, value: str) -> bool: """Set hash field value""" try: return bool(await self.redis.hset(name, key, value)) except Exception as e: logger.error("Cache hset error", name=name, key=key, error=str(e)) return False async def hdel(self, name: str, key: str) -> bool: """Delete hash field""" try: return bool(await self.redis.hdel(name, key)) except Exception as e: logger.error("Cache hdel error", name=name, key=key, error=str(e)) return False # Global instances db_manager = DatabaseManager() cache_manager: Optional[CacheManager] = None def get_db_session(): """Dependency for getting database session - returns async context manager""" return db_manager.get_session() async def get_cache() -> CacheManager: """Dependency for getting cache manager""" global cache_manager if not cache_manager: redis_client = await db_manager.get_redis() cache_manager = CacheManager(redis_client) return cache_manager async def init_database(): """Initialize database connections""" await db_manager.initialize() async def close_database(): """Close database connections""" await db_manager.close() # Алиасы для совместимости с существующим кодом # УДАЛЁН: get_async_session() - вызывал ошибки context manager protocol # Все места использования исправлены на db_manager.get_session() async def get_cache_manager() -> CacheManager: """Alias for get_cache for compatibility""" return await get_cache()