from types import NoneType, UnionType from typing import ( Any, ClassVar, ForwardRef, Optional, Self, Union, get_args, get_origin, TYPE_CHECKING, dataclass_transform, ) from pydantic import BaseModel from sqlmodel import SQLModel, BigInteger, Field, select, func, column, col from sqlmodel.main import FieldInfo from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.main import SQLModelMetaclass, RelationshipInfo from .descriptors import EntityDescriptor, EntityField, FieldDescriptor, Filter from .entity_metadata import EntityMetadata from . import session_dep if TYPE_CHECKING: from .user import UserBase @dataclass_transform( kw_only_default=True, field_specifiers=(Field, FieldInfo, EntityField, FieldDescriptor), ) class BotEntityMetaclass(SQLModelMetaclass): _future_references = {} def __new__(mcs, name, bases, namespace, **kwargs): bot_fields_descriptors = {} if bases: bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor") bot_fields_descriptors = ( { key: FieldDescriptor(**value.__dict__.copy()) for key, value in bot_entity_descriptor.fields_descriptors.items() } if bot_entity_descriptor else {} ) if "__annotations__" in namespace: for annotation in namespace["__annotations__"]: if annotation in ["bot_entity_descriptor", "entity_metadata"]: continue attribute_value = namespace.get(annotation) if isinstance(attribute_value, RelationshipInfo): continue descriptor_kwargs = {} descriptor_name = annotation if attribute_value: if isinstance(attribute_value, EntityField): descriptor_kwargs = attribute_value.__dict__.copy() sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None) if sm_descriptor: namespace[annotation] = sm_descriptor else: namespace.pop(annotation) descriptor_name = descriptor_kwargs.pop("name") or annotation type_ = namespace["__annotations__"][annotation] type_origin = get_origin(type_) field_descriptor = FieldDescriptor( name=descriptor_name, field_name=annotation, type_=type_, type_base=type_, **descriptor_kwargs, ) is_list = False is_optional = False if type_origin is list: field_descriptor.is_list = is_list = True field_descriptor.type_base = type_ = get_args(type_)[0] if type_origin is Union: args = get_args(type_) if isinstance(args[0], ForwardRef): field_descriptor.is_optional = is_optional = True field_descriptor.type_base = type_ = args[0].__forward_arg__ elif args[1] is NoneType: field_descriptor.is_optional = is_optional = True field_descriptor.type_base = type_ = args[0] if type_origin is UnionType and get_args(type_)[1] is NoneType: field_descriptor.is_optional = is_optional = True field_descriptor.type_base = type_ = get_args(type_)[0] if isinstance(type_, str): type_not_found = True for ( entity_descriptor ) in EntityMetadata().entity_descriptors.values(): if type_ == entity_descriptor.class_name: field_descriptor.type_base = entity_descriptor.type_ field_descriptor.type_ = ( list[entity_descriptor.type_] if is_list else ( Optional[entity_descriptor.type_] if type_origin == Union and is_optional else ( entity_descriptor.type_ | None if (type_origin == UnionType and is_optional) else entity_descriptor.type_ ) ) ) type_not_found = False break if type_not_found: if type_ in mcs._future_references: mcs._future_references[type_].append(field_descriptor) else: mcs._future_references[type_] = [field_descriptor] bot_fields_descriptors[descriptor_name] = field_descriptor descriptor_name = name if "bot_entity_descriptor" in namespace: entity_descriptor = namespace.pop("bot_entity_descriptor") descriptor_kwargs: dict = entity_descriptor.__dict__.copy() descriptor_name = descriptor_kwargs.pop("name", None) descriptor_name = descriptor_name or name.lower() namespace["bot_entity_descriptor"] = EntityDescriptor( name=descriptor_name, class_name=name, type_=name, fields_descriptors=bot_fields_descriptors, **descriptor_kwargs, ) else: descriptor_name = name.lower() namespace["bot_entity_descriptor"] = EntityDescriptor( name=descriptor_name, class_name=name, type_=name, fields_descriptors=bot_fields_descriptors, ) descriptor_fields_sequence = [ key for key, val in bot_fields_descriptors.items() if not (val.is_optional or val.name == "id") ] entity_descriptor: EntityDescriptor = namespace["bot_entity_descriptor"] if entity_descriptor.default_form.edit_field_sequence is None: entity_descriptor.default_form.edit_field_sequence = ( descriptor_fields_sequence ) for form in entity_descriptor.forms.values(): if form.edit_field_sequence is None: form.edit_field_sequence = descriptor_fields_sequence for field_descriptor in bot_fields_descriptors.values(): field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"] if "table" not in kwargs: kwargs["table"] = True if kwargs["table"]: entity_metadata = EntityMetadata() entity_metadata.entity_descriptors[descriptor_name] = namespace[ "bot_entity_descriptor" ] if "__annotations__" in namespace: namespace["__annotations__"]["entity_metadata"] = ClassVar[ EntityMetadata ] else: namespace["__annotations__"] = { "entity_metadata": ClassVar[EntityMetadata] } namespace["entity_metadata"] = entity_metadata type_ = super().__new__(mcs, name, bases, namespace, **kwargs) if name in mcs._future_references: for field_descriptor in mcs._future_references[name]: type_origin = get_origin(field_descriptor.type_) field_descriptor.type_base = type_ field_descriptor.type_ = ( list[type_] if type_origin is list else ( Optional[type_] if type_origin == Union and isinstance(get_args(field_descriptor.type_)[0], ForwardRef) else type_ | None if type_origin == UnionType else type_ ) ) setattr(namespace["bot_entity_descriptor"], "type_", type_) return type_ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel]( SQLModel, metaclass=BotEntityMetaclass, table=False ): bot_entity_descriptor: ClassVar[EntityDescriptor] entity_metadata: ClassVar[EntityMetadata] id: int = EntityField( sm_descriptor=Field(primary_key=True, sa_type=BigInteger), is_visible=False ) @classmethod @session_dep async def get(cls, *, session: AsyncSession | None = None, id: int): return await session.get(cls, id, populate_existing=True) @classmethod def _static_filter_condition( cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter] ): for sfilt in static_filter: column = getattr(cls, sfilt.field_name) if sfilt.operator == "==": condition = column.__eq__(sfilt.value) elif sfilt.operator == "!=": condition = column.__ne__(sfilt.value) elif sfilt.operator == "<": condition = column.__lt__(sfilt.value) elif sfilt.operator == "<=": condition = column.__le__(sfilt.value) elif sfilt.operator == ">": condition = column.__gt__(sfilt.value) elif sfilt.operator == ">=": condition = column.__ge__(sfilt.value) elif sfilt.operator == "ilike": condition = col(column).ilike(f"%{sfilt.value}%") elif sfilt.operator == "like": condition = col(column).like(f"%{sfilt.value}%") elif sfilt.operator == "in": condition = col(column).in_(sfilt.value) elif sfilt.operator == "not in": condition = col(column).notin_(sfilt.value) elif sfilt.operator == "is none": condition = col(column).is_(None) elif sfilt.operator == "is not none": condition = col(column).isnot(None) elif sfilt.operator == "contains": condition = sfilt.value == col(column).any_() else: condition = None if condition is not None: select_statement = select_statement.where(condition) return select_statement @classmethod def _filter_condition( cls, select_statement: SelectOfScalar[Self], filter: str, filter_fields: list[str], ): condition = None for field in filter_fields: if condition is not None: condition = condition | (column(field).ilike(f"%{filter}%")) else: condition = column(field).ilike(f"%{filter}%") return select_statement.where(condition) @classmethod @session_dep async def get_count( cls, *, session: AsyncSession | None = None, static_filter: list[Filter] | Any = None, filter: str = None, filter_fields: list[str] = None, ext_filter: Any = None, user: "UserBase" = None, ) -> int: select_statement = select(func.count()).select_from(cls) if static_filter: if isinstance(static_filter, list): select_statement = cls._static_filter_condition( select_statement, static_filter ) else: select_statement = select_statement.where(static_filter) if filter and filter_fields: select_statement = cls._filter_condition( select_statement, filter, filter_fields ) if ext_filter: select_statement = select_statement.where(ext_filter) if user: select_statement = cls._ownership_condition(select_statement, user) return await session.scalar(select_statement) @classmethod @session_dep async def get_multi( cls, *, session: AsyncSession | None = None, order_by=None, static_filter: list[Filter] | Any = None, filter: str = None, filter_fields: list[str] = None, ext_filter: Any = None, user: "UserBase" = None, skip: int = 0, limit: int = None, ): select_statement = select(cls).offset(skip) if limit: select_statement = select_statement.limit(limit) if static_filter is not None: if isinstance(static_filter, list): select_statement = cls._static_filter_condition( select_statement, static_filter ) else: select_statement = select_statement.where(static_filter) if filter and filter_fields: select_statement = cls._filter_condition( select_statement, filter, filter_fields ) if ext_filter is not None: select_statement = select_statement.where(ext_filter) if user: select_statement = cls._ownership_condition(select_statement, user) if order_by is not None: select_statement = select_statement.order_by(order_by) return (await session.exec(select_statement)).all() @classmethod def _ownership_condition( cls, select_statement: SelectOfScalar[Self], user: "UserBase" ): if cls.bot_entity_descriptor.ownership_fields: condition = None for role in user.roles: if role in cls.bot_entity_descriptor.ownership_fields: owner_col = column(cls.bot_entity_descriptor.ownership_fields[role]) if condition is not None: condition = condition | (owner_col == user.id) else: condition = owner_col == user.id else: condition = None break if condition is not None: return select_statement.where(condition) return select_statement @classmethod @session_dep async def create( cls, *, session: AsyncSession | None = None, obj_in: CreateSchemaType, commit: bool = False, ): if isinstance(obj_in, cls): obj = obj_in else: obj = cls(**obj_in.model_dump()) session.add(obj) if commit: await session.commit() return obj @classmethod @session_dep async def update( cls, *, session: AsyncSession | None = None, id: int, obj_in: UpdateSchemaType, commit: bool = False, ): obj = await session.get(cls, id) if obj: obj_data = obj.model_dump() update_data = obj_in.model_dump(exclude_unset=True) for field in obj_data: if field in update_data: setattr(obj, field, update_data[field]) session.add(obj) if commit: await session.commit() return obj return None @classmethod @session_dep async def remove( cls, *, session: AsyncSession | None = None, id: int, commit: bool = False ): obj = await session.get(cls, id) if obj: await session.delete(obj) if commit: await session.commit() return obj return None