243 lines
8.0 KiB
Python
243 lines
8.0 KiB
Python
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import func, and_, or_, desc, asc
|
||
from app.models.base import BaseModel
|
||
|
||
# Тип для модели SQLAlchemy
|
||
ModelType = TypeVar("ModelType", bound=BaseModel)
|
||
|
||
|
||
class BaseRepository(Generic[ModelType]):
|
||
"""
|
||
Базовый репозиторий с CRUD операциями.
|
||
Использует Generic для поддержки разных моделей.
|
||
"""
|
||
|
||
def __init__(self, db: Session, model: Type[ModelType]):
|
||
"""
|
||
Инициализация репозитория
|
||
|
||
Args:
|
||
db: Сессия SQLAlchemy
|
||
model: Класс модели
|
||
"""
|
||
self.db = db
|
||
self.model = model
|
||
|
||
def get_by_id(self, id: int) -> Optional[ModelType]:
|
||
"""Получить запись по ID"""
|
||
return self.db.query(self.model).filter(self.model.id == id).first()
|
||
|
||
def get_by_uuid(self, uuid: str) -> Optional[ModelType]:
|
||
"""Получить запись по UUID (если модель имеет поле uuid)"""
|
||
if hasattr(self.model, 'uuid'):
|
||
return self.db.query(self.model).filter(self.model.uuid == uuid).first()
|
||
raise AttributeError(f"Model {self.model.__name__} does not have 'uuid' field")
|
||
|
||
def get_all(
|
||
self,
|
||
skip: int = 0,
|
||
limit: int = 100,
|
||
order_by: Optional[str] = None,
|
||
descending: bool = False
|
||
) -> List[ModelType]:
|
||
"""Получить все записи с пагинацией"""
|
||
query = self.db.query(self.model)
|
||
|
||
if order_by and hasattr(self.model, order_by):
|
||
order_column = getattr(self.model, order_by)
|
||
if descending:
|
||
query = query.order_by(desc(order_column))
|
||
else:
|
||
query = query.order_by(asc(order_column))
|
||
elif hasattr(self.model, 'created_at'):
|
||
query = query.order_by(desc(self.model.created_at))
|
||
|
||
return query.offset(skip).limit(limit).all()
|
||
|
||
def create(self, obj_in: Dict[str, Any]) -> ModelType:
|
||
"""Создать новую запись"""
|
||
db_obj = self.model(**obj_in)
|
||
self.db.add(db_obj)
|
||
self.db.commit()
|
||
self.db.refresh(db_obj)
|
||
return db_obj
|
||
|
||
def create_many(self, objects_in: List[Dict[str, Any]]) -> List[ModelType]:
|
||
"""Создать несколько записей"""
|
||
db_objects = [self.model(**obj_in) for obj_in in objects_in]
|
||
self.db.add_all(db_objects)
|
||
self.db.commit()
|
||
for obj in db_objects:
|
||
self.db.refresh(obj)
|
||
return db_objects
|
||
|
||
def update(self, id: int, obj_in: Dict[str, Any]) -> Optional[ModelType]:
|
||
"""Обновить запись"""
|
||
db_obj = self.get_by_id(id)
|
||
if not db_obj:
|
||
return None
|
||
|
||
for field, value in obj_in.items():
|
||
if hasattr(db_obj, field):
|
||
setattr(db_obj, field, value)
|
||
|
||
self.db.commit()
|
||
self.db.refresh(db_obj)
|
||
return db_obj
|
||
|
||
def update_by_uuid(self, uuid: str, obj_in: Dict[str, Any]) -> Optional[ModelType]:
|
||
"""Обновить запись по UUID"""
|
||
db_obj = self.get_by_uuid(uuid)
|
||
if not db_obj:
|
||
return None
|
||
|
||
for field, value in obj_in.items():
|
||
if hasattr(db_obj, field):
|
||
setattr(db_obj, field, value)
|
||
|
||
self.db.commit()
|
||
self.db.refresh(db_obj)
|
||
return db_obj
|
||
|
||
def delete(self, id: int) -> bool:
|
||
"""Удалить запись"""
|
||
db_obj = self.get_by_id(id)
|
||
if not db_obj:
|
||
return False
|
||
|
||
self.db.delete(db_obj)
|
||
self.db.commit()
|
||
return True
|
||
|
||
def delete_by_uuid(self, uuid: str) -> bool:
|
||
"""Удалить запись по UUID"""
|
||
db_obj = self.get_by_uuid(uuid)
|
||
if not db_obj:
|
||
return False
|
||
|
||
self.db.delete(db_obj)
|
||
self.db.commit()
|
||
return True
|
||
|
||
def count(self, **filters) -> int:
|
||
"""Подсчитать количество записей с фильтрами"""
|
||
query = self.db.query(self.model)
|
||
|
||
for field, value in filters.items():
|
||
if hasattr(self.model, field):
|
||
query = query.filter(getattr(self.model, field) == value)
|
||
|
||
return query.count()
|
||
|
||
def exists(self, **filters) -> bool:
|
||
"""Проверить существование записи"""
|
||
return self.count(**filters) > 0
|
||
|
||
def filter_by(
|
||
self,
|
||
skip: int = 0,
|
||
limit: int = 100,
|
||
order_by: Optional[str] = None,
|
||
descending: bool = False,
|
||
**filters
|
||
) -> List[ModelType]:
|
||
"""Фильтрация записей по параметрам"""
|
||
query = self.db.query(self.model)
|
||
|
||
for field, value in filters.items():
|
||
if hasattr(self.model, field):
|
||
query = query.filter(getattr(self.model, field) == value)
|
||
|
||
if order_by and hasattr(self.model, order_by):
|
||
order_column = getattr(self.model, order_by)
|
||
if descending:
|
||
query = query.order_by(desc(order_column))
|
||
else:
|
||
query = query.order_by(asc(order_column))
|
||
|
||
return query.offset(skip).limit(limit).all()
|
||
|
||
def find_first(self, **filters) -> Optional[ModelType]:
|
||
"""Найти первую запись по фильтрам"""
|
||
query = self.db.query(self.model)
|
||
|
||
for field, value in filters.items():
|
||
if hasattr(self.model, field):
|
||
query = query.filter(getattr(self.model, field) == value)
|
||
|
||
return query.first()
|
||
|
||
def bulk_update(self, ids: List[int], obj_in: Dict[str, Any]) -> int:
|
||
"""Массовое обновление записей"""
|
||
updated_count = self.db.query(self.model).filter(
|
||
self.model.id.in_(ids)
|
||
).update(obj_in, synchronize_session=False)
|
||
|
||
self.db.commit()
|
||
return updated_count
|
||
|
||
def bulk_delete(self, ids: List[int]) -> int:
|
||
"""Массовое удаление записей"""
|
||
deleted_count = self.db.query(self.model).filter(
|
||
self.model.id.in_(ids)
|
||
).delete(synchronize_session=False)
|
||
|
||
self.db.commit()
|
||
return deleted_count
|
||
|
||
def get_or_create(self, defaults: Optional[Dict] = None, **filters) -> tuple[ModelType, bool]:
|
||
"""
|
||
Получить запись или создать новую
|
||
|
||
Returns:
|
||
tuple: (объект, создан_ли_новый)
|
||
"""
|
||
instance = self.find_first(**filters)
|
||
|
||
if instance:
|
||
return instance, False
|
||
|
||
if defaults:
|
||
filters.update(defaults)
|
||
|
||
return self.create(filters), True
|
||
|
||
def paginate(
|
||
self,
|
||
page: int = 1,
|
||
per_page: int = 20,
|
||
order_by: Optional[str] = None,
|
||
descending: bool = False,
|
||
**filters
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
Пагинированный список записей
|
||
|
||
Returns:
|
||
Dict с ключами: items, total, page, per_page, pages
|
||
"""
|
||
query = self.db.query(self.model)
|
||
|
||
for field, value in filters.items():
|
||
if hasattr(self.model, field):
|
||
query = query.filter(getattr(self.model, field) == value)
|
||
|
||
total = query.count()
|
||
|
||
if order_by and hasattr(self.model, order_by):
|
||
order_column = getattr(self.model, order_by)
|
||
if descending:
|
||
query = query.order_by(desc(order_column))
|
||
else:
|
||
query = query.order_by(asc(order_column))
|
||
|
||
items = query.offset((page - 1) * per_page).limit(per_page).all()
|
||
|
||
return {
|
||
"items": items,
|
||
"total": total,
|
||
"page": page,
|
||
"per_page": per_page,
|
||
"pages": (total + per_page - 1) // per_page
|
||
} |