from functools import wraps from typing import ClassVar, cast, get_args, get_origin from pydantic import BaseModel from sqlmodel import SQLModel, BIGINT, Field, select, func from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.main import SQLModelMetaclass, RelationshipInfo from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor from .entity_metadata import EntityMetadata from . import session_dep class BotEntityMetaclass(SQLModelMetaclass): __future_references__ = {} def __new__(mcs, name, bases, namespace, **kwargs): bot_fields_descriptors = {} if bases: bot_entity_descriptor = bases[0].__dict__.get('bot_entity_descriptor') bot_fields_descriptors = {key: EntityFieldDescriptor(**value.__dict__.copy()) for key, value in bot_entity_descriptor.fields_descriptors.items()} if bot_entity_descriptor else {} if '__annotations__' in namespace: for annotation in namespace['__annotations__']: if annotation in ["bot_entity_descriptor", "entity_metadata"]: continue attribute_value = namespace.get(annotation) if isinstance(attribute_value, RelationshipInfo): continue descriptor_kwargs = {} descriptor_name = annotation if attribute_value: if isinstance(attribute_value, EntityField): descriptor_kwargs = attribute_value.__dict__.copy() sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None) if sm_descriptor: namespace[annotation] = sm_descriptor else: namespace.pop(annotation) descriptor_name = descriptor_kwargs.pop("name") or annotation type_ = namespace['__annotations__'][annotation] field_descriptor = EntityFieldDescriptor( name = descriptor_name, field_name = annotation, type_ = type_, **descriptor_kwargs) type_origin = get_origin(type_) is_list = False if type_origin == list: is_list = True type_ = get_args(type_)[0] if isinstance(type_, str): type_not_found = True for entity_descriptor in EntityMetadata().entity_descriptors.values(): if type_ == entity_descriptor.class_name: field_descriptor.type_ = list[entity_descriptor.type_] if is_list else entity_descriptor.type_ type_not_found = False break if type_not_found: if type_ in mcs.__future_references__: mcs.__future_references__[type_].append(field_descriptor) else: mcs.__future_references__[type_] = [field_descriptor] bot_fields_descriptors[descriptor_name] = field_descriptor descriptor_name = name if "bot_entity_descriptor" in namespace: entity_descriptor = namespace.pop("bot_entity_descriptor") descriptor_kwargs: dict = entity_descriptor.__dict__.copy() descriptor_name = descriptor_kwargs.pop("name", None) descriptor_name = descriptor_name or name.lower() descriptor_fields_sequence = descriptor_kwargs.pop("field_sequence", None) if not descriptor_fields_sequence: descriptor_fields_sequence = list(bot_fields_descriptors.keys()) descriptor_fields_sequence.remove("id") namespace["bot_entity_descriptor"] = EntityDescriptor( name = descriptor_name, class_name = name, type_ = name, fields_descriptors = bot_fields_descriptors, field_sequence = descriptor_fields_sequence, **descriptor_kwargs) else: descriptor_fields_sequence = list(bot_fields_descriptors.keys()) descriptor_fields_sequence.remove("id") descriptor_name = name.lower() namespace["bot_entity_descriptor"] = EntityDescriptor( name = descriptor_name, class_name = name, type_ = name, fields_descriptors = bot_fields_descriptors, field_sequence = descriptor_fields_sequence) for field_descriptor in bot_fields_descriptors.values(): field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"] if "table" not in kwargs: kwargs["table"] = True if kwargs["table"] == True: entity_metadata = EntityMetadata() entity_metadata.entity_descriptors[descriptor_name] = namespace["bot_entity_descriptor"] if "__annotations__" in namespace: namespace["__annotations__"]["entity_metadata"] = ClassVar[EntityMetadata] else: namespace["__annotations__"] = {"entity_metadata": ClassVar[EntityMetadata]} namespace["entity_metadata"] = entity_metadata type_ = super().__new__(mcs, name, bases, namespace, **kwargs) if name in mcs.__future_references__: for field_descriptor in mcs.__future_references__[name]: field_descriptor.type_ = list[type_] if get_origin(field_descriptor.type_) == list else type_ a = field_descriptor setattr(namespace["bot_entity_descriptor"], "type_", type_) return type_ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](SQLModel, metaclass = BotEntityMetaclass, table = False): bot_entity_descriptor: ClassVar[EntityDescriptor] entity_metadata: ClassVar[EntityMetadata] id: int = Field( primary_key = True, sa_type = BIGINT) name: str @classmethod @session_dep async def get(cls, *, session: AsyncSession | None = None, id: int): return await session.get(cls, id) @classmethod @session_dep async def get_count(cls, *, session: AsyncSession | None = None) -> int: return await session.scalar(select(func.count()).select_from(cls)) @classmethod @session_dep async def get_multi(cls, *, session: AsyncSession | None = None, order_by = None, skip: int = 0, limit: int = None): select_statement = select(cls).offset(skip) if limit: select_statement = select_statement.limit(limit) if order_by: select_statement = select_statement.order_by(order_by) return (await session.exec(select_statement)).all() @classmethod @session_dep async def create(cls, *, session: AsyncSession | None = None, obj_in: CreateSchemaType, commit: bool = False): if isinstance(obj_in, cls): obj = obj_in else: obj = cls(**obj_in.model_dump()) session.add(obj) if commit: await session.commit() return obj @classmethod @session_dep async def update(cls, *, session: AsyncSession | None = None, id: int, obj_in: UpdateSchemaType, commit: bool = False): obj = await session.get(cls, id) if obj: obj_data = obj.model_dump() update_data = obj_in.model_dump(exclude_unset=True) for field in obj_data: if field in update_data: setattr(obj, field, update_data[field]) session.add(obj) if commit: await session.commit() return obj return None @classmethod @session_dep async def remove(cls, *, session: AsyncSession | None = None, id: int, commit: bool = False): obj = await session.get(cls, id) if obj: await session.delete(obj) if commit: await session.commit() return obj return None