220 lines
7.2 KiB
Python
220 lines
7.2 KiB
Python
"""Compatible database configuration with MariaDB support."""
|
|
|
|
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional
from urllib.parse import quote

from sqlalchemy import MetaData, text
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.pool import NullPool

from app.core.config import get_settings
|
|
|
|
logger = logging.getLogger(__name__)

# Templates passed to MetaData so that indexes and constraints created through
# it receive predictable, convention-based names.
naming_convention = dict(
    ix="ix_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_%(constraint_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)

metadata = MetaData(naming_convention=naming_convention)

# Module-wide engine and session factory: populated by init_database() and
# reset to None by close_database().
_engine: Optional[AsyncEngine] = None
_async_session: Optional[async_sessionmaker[AsyncSession]] = None
|
|
|
|
|
|
def get_database_url() -> str:
    """Return the SQLAlchemy database URL built from application settings.

    A non-empty ``settings.database_url`` is returned unchanged. Otherwise a
    ``mysql+aiomysql://`` URL is assembled from the legacy ``mysql_*`` settings,
    falling back to hard-coded defaults for any setting that is missing.

    Returns:
        The database URL as a string.
    """
    settings = get_settings()

    # Support both new DATABASE_URL and legacy MariaDB settings.
    database_url = getattr(settings, 'database_url', None)
    if database_url:
        return database_url

    # Fallback to MariaDB configuration.
    mysql_host = getattr(settings, 'mysql_host', 'maria_db')
    mysql_port = getattr(settings, 'mysql_port', 3306)
    mysql_user = getattr(settings, 'mysql_user', 'myuploader')
    mysql_password = getattr(settings, 'mysql_password', 'password')
    mysql_database = getattr(settings, 'mysql_database', 'myuploader')

    # Percent-encode the credentials so that characters such as '@', ':' or '/'
    # in a password cannot break the URL structure. quote(safe="") escapes every
    # reserved character, so unquoting restores the original value.
    # str() matches how the previous f-string rendered non-str values.
    user = quote(str(mysql_user), safe="")
    password = quote(str(mysql_password), safe="")

    return f"mysql+aiomysql://{user}:{password}@{mysql_host}:{mysql_port}/{mysql_database}"
|
|
|
|
|
|
async def init_database() -> None:
    """Create the global async engine and session factory, then verify connectivity.

    If the engine already exists, a warning is logged and nothing else happens.
    If engine creation or the ``SELECT 1`` probe fails, any partially created
    engine is disposed and both module globals are reset to ``None``, so a
    later call can retry instead of returning early.

    Raises:
        Exception: whatever settings lookup, engine creation or the connectivity
            probe raises; it is logged and re-raised.
    """
    global _engine, _async_session

    if _engine is not None:
        logger.warning("Database already initialized")
        return

    try:
        settings = get_settings()
        database_url = get_database_url()

        # Log only the part after the credentials. rsplit keeps this correct
        # even if an unescaped '@' appears inside the password.
        logger.info(
            "Connecting to database: %s",
            database_url.rsplit('@', 1)[1] if '@' in database_url else 'unknown',
        )

        # Create async engine with MariaDB/MySQL optimizations.
        _engine = create_async_engine(
            database_url,
            echo=getattr(settings, 'debug', False),
            pool_size=getattr(settings, 'database_pool_size', 20),
            max_overflow=getattr(settings, 'database_max_overflow', 30),
            pool_timeout=getattr(settings, 'database_pool_timeout', 30),
            pool_recycle=getattr(settings, 'database_pool_recycle', 3600),
            pool_pre_ping=True,  # Verify connections before use
            # MariaDB specific settings, passed through to the DBAPI connect().
            connect_args={
                "charset": "utf8mb4",
                "use_unicode": True,
                "autocommit": False,
            },
        )

        # Create async session factory. autocommit is not passed: it is a
        # legacy Session option and False is already the only supported value.
        _async_session = async_sessionmaker(
            bind=_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=True,
        )

        # Test the connection. SQLAlchemy 2.x no longer accepts plain strings
        # in execute(); textual SQL must be wrapped in text().
        async with _engine.begin() as conn:
            await conn.execute(text("SELECT 1"))

        logger.info("Database connection established successfully")

    except Exception as e:
        logger.exception("Failed to initialize database: %s", e)
        # Do not leave a half-initialized engine behind. Otherwise every later
        # init_database() call would return early with "already initialized".
        engine, _engine, _async_session = _engine, None, None
        if engine is not None:
            try:
                await engine.dispose()
            except Exception:
                logger.warning(
                    "Failed to dispose engine after initialization error",
                    exc_info=True,
                )
        raise
|
|
|
|
|
|
async def close_database() -> None:
    """Dispose of the global engine and clear the session factory.

    Does nothing when no engine is currently set.
    """
    global _engine, _async_session

    if _engine is None:
        return

    logger.info("Closing database connection")
    await _engine.dispose()
    _engine, _async_session = None, None
    logger.info("Database connection closed")
|
|
|
|
|
|
def get_engine() -> AsyncEngine:
    """Return the global async engine created by init_database().

    Raises:
        RuntimeError: if the engine has not been initialized.
    """
    engine = _engine
    if engine is not None:
        return engine
    raise RuntimeError("Database not initialized. Call init_database() first.")
|
|
|
|
|
|
def get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the global session factory created by init_database().

    Raises:
        RuntimeError: if the factory has not been initialized.
    """
    factory = _async_session
    if factory is not None:
        return factory
    raise RuntimeError("Database not initialized. Call init_database() first.")
|
|
|
|
|
|
@asynccontextmanager
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a new session from the global factory, rolling back on error.

    The session is closed by the factory's own ``async with`` when the block
    exits, so no separate close() call is needed.

    Raises:
        RuntimeError: if init_database() has not been called yet.
        Exception: any exception raised inside the caller's block is logged,
            the session is rolled back, and the exception is re-raised.
    """
    if _async_session is None:
        raise RuntimeError("Database not initialized. Call init_database() first.")

    async with _async_session() as session:
        try:
            yield session
        except Exception as e:
            logger.error("Database session error: %s", e)
            await session.rollback()
            raise
|
|
|
|
|
|
async def check_database_health() -> bool:
    """Return True if a trivial ``SELECT 1`` succeeds on a fresh session.

    Any failure, including the database not being initialized, is logged and
    reported as False rather than raised.
    """
    try:
        async with get_async_session() as session:
            # SQLAlchemy 2.x rejects plain strings in execute(). Without text()
            # this call always raised, so the check always returned False.
            await session.execute(text("SELECT 1"))
        return True
    except Exception as e:
        logger.error("Database health check failed: %s", e)
        return False
|
|
|
|
|
|
async def _first_row(session, sql: str):
    """Execute *sql* as textual SQL on *session* and return its first row, or None."""
    result = await session.execute(text(sql))
    return result.fetchone()


async def get_database_info() -> dict:
    """Collect basic server and connection-pool statistics.

    Returns:
        On success, a dict with these keys:

        - ``version``: result of ``SELECT VERSION()``, or "Unknown".
        - ``connections``: the ``Threads_connected`` status value, or 0 if it
          cannot be read.
        - ``size_mb``: data plus index size of the current schema, or 0 if it
          cannot be read.
        - ``engine_pool_size`` and ``engine_checked_out``: taken from the
          engine's pool, or 0 when no engine is set.

        On failure, ``{"error": <message>}``.
    """
    try:
        async with get_async_session() as session:
            # Get database version.
            version_row = await _first_row(session, "SELECT VERSION() as version")
            version = version_row[0] if version_row else "Unknown"

            # Get connection count (MariaDB specific). This is best effort: on
            # failure fall back to 0 but keep a trace at debug level.
            try:
                conn_row = await _first_row(
                    session, "SHOW STATUS LIKE 'Threads_connected'"
                )
                connections = int(conn_row[1]) if conn_row else 0
            except Exception:
                logger.debug("Could not read Threads_connected", exc_info=True)
                connections = 0

            # Get database size (best effort, as above).
            try:
                size_row = await _first_row(session, """
                    SELECT
                        ROUND(SUM(data_length + index_length) / 1024 / 1024, 2) as size_mb
                    FROM information_schema.tables
                    WHERE table_schema = DATABASE()
                """)
                size_mb = float(size_row[0]) if size_row and size_row[0] else 0
            except Exception:
                logger.debug("Could not compute database size", exc_info=True)
                size_mb = 0

            return {
                "version": version,
                "connections": connections,
                "size_mb": size_mb,
                "engine_pool_size": _engine.pool.size() if _engine else 0,
                "engine_checked_out": _engine.pool.checkedout() if _engine else 0,
            }

    except Exception as e:
        logger.error("Failed to get database info: %s", e)
        return {"error": str(e)}
|
|
|
|
|
|
# Database session dependency for dependency injection
def get_db_session():
    """Return a new get_async_session() context manager for API routes.

    The caller must enter it with ``async with`` to obtain the session.
    """
    session_context = get_async_session()
    return session_context
|
|
|
|
|
|
# Backward compatibility functions
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session from get_async_session(); kept under the legacy name."""
    session_context = get_async_session()
    async with session_context as db_session:
        yield db_session
|
|
|
|
|
|
# Transaction context manager
@asynccontextmanager
async def transaction():
    """Yield a session whose work runs inside ``session.begin()``.

    The transaction is committed if the block exits normally and rolled back
    if it raises.
    """
    # Context managers in one `async with` are entered left to right, so
    # `session` is already bound when `session.begin()` is evaluated.
    async with get_async_session() as session, session.begin():
        yield session