88 lines
2.6 KiB
Python
88 lines
2.6 KiB
Python
"""Compatible SQLAlchemy base models for MariaDB."""
|
|
|
|
from datetime import datetime
|
|
from typing import Optional, Dict, Any
|
|
from sqlalchemy import Column, Integer, DateTime, text
|
|
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
# Single declarative base for this module: BaseModel inherits from it, so
# every model built on BaseModel shares this one base class.
# NOTE(review): ``sqlalchemy.ext.declarative.declarative_base`` (imported at
# the top of this file) was moved to ``sqlalchemy.orm.declarative_base`` in
# SQLAlchemy 1.4 and the old location is deprecated -- consider switching.
Base = declarative_base()
|
|
|
|
|
|
def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Drop-in replacement for ``datetime.utcnow()``, which is deprecated as of
    Python 3.12. The tzinfo is stripped because the timestamp columns below
    are plain ``DateTime`` (no ``timezone=True``), so the stored value stays
    exactly what ``utcnow()`` used to produce.
    """
    # Local import so the module-level ``from datetime import datetime`` line
    # does not need to change.
    from datetime import timezone

    return datetime.now(timezone.utc).replace(tzinfo=None)


class TimestampMixin:
    """Mixin that adds NOT NULL ``created_at`` / ``updated_at`` columns.

    Each column is populated from two sources:

    * Python side: ``_utcnow_naive`` supplies the value on INSERT, and for
      ``updated_at`` also on UPDATE statements SQLAlchemy emits (``onupdate``).
    * Server side: ``CURRENT_TIMESTAMP`` defaults in the generated DDL, so
      inserts that omit the column are still filled in. The
      ``ON UPDATE CURRENT_TIMESTAMP`` clause on ``updated_at`` is
      MySQL/MariaDB syntax.

    ``declared_attr`` builds a fresh ``Column`` for every mapped subclass.

    NOTE(review): server-side ``CURRENT_TIMESTAMP`` is evaluated in the
    database session time zone, while the Python default is UTC -- confirm
    the server runs in UTC so both sources agree.
    """

    @declared_attr
    def created_at(cls):
        """Row creation time (naive UTC when set from Python)."""
        return Column(
            DateTime,
            nullable=False,
            default=_utcnow_naive,
            server_default=text('CURRENT_TIMESTAMP'),
        )

    @declared_attr
    def updated_at(cls):
        """Last-modification time (naive UTC when set from Python)."""
        return Column(
            DateTime,
            nullable=False,
            default=_utcnow_naive,
            # Refreshed by SQLAlchemy on UPDATE; the server clause covers
            # updates issued outside the ORM.
            onupdate=_utcnow_naive,
            server_default=text('CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP'),
        )
|
|
|
|
|
|
class BaseModel(Base, TimestampMixin):
    """Abstract base for all entities: an integer ``id`` plus timestamps.

    ``__abstract__ = True`` tells SQLAlchemy not to map this class to a table
    of its own. Concrete subclasses inherit the ``id`` column, the
    ``created_at`` / ``updated_at`` columns from ``TimestampMixin``, and the
    helper methods below.
    """

    __abstract__ = True

    # Auto-incrementing integer primary key inherited by every subclass.
    id = Column(Integer, primary_key=True, autoincrement=True)

    def to_dict(self, exclude: Optional[set] = None) -> Dict[str, Any]:
        """Return this row's column values as a dict keyed by column name.

        Args:
            exclude: Column names to leave out; ``None`` includes every column.

        Returns:
            A new dict. ``datetime`` values are converted with ``isoformat()``
            so the result is JSON-friendly; all other values are returned
            unchanged.

        Note:
            Values are read with ``getattr(self, column.name)``, so this
            assumes each mapped attribute has the same name as its column.
        """
        skip = exclude or set()
        result: Dict[str, Any] = {}
        for column in self.__table__.columns:
            if column.name in skip:
                continue
            value = getattr(self, column.name)
            result[column.name] = (
                value.isoformat() if isinstance(value, datetime) else value
            )
        return result

    def update_from_dict(
        self, data: Dict[str, Any], exclude: Optional[set] = None
    ) -> None:
        """Assign values from ``data`` onto attributes of this instance.

        Args:
            data: Mapping of attribute name to new value. Keys that are not
                existing attributes of the instance are silently ignored.
            exclude: Names that must never be assigned. ``None`` (the default)
                protects ``{"id", "created_at", "updated_at"}``; an explicitly
                empty collection means "exclude nothing".

        NOTE(review): any existing attribute passes the ``hasattr`` check,
        not only mapped columns, so do not feed unfiltered client input here.
        """
        # ``is None`` rather than truthiness, so a caller's explicit empty
        # set is honoured instead of being replaced by the defaults.
        if exclude is None:
            exclude = {"id", "created_at", "updated_at"}

        for key, value in data.items():
            if key not in exclude and hasattr(self, key):
                setattr(self, key, value)

    @classmethod
    def get_table_name(cls) -> str:
        """Return ``__tablename__`` (only defined on concrete subclasses)."""
        return cls.__tablename__

    @classmethod
    def get_columns(cls) -> list:
        """Return the names of the columns in this model's ``__table__``."""
        return [col.name for col in cls.__table__.columns]

    def __repr__(self) -> str:
        """Debug representation with the class name and primary key."""
        return f"<{type(self).__name__}(id={getattr(self, 'id', None)})>"
|
|
|
|
|
|
# Module-level session factory retained for legacy callers of get_session().
# It is created here without a ``bind``. NOTE(review): presumably it is bound
# to an engine elsewhere (e.g. ``SessionLocal.configure(bind=engine)``) --
# confirm, otherwise sessions it produces have no engine to execute against.
SessionLocal = sessionmaker()
|
|
|
|
|
|
def get_session():
    """Create and return a new session from ``SessionLocal``.

    Legacy helper kept for compatibility. It does not close the session it
    returns; closing it is left to the caller.
    """
    session = SessionLocal()
    return session