278 lines
7.7 KiB
Python
278 lines
7.7 KiB
Python
"""
|
|
Base model classes with async SQLAlchemy support
|
|
"""
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Any, Dict, Optional, Type, TypeVar, Union
|
|
|
|
from sqlalchemy import Column, DateTime, String, Boolean, Integer, Text, JSON
|
|
from sqlalchemy.dialects.postgresql import UUID
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.future import select
|
|
from sqlalchemy.orm import sessionmaker
|
|
from pydantic import BaseModel
|
|
import structlog
|
|
|
|
# Structured logger scoped to this module's import path.
logger = structlog.get_logger(__name__)

# Declarative base that every mapped model in the application derives from.
Base = declarative_base()

# Generic placeholder for "some subclass of BaseModel"; the bound is given as
# a string because the BaseModel class is declared further down in this file.
ModelType = TypeVar("ModelType", bound="BaseModel")
|
|
|
|
|
|
class TimestampMixin:
    """Adds ``created_at`` and ``updated_at`` columns that fill themselves in.

    Both columns are non-nullable ``DateTime`` values whose defaults come from
    ``datetime.utcnow``, so the stored timestamps are naive datetimes in UTC.
    """

    # Set from the column default when no value is supplied on insert.
    created_at = Column(
        DateTime,
        default=datetime.utcnow,
        nullable=False,
        comment="Record creation timestamp",
    )

    # Same default as created_at, plus an ``onupdate`` hook so the value is
    # regenerated whenever the row is updated.
    updated_at = Column(
        DateTime,
        default=datetime.utcnow,
        onupdate=datetime.utcnow,
        nullable=False,
        comment="Record last update timestamp",
    )
|
|
|
|
|
|
class UUIDMixin:
    """Provides a UUID primary key column named ``id``."""

    # ``as_uuid=True`` makes the column yield ``uuid.UUID`` objects; when no
    # value is supplied, the default calls ``uuid.uuid4`` to generate one.
    id = Column(
        UUID(as_uuid=True),
        default=uuid.uuid4,
        primary_key=True,
        comment="Unique identifier",
    )
|
|
|
|
|
|
class SoftDeleteMixin:
    """Lets a record be flagged as deleted without removing its row.

    A record counts as deleted whenever ``deleted_at`` holds a timestamp.
    """

    # NULL means "not deleted"; otherwise the naive UTC time of deletion.
    deleted_at = Column(DateTime, nullable=True, comment="Soft delete timestamp")

    @property
    def is_deleted(self) -> bool:
        """Return True when a deletion timestamp is present."""
        deleted_marker = self.deleted_at
        return deleted_marker is not None

    def soft_delete(self):
        """Stamp the record with the current UTC time to mark it deleted."""
        now = datetime.utcnow()
        self.deleted_at = now

    def restore(self):
        """Clear the deletion timestamp so the record is live again."""
        self.deleted_at = None
|
|
|
|
|
|
class MetadataMixin:
    """Mixin for flexible, schema-less metadata stored in a JSON column.

    The column is exposed as ``meta_data`` rather than ``metadata``.
    SQLAlchemy's Declarative API reserves the ``metadata`` attribute for the
    ``MetaData`` registry of the declarative base: a column declared under
    that name is either rejected with ``InvalidRequestError`` or shadowed by
    the registry and never mapped, depending on the order of the bases.

    The helper methods always assign a *new* dict to the column instead of
    mutating the existing one, because a plain ``JSON`` column does not
    detect in-place changes and would not flush them.
    """

    meta_data = Column(
        JSON,
        nullable=False,
        default=dict,
        comment="Flexible metadata storage",
    )

    def set_meta(self, key: str, value: Any) -> None:
        """Store ``value`` under ``key`` in the metadata.

        Args:
            key: Metadata key to set.
            value: Value to store; it ends up in a JSON column.
        """
        # Rebuild and reassign so the ORM sees the attribute as changed.
        self.meta_data = {**(self.meta_data or {}), key: value}

    def get_meta(self, key: str, default: Any = None) -> Any:
        """Return the metadata value for ``key``.

        Args:
            key: Metadata key to look up.
            default: Value returned when the key is absent or no metadata
                has been set yet.

        Returns:
            The stored value, or ``default``.
        """
        current = self.meta_data
        if current is None:
            return default
        return current.get(key, default)

    def update_meta(self, updates: Dict[str, Any]) -> None:
        """Merge ``updates`` into the metadata, overwriting existing keys.

        Args:
            updates: Mapping of keys to new values.
        """
        # Rebuild and reassign so the ORM sees the attribute as changed.
        self.meta_data = {**(self.meta_data or {}), **updates}
|
|
|
|
|
|
class StatusMixin:
    """Tracks a short textual status for a record.

    ``set_status`` calls ``self.set_meta``, so the class this mixin is
    combined with must also provide that method.
    """

    # Indexed, non-nullable status string; defaults to "active".
    status = Column(
        String(64),
        default="active",
        nullable=False,
        index=True,
        comment="Record status",
    )

    def set_status(self, status: str, reason: Optional[str] = None):
        """Change the status and note when (and optionally why) it changed.

        Args:
            status: New status value.
            reason: Optional explanation; stored under the
                ``status_reason`` metadata key when truthy.
        """
        self.status = status

        if reason:
            self.set_meta("status_reason", reason)

        # NOTE(review): the change timestamp is recorded on every call,
        # whether or not a reason was given -- confirm this is intended.
        changed_at = datetime.utcnow().isoformat()
        self.set_meta("status_changed_at", changed_at)
|
|
|
|
|
|
class BaseModelMixin:
    """Common serialization and persistence helpers for mapped models.

    The methods read ``__table__`` and ``id``, so this mixin is meant to be
    combined with a SQLAlchemy declarative class that provides both.
    """

    def to_dict(self) -> Dict[str, Any]:
        """Convert the model's table columns to a plain dictionary.

        ``datetime`` values become ISO-8601 strings. ``uuid.UUID`` values and
        any object carrying a ``__dict__`` are converted with ``str()``.
        UUIDs are handled explicitly because ``uuid.UUID`` defines
        ``__slots__`` and would otherwise slip past the ``__dict__`` check.

        Returns:
            Mapping of column name to (possibly stringified) value.
        """
        result: Dict[str, Any] = {}
        for column in self.__table__.columns:
            value = getattr(self, column.name)
            if isinstance(value, datetime):
                value = value.isoformat()
            elif isinstance(value, uuid.UUID) or hasattr(value, "__dict__"):
                value = str(value)
            result[column.name] = value
        return result

    def update_from_dict(self, data: Dict[str, Any]) -> None:
        """Assign values from ``data`` to attributes that already exist.

        Keys that do not name an existing attribute are ignored.

        Args:
            data: Mapping of attribute name to new value.
        """
        for key, value in data.items():
            if hasattr(self, key):
                setattr(self, key, value)

    @classmethod
    async def get_by_id(
        cls: Type[ModelType],
        session: AsyncSession,
        id_value: Union[int, str, uuid.UUID]
    ) -> Optional[ModelType]:
        """Fetch the record whose ``id`` equals ``id_value``.

        Args:
            session: Async session used to run the query.
            id_value: Primary key value to look up.

        Returns:
            The matching instance, or ``None`` when there is no match or the
            query failed (failures are logged, not raised).
        """
        try:
            stmt = select(cls).where(cls.id == id_value)
            result = await session.execute(stmt)
            return result.scalar_one_or_none()
        except Exception as e:
            logger.error("Error getting record by ID", model=cls.__name__, id=id_value, error=str(e))
            return None

    @classmethod
    async def get_all(
        cls: Type[ModelType],
        session: AsyncSession,
        limit: Optional[int] = None,
        offset: Optional[int] = None
    ) -> list[ModelType]:
        """Fetch records, optionally paginated.

        Args:
            session: Async session used to run the query.
            limit: Maximum number of rows; ``None`` means no limit. A value
                of ``0`` is honoured and yields an empty result.
            offset: Number of rows to skip; ``None`` means no offset.

        Returns:
            List of instances; empty when the query failed (failures are
            logged, not raised).
        """
        try:
            stmt = select(cls)
            # Compare against None rather than truthiness so that an explicit
            # limit of 0 is not mistaken for "no limit".
            if offset is not None:
                stmt = stmt.offset(offset)
            if limit is not None:
                stmt = stmt.limit(limit)
            result = await session.execute(stmt)
            # Materialize as a real list to match the annotated return type.
            return list(result.scalars().all())
        except Exception as e:
            logger.error("Error getting all records", model=cls.__name__, error=str(e))
            return []

    @classmethod
    async def count(cls: Type[ModelType], session: AsyncSession) -> int:
        """Return the number of records for this model.

        Args:
            session: Async session used to run the query.

        Returns:
            Row count; ``0`` when the query failed (failures are logged, not
            raised).
        """
        try:
            from sqlalchemy import func

            stmt = select(func.count(cls.id))
            result = await session.execute(stmt)
            return result.scalar() or 0
        except Exception as e:
            logger.error("Error counting records", model=cls.__name__, error=str(e))
            return 0

    async def save(self, session: AsyncSession) -> None:
        """Add this instance to ``session``, commit, and refresh it.

        Args:
            session: Async session to persist through.

        Raises:
            Exception: Any error from the session is re-raised after the
                transaction has been rolled back and the error logged.
        """
        try:
            session.add(self)
            await session.commit()
            await session.refresh(self)
        except Exception as e:
            await session.rollback()
            logger.error("Error saving model", model=self.__class__.__name__, error=str(e))
            raise

    async def delete(self, session: AsyncSession) -> None:
        """Delete this instance through ``session`` and commit.

        Args:
            session: Async session to delete through.

        Raises:
            Exception: Any error from the session is re-raised after the
                transaction has been rolled back and the error logged.
        """
        try:
            await session.delete(self)
            await session.commit()
        except Exception as e:
            await session.rollback()
            logger.error("Error deleting model", model=self.__class__.__name__, error=str(e))
            raise
|
|
|
|
|
|
class AuditMixin:
    """Records which user created and last modified a row."""

    # Both columns are nullable UUIDs; no foreign key is declared here.
    created_by = Column(
        UUID(as_uuid=True),
        nullable=True,
        comment="User who created the record",
    )
    updated_by = Column(
        UUID(as_uuid=True),
        nullable=True,
        comment="User who last updated the record",
    )

    def set_audit_info(self, user_id: Optional[uuid.UUID] = None):
        """Stamp the record with the acting user.

        Does nothing when ``user_id`` is falsy. Otherwise ``updated_by`` is
        always set, and ``created_by`` is set as well when the instance has
        no truthy ``created_at`` attribute.

        Args:
            user_id: Identifier of the acting user, if known.
        """
        if not user_id:
            return

        # A missing or empty creation timestamp is taken to mean the record
        # has no creator recorded yet.
        if not getattr(self, "created_at", None):
            self.created_by = user_id
        self.updated_by = user_id
|
|
|
|
|
|
class CacheableMixin:
    """Helpers for models that are stored in a cache.

    Relies on the host class providing an ``id`` attribute and a
    ``to_dict()`` method.
    """

    @property
    def cache_key(self) -> str:
        """Return ``"<lowercased class name>:<id>"`` for this instance."""
        prefix = self.__class__.__name__.lower()
        return "{}:{}".format(prefix, self.id)

    @property
    def cache_ttl(self) -> int:
        """Return the default time-to-live for cache entries, in seconds."""
        one_hour = 3600
        return one_hour

    def get_cache_data(self) -> Dict[str, Any]:
        """Return the payload to cache: the model's ``to_dict()`` output."""
        payload = self.to_dict()
        return payload
|
|
|
|
|
|
# Combined base model class
|
|
class BaseModel(
    Base,
    BaseModelMixin,
    TimestampMixin,
    UUIDMixin,
    SoftDeleteMixin,
    MetadataMixin,
    StatusMixin,
    AuditMixin,
    CacheableMixin
):
    """Abstract parent for application models, combining every mixin above.

    Marked ``__abstract__`` so that no table is created for this class
    itself; concrete subclasses supply their own table.
    """

    __abstract__ = True

    def __repr__(self) -> str:
        """Return ``<ClassName(id=...)>`` for debugging output."""
        class_name = self.__class__.__name__
        return "<{}(id={})>".format(class_name, self.id)
|
|
|
|
|
|
# Backward-compatible alias: code written against the old model base can keep
# referring to the declarative base as ``AlchemyBase``.
AlchemyBase = Base
|