add ruff format, ruff check, time_picker, project structure and imports reorganized

This commit is contained in:
Alexander Kalinovsky
2025-01-21 23:50:19 +01:00
parent ced47ac993
commit 9dd0708a5b
58 changed files with 3690 additions and 2583 deletions

View File

@@ -1,33 +1,49 @@
from functools import wraps
from types import NoneType, UnionType
from typing import ClassVar, ForwardRef, Optional, Union, cast, get_args, get_origin
from typing import (
Any,
ClassVar,
ForwardRef,
Optional,
Self,
Union,
get_args,
get_origin,
TYPE_CHECKING,
)
from pydantic import BaseModel
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor, Filter
from .entity_metadata import EntityMetadata
from . import session_dep
if TYPE_CHECKING:
from .user import UserBase
class BotEntityMetaclass(SQLModelMetaclass):
__future_references__ = {}
def __new__(mcs, name, bases, namespace, **kwargs):
bot_fields_descriptors = {}
if bases:
bot_entity_descriptor = bases[0].__dict__.get('bot_entity_descriptor')
bot_fields_descriptors = {key: EntityFieldDescriptor(**value.__dict__.copy())
for key, value in bot_entity_descriptor.fields_descriptors.items()} if bot_entity_descriptor else {}
if '__annotations__' in namespace:
for annotation in namespace['__annotations__']:
bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor")
bot_fields_descriptors = (
{
key: EntityFieldDescriptor(**value.__dict__.copy())
for key, value in bot_entity_descriptor.fields_descriptors.items()
}
if bot_entity_descriptor
else {}
)
if "__annotations__" in namespace:
for annotation in namespace["__annotations__"]:
if annotation in ["bot_entity_descriptor", "entity_metadata"]:
continue
@@ -36,41 +52,43 @@ class BotEntityMetaclass(SQLModelMetaclass):
if isinstance(attribute_value, RelationshipInfo):
continue
descriptor_kwargs = {}
descriptor_kwargs = {}
descriptor_name = annotation
if attribute_value:
if isinstance(attribute_value, EntityField):
descriptor_kwargs = attribute_value.__dict__.copy()
sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None)
if sm_descriptor:
namespace[annotation] = sm_descriptor
else:
namespace.pop(annotation)
namespace.pop(annotation)
descriptor_name = descriptor_kwargs.pop("name") or annotation
type_ = namespace['__annotations__'][annotation]
type_ = namespace["__annotations__"][annotation]
type_origin = get_origin(type_)
field_descriptor = EntityFieldDescriptor(
name = descriptor_name,
field_name = annotation,
type_ = type_,
type_base = type_,
**descriptor_kwargs)
name=descriptor_name,
field_name=annotation,
type_=type_,
type_base=type_,
**descriptor_kwargs,
)
is_list = False
if type_origin == list:
if type_origin is list:
field_descriptor.is_list = is_list = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
field_descriptor.is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0].__forward_arg__
field_descriptor.type_base = type_ = get_args(type_)[
0
].__forward_arg__
if type_origin == UnionType and get_args(type_)[1] == NoneType:
field_descriptor.is_optional = True
@@ -78,12 +96,26 @@ class BotEntityMetaclass(SQLModelMetaclass):
if isinstance(type_, str):
type_not_found = True
for entity_descriptor in EntityMetadata().entity_descriptors.values():
for (
entity_descriptor
) in EntityMetadata().entity_descriptors.values():
if type_ == entity_descriptor.class_name:
field_descriptor.type_ = (list[entity_descriptor.type_] if is_list
else Optional[entity_descriptor.type_] if type_origin == Optional
else entity_descriptor.type_ | None if (type_origin == UnionType and get_args(type_)[1] == NoneType)
else entity_descriptor.type_)
field_descriptor.type_ = (
list[entity_descriptor.type_]
if is_list
else (
Optional[entity_descriptor.type_]
if type_origin == Optional
else (
entity_descriptor.type_ | None
if (
type_origin == UnionType
and get_args(type_)[1] == NoneType
)
else entity_descriptor.type_
)
)
)
type_not_found = False
break
if type_not_found:
@@ -91,7 +123,7 @@ class BotEntityMetaclass(SQLModelMetaclass):
mcs.__future_references__[type_].append(field_descriptor)
else:
mcs.__future_references__[type_] = [field_descriptor]
bot_fields_descriptors[descriptor_name] = field_descriptor
descriptor_name = name
@@ -101,123 +133,246 @@ class BotEntityMetaclass(SQLModelMetaclass):
descriptor_kwargs: dict = entity_descriptor.__dict__.copy()
descriptor_name = descriptor_kwargs.pop("name", None)
descriptor_name = descriptor_name or name.lower()
descriptor_fields_sequence = descriptor_kwargs.pop("field_sequence", None)
if not descriptor_fields_sequence:
descriptor_fields_sequence = list(bot_fields_descriptors.keys())
descriptor_fields_sequence.remove("id")
namespace["bot_entity_descriptor"] = EntityDescriptor(
name = descriptor_name,
class_name = name,
type_ = name,
fields_descriptors = bot_fields_descriptors,
field_sequence = descriptor_fields_sequence,
**descriptor_kwargs)
name=descriptor_name,
class_name=name,
type_=name,
fields_descriptors=bot_fields_descriptors,
**descriptor_kwargs,
)
else:
descriptor_fields_sequence = list(bot_fields_descriptors.keys())
descriptor_fields_sequence.remove("id")
descriptor_name = name.lower()
namespace["bot_entity_descriptor"] = EntityDescriptor(
name = descriptor_name,
class_name = name,
type_ = name,
fields_descriptors = bot_fields_descriptors,
field_sequence = descriptor_fields_sequence)
name=descriptor_name,
class_name=name,
type_=name,
fields_descriptors=bot_fields_descriptors,
)
descriptor_fields_sequence = [
key
for key, val in bot_fields_descriptors.items()
if not (val.is_optional or val.name == "id")
]
entity_descriptor: EntityDescriptor = namespace["bot_entity_descriptor"]
if entity_descriptor.default_form.edit_field_sequence is None:
entity_descriptor.default_form.edit_field_sequence = (
descriptor_fields_sequence
)
for form in entity_descriptor.forms.values():
if form.edit_field_sequence is None:
form.edit_field_sequence = descriptor_fields_sequence
for field_descriptor in bot_fields_descriptors.values():
field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"]
if "table" not in kwargs:
kwargs["table"] = True
if kwargs["table"] == True:
if kwargs["table"]:
entity_metadata = EntityMetadata()
entity_metadata.entity_descriptors[descriptor_name] = namespace["bot_entity_descriptor"]
entity_metadata.entity_descriptors[descriptor_name] = namespace[
"bot_entity_descriptor"
]
if "__annotations__" in namespace:
namespace["__annotations__"]["entity_metadata"] = ClassVar[EntityMetadata]
namespace["__annotations__"]["entity_metadata"] = ClassVar[
EntityMetadata
]
else:
namespace["__annotations__"] = {"entity_metadata": ClassVar[EntityMetadata]}
namespace["__annotations__"] = {
"entity_metadata": ClassVar[EntityMetadata]
}
namespace["entity_metadata"] = entity_metadata
type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
if name in mcs.__future_references__:
for field_descriptor in mcs.__future_references__[name]:
type_origin = get_origin(field_descriptor.type_)
field_descriptor.type_base = type_
field_descriptor.type_ = (list[type_] if get_origin(field_descriptor.type_) == list else
Optional[type_] if type_origin == Union and isinstance(get_args(field_descriptor.type_)[0], ForwardRef) else
type_ | None if type_origin == UnionType else
type_)
field_descriptor.type_base = type_
field_descriptor.type_ = (
list[type_]
if get_origin(field_descriptor.type_) is list
else (
Optional[type_]
if type_origin == Union
and isinstance(get_args(field_descriptor.type_)[0], ForwardRef)
else type_ | None
if type_origin == UnionType
else type_
)
)
setattr(namespace["bot_entity_descriptor"], "type_", type_)
return type_
class BotEntity[CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel](SQLModel,
metaclass = BotEntityMetaclass,
table = False):
class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
SQLModel, metaclass=BotEntityMetaclass, table=False
):
bot_entity_descriptor: ClassVar[EntityDescriptor]
entity_metadata: ClassVar[EntityMetadata]
id: int = Field(
primary_key = True,
sa_type = BIGINT)
name: str
id: int = Field(primary_key=True, sa_type=BIGINT)
@classmethod
@session_dep
async def get(cls, *,
session: AsyncSession | None = None,
id: int):
return await session.get(cls, id, populate_existing = True)
async def get(cls, *, session: AsyncSession | None = None, id: int):
return await session.get(cls, id, populate_existing=True)
@classmethod
def _static_fiter_condition(
cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter]
):
for sfilt in static_filter:
if sfilt.operator == "==":
condition = column(sfilt.field_name).__eq__(sfilt.value)
elif sfilt.operator == "!=":
condition = column(sfilt.field_name).__ne__(sfilt.value)
elif sfilt.operator == "<":
condition = column(sfilt.field_name).__lt__(sfilt.value)
elif sfilt.operator == "<=":
condition = column(sfilt.field_name).__le__(sfilt.value)
elif sfilt.operator == ">":
condition = column(sfilt.field_name).__gt__(sfilt.value)
elif sfilt.operator == ">=":
condition = column(sfilt.field_name).__ge__(sfilt.value)
elif sfilt.operator == "ilike":
condition = column(sfilt.field_name).ilike(f"%{sfilt.value}%")
elif sfilt.operator == "like":
condition = column(sfilt.field_name).like(f"%{sfilt.value}%")
elif sfilt.operator == "in":
condition = column(sfilt.field_name).in_(sfilt.value)
elif sfilt.operator == "not in":
condition = column(sfilt.field_name).notin_(sfilt.value)
elif sfilt.operator == "is":
condition = column(sfilt.field_name).is_(None)
elif sfilt.operator == "is not":
condition = column(sfilt.field_name).isnot(None)
else:
condition = None
if condition:
select_statement = select_statement.where(condition)
return select_statement
@classmethod
def _filter_condition(
cls,
select_statement: SelectOfScalar[Self],
filter: str,
filter_fields: list[str],
):
condition = None
for field in filter_fields:
if condition is not None:
condition = condition | (column(field).ilike(f"%{filter}%"))
else:
condition = column(field).ilike(f"%{filter}%")
return select_statement.where(condition)
@classmethod
@session_dep
async def get_count(cls, *,
session: AsyncSession | None = None,
filter: str = None) -> int:
async def get_count(
cls,
*,
session: AsyncSession | None = None,
static_filter: list[Filter] | Any = None,
filter: str = None,
filter_fields: list[str] = None,
ext_filter: Any = None,
user: "UserBase" = None,
) -> int:
select_statement = select(func.count()).select_from(cls)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if static_filter:
if isinstance(static_filter, list):
select_statement = cls._static_fiter_condition(
select_statement, static_filter
)
else:
select_statement = select_statement.where(static_filter)
if filter and filter_fields:
select_statement = cls._filter_condition(
select_statement, filter, filter_fields
)
if ext_filter:
select_statement = select_statement.where(ext_filter)
if user:
select_statement = cls._ownership_condition(select_statement, user)
return await session.scalar(select_statement)
@classmethod
@session_dep
async def get_multi(cls, *,
session: AsyncSession | None = None,
order_by = None,
filter:str = None,
skip: int = 0,
limit: int = None):
async def get_multi(
cls,
*,
session: AsyncSession | None = None,
order_by=None,
static_filter: list[Filter] | Any = None,
filter: str = None,
filter_fields: list[str] = None,
ext_filter: Any = None,
user: "UserBase" = None,
skip: int = 0,
limit: int = None,
):
select_statement = select(cls).offset(skip)
if limit:
select_statement = select_statement.limit(limit)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if static_filter:
if isinstance(static_filter, list):
select_statement = cls._static_fiter_condition(
select_statement, static_filter
)
else:
select_statement = select_statement.where(static_filter)
if filter and filter_fields:
select_statement = cls._filter_condition(
select_statement, filter, filter_fields
)
if ext_filter:
select_statement = select_statement.where(ext_filter)
if user:
select_statement = cls._ownership_condition(select_statement, user)
if order_by:
select_statement = select_statement.order_by(order_by)
return (await session.exec(select_statement)).all()
@classmethod
def _ownership_condition(
cls, select_statement: SelectOfScalar[Self], user: "UserBase"
):
if cls.bot_entity_descriptor.ownership_fields:
condition = None
for role in user.roles:
if role in cls.bot_entity_descriptor.ownership_fields:
owner_col = column(cls.bot_entity_descriptor.ownership_fields[role])
if condition is not None:
condition = condition | (owner_col == user.id)
else:
condition = owner_col == user.id
else:
condition = None
break
if condition is not None:
return select_statement.where(condition)
return select_statement
@classmethod
@session_dep
async def create(cls, *,
session: AsyncSession | None = None,
obj_in: CreateSchemaType,
commit: bool = False):
async def create(
cls,
*,
session: AsyncSession | None = None,
obj_in: CreateSchemaType,
commit: bool = False,
):
if isinstance(obj_in, cls):
obj = obj_in
else:
@@ -226,16 +381,17 @@ class BotEntity[CreateSchemaType: BaseModel,
if commit:
await session.commit()
return obj
@classmethod
@session_dep
async def update(cls, *,
session: AsyncSession | None = None,
id: int,
obj_in: UpdateSchemaType,
commit: bool = False):
async def update(
cls,
*,
session: AsyncSession | None = None,
id: int,
obj_in: UpdateSchemaType,
commit: bool = False,
):
obj = await session.get(cls, id)
if obj:
obj_data = obj.model_dump()
@@ -248,14 +404,12 @@ class BotEntity[CreateSchemaType: BaseModel,
await session.commit()
return obj
return None
@classmethod
@session_dep
async def remove(cls, *,
session: AsyncSession | None = None,
id: int,
commit: bool = False):
async def remove(
cls, *, session: AsyncSession | None = None, id: int, commit: bool = False
):
obj = await session.get(cls, id)
if obj:
await session.delete(obj)
@@ -263,4 +417,3 @@ class BotEntity[CreateSchemaType: BaseModel,
await session.commit()
return obj
return None