Files
quickbot/model/bot_entity.py
Alexander Kalinovsky 3898a333fa refactoring
2025-01-09 13:11:10 +01:00

267 lines
10 KiB
Python

from functools import wraps
from types import NoneType, UnionType
from typing import ClassVar, ForwardRef, Optional, Union, cast, get_args, get_origin
from pydantic import BaseModel
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor
from .entity_metadata import EntityMetadata
from . import session_dep
class BotEntityMetaclass(SQLModelMetaclass):
    """Metaclass that attaches an ``EntityDescriptor`` to every entity class.

    While a class is being created it collects one ``EntityFieldDescriptor``
    per annotated attribute (plus copies of the descriptors declared on the
    first base class), stores the resulting ``EntityDescriptor`` in the class
    namespace as ``bot_entity_descriptor`` and, when the class is created
    with ``table=True`` (the default applied here), records the descriptor in
    ``EntityMetadata().entity_descriptors``.
    """

    # Field descriptors whose annotation names an entity class that was not
    # registered yet, keyed by that class name. They are completed in
    # ``__new__`` once a class with the matching name is created.
    __future_references__ = {}

    def __new__(mcs, name, bases, namespace, **kwargs):
        """Create the entity class and build its bot descriptors.

        Args:
            mcs: This metaclass.
            name: Name of the class being created.
            bases: Base classes; only ``bases[0]`` is inspected for an
                inherited ``bot_entity_descriptor``.
            namespace: Class body namespace. It is modified in place:
                ``EntityField`` values are replaced or removed and
                ``bot_entity_descriptor`` / ``entity_metadata`` are injected.
            **kwargs: Class keyword arguments forwarded to
                ``SQLModelMetaclass``; ``table`` defaults to ``True``.

        Returns:
            The newly created class.
        """
        bot_fields_descriptors = {}

        if bases:
            # Start from independent copies of the first base's field
            # descriptors so the subclass can modify them freely.
            bot_entity_descriptor = bases[0].__dict__.get('bot_entity_descriptor')
            bot_fields_descriptors = (
                {key: EntityFieldDescriptor(**value.__dict__.copy())
                 for key, value in bot_entity_descriptor.fields_descriptors.items()}
                if bot_entity_descriptor else {}
            )

        if '__annotations__' in namespace:
            for annotation in namespace['__annotations__']:
                if annotation in ["bot_entity_descriptor", "entity_metadata"]:
                    continue

                attribute_value = namespace.get(annotation)
                if isinstance(attribute_value, RelationshipInfo):
                    continue

                descriptor_kwargs = {}
                descriptor_name = annotation

                if attribute_value and isinstance(attribute_value, EntityField):
                    descriptor_kwargs = attribute_value.__dict__.copy()
                    # The EntityField placeholder must not reach SQLModel:
                    # swap it for the wrapped descriptor, or drop it.
                    sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None)
                    if sm_descriptor:
                        namespace[annotation] = sm_descriptor
                    else:
                        namespace.pop(annotation)
                    descriptor_name = descriptor_kwargs.pop("name") or annotation

                type_ = namespace['__annotations__'][annotation]
                type_origin = get_origin(type_)

                field_descriptor = EntityFieldDescriptor(
                    name=descriptor_name,
                    field_name=annotation,
                    type_=type_,
                    type_base=type_,
                    **descriptor_kwargs)

                # Unwrap list[X], Optional["X"] and X | None down to the
                # base type. The three origins are mutually exclusive.
                is_list = False
                if type_origin == list:
                    field_descriptor.is_list = is_list = True
                    field_descriptor.type_base = type_ = get_args(type_)[0]
                elif type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
                    field_descriptor.is_optional = True
                    field_descriptor.type_base = type_ = get_args(type_)[0].__forward_arg__
                elif type_origin == UnionType and get_args(type_)[1] == NoneType:
                    field_descriptor.is_optional = True
                    field_descriptor.type_base = type_ = get_args(type_)[0]

                if isinstance(type_, str):
                    # The base type is a class name: resolve it against the
                    # entities registered so far.
                    target = next(
                        (descriptor
                         for descriptor in EntityMetadata().entity_descriptors.values()
                         if descriptor.class_name == type_),
                        None)
                    if target is None:
                        # Not defined yet; finished when that class is created.
                        mcs.__future_references__.setdefault(type_, []).append(field_descriptor)
                    elif is_list:
                        field_descriptor.type_ = list[target.type_]
                    elif type_origin == Union:
                        # get_origin(Optional[X]) is typing.Union, never
                        # typing.Optional, so compare against Union to keep
                        # the Optional wrapper (as the deferred path does).
                        field_descriptor.type_ = Optional[target.type_]
                    elif type_origin == UnionType:
                        field_descriptor.type_ = target.type_ | None
                    else:
                        field_descriptor.type_ = target.type_

                bot_fields_descriptors[descriptor_name] = field_descriptor

        # Options declared in the class body (if any) override the defaults.
        descriptor_kwargs = {}
        descriptor_name = None
        descriptor_fields_sequence = None
        if "bot_entity_descriptor" in namespace:
            descriptor_kwargs = namespace.pop("bot_entity_descriptor").__dict__.copy()
            descriptor_name = descriptor_kwargs.pop("name", None)
            descriptor_fields_sequence = descriptor_kwargs.pop("field_sequence", None)

        descriptor_name = descriptor_name or name.lower()
        if not descriptor_fields_sequence:
            # Default sequence: every collected field except "id". Filtering
            # (rather than list.remove) does not fail when "id" is absent.
            descriptor_fields_sequence = [
                key for key in bot_fields_descriptors if key != "id"]

        entity_descriptor = EntityDescriptor(
            name=descriptor_name,
            class_name=name,
            type_=name,  # placeholder; replaced by the class object below
            fields_descriptors=bot_fields_descriptors,
            field_sequence=descriptor_fields_sequence,
            **descriptor_kwargs)
        namespace["bot_entity_descriptor"] = entity_descriptor

        for field_descriptor in bot_fields_descriptors.values():
            field_descriptor.entity_descriptor = entity_descriptor

        kwargs.setdefault("table", True)

        if kwargs["table"] == True:  # noqa: E712 - keeps the original equality semantics
            # NOTE(review): EntityMetadata() is instantiated in several places
            # and is presumably a shared/singleton registry - confirm.
            entity_metadata = EntityMetadata()
            entity_metadata.entity_descriptors[descriptor_name] = entity_descriptor
            namespace.setdefault("__annotations__", {})["entity_metadata"] = ClassVar[EntityMetadata]
            namespace["entity_metadata"] = entity_metadata

        entity_class = super().__new__(mcs, name, bases, namespace, **kwargs)

        # Complete descriptors that referenced this class before it existed,
        # rebuilding the same wrapper the annotation was declared with.
        for field_descriptor in mcs.__future_references__.get(name, ()):
            type_origin = get_origin(field_descriptor.type_)
            field_descriptor.type_base = entity_class
            if type_origin == list:
                field_descriptor.type_ = list[entity_class]
            elif type_origin == Union and isinstance(get_args(field_descriptor.type_)[0], ForwardRef):
                field_descriptor.type_ = Optional[entity_class]
            elif type_origin == UnionType:
                field_descriptor.type_ = entity_class | None
            else:
                field_descriptor.type_ = entity_class

        entity_descriptor.type_ = entity_class

        return entity_class
class BotEntity[CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel](SQLModel,
metaclass = BotEntityMetaclass,
table = False):
bot_entity_descriptor: ClassVar[EntityDescriptor]
entity_metadata: ClassVar[EntityMetadata]
id: int = Field(
primary_key = True,
sa_type = BIGINT)
name: str
@classmethod
@session_dep
async def get(cls, *,
session: AsyncSession | None = None,
id: int):
return await session.get(cls, id, populate_existing = True)
@classmethod
@session_dep
async def get_count(cls, *,
session: AsyncSession | None = None,
filter: str = None) -> int:
select_statement = select(func.count()).select_from(cls)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
return await session.scalar(select_statement)
@classmethod
@session_dep
async def get_multi(cls, *,
session: AsyncSession | None = None,
order_by = None,
filter:str = None,
skip: int = 0,
limit: int = None):
select_statement = select(cls).offset(skip)
if limit:
select_statement = select_statement.limit(limit)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if order_by:
select_statement = select_statement.order_by(order_by)
return (await session.exec(select_statement)).all()
@classmethod
@session_dep
async def create(cls, *,
session: AsyncSession | None = None,
obj_in: CreateSchemaType,
commit: bool = False):
if isinstance(obj_in, cls):
obj = obj_in
else:
obj = cls(**obj_in.model_dump())
session.add(obj)
if commit:
await session.commit()
return obj
@classmethod
@session_dep
async def update(cls, *,
session: AsyncSession | None = None,
id: int,
obj_in: UpdateSchemaType,
commit: bool = False):
obj = await session.get(cls, id)
if obj:
obj_data = obj.model_dump()
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(obj, field, update_data[field])
session.add(obj)
if commit:
await session.commit()
return obj
return None
@classmethod
@session_dep
async def remove(cls, *,
session: AsyncSession | None = None,
id: int,
commit: bool = False):
obj = await session.get(cls, id)
if obj:
await session.delete(obj)
if commit:
await session.commit()
return obj
return None