upd project structure
This commit is contained in:
432
src/qbot/model/bot_entity.py
Normal file
432
src/qbot/model/bot_entity.py
Normal file
@@ -0,0 +1,432 @@
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
ForwardRef,
|
||||
Optional,
|
||||
Self,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
TYPE_CHECKING,
|
||||
dataclass_transform,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import SQLModel, BigInteger, Field, select, func, column, col
|
||||
from sqlmodel.main import FieldInfo
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
|
||||
|
||||
|
||||
from .descriptors import EntityDescriptor, EntityField, FieldDescriptor, Filter
|
||||
from .entity_metadata import EntityMetadata
|
||||
from . import session_dep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .user import UserBase
|
||||
|
||||
|
||||
@dataclass_transform(
|
||||
kw_only_default=True,
|
||||
field_specifiers=(Field, FieldInfo, EntityField, FieldDescriptor),
|
||||
)
|
||||
class BotEntityMetaclass(SQLModelMetaclass):
|
||||
_future_references = {}
|
||||
|
||||
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||
bot_fields_descriptors = {}
|
||||
|
||||
if bases:
|
||||
bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor")
|
||||
bot_fields_descriptors = (
|
||||
{
|
||||
key: FieldDescriptor(**value.__dict__.copy())
|
||||
for key, value in bot_entity_descriptor.fields_descriptors.items()
|
||||
}
|
||||
if bot_entity_descriptor
|
||||
else {}
|
||||
)
|
||||
|
||||
if "__annotations__" in namespace:
|
||||
for annotation in namespace["__annotations__"]:
|
||||
if annotation in ["bot_entity_descriptor", "entity_metadata"]:
|
||||
continue
|
||||
|
||||
attribute_value = namespace.get(annotation)
|
||||
|
||||
if isinstance(attribute_value, RelationshipInfo):
|
||||
continue
|
||||
|
||||
descriptor_kwargs = {}
|
||||
descriptor_name = annotation
|
||||
|
||||
if attribute_value:
|
||||
if isinstance(attribute_value, EntityField):
|
||||
descriptor_kwargs = attribute_value.__dict__.copy()
|
||||
sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None)
|
||||
|
||||
if sm_descriptor:
|
||||
namespace[annotation] = sm_descriptor
|
||||
else:
|
||||
namespace.pop(annotation)
|
||||
|
||||
descriptor_name = descriptor_kwargs.pop("name") or annotation
|
||||
|
||||
type_ = namespace["__annotations__"][annotation]
|
||||
|
||||
type_origin = get_origin(type_)
|
||||
|
||||
field_descriptor = FieldDescriptor(
|
||||
name=descriptor_name,
|
||||
field_name=annotation,
|
||||
type_=type_,
|
||||
type_base=type_,
|
||||
**descriptor_kwargs,
|
||||
)
|
||||
|
||||
is_list = False
|
||||
is_optional = False
|
||||
if type_origin is list:
|
||||
field_descriptor.is_list = is_list = True
|
||||
field_descriptor.type_base = type_ = get_args(type_)[0]
|
||||
|
||||
if type_origin is Union:
|
||||
args = get_args(type_)
|
||||
if isinstance(args[0], ForwardRef):
|
||||
field_descriptor.is_optional = is_optional = True
|
||||
field_descriptor.type_base = type_ = args[0].__forward_arg__
|
||||
elif args[1] is NoneType:
|
||||
field_descriptor.is_optional = is_optional = True
|
||||
field_descriptor.type_base = type_ = args[0]
|
||||
|
||||
if type_origin is UnionType and get_args(type_)[1] is NoneType:
|
||||
field_descriptor.is_optional = is_optional = True
|
||||
field_descriptor.type_base = type_ = get_args(type_)[0]
|
||||
|
||||
if isinstance(type_, str):
|
||||
type_not_found = True
|
||||
for (
|
||||
entity_descriptor
|
||||
) in EntityMetadata().entity_descriptors.values():
|
||||
if type_ == entity_descriptor.class_name:
|
||||
field_descriptor.type_base = entity_descriptor.type_
|
||||
field_descriptor.type_ = (
|
||||
list[entity_descriptor.type_]
|
||||
if is_list
|
||||
else (
|
||||
Optional[entity_descriptor.type_]
|
||||
if type_origin == Union and is_optional
|
||||
else (
|
||||
entity_descriptor.type_ | None
|
||||
if (type_origin == UnionType and is_optional)
|
||||
else entity_descriptor.type_
|
||||
)
|
||||
)
|
||||
)
|
||||
type_not_found = False
|
||||
break
|
||||
if type_not_found:
|
||||
if type_ in mcs._future_references:
|
||||
mcs._future_references[type_].append(field_descriptor)
|
||||
else:
|
||||
mcs._future_references[type_] = [field_descriptor]
|
||||
|
||||
bot_fields_descriptors[descriptor_name] = field_descriptor
|
||||
|
||||
descriptor_name = name
|
||||
|
||||
if "bot_entity_descriptor" in namespace:
|
||||
entity_descriptor = namespace.pop("bot_entity_descriptor")
|
||||
descriptor_kwargs: dict = entity_descriptor.__dict__.copy()
|
||||
descriptor_name = descriptor_kwargs.pop("name", None)
|
||||
descriptor_name = descriptor_name or name.lower()
|
||||
namespace["bot_entity_descriptor"] = EntityDescriptor(
|
||||
name=descriptor_name,
|
||||
class_name=name,
|
||||
type_=name,
|
||||
fields_descriptors=bot_fields_descriptors,
|
||||
**descriptor_kwargs,
|
||||
)
|
||||
else:
|
||||
descriptor_name = name.lower()
|
||||
namespace["bot_entity_descriptor"] = EntityDescriptor(
|
||||
name=descriptor_name,
|
||||
class_name=name,
|
||||
type_=name,
|
||||
fields_descriptors=bot_fields_descriptors,
|
||||
)
|
||||
|
||||
descriptor_fields_sequence = [
|
||||
key
|
||||
for key, val in bot_fields_descriptors.items()
|
||||
if not (val.is_optional or val.name == "id")
|
||||
]
|
||||
|
||||
entity_descriptor: EntityDescriptor = namespace["bot_entity_descriptor"]
|
||||
|
||||
if entity_descriptor.default_form.edit_field_sequence is None:
|
||||
entity_descriptor.default_form.edit_field_sequence = (
|
||||
descriptor_fields_sequence
|
||||
)
|
||||
|
||||
for form in entity_descriptor.forms.values():
|
||||
if form.edit_field_sequence is None:
|
||||
form.edit_field_sequence = descriptor_fields_sequence
|
||||
|
||||
for field_descriptor in bot_fields_descriptors.values():
|
||||
field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"]
|
||||
|
||||
if "table" not in kwargs:
|
||||
kwargs["table"] = True
|
||||
|
||||
if kwargs["table"]:
|
||||
entity_metadata = EntityMetadata()
|
||||
entity_metadata.entity_descriptors[descriptor_name] = namespace[
|
||||
"bot_entity_descriptor"
|
||||
]
|
||||
|
||||
if "__annotations__" in namespace:
|
||||
namespace["__annotations__"]["entity_metadata"] = ClassVar[
|
||||
EntityMetadata
|
||||
]
|
||||
else:
|
||||
namespace["__annotations__"] = {
|
||||
"entity_metadata": ClassVar[EntityMetadata]
|
||||
}
|
||||
|
||||
namespace["entity_metadata"] = entity_metadata
|
||||
|
||||
type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
|
||||
|
||||
if name in mcs._future_references:
|
||||
for field_descriptor in mcs._future_references[name]:
|
||||
type_origin = get_origin(field_descriptor.type_)
|
||||
field_descriptor.type_base = type_
|
||||
|
||||
field_descriptor.type_ = (
|
||||
list[type_]
|
||||
if type_origin is list
|
||||
else (
|
||||
Optional[type_]
|
||||
if type_origin == Union
|
||||
and isinstance(get_args(field_descriptor.type_)[0], ForwardRef)
|
||||
else type_ | None
|
||||
if type_origin == UnionType
|
||||
else type_
|
||||
)
|
||||
)
|
||||
|
||||
setattr(namespace["bot_entity_descriptor"], "type_", type_)
|
||||
|
||||
return type_
|
||||
|
||||
|
||||
class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
|
||||
SQLModel, metaclass=BotEntityMetaclass, table=False
|
||||
):
|
||||
bot_entity_descriptor: ClassVar[EntityDescriptor]
|
||||
entity_metadata: ClassVar[EntityMetadata]
|
||||
|
||||
id: int = EntityField(
|
||||
sm_descriptor=Field(primary_key=True, sa_type=BigInteger), is_visible=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@session_dep
|
||||
async def get(cls, *, session: AsyncSession | None = None, id: int):
|
||||
return await session.get(cls, id, populate_existing=True)
|
||||
|
||||
@classmethod
|
||||
def _static_filter_condition(
|
||||
cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter]
|
||||
):
|
||||
for sfilt in static_filter:
|
||||
column = getattr(cls, sfilt.field_name)
|
||||
if sfilt.operator == "==":
|
||||
condition = column.__eq__(sfilt.value)
|
||||
elif sfilt.operator == "!=":
|
||||
condition = column.__ne__(sfilt.value)
|
||||
elif sfilt.operator == "<":
|
||||
condition = column.__lt__(sfilt.value)
|
||||
elif sfilt.operator == "<=":
|
||||
condition = column.__le__(sfilt.value)
|
||||
elif sfilt.operator == ">":
|
||||
condition = column.__gt__(sfilt.value)
|
||||
elif sfilt.operator == ">=":
|
||||
condition = column.__ge__(sfilt.value)
|
||||
elif sfilt.operator == "ilike":
|
||||
condition = col(column).ilike(f"%{sfilt.value}%")
|
||||
elif sfilt.operator == "like":
|
||||
condition = col(column).like(f"%{sfilt.value}%")
|
||||
elif sfilt.operator == "in":
|
||||
condition = col(column).in_(sfilt.value)
|
||||
elif sfilt.operator == "not in":
|
||||
condition = col(column).notin_(sfilt.value)
|
||||
elif sfilt.operator == "is none":
|
||||
condition = col(column).is_(None)
|
||||
elif sfilt.operator == "is not none":
|
||||
condition = col(column).isnot(None)
|
||||
elif sfilt.operator == "contains":
|
||||
condition = sfilt.value == col(column).any_()
|
||||
else:
|
||||
condition = None
|
||||
if condition is not None:
|
||||
select_statement = select_statement.where(condition)
|
||||
return select_statement
|
||||
|
||||
@classmethod
|
||||
def _filter_condition(
|
||||
cls,
|
||||
select_statement: SelectOfScalar[Self],
|
||||
filter: str,
|
||||
filter_fields: list[str],
|
||||
):
|
||||
condition = None
|
||||
for field in filter_fields:
|
||||
if condition is not None:
|
||||
condition = condition | (column(field).ilike(f"%{filter}%"))
|
||||
else:
|
||||
condition = column(field).ilike(f"%{filter}%")
|
||||
return select_statement.where(condition)
|
||||
|
||||
@classmethod
|
||||
@session_dep
|
||||
async def get_count(
|
||||
cls,
|
||||
*,
|
||||
session: AsyncSession | None = None,
|
||||
static_filter: list[Filter] | Any = None,
|
||||
filter: str = None,
|
||||
filter_fields: list[str] = None,
|
||||
ext_filter: Any = None,
|
||||
user: "UserBase" = None,
|
||||
) -> int:
|
||||
select_statement = select(func.count()).select_from(cls)
|
||||
if static_filter:
|
||||
if isinstance(static_filter, list):
|
||||
select_statement = cls._static_filter_condition(
|
||||
select_statement, static_filter
|
||||
)
|
||||
else:
|
||||
select_statement = select_statement.where(static_filter)
|
||||
if filter and filter_fields:
|
||||
select_statement = cls._filter_condition(
|
||||
select_statement, filter, filter_fields
|
||||
)
|
||||
if ext_filter:
|
||||
select_statement = select_statement.where(ext_filter)
|
||||
if user:
|
||||
select_statement = cls._ownership_condition(select_statement, user)
|
||||
return await session.scalar(select_statement)
|
||||
|
||||
@classmethod
|
||||
@session_dep
|
||||
async def get_multi(
|
||||
cls,
|
||||
*,
|
||||
session: AsyncSession | None = None,
|
||||
order_by=None,
|
||||
static_filter: list[Filter] | Any = None,
|
||||
filter: str = None,
|
||||
filter_fields: list[str] = None,
|
||||
ext_filter: Any = None,
|
||||
user: "UserBase" = None,
|
||||
skip: int = 0,
|
||||
limit: int = None,
|
||||
):
|
||||
select_statement = select(cls).offset(skip)
|
||||
if limit:
|
||||
select_statement = select_statement.limit(limit)
|
||||
if static_filter is not None:
|
||||
if isinstance(static_filter, list):
|
||||
select_statement = cls._static_filter_condition(
|
||||
select_statement, static_filter
|
||||
)
|
||||
else:
|
||||
select_statement = select_statement.where(static_filter)
|
||||
if filter and filter_fields:
|
||||
select_statement = cls._filter_condition(
|
||||
select_statement, filter, filter_fields
|
||||
)
|
||||
if ext_filter is not None:
|
||||
select_statement = select_statement.where(ext_filter)
|
||||
if user:
|
||||
select_statement = cls._ownership_condition(select_statement, user)
|
||||
if order_by is not None:
|
||||
select_statement = select_statement.order_by(order_by)
|
||||
return (await session.exec(select_statement)).all()
|
||||
|
||||
@classmethod
|
||||
def _ownership_condition(
|
||||
cls, select_statement: SelectOfScalar[Self], user: "UserBase"
|
||||
):
|
||||
if cls.bot_entity_descriptor.ownership_fields:
|
||||
condition = None
|
||||
for role in user.roles:
|
||||
if role in cls.bot_entity_descriptor.ownership_fields:
|
||||
owner_col = column(cls.bot_entity_descriptor.ownership_fields[role])
|
||||
if condition is not None:
|
||||
condition = condition | (owner_col == user.id)
|
||||
else:
|
||||
condition = owner_col == user.id
|
||||
else:
|
||||
condition = None
|
||||
break
|
||||
if condition is not None:
|
||||
return select_statement.where(condition)
|
||||
return select_statement
|
||||
|
||||
@classmethod
|
||||
@session_dep
|
||||
async def create(
|
||||
cls,
|
||||
*,
|
||||
session: AsyncSession | None = None,
|
||||
obj_in: CreateSchemaType,
|
||||
commit: bool = False,
|
||||
):
|
||||
if isinstance(obj_in, cls):
|
||||
obj = obj_in
|
||||
else:
|
||||
obj = cls(**obj_in.model_dump())
|
||||
session.add(obj)
|
||||
if commit:
|
||||
await session.commit()
|
||||
return obj
|
||||
|
||||
@classmethod
|
||||
@session_dep
|
||||
async def update(
|
||||
cls,
|
||||
*,
|
||||
session: AsyncSession | None = None,
|
||||
id: int,
|
||||
obj_in: UpdateSchemaType,
|
||||
commit: bool = False,
|
||||
):
|
||||
obj = await session.get(cls, id)
|
||||
if obj:
|
||||
obj_data = obj.model_dump()
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
for field in obj_data:
|
||||
if field in update_data:
|
||||
setattr(obj, field, update_data[field])
|
||||
session.add(obj)
|
||||
if commit:
|
||||
await session.commit()
|
||||
return obj
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
@session_dep
|
||||
async def remove(
|
||||
cls, *, session: AsyncSession | None = None, id: int, commit: bool = False
|
||||
):
|
||||
obj = await session.get(cls, id)
|
||||
if obj:
|
||||
await session.delete(obj)
|
||||
if commit:
|
||||
await session.commit()
|
||||
return obj
|
||||
return None
|
||||
Reference in New Issue
Block a user