"""Compatible database configuration with MariaDB support.""" import logging from contextlib import asynccontextmanager from typing import AsyncGenerator, Optional from sqlalchemy import MetaData from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine ) from sqlalchemy.pool import NullPool from app.core.config import get_settings logger = logging.getLogger(__name__) # Global variables for database engine and session _engine: Optional[AsyncEngine] = None _async_session: Optional[async_sessionmaker[AsyncSession]] = None # Naming convention for consistent constraint names naming_convention = { "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", "pk": "pk_%(table_name)s" } metadata = MetaData(naming_convention=naming_convention) def get_database_url() -> str: """Get database URL from settings.""" settings = get_settings() # Support both new DATABASE_URL and legacy MariaDB settings if hasattr(settings, 'database_url') and settings.database_url: return settings.database_url # Fallback to MariaDB configuration mysql_host = getattr(settings, 'mysql_host', 'maria_db') mysql_port = getattr(settings, 'mysql_port', 3306) mysql_user = getattr(settings, 'mysql_user', 'myuploader') mysql_password = getattr(settings, 'mysql_password', 'password') mysql_database = getattr(settings, 'mysql_database', 'myuploader') return f"mysql+aiomysql://{mysql_user}:{mysql_password}@{mysql_host}:{mysql_port}/{mysql_database}" async def init_database() -> None: """Initialize database connection.""" global _engine, _async_session if _engine is not None: logger.warning("Database already initialized") return try: settings = get_settings() database_url = get_database_url() logger.info(f"Connecting to database: {database_url.split('@')[1] if '@' in database_url else 'unknown'}") # Create async engine with MariaDB/MySQL optimizations _engine = create_async_engine( database_url, echo=settings.debug if hasattr(settings, 'debug') else False, pool_size=getattr(settings, 'database_pool_size', 20), max_overflow=getattr(settings, 'database_max_overflow', 30), pool_timeout=getattr(settings, 'database_pool_timeout', 30), pool_recycle=getattr(settings, 'database_pool_recycle', 3600), pool_pre_ping=True, # Verify connections before use # MariaDB specific settings connect_args={ "charset": "utf8mb4", "use_unicode": True, "autocommit": False, } ) # Create async session factory _async_session = async_sessionmaker( bind=_engine, class_=AsyncSession, expire_on_commit=False, autoflush=True, autocommit=False ) # Test the connection async with _engine.begin() as conn: await conn.execute("SELECT 1") logger.info("Database connection established successfully") except Exception as e: logger.error(f"Failed to initialize database: {e}") raise async def close_database() -> None: """Close database connection.""" global _engine, _async_session if _engine is not None: logger.info("Closing database connection") await _engine.dispose() _engine = None _async_session = None logger.info("Database connection closed") def get_engine() -> AsyncEngine: """Get database engine.""" if _engine is None: raise RuntimeError("Database not initialized. Call init_database() first.") return _engine def get_session_factory() -> async_sessionmaker[AsyncSession]: """Get session factory.""" if _async_session is None: raise RuntimeError("Database not initialized. Call init_database() first.") return _async_session @asynccontextmanager async def get_async_session() -> AsyncGenerator[AsyncSession, None]: """Get async database session with automatic cleanup.""" if _async_session is None: raise RuntimeError("Database not initialized. Call init_database() first.") async with _async_session() as session: try: yield session except Exception as e: logger.error(f"Database session error: {e}") await session.rollback() raise finally: await session.close() async def check_database_health() -> bool: """Check database connection health.""" try: async with get_async_session() as session: await session.execute("SELECT 1") return True except Exception as e: logger.error(f"Database health check failed: {e}") return False async def get_database_info() -> dict: """Get database information.""" try: async with get_async_session() as session: # Get database version result = await session.execute("SELECT VERSION() as version") version_row = result.fetchone() version = version_row[0] if version_row else "Unknown" # Get connection count (MariaDB specific) try: result = await session.execute("SHOW STATUS LIKE 'Threads_connected'") conn_row = result.fetchone() connections = int(conn_row[1]) if conn_row else 0 except: connections = 0 # Get database size try: result = await session.execute(""" SELECT ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) as size_mb FROM information_schema.tables WHERE table_schema = DATABASE() """) size_row = result.fetchone() size_mb = float(size_row[0]) if size_row and size_row[0] else 0 except: size_mb = 0 return { "version": version, "connections": connections, "size_mb": size_mb, "engine_pool_size": _engine.pool.size() if _engine else 0, "engine_checked_out": _engine.pool.checkedout() if _engine else 0, } except Exception as e: logger.error(f"Failed to get database info: {e}") return {"error": str(e)} # Database session dependency for dependency injection async def get_db_session() -> AsyncGenerator[AsyncSession, None]: """Database session dependency for API routes.""" async with get_async_session() as session: yield session # Backward compatibility functions async def get_db() -> AsyncGenerator[AsyncSession, None]: """Legacy function name for backward compatibility.""" async with get_async_session() as session: yield session # Transaction context manager @asynccontextmanager async def transaction(): """Transaction context manager.""" async with get_async_session() as session: async with session.begin(): yield session