267 lines
10 KiB
Python
267 lines
10 KiB
Python
from functools import wraps
|
|
from types import NoneType, UnionType
|
|
from typing import ClassVar, ForwardRef, Optional, Union, cast, get_args, get_origin
|
|
from pydantic import BaseModel
|
|
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
|
|
|
|
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor
|
|
from .entity_metadata import EntityMetadata
|
|
from . import session_dep
|
|
|
|
|
|
class BotEntityMetaclass(SQLModelMetaclass):
    """Metaclass that attaches bot-side descriptors to ``BotEntity`` classes.

    For every class it creates, it:

    * starts from copies of the field descriptors of ``bases[0]`` (when that
      base carries a ``bot_entity_descriptor``);
    * builds an ``EntityFieldDescriptor`` for each annotated attribute except
      ``bot_entity_descriptor``, ``entity_metadata`` and SQLModel
      relationships, unwrapping ``EntityField`` placeholders;
    * resolves entity types written as strings / forward references, either
      right away (target already registered in ``EntityMetadata``) or when a
      class with that name is created later;
    * stores an ``EntityDescriptor`` as ``bot_entity_descriptor`` and, for
      table classes (``table`` defaults to ``True``), registers it in
      ``EntityMetadata``.
    """

    # Field descriptors whose entity type names a class that did not exist yet
    # when the field was declared, keyed by that class name. The dict lives on
    # the metaclass, so it is shared by every class the metaclass builds.
    __future_references__ = {}

    @staticmethod
    def _bind_entity_type(field_descriptor, annotation_type, entity_type):
        """Point ``field_descriptor`` at the concrete class ``entity_type``.

        ``annotation_type`` is the field's original annotation; its shape
        decides how ``entity_type`` is wrapped in ``type_``:

        * ``list[...]``                     -> ``list[entity_type]``
        * ``Optional["X"]`` (forward ref)   -> ``Optional[entity_type]``
        * ``X | None``                      -> ``entity_type | None``
        * anything else                     -> ``entity_type``

        ``type_base`` is always set to the bare ``entity_type``. Both the
        immediate and the deferred resolution paths use this helper so they
        produce identical descriptors.
        """
        origin = get_origin(annotation_type)
        if origin is list:
            wrapped = list[entity_type]
        elif origin is Union and isinstance(get_args(annotation_type)[0], ForwardRef):
            wrapped = Optional[entity_type]
        elif origin is UnionType:
            wrapped = entity_type | None
        else:
            wrapped = entity_type
        field_descriptor.type_ = wrapped
        field_descriptor.type_base = entity_type

    def __new__(mcs, name, bases, namespace, **kwargs):
        """Create the class and attach its ``bot_entity_descriptor``.

        Args:
            name: Name of the class being created.
            bases: Base classes; only ``bases[0]`` is consulted for inherited
                field descriptors.
            namespace: Class-body namespace. Mutated in place: ``EntityField``
                placeholders are replaced or removed, ``bot_entity_descriptor``
                is (re)set and, for table classes, ``entity_metadata`` is
                added together with its ``ClassVar`` annotation.
            **kwargs: Class keyword arguments forwarded to
                ``SQLModelMetaclass``; ``table`` defaults to ``True``.

        Returns:
            The newly created class.
        """
        bot_fields_descriptors = {}

        if bases:
            base_descriptor = bases[0].__dict__.get("bot_entity_descriptor")
            if base_descriptor:
                # Fresh copies: every descriptor's ``entity_descriptor`` is
                # re-pointed at this class below, which must not leak back
                # into the base class's descriptors.
                bot_fields_descriptors = {
                    key: EntityFieldDescriptor(**value.__dict__.copy())
                    for key, value in base_descriptor.fields_descriptors.items()
                }

        annotations = namespace.get("__annotations__", {})

        for field_name, annotation_type in annotations.items():

            if field_name in ("bot_entity_descriptor", "entity_metadata"):
                continue

            attribute_value = namespace.get(field_name)

            if isinstance(attribute_value, RelationshipInfo):
                continue

            descriptor_kwargs = {}
            descriptor_name = field_name

            if isinstance(attribute_value, EntityField):
                descriptor_kwargs = attribute_value.__dict__.copy()
                sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None)
                # Only the wrapped SQLModel field (if any) is left in the
                # namespace handed to SQLModelMetaclass.
                if sm_descriptor:
                    namespace[field_name] = sm_descriptor
                else:
                    namespace.pop(field_name)
                descriptor_name = descriptor_kwargs.pop("name") or field_name

            field_descriptor = EntityFieldDescriptor(
                name=descriptor_name,
                field_name=field_name,
                type_=annotation_type,
                type_base=annotation_type,
                **descriptor_kwargs)

            # Unwrap list[...] / Optional["X"] / X | None to the element type.
            origin = get_origin(annotation_type)
            base_type = annotation_type
            if origin is list:
                field_descriptor.is_list = True
                field_descriptor.type_base = base_type = get_args(annotation_type)[0]
            elif origin is Union and isinstance(get_args(annotation_type)[0], ForwardRef):
                field_descriptor.is_optional = True
                field_descriptor.type_base = base_type = get_args(annotation_type)[0].__forward_arg__
            elif origin is UnionType and get_args(annotation_type)[1] is NoneType:
                field_descriptor.is_optional = True
                field_descriptor.type_base = base_type = get_args(annotation_type)[0]

            # A string element type names another entity class by class name.
            if isinstance(base_type, str):
                target = next(
                    (descriptor
                     for descriptor in EntityMetadata().entity_descriptors.values()
                     if descriptor.class_name == base_type),
                    None)
                if target is not None:
                    mcs._bind_entity_type(field_descriptor, annotation_type, target.type_)
                else:
                    # Resolved when a class with this name is created.
                    mcs.__future_references__.setdefault(base_type, []).append(field_descriptor)

            bot_fields_descriptors[descriptor_name] = field_descriptor

        # By default every field except "id" is listed in the field sequence.
        default_sequence = [key for key in bot_fields_descriptors if key != "id"]

        if "bot_entity_descriptor" in namespace:
            declared = namespace.pop("bot_entity_descriptor")
            descriptor_kwargs = declared.__dict__.copy()
            descriptor_name = descriptor_kwargs.pop("name", None) or name.lower()
            field_sequence = descriptor_kwargs.pop("field_sequence", None) or default_sequence
        else:
            descriptor_kwargs = {}
            descriptor_name = name.lower()
            field_sequence = default_sequence

        entity_descriptor = EntityDescriptor(
            name=descriptor_name,
            class_name=name,
            # Placeholder; replaced by the real class once it is created.
            type_=name,
            fields_descriptors=bot_fields_descriptors,
            field_sequence=field_sequence,
            **descriptor_kwargs)
        namespace["bot_entity_descriptor"] = entity_descriptor

        for field_descriptor in bot_fields_descriptors.values():
            field_descriptor.entity_descriptor = entity_descriptor

        kwargs.setdefault("table", True)

        if kwargs["table"] == True:
            entity_metadata = EntityMetadata()
            entity_metadata.entity_descriptors[descriptor_name] = entity_descriptor
            namespace.setdefault("__annotations__", {})["entity_metadata"] = ClassVar[EntityMetadata]
            namespace["entity_metadata"] = entity_metadata

        new_class = super().__new__(mcs, name, bases, namespace, **kwargs)

        # Bind descriptors that referenced this class by name before it existed;
        # their ``type_`` still holds the original annotation.
        for pending in mcs.__future_references__.get(name, ()):
            mcs._bind_entity_type(pending, pending.type_, new_class)

        entity_descriptor.type_ = new_class

        return new_class
|
|
|
|
|
|
class BotEntity[CreateSchemaType: BaseModel,
|
|
UpdateSchemaType: BaseModel](SQLModel,
|
|
metaclass = BotEntityMetaclass,
|
|
table = False):
|
|
|
|
bot_entity_descriptor: ClassVar[EntityDescriptor]
|
|
entity_metadata: ClassVar[EntityMetadata]
|
|
|
|
id: int = Field(
|
|
primary_key = True,
|
|
sa_type = BIGINT)
|
|
|
|
name: str
|
|
|
|
|
|
@classmethod
|
|
@session_dep
|
|
async def get(cls, *,
|
|
session: AsyncSession | None = None,
|
|
id: int):
|
|
|
|
return await session.get(cls, id, populate_existing = True)
|
|
|
|
|
|
@classmethod
|
|
@session_dep
|
|
async def get_count(cls, *,
|
|
session: AsyncSession | None = None,
|
|
filter: str = None) -> int:
|
|
|
|
select_statement = select(func.count()).select_from(cls)
|
|
if filter:
|
|
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
|
|
return await session.scalar(select_statement)
|
|
|
|
|
|
@classmethod
|
|
@session_dep
|
|
async def get_multi(cls, *,
|
|
session: AsyncSession | None = None,
|
|
order_by = None,
|
|
filter:str = None,
|
|
skip: int = 0,
|
|
limit: int = None):
|
|
|
|
select_statement = select(cls).offset(skip)
|
|
if limit:
|
|
select_statement = select_statement.limit(limit)
|
|
if filter:
|
|
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
|
|
if order_by:
|
|
select_statement = select_statement.order_by(order_by)
|
|
return (await session.exec(select_statement)).all()
|
|
|
|
|
|
@classmethod
|
|
@session_dep
|
|
async def create(cls, *,
|
|
session: AsyncSession | None = None,
|
|
obj_in: CreateSchemaType,
|
|
commit: bool = False):
|
|
|
|
if isinstance(obj_in, cls):
|
|
obj = obj_in
|
|
else:
|
|
obj = cls(**obj_in.model_dump())
|
|
session.add(obj)
|
|
if commit:
|
|
await session.commit()
|
|
return obj
|
|
|
|
|
|
@classmethod
|
|
@session_dep
|
|
async def update(cls, *,
|
|
session: AsyncSession | None = None,
|
|
id: int,
|
|
obj_in: UpdateSchemaType,
|
|
commit: bool = False):
|
|
|
|
obj = await session.get(cls, id)
|
|
if obj:
|
|
obj_data = obj.model_dump()
|
|
update_data = obj_in.model_dump(exclude_unset=True)
|
|
for field in obj_data:
|
|
if field in update_data:
|
|
setattr(obj, field, update_data[field])
|
|
session.add(obj)
|
|
if commit:
|
|
await session.commit()
|
|
return obj
|
|
return None
|
|
|
|
|
|
@classmethod
|
|
@session_dep
|
|
async def remove(cls, *,
|
|
session: AsyncSession | None = None,
|
|
id: int,
|
|
commit: bool = False):
|
|
obj = await session.get(cls, id)
|
|
if obj:
|
|
await session.delete(obj)
|
|
if commit:
|
|
await session.commit()
|
|
return obj
|
|
return None
|
|
|