# uploader-bot/app/core/database_compatible.py
# 220 lines, 7.2 KiB, Python

"""Compatible database configuration with MariaDB support."""
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from urllib.parse import quote

from sqlalchemy import MetaData, text
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.pool import NullPool

from app.core.config import get_settings
logger = logging.getLogger(__name__)

# Process-wide engine and session factory. init_database() populates them;
# close_database() resets them to None.
_engine: Optional[AsyncEngine] = None
_async_session: Optional[async_sessionmaker[AsyncSession]] = None

# SQLAlchemy naming-convention templates, so indexes and constraints
# (ix/uq/ck/fk/pk) receive consistent, predictable names.
naming_convention = dict(
    ix="ix_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_%(constraint_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)
metadata = MetaData(naming_convention=naming_convention)
def get_database_url(settings=None) -> str:
    """Build the database URL from application settings.

    A non-empty ``settings.database_url`` is returned verbatim. Otherwise an
    ``mysql+aiomysql://`` URL is assembled from the legacy ``mysql_*``
    settings. Any attribute that is missing falls back to the defaults below.

    Args:
        settings: Settings object to read from. Defaults to ``get_settings()``.
            Pass one explicitly to avoid the global lookup, e.g. in tests.

    Returns:
        A SQLAlchemy database URL string. Username and password are
        percent-encoded, so characters such as ``@``, ``:``, ``/`` or spaces
        cannot break URL parsing.
    """
    if settings is None:
        settings = get_settings()

    # Prefer an explicit DATABASE_URL when one is configured.
    database_url = getattr(settings, 'database_url', None)
    if database_url:
        return database_url

    # Fall back to the legacy MariaDB settings.
    mysql_host = getattr(settings, 'mysql_host', 'maria_db')
    mysql_port = getattr(settings, 'mysql_port', 3306)
    mysql_user = getattr(settings, 'mysql_user', 'myuploader')
    mysql_password = getattr(settings, 'mysql_password', 'password')
    mysql_database = getattr(settings, 'mysql_database', 'myuploader')

    # safe="" also encodes "/" and ":". quote() (unlike quote_plus) encodes a
    # space as %20, which SQLAlchemy's URL parser decodes back to a space.
    user = quote(str(mysql_user), safe="")
    password = quote(str(mysql_password), safe="")
    return f"mysql+aiomysql://{user}:{password}@{mysql_host}:{mysql_port}/{mysql_database}"
async def init_database() -> None:
    """Create the global async engine and session factory, then verify connectivity.

    If an engine already exists, only a warning is logged. Pool options are
    read from settings (``database_pool_size``, ``database_max_overflow``,
    ``database_pool_timeout``, ``database_pool_recycle``, ``debug``), with
    the fallbacks shown below.

    Raises:
        Exception: Whatever engine creation or the ``SELECT 1`` probe raised.
            Any partially created engine is disposed first, and the module is
            left uninitialized so a later call can retry.
    """
    global _engine, _async_session
    if _engine is not None:
        logger.warning("Database already initialized")
        return
    try:
        settings = get_settings()
        database_url = get_database_url()
        # Log only the text after the last '@' so credentials never reach the logs.
        target = database_url.rpartition('@')[2] if '@' in database_url else 'unknown'
        logger.info(f"Connecting to database: {target}")

        # These driver options are MariaDB/MySQL-specific. Other DBAPIs that a
        # DATABASE_URL might select do not necessarily accept them.
        connect_args = {}
        if database_url.startswith(("mysql", "mariadb")):
            connect_args = {
                "charset": "utf8mb4",
                "use_unicode": True,
                "autocommit": False,
            }

        _engine = create_async_engine(
            database_url,
            echo=getattr(settings, 'debug', False),
            pool_size=getattr(settings, 'database_pool_size', 20),
            max_overflow=getattr(settings, 'database_max_overflow', 30),
            pool_timeout=getattr(settings, 'database_pool_timeout', 30),
            pool_recycle=getattr(settings, 'database_pool_recycle', 3600),
            pool_pre_ping=True,  # Verify connections before use
            connect_args=connect_args,
        )

        _async_session = async_sessionmaker(
            bind=_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=True,
            autocommit=False,
        )

        # Probe the connection. Raw SQL must be wrapped in text() in SQLAlchemy 2.x.
        async with _engine.begin() as conn:
            await conn.execute(text("SELECT 1"))
        logger.info("Database connection established successfully")
    except Exception as e:
        logger.error(f"Failed to initialize database: {e}")
        # Roll back global state so a retry is possible and the pool is not leaked.
        engine = _engine
        _engine = None
        _async_session = None
        if engine is not None:
            try:
                await engine.dispose()
            except Exception:
                logger.exception("Failed to dispose engine after initialization error")
        raise
async def close_database() -> None:
    """Dispose of the global engine and reset the session factory.

    This does nothing when no engine has been created.
    """
    global _engine, _async_session
    if _engine is None:
        return

    logger.info("Closing database connection")
    await _engine.dispose()
    _engine = None
    _async_session = None
    logger.info("Database connection closed")
def get_engine() -> AsyncEngine:
    """Return the engine created by ``init_database()``.

    Raises:
        RuntimeError: If ``init_database()`` has not completed yet.
    """
    engine = _engine
    if engine is not None:
        return engine
    raise RuntimeError("Database not initialized. Call init_database() first.")
def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the session factory created by ``init_database()``.

    Raises:
        RuntimeError: If ``init_database()`` has not completed yet.
    """
    factory = _async_session
    if factory is not None:
        return factory
    raise RuntimeError("Database not initialized. Call init_database() first.")
@asynccontextmanager
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a new ``AsyncSession``, rolling back on error and always closing it.

    Any ``Exception`` raised inside the ``async with`` body is logged, the
    session is rolled back, and the exception propagates.

    Raises:
        RuntimeError: On entry, if ``init_database()`` has not completed yet.
    """
    factory = get_session_factory()
    async with factory() as db_session:
        try:
            yield db_session
        except Exception as exc:
            logger.error(f"Database session error: {exc}")
            await db_session.rollback()
            raise
        finally:
            await db_session.close()
async def check_database_health() -> bool:
    """Report whether a trivial ``SELECT 1`` succeeds on a fresh session.

    Returns:
        True on success. False on any failure, including an uninitialized
        database. Failures are logged, not raised.
    """
    try:
        async with get_async_session() as session:
            # SQLAlchemy 2.x refuses plain strings; raw SQL must go through text().
            await session.execute(text("SELECT 1"))
        return True
    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        return False
async def get_database_info() -> dict:
    """Collect basic server and connection-pool statistics.

    Returns:
        On success, a dict with these keys:

        - ``version``: the server's ``VERSION()``, or ``"Unknown"``.
        - ``connections``: ``Threads_connected`` from ``SHOW STATUS``, or 0
          if that query fails.
        - ``size_mb``: data + index size of the current schema in MiB, or 0
          if that query fails.
        - ``engine_pool_size`` / ``engine_checked_out``: taken from the
          engine's pool, or 0 when there is no engine.

        On any other failure, ``{"error": <message>}``. Nothing is raised.
    """
    try:
        async with get_async_session() as session:
            # Database version. Raw SQL must be wrapped in text() in SQLAlchemy 2.x.
            result = await session.execute(text("SELECT VERSION() as version"))
            version_row = result.fetchone()
            version = version_row[0] if version_row else "Unknown"

            # Connection count. SHOW STATUS is MariaDB/MySQL-specific, so
            # failure is treated as "unknown" (0) rather than fatal.
            try:
                result = await session.execute(text("SHOW STATUS LIKE 'Threads_connected'"))
                conn_row = result.fetchone()
                connections = int(conn_row[1]) if conn_row else 0
            except Exception as e:
                logger.debug(f"Could not read connection count: {e}")
                connections = 0

            # Size of the current schema (data + indexes), in MiB.
            try:
                result = await session.execute(text("""
                SELECT
                ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) as size_mb
                FROM information_schema.tables
                WHERE table_schema = DATABASE()
                """))
                size_row = result.fetchone()
                size_mb = float(size_row[0]) if size_row and size_row[0] else 0
            except Exception as e:
                logger.debug(f"Could not read database size: {e}")
                size_mb = 0

            return {
                "version": version,
                "connections": connections,
                "size_mb": size_mb,
                "engine_pool_size": _engine.pool.size() if _engine else 0,
                "engine_checked_out": _engine.pool.checkedout() if _engine else 0,
            }
    except Exception as e:
        logger.error(f"Failed to get database info: {e}")
        return {"error": str(e)}
# Session dependency for API routes (dependency injection).
def get_db_session():
    """Return a new ``get_async_session()`` async context manager.

    The caller enters it with ``async with`` to obtain the session.
    """
    session_context = get_async_session()
    return session_context
# Backward-compatible alias kept for older callers.
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one session from ``get_async_session()``; the old function name."""
    session_context = get_async_session()
    async with session_context as db_session:
        yield db_session
# Context manager wrapping a session in an explicit transaction.
@asynccontextmanager
async def transaction():
    """Yield a session inside ``session.begin()``.

    The transaction commits when the block exits normally and rolls back if
    it raises. Session cleanup is handled by ``get_async_session()``.
    """
    async with get_async_session() as session, session.begin():
        yield session