from functools import wraps from types import NoneType, UnionType from typing import ClassVar, ForwardRef, Optional, Union, cast, get_args, get_origin from pydantic import BaseModel from sqlmodel import SQLModel, BIGINT, Field, select, func, column from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.main import SQLModelMetaclass, RelationshipInfo from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor from .entity_metadata import EntityMetadata from . import session_dep class BotEntityMetaclass(SQLModelMetaclass): __future_references__ = {} def __new__(mcs, name, bases, namespace, **kwargs): bot_fields_descriptors = {} if bases: bot_entity_descriptor = bases[0].__dict__.get('bot_entity_descriptor') bot_fields_descriptors = {key: EntityFieldDescriptor(**value.__dict__.copy()) for key, value in bot_entity_descriptor.fields_descriptors.items()} if bot_entity_descriptor else {} if '__annotations__' in namespace: for annotation in namespace['__annotations__']: if annotation in ["bot_entity_descriptor", "entity_metadata"]: continue attribute_value = namespace.get(annotation) if isinstance(attribute_value, RelationshipInfo): continue descriptor_kwargs = {} descriptor_name = annotation if attribute_value: if isinstance(attribute_value, EntityField): descriptor_kwargs = attribute_value.__dict__.copy() sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None) if sm_descriptor: namespace[annotation] = sm_descriptor else: namespace.pop(annotation) descriptor_name = descriptor_kwargs.pop("name") or annotation type_ = namespace['__annotations__'][annotation] type_origin = get_origin(type_) field_descriptor = EntityFieldDescriptor( name = descriptor_name, field_name = annotation, type_ = type_, type_base = type_, **descriptor_kwargs) is_list = False if type_origin == list: field_descriptor.is_list = is_list = True field_descriptor.type_base = type_ = get_args(type_)[0] if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef): field_descriptor.is_optional = True field_descriptor.type_base = type_ = get_args(type_)[0].__forward_arg__ if type_origin == UnionType and get_args(type_)[1] == NoneType: field_descriptor.is_optional = True field_descriptor.type_base = type_ = get_args(type_)[0] if isinstance(type_, str): type_not_found = True for entity_descriptor in EntityMetadata().entity_descriptors.values(): if type_ == entity_descriptor.class_name: field_descriptor.type_ = (list[entity_descriptor.type_] if is_list else Optional[entity_descriptor.type_] if type_origin == Optional else entity_descriptor.type_ | None if (type_origin == UnionType and get_args(type_)[1] == NoneType) else entity_descriptor.type_) type_not_found = False break if type_not_found: if type_ in mcs.__future_references__: mcs.__future_references__[type_].append(field_descriptor) else: mcs.__future_references__[type_] = [field_descriptor] bot_fields_descriptors[descriptor_name] = field_descriptor descriptor_name = name if "bot_entity_descriptor" in namespace: entity_descriptor = namespace.pop("bot_entity_descriptor") descriptor_kwargs: dict = entity_descriptor.__dict__.copy() descriptor_name = descriptor_kwargs.pop("name", None) descriptor_name = descriptor_name or name.lower() descriptor_fields_sequence = descriptor_kwargs.pop("field_sequence", None) if not descriptor_fields_sequence: descriptor_fields_sequence = list(bot_fields_descriptors.keys()) descriptor_fields_sequence.remove("id") namespace["bot_entity_descriptor"] = EntityDescriptor( name = descriptor_name, class_name = name, type_ = name, fields_descriptors = bot_fields_descriptors, field_sequence = descriptor_fields_sequence, **descriptor_kwargs) else: descriptor_fields_sequence = list(bot_fields_descriptors.keys()) descriptor_fields_sequence.remove("id") descriptor_name = name.lower() namespace["bot_entity_descriptor"] = EntityDescriptor( name = descriptor_name, class_name = name, type_ = name, fields_descriptors = bot_fields_descriptors, field_sequence = descriptor_fields_sequence) for field_descriptor in bot_fields_descriptors.values(): field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"] if "table" not in kwargs: kwargs["table"] = True if kwargs["table"] == True: entity_metadata = EntityMetadata() entity_metadata.entity_descriptors[descriptor_name] = namespace["bot_entity_descriptor"] if "__annotations__" in namespace: namespace["__annotations__"]["entity_metadata"] = ClassVar[EntityMetadata] else: namespace["__annotations__"] = {"entity_metadata": ClassVar[EntityMetadata]} namespace["entity_metadata"] = entity_metadata type_ = super().__new__(mcs, name, bases, namespace, **kwargs) if name in mcs.__future_references__: for field_descriptor in mcs.__future_references__[name]: type_origin = get_origin(field_descriptor.type_) field_descriptor.type_base = type_ field_descriptor.type_ = (list[type_] if get_origin(field_descriptor.type_) == list else Optional[type_] if type_origin == Union and isinstance(get_args(field_descriptor.type_)[0], ForwardRef) else type_ | None if type_origin == UnionType else type_) setattr(namespace["bot_entity_descriptor"], "type_", type_) return type_ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](SQLModel, metaclass = BotEntityMetaclass, table = False): bot_entity_descriptor: ClassVar[EntityDescriptor] entity_metadata: ClassVar[EntityMetadata] id: int = Field( primary_key = True, sa_type = BIGINT) name: str @classmethod @session_dep async def get(cls, *, session: AsyncSession | None = None, id: int): return await session.get(cls, id, populate_existing = True) @classmethod @session_dep async def get_count(cls, *, session: AsyncSession | None = None, filter: str = None) -> int: select_statement = select(func.count()).select_from(cls) if filter: select_statement = select_statement.where(column("name").ilike(f"%{filter}%")) return await session.scalar(select_statement) @classmethod @session_dep async def get_multi(cls, *, session: AsyncSession | None = None, order_by = None, filter:str = None, skip: int = 0, limit: int = None): select_statement = select(cls).offset(skip) if limit: select_statement = select_statement.limit(limit) if filter: select_statement = select_statement.where(column("name").ilike(f"%{filter}%")) if order_by: select_statement = select_statement.order_by(order_by) return (await session.exec(select_statement)).all() @classmethod @session_dep async def create(cls, *, session: AsyncSession | None = None, obj_in: CreateSchemaType, commit: bool = False): if isinstance(obj_in, cls): obj = obj_in else: obj = cls(**obj_in.model_dump()) session.add(obj) if commit: await session.commit() return obj @classmethod @session_dep async def update(cls, *, session: AsyncSession | None = None, id: int, obj_in: UpdateSchemaType, commit: bool = False): obj = await session.get(cls, id) if obj: obj_data = obj.model_dump() update_data = obj_in.model_dump(exclude_unset=True) for field in obj_data: if field in update_data: setattr(obj, field, update_data[field]) session.add(obj) if commit: await session.commit() return obj return None @classmethod @session_dep async def remove(cls, *, session: AsyncSession | None = None, id: int, commit: bool = False): obj = await session.get(cls, id) if obj: await session.delete(obj) if commit: await session.commit() return obj return None