"""Compatible SQLAlchemy base models for MariaDB.""" from datetime import datetime from typing import Optional, Dict, Any from sqlalchemy import Column, Integer, DateTime, text from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.orm import sessionmaker # Create base class Base = declarative_base() class TimestampMixin: """Mixin for adding timestamp fields.""" @declared_attr def created_at(cls): return Column( DateTime, nullable=False, default=datetime.utcnow, server_default=text('CURRENT_TIMESTAMP') ) @declared_attr def updated_at(cls): return Column( DateTime, nullable=False, default=datetime.utcnow, onupdate=datetime.utcnow, server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP') ) class BaseModel(Base, TimestampMixin): """Base model with common fields for all entities.""" __abstract__ = True id = Column(Integer, primary_key=True, autoincrement=True) def to_dict(self, exclude: Optional[set] = None) -> Dict[str, Any]: """Convert model instance to dictionary.""" exclude = exclude or set() result = {} for column in self.__table__.columns: if column.name not in exclude: value = getattr(self, column.name) # Handle datetime serialization if isinstance(value, datetime): result[column.name] = value.isoformat() else: result[column.name] = value return result def update_from_dict(self, data: Dict[str, Any], exclude: Optional[set] = None) -> None: """Update model instance from dictionary.""" exclude = exclude or {"id", "created_at", "updated_at"} for key, value in data.items(): if key not in exclude and hasattr(self, key): setattr(self, key, value) @classmethod def get_table_name(cls) -> str: """Get table name.""" return cls.__tablename__ @classmethod def get_columns(cls) -> list: """Get list of column names.""" return [column.name for column in cls.__table__.columns] def __repr__(self) -> str: """String representation of model.""" return f"<{self.__class__.__name__}(id={getattr(self, 'id', None)})>" # Legacy session factory for backward compatibility SessionLocal = sessionmaker() def get_session(): """Get database session (legacy function for compatibility).""" return SessionLocal()