add command params

This commit is contained in:
Alexander Kalinovsky
2025-01-29 23:40:43 +01:00
parent b40e588379
commit f666bcfba3
33 changed files with 547 additions and 340 deletions

View File

@@ -11,13 +11,13 @@ from typing import (
TYPE_CHECKING,
)
from pydantic import BaseModel
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
from sqlmodel import SQLModel, BigInteger, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor, Filter
from .descriptors import EntityDescriptor, EntityField, FieldDescriptor, Filter
from .entity_metadata import EntityMetadata
from . import session_dep
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
class BotEntityMetaclass(SQLModelMetaclass):
__future_references__ = {}
_future_references = {}
def __new__(mcs, name, bases, namespace, **kwargs):
bot_fields_descriptors = {}
@@ -35,7 +35,7 @@ class BotEntityMetaclass(SQLModelMetaclass):
bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor")
bot_fields_descriptors = (
{
key: EntityFieldDescriptor(**value.__dict__.copy())
key: FieldDescriptor(**value.__dict__.copy())
for key, value in bot_entity_descriptor.fields_descriptors.items()
}
if bot_entity_descriptor
@@ -71,7 +71,7 @@ class BotEntityMetaclass(SQLModelMetaclass):
type_origin = get_origin(type_)
field_descriptor = EntityFieldDescriptor(
field_descriptor = FieldDescriptor(
name=descriptor_name,
field_name=annotation,
type_=type_,
@@ -80,18 +80,19 @@ class BotEntityMetaclass(SQLModelMetaclass):
)
is_list = False
is_optional = False
if type_origin is list:
field_descriptor.is_list = is_list = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
field_descriptor.is_optional = True
field_descriptor.is_optional = is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[
0
].__forward_arg__
if type_origin == UnionType and get_args(type_)[1] == NoneType:
field_descriptor.is_optional = True
field_descriptor.is_optional = is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if isinstance(type_, str):
@@ -100,18 +101,16 @@ class BotEntityMetaclass(SQLModelMetaclass):
entity_descriptor
) in EntityMetadata().entity_descriptors.values():
if type_ == entity_descriptor.class_name:
field_descriptor.type_base = entity_descriptor.type_
field_descriptor.type_ = (
list[entity_descriptor.type_]
if is_list
else (
Optional[entity_descriptor.type_]
if type_origin == Optional
if type_origin == Union and is_optional
else (
entity_descriptor.type_ | None
if (
type_origin == UnionType
and get_args(type_)[1] == NoneType
)
if (type_origin == UnionType and is_optional)
else entity_descriptor.type_
)
)
@@ -119,10 +118,10 @@ class BotEntityMetaclass(SQLModelMetaclass):
type_not_found = False
break
if type_not_found:
if type_ in mcs.__future_references__:
mcs.__future_references__[type_].append(field_descriptor)
if type_ in mcs._future_references:
mcs._future_references[type_].append(field_descriptor)
else:
mcs.__future_references__[type_] = [field_descriptor]
mcs._future_references[type_] = [field_descriptor]
bot_fields_descriptors[descriptor_name] = field_descriptor
@@ -191,14 +190,14 @@ class BotEntityMetaclass(SQLModelMetaclass):
type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
if name in mcs.__future_references__:
for field_descriptor in mcs.__future_references__[name]:
if name in mcs._future_references:
for field_descriptor in mcs._future_references[name]:
type_origin = get_origin(field_descriptor.type_)
field_descriptor.type_base = type_
field_descriptor.type_ = (
list[type_]
if get_origin(field_descriptor.type_) is list
if type_origin is list
else (
Optional[type_]
if type_origin == Union
@@ -220,7 +219,9 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
bot_entity_descriptor: ClassVar[EntityDescriptor]
entity_metadata: ClassVar[EntityMetadata]
id: int = Field(primary_key=True, sa_type=BIGINT)
id: int = EntityField(
sm_descriptor=Field(primary_key=True, sa_type=BigInteger), is_visible=False
)
@classmethod
@session_dep
@@ -228,7 +229,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
return await session.get(cls, id, populate_existing=True)
@classmethod
def _static_fiter_condition(
def _static_filter_condition(
cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter]
):
for sfilt in static_filter:
@@ -292,7 +293,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
select_statement = select(func.count()).select_from(cls)
if static_filter:
if isinstance(static_filter, list):
select_statement = cls._static_fiter_condition(
select_statement = cls._static_filter_condition(
select_statement, static_filter
)
else:
@@ -327,7 +328,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
select_statement = select_statement.limit(limit)
if static_filter:
if isinstance(static_filter, list):
select_statement = cls._static_fiter_condition(
select_statement = cls._static_filter_condition(
select_statement, static_filter
)
else:

View File

@@ -143,7 +143,13 @@ class BotEnum(EnumMember, metaclass=BotEnumMetaclass):
class EnumType(TypeDecorator):
impl = String(256)
impl = String(64)
cache_ok = True
# class comparator_factory(TypeDecorator.Comparator):
# def __eq__(self, other):
# expr = type_coerce(self.expr, String)
# return expr != other.value
def __init__(self, enum_type: BotEnum):
self._enum_type = enum_type

View File

@@ -2,7 +2,7 @@ from aiogram.types import Message, CallbackQuery
from aiogram.fsm.context import FSMContext
from aiogram.utils.i18n import I18n
from aiogram.utils.keyboard import InlineKeyboardBuilder
from typing import Any, Callable, TYPE_CHECKING, Literal
from typing import Any, Callable, TYPE_CHECKING, Literal, Union
from babel.support import LazyProxy
from dataclasses import dataclass, field
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -18,22 +18,21 @@ if TYPE_CHECKING:
EntityCaptionCallable = Callable[["EntityDescriptor"], str]
EntityItemCaptionCallable = Callable[["EntityDescriptor", Any], str]
EntityFieldCaptionCallable = Callable[["EntityFieldDescriptor", Any, Any], str]
EntityFieldCaptionCallable = Callable[["FieldDescriptor", Any, Any], str]
@dataclass
class FieldEditButton:
field_name: str
visibility: Callable[[Any], bool] | None = None
caption: str | LazyProxy | EntityFieldCaptionCallable | None = None
visibility: Callable[[Any], bool] | None = None
@dataclass
class CommandButton:
command: str
caption: str | LazyProxy | EntityItemCaptionCallable | None = None
command: ContextData | Callable[[ContextData, Any], ContextData] | str
caption: str | LazyProxy | EntityItemCaptionCallable
visibility: Callable[[Any], bool] | None = None
context_data: ContextData | Callable[[ContextData, Any], ContextData] | None = None
@dataclass
@@ -81,7 +80,7 @@ class EntityForm:
@dataclass(kw_only=True)
class _BaseEntityFieldDescriptor:
class _BaseFieldDescriptor:
icon: str = None
caption: str | LazyProxy | EntityFieldCaptionCallable | None = None
description: str | LazyProxy | EntityFieldCaptionCallable | None = None
@@ -99,18 +98,24 @@ class _BaseEntityFieldDescriptor:
@dataclass(kw_only=True)
class EntityField(_BaseEntityFieldDescriptor):
class EntityField(_BaseFieldDescriptor):
name: str | None = None
sm_descriptor: Any = None
@dataclass(kw_only=True)
class Setting(_BaseEntityFieldDescriptor):
class Setting(_BaseFieldDescriptor):
name: str | None = None
@dataclass(kw_only=True)
class EntityFieldDescriptor(_BaseEntityFieldDescriptor):
class FormField(_BaseFieldDescriptor):
name: str | None = None
type_: type
@dataclass(kw_only=True)
class FieldDescriptor(_BaseFieldDescriptor):
name: str
field_name: str
type_: type
@@ -118,6 +123,7 @@ class EntityFieldDescriptor(_BaseEntityFieldDescriptor):
is_list: bool = False
is_optional: bool = False
entity_descriptor: "EntityDescriptor" = None
command: "BotCommand" = None
def __hash__(self):
return self.name.__hash__()
@@ -162,7 +168,7 @@ class EntityDescriptor(_BaseEntityDescriptor):
name: str
class_name: str
type_: type["BotEntity"]
fields_descriptors: dict[str, EntityFieldDescriptor]
fields_descriptors: dict[str, FieldDescriptor]
@dataclass(kw_only=True)
@@ -179,24 +185,21 @@ class CommandCallbackContext[UT: UserBase]:
app: "QBotApp"
state_data: dict[str, Any]
state: FSMContext
form_data: dict[str, Any]
i18n: I18n
kwargs: dict[str, Any] = field(default_factory=dict)
@dataclass(kw_only=True)
class _BotCommand:
class BotCommand:
name: str
caption: str | dict[str, str] | None = None
pre_check: Callable[[Union[Message, CallbackQuery], Any], bool] | None = None
show_in_bot_commands: bool = False
register_navigation: bool = True
clear_navigation: bool = False
clear_state: bool = True
@dataclass(kw_only=True)
class BotCommand(_BotCommand):
param_form: dict[str, FieldDescriptor] | None = None
show_cancel_in_param_form: bool = True
show_back_in_param_form: bool = True
handler: Callable[[CommandCallbackContext], None]
@dataclass(kw_only=True)
class Command(_BotCommand): ...

View File

@@ -6,7 +6,7 @@ from typing import Any, get_args, get_origin
from ..db import async_session
from .role import RoleBase
from .descriptors import EntityFieldDescriptor, Setting
from .descriptors import FieldDescriptor, Setting
from ..utils.serialization import deserialize, serialize
import ujson as json
@@ -39,7 +39,7 @@ class SettingsMetaclass(type):
if isinstance(attr_value, Setting):
descriptor_kwargs = attr_value.__dict__.copy()
name = descriptor_kwargs.pop("name") or annotation
attributes[annotation] = EntityFieldDescriptor(
attributes[annotation] = FieldDescriptor(
name=name,
field_name=annotation,
type_=type_,
@@ -48,7 +48,7 @@ class SettingsMetaclass(type):
)
else:
attributes[annotation] = EntityFieldDescriptor(
attributes[annotation] = FieldDescriptor(
name=annotation,
field_name=annotation,
type_=type_,
@@ -83,7 +83,7 @@ class SettingsMetaclass(type):
class Settings(metaclass=SettingsMetaclass):
_cache: dict[str, Any] = dict[str, Any]()
_settings_descriptors: dict[str, EntityFieldDescriptor] = {}
_settings_descriptors: dict[str, FieldDescriptor] = {}
PAGE_SIZE: int = Setting(
default=10,
@@ -213,7 +213,7 @@ class Settings(metaclass=SettingsMetaclass):
return ret_val
@classmethod
async def load_param(cls, param: EntityFieldDescriptor) -> Any:
async def load_param(cls, param: FieldDescriptor) -> Any:
async with async_session() as session:
db_setting = (
await session.exec(
@@ -244,7 +244,7 @@ class Settings(metaclass=SettingsMetaclass):
db_settings = (await session.exec(select(DbSettings))).all()
for db_setting in db_settings:
if db_setting.name in cls.__dict__:
setting = cls.__dict__[db_setting.name] # type: EntityFieldDescriptor
setting = cls.__dict__[db_setting.name] # type: FieldDescriptor
cls._cache[db_setting.name] = await deserialize(
session=session,
type_=setting.type_,
@@ -255,7 +255,7 @@ class Settings(metaclass=SettingsMetaclass):
cls._loaded = True
@classmethod
async def set_param(cls, param: str | EntityFieldDescriptor, value) -> None:
async def set_param(cls, param: str | FieldDescriptor, value) -> None:
if isinstance(param, str):
param = cls._settings_descriptors[param]
ser_value = serialize(value, param)
@@ -273,11 +273,11 @@ class Settings(metaclass=SettingsMetaclass):
cls._cache[param.field_name] = value
@classmethod
def list_params(cls) -> dict[str, EntityFieldDescriptor]:
def list_params(cls) -> dict[str, FieldDescriptor]:
return cls._settings_descriptors
@classmethod
async def get_params(cls) -> dict[EntityFieldDescriptor, Any]:
async def get_params(cls) -> dict[FieldDescriptor, Any]:
params = cls.list_params()
return {
param: await cls.get(param, all_locales=True) for _, param in params.items()

View File

@@ -1,4 +1,4 @@
from sqlmodel import SQLModel, Field, BIGINT
from sqlmodel import SQLModel, Field, BigInteger
from sqlalchemy.ext.asyncio.session import AsyncSession
from . import session_dep
@@ -7,7 +7,7 @@ from . import session_dep
class ViewSetting(SQLModel, table=True):
__tablename__ = "view_setting"
user_id: int = Field(
sa_type=BIGINT, primary_key=True, foreign_key="user.id", ondelete="CASCADE"
sa_type=BigInteger, primary_key=True, foreign_key="user.id", ondelete="CASCADE"
)
entity_name: str = Field(primary_key=True)
filter: str | None = None