This commit is contained in:
Alexander Kalinovsky
2025-01-04 12:00:12 +01:00
commit 6dbe0536ca
94 changed files with 3467 additions and 0 deletions

241
model/bot_entity.py Normal file
View File

@@ -0,0 +1,241 @@
from functools import wraps
from typing import ClassVar, cast, get_args, get_origin
from pydantic import BaseModel
from sqlmodel import SQLModel, BIGINT, Field, select, func
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor
from .entity_metadata import EntityMetadata
from . import session_dep
class BotEntityMetaclass(SQLModelMetaclass):
__future_references__ = {}
def __new__(mcs, name, bases, namespace, **kwargs):
bot_fields_descriptors = {}
if bases:
bot_entity_descriptor = bases[0].__dict__.get('bot_entity_descriptor')
bot_fields_descriptors = {key: EntityFieldDescriptor(**value.__dict__.copy())
for key, value in bot_entity_descriptor.fields_descriptors.items()} if bot_entity_descriptor else {}
if '__annotations__' in namespace:
for annotation in namespace['__annotations__']:
if annotation in ["bot_entity_descriptor", "entity_metadata"]:
continue
attribute_value = namespace.get(annotation)
if isinstance(attribute_value, RelationshipInfo):
continue
descriptor_kwargs = {}
descriptor_name = annotation
if attribute_value:
if isinstance(attribute_value, EntityField):
descriptor_kwargs = attribute_value.__dict__.copy()
sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None)
if sm_descriptor:
namespace[annotation] = sm_descriptor
else:
namespace.pop(annotation)
descriptor_name = descriptor_kwargs.pop("name") or annotation
type_ = namespace['__annotations__'][annotation]
field_descriptor = EntityFieldDescriptor(
name = descriptor_name,
field_name = annotation,
type_ = type_,
**descriptor_kwargs)
type_origin = get_origin(type_)
is_list = False
if type_origin == list:
is_list = True
type_ = get_args(type_)[0]
if isinstance(type_, str):
type_not_found = True
for entity_descriptor in EntityMetadata().entity_descriptors.values():
if type_ == entity_descriptor.class_name:
field_descriptor.type_ = list[entity_descriptor.type_] if is_list else entity_descriptor.type_
type_not_found = False
break
if type_not_found:
if type_ in mcs.__future_references__:
mcs.__future_references__[type_].append(field_descriptor)
else:
mcs.__future_references__[type_] = [field_descriptor]
bot_fields_descriptors[descriptor_name] = field_descriptor
descriptor_name = name
if "bot_entity_descriptor" in namespace:
entity_descriptor = namespace.pop("bot_entity_descriptor")
descriptor_kwargs: dict = entity_descriptor.__dict__.copy()
descriptor_name = descriptor_kwargs.pop("name", None)
descriptor_name = descriptor_name or name.lower()
descriptor_fields_sequence = descriptor_kwargs.pop("field_sequence", None)
if not descriptor_fields_sequence:
descriptor_fields_sequence = list(bot_fields_descriptors.keys())
descriptor_fields_sequence.remove("id")
namespace["bot_entity_descriptor"] = EntityDescriptor(
name = descriptor_name,
class_name = name,
type_ = name,
fields_descriptors = bot_fields_descriptors,
field_sequence = descriptor_fields_sequence,
**descriptor_kwargs)
else:
descriptor_fields_sequence = list(bot_fields_descriptors.keys())
descriptor_fields_sequence.remove("id")
descriptor_name = name.lower()
namespace["bot_entity_descriptor"] = EntityDescriptor(
name = descriptor_name,
class_name = name,
type_ = name,
fields_descriptors = bot_fields_descriptors,
field_sequence = descriptor_fields_sequence)
for field_descriptor in bot_fields_descriptors.values():
field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"]
if "table" not in kwargs:
kwargs["table"] = True
if kwargs["table"] == True:
entity_metadata = EntityMetadata()
entity_metadata.entity_descriptors[descriptor_name] = namespace["bot_entity_descriptor"]
if "__annotations__" in namespace:
namespace["__annotations__"]["entity_metadata"] = ClassVar[EntityMetadata]
else:
namespace["__annotations__"] = {"entity_metadata": ClassVar[EntityMetadata]}
namespace["entity_metadata"] = entity_metadata
type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
if name in mcs.__future_references__:
for field_descriptor in mcs.__future_references__[name]:
field_descriptor.type_ = list[type_] if get_origin(field_descriptor.type_) == list else type_
a = field_descriptor
setattr(namespace["bot_entity_descriptor"], "type_", type_)
return type_
class BotEntity[CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel](SQLModel,
metaclass = BotEntityMetaclass,
table = False):
bot_entity_descriptor: ClassVar[EntityDescriptor]
entity_metadata: ClassVar[EntityMetadata]
id: int = Field(
primary_key = True,
sa_type = BIGINT)
name: str
@classmethod
@session_dep
async def get(cls, *,
session: AsyncSession | None = None,
id: int):
return await session.get(cls, id)
@classmethod
@session_dep
async def get_count(cls, *,
session: AsyncSession | None = None) -> int:
return await session.scalar(select(func.count()).select_from(cls))
@classmethod
@session_dep
async def get_multi(cls, *,
session: AsyncSession | None = None,
order_by = None,
skip: int = 0,
limit: int = None):
select_statement = select(cls).offset(skip)
if limit:
select_statement = select_statement.limit(limit)
if order_by:
select_statement = select_statement.order_by(order_by)
return (await session.exec(select_statement)).all()
@classmethod
@session_dep
async def create(cls, *,
session: AsyncSession | None = None,
obj_in: CreateSchemaType,
commit: bool = False):
if isinstance(obj_in, cls):
obj = obj_in
else:
obj = cls(**obj_in.model_dump())
session.add(obj)
if commit:
await session.commit()
return obj
@classmethod
@session_dep
async def update(cls, *,
session: AsyncSession | None = None,
id: int,
obj_in: UpdateSchemaType,
commit: bool = False):
obj = await session.get(cls, id)
if obj:
obj_data = obj.model_dump()
update_data = obj_in.model_dump(exclude_unset=True)
for field in obj_data:
if field in update_data:
setattr(obj, field, update_data[field])
session.add(obj)
if commit:
await session.commit()
return obj
return None
@classmethod
@session_dep
async def remove(cls, *,
session: AsyncSession | None = None,
id: int,
commit: bool = False):
obj = await session.get(cls, id)
if obj:
await session.delete(obj)
if commit:
await session.commit()
return obj
return None