refactoring

This commit is contained in:
Alexander Kalinovsky
2025-01-09 13:11:10 +01:00
parent 7793a0cb77
commit 3898a333fa
29 changed files with 1065 additions and 381 deletions

View File

@@ -1,7 +1,8 @@
from functools import wraps
from typing import ClassVar, cast, get_args, get_origin
from types import NoneType, UnionType
from typing import ClassVar, ForwardRef, Optional, Union, cast, get_args, get_origin
from pydantic import BaseModel
from sqlmodel import SQLModel, BIGINT, Field, select, func
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
@@ -53,24 +54,36 @@ class BotEntityMetaclass(SQLModelMetaclass):
type_ = namespace['__annotations__'][annotation]
type_origin = get_origin(type_)
field_descriptor = EntityFieldDescriptor(
name = descriptor_name,
field_name = annotation,
type_ = type_,
type_base = type_,
**descriptor_kwargs)
type_origin = get_origin(type_)
is_list = False
if type_origin == list:
is_list = True
type_ = get_args(type_)[0]
field_descriptor.is_list = is_list = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
field_descriptor.is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0].__forward_arg__
if type_origin == UnionType and get_args(type_)[1] == NoneType:
field_descriptor.is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if isinstance(type_, str):
type_not_found = True
for entity_descriptor in EntityMetadata().entity_descriptors.values():
if type_ == entity_descriptor.class_name:
field_descriptor.type_ = list[entity_descriptor.type_] if is_list else entity_descriptor.type_
field_descriptor.type_ = (list[entity_descriptor.type_] if is_list
else Optional[entity_descriptor.type_] if type_origin == Optional
else entity_descriptor.type_ | None if (type_origin == UnionType and get_args(type_)[1] == NoneType)
else entity_descriptor.type_)
type_not_found = False
break
if type_not_found:
@@ -131,8 +144,12 @@ class BotEntityMetaclass(SQLModelMetaclass):
if name in mcs.__future_references__:
for field_descriptor in mcs.__future_references__[name]:
field_descriptor.type_ = list[type_] if get_origin(field_descriptor.type_) == list else type_
a = field_descriptor
type_origin = get_origin(field_descriptor.type_)
field_descriptor.type_base = type_
field_descriptor.type_ = (list[type_] if get_origin(field_descriptor.type_) == list else
Optional[type_] if type_origin == Union and isinstance(get_args(field_descriptor.type_)[0], ForwardRef) else
type_ | None if type_origin == UnionType else
type_)
setattr(namespace["bot_entity_descriptor"], "type_", type_)
@@ -160,15 +177,19 @@ class BotEntity[CreateSchemaType: BaseModel,
session: AsyncSession | None = None,
id: int):
return await session.get(cls, id)
return await session.get(cls, id, populate_existing = True)
@classmethod
@session_dep
async def get_count(cls, *,
session: AsyncSession | None = None) -> int:
session: AsyncSession | None = None,
filter: str = None) -> int:
return await session.scalar(select(func.count()).select_from(cls))
select_statement = select(func.count()).select_from(cls)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
return await session.scalar(select_statement)
@classmethod
@@ -176,12 +197,15 @@ class BotEntity[CreateSchemaType: BaseModel,
async def get_multi(cls, *,
session: AsyncSession | None = None,
order_by = None,
filter:str = None,
skip: int = 0,
limit: int = None):
select_statement = select(cls).offset(skip)
if limit:
select_statement = select_statement.limit(limit)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if order_by:
select_statement = select_statement.order_by(order_by)
return (await session.exec(select_statement)).all()
@@ -238,4 +262,5 @@ class BotEntity[CreateSchemaType: BaseModel,
if commit:
await session.commit()
return obj
return None
return None

View File

@@ -1,10 +1,13 @@
from typing import Any, Callable
from typing import Any, Callable, TYPE_CHECKING
from babel.support import LazyProxy
from dataclasses import dataclass, field
from .role import RoleBase
from . import EntityPermission
if TYPE_CHECKING:
from .bot_entity import BotEntity
EntityCaptionCallable = Callable[["EntityDescriptor"], str]
EntityItemCaptionCallable = Callable[["EntityDescriptor", Any], str]
EntityFieldCaptionCallable = Callable[["EntityFieldDescriptor", Any, Any], str]
@@ -13,18 +16,14 @@ EntityFieldCaptionCallable = Callable[["EntityFieldDescriptor", Any, Any], str]
@dataclass(kw_only = True)
class _BaseEntityFieldDescriptor():
icon: str = None
caption_str: str | LazyProxy | EntityFieldCaptionCallable | None = None
caption_btn: str | LazyProxy | EntityFieldCaptionCallable | None = None
caption: str | LazyProxy | EntityFieldCaptionCallable | None = None
description: str | LazyProxy | EntityFieldCaptionCallable | None = None
edit_prompt: str | LazyProxy | EntityFieldCaptionCallable | None = None
caption_value_str: str | LazyProxy | EntityFieldCaptionCallable | None = None
caption_value_btn: str | LazyProxy | EntityFieldCaptionCallable | None = None
caption_value: EntityFieldCaptionCallable | None = None
is_visible: bool = True
localizable: bool = False
bool_false_value: str | LazyProxy = "no"
bool_false_value_btn: str | LazyProxy = "no"
bool_true_value: str | LazyProxy = "yes"
bool_true_value_btn: str | LazyProxy = "yes"
default: Any = None
@@ -44,6 +43,9 @@ class EntityFieldDescriptor(_BaseEntityFieldDescriptor):
name: str
field_name: str
type_: type
type_base: type = None
is_list: bool = False
is_optional: bool = False
entity_descriptor: "EntityDescriptor" = None
def __hash__(self):
@@ -54,14 +56,14 @@ class EntityFieldDescriptor(_BaseEntityFieldDescriptor):
class _BaseEntityDescriptor:
icon: str = "📘"
caption_msg: str | LazyProxy | EntityCaptionCallable | None = None
caption_btn: str | LazyProxy | EntityCaptionCallable | None = None
caption: str | LazyProxy | EntityCaptionCallable | None = None
caption_plural: str | LazyProxy | EntityCaptionCallable | None = None
description: str | LazyProxy | EntityCaptionCallable | None = None
item_caption_msg: EntityItemCaptionCallable | None = None
item_caption_btn: EntityItemCaptionCallable | None = None
item_caption: EntityItemCaptionCallable | None = None
show_in_entities_menu: bool = True
field_sequence: list[str] = None
edit_buttons: list[list[str]] = None
edit_button_visible: bool = True
edit_buttons: list[list[str | tuple[str, str | LazyProxy | EntityFieldCaptionCallable]]] = None
permissions: dict[EntityPermission, list[RoleBase]] = field(default_factory = lambda: {
EntityPermission.LIST: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.READ: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
@@ -87,5 +89,5 @@ class EntityDescriptor(_BaseEntityDescriptor):
name: str
class_name: str
type_: type
fields_descriptors: dict[str, EntityFieldDescriptor]
type_: type["BotEntity"]
fields_descriptors: dict[str, EntityFieldDescriptor]

26
model/menu.py Normal file
View File

@@ -0,0 +1,26 @@
# from aiogram.types import Message, CallbackQuery
# from aiogram.utils.keyboard import InlineKeyboardBuilder
# from typing import Any, Callable, Self, Union, overload
# from babel.support import LazyProxy
# from dataclasses import dataclass
# from ..bot.handlers.context import ContextData
# class Menu:
# @overload
# def __init__(self, description: str | LazyProxy): ...
# @overload
# def __init__(self, menu_factory: Callable[[InlineKeyboardBuilder, Union[Message, CallbackQuery], Any], str]): ...
# def __init__(self, description: str | LazyProxy = None,
# menu_factory: Callable[[InlineKeyboardBuilder, Union[Message, CallbackQuery], Any], str] = None) -> None:
# self.menu_factory = menu_factory
# self.description = description
# self.parent: Menu = None
# self.items: list[list[Menu]] = []

View File

@@ -1,4 +1,4 @@
from sqlmodel import BIGINT, Field, select, func
from sqlmodel import BIGINT, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
@@ -20,11 +20,14 @@ class OwnedBotEntity(BotEntity, table = False):
async def get_multi_by_user(cls, *,
session: AsyncSession | None = None,
user_id: int,
filter: str = None,
order_by = None,
skip: int = 0,
limit: int = None):
select_statement = select(cls).where(cls.user_id == user_id).offset(skip)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if limit:
select_statement = select_statement.limit(limit)
if order_by:
@@ -36,9 +39,11 @@ class OwnedBotEntity(BotEntity, table = False):
@session_dep
async def get_count_by_user(cls, *,
session: AsyncSession | None = None,
user_id: int):
user_id: int,
filter: str = None) -> int:
return await session.scalar(
select(func.count()).
select_from(cls).
where(cls.user_id == user_id))
select_statement = select(func.count()).select_from(cls).where(cls.user_id == user_id)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
return await session.scalar(select_statement)

View File

@@ -1,12 +1,16 @@
from types import NoneType, UnionType
from aiogram.utils.i18n.context import get_i18n
from datetime import datetime
from sqlmodel import SQLModel, Field, select
from typing import Any, get_origin
from typing import Any, get_args, get_origin
from ..db import async_session
from .role import RoleBase
from .descriptors import EntityFieldDescriptor, Setting
from ..utils import deserialize, serialize
import ujson as json
class DbSettings(SQLModel, table = True):
__tablename__ = "settings"
@@ -31,23 +35,40 @@ class SettingsMetaclass(type):
attr_value = attributes.get(annotation)
name = annotation
type_ = attributes['__annotations__'][annotation]
if isinstance(attr_value, Setting):
descriptor_kwargs = attr_value.__dict__.copy()
name = descriptor_kwargs.pop("name") or annotation
attributes[annotation] = EntityFieldDescriptor(
name = name,
field_name = annotation,
type_ = attributes['__annotations__'][annotation],
type_ = type_,
type_base = type_,
**descriptor_kwargs)
else:
attributes[annotation] = EntityFieldDescriptor(
name = annotation,
field_name = annotation,
type_ = attributes['__annotations__'][annotation],
type_ = type_,
type_base = type_,
default = attr_value)
type_origin = get_origin(type_)
if type_origin == list:
attributes[annotation].is_list = True
attributes[annotation].type_base = type_ = get_args(type_)[0]
elif type_origin == UnionType and get_args(type_)[1] == NoneType:
attributes[annotation].is_optional = True
attributes[annotation].type_base = type_ = get_args(type_)[0]
settings_descriptors[name] = attributes[annotation]
if base_classes and base_classes[0].__name__ == "Settings" and hasattr(base_classes[0], annotation):
setattr(base_classes[0], annotation, attributes[annotation])
attributes["__annotations__"] = {}
attributes["_settings_descriptors"] = settings_descriptors
@@ -61,8 +82,7 @@ class Settings(metaclass = SettingsMetaclass):
_settings_descriptors: dict[str, EntityFieldDescriptor] = {}
PAGE_SIZE: int = Setting(default = 10, )
SECURITY_SETTINGS_ROLES: list[RoleBase] = [RoleBase.SUPER_USER]
SECURITY_PARAMETERS_ROLES: list[RoleBase] = Setting(name = "SECPARAMS_ROLES", default = [RoleBase.SUPER_USER], is_visible = False)
APP_STRINGS_WELCOME_P_NAME: str = Setting(name = "AS_WELCOME", default = "Welcome, {name}", is_visible = False)
APP_STRINGS_GREETING_P_NAME: str = Setting(name = "AS_GREETING", default = "Hello, {name}", is_visible = False)
@@ -90,6 +110,7 @@ class Settings(metaclass = SettingsMetaclass):
APP_STRINGS_YES_BTN: str = Setting(name = "AS_YES_BTN", default = "✅ Yes", is_visible = False)
APP_STRINGS_NO_BTN: str = Setting(name = "AS_NO_BTN", default = "❌ No", is_visible = False)
APP_STRINGS_CANCEL_BTN: str = Setting(name = "AS_CANCEL_BTN", default = "❌ Cancel", is_visible = False)
APP_STRINGS_CLEAR_BTN: str = Setting(name = "AS_CLEAR_BTN", default = "⌫ Clear", is_visible = False)
APP_STRINGS_DONE_BTN: str = Setting(name = "AS_DONE_BTN", default = "✅ Done", is_visible = False)
APP_STRINGS_SKIP_BTN: str = Setting(name = "AS_SKIP_BTN", default = "⏩️ Skip", is_visible = False)
APP_STRINGS_FIELD_EDIT_PROMPT_TEMPLATE_P_NAME_VALUE: str = Setting(
@@ -104,18 +125,30 @@ class Settings(metaclass = SettingsMetaclass):
name = "AS_STREDIT_LOC_TEMPLATE",
default = "string for \"{name}\"",
is_visible = False)
APP_STRINGS_VIEW_FILTER_EDIT_PROMPT: str = Setting(name = "AS_FILTEREDIT_PROMPT", default = "Enter filter value", is_visible = False)
APP_STRINGS_INVALID_INPUT: str = Setting(name = "AS_INVALID_INPUT", default = "Invalid input", is_visible = False)
@classmethod
async def get[T](cls, param: T) -> T:
async def get[T](cls, param: T, all_locales = False, locale: str = None) -> T:
name = param.field_name
if param.name not in cls._cache.keys():
if name not in cls._cache.keys():
cls._cache[name] = await cls.load_param(param)
return cls._cache[name]
ret_val = cls._cache[name]
if param.localizable and not all_locales:
if not locale:
locale = get_i18n().current_locale
try:
obj = json.loads(ret_val)
except:
return ret_val
return obj.get(locale, obj[list(obj.keys())[0]])
return ret_val
@classmethod
@@ -180,4 +213,4 @@ class Settings(metaclass = SettingsMetaclass):
async def get_params(cls) -> dict[EntityFieldDescriptor, Any]:
params = cls.list_params()
return {param: await cls.get(param) for _, param in params.items()}
return {param: await cls.get(param, all_locales = True) for _, param in params.items()}

View File

@@ -7,6 +7,7 @@ from .role import RoleBase
from .settings import DbSettings as DbSettings
from .fsm_storage import FSMStorage as FSMStorage
from .view_setting import ViewSetting as ViewSetting
class UserBase(BotEntity, table = False):

39
model/view_setting.py Normal file
View File

@@ -0,0 +1,39 @@
from sqlmodel import SQLModel, Field, BIGINT
from sqlalchemy.ext.asyncio.session import AsyncSession
from . import session_dep
class ViewSetting(SQLModel, table = True):
__tablename__ = "view_setting"
user_id: int = Field(sa_type = BIGINT, primary_key = True, foreign_key="user.id", ondelete="CASCADE")
entity_name: str = Field(primary_key = True)
filter: str | None = None
@classmethod
@session_dep
async def get_filter(cls, *,
session: AsyncSession | None = None,
user_id: int,
entity_name: str):
setting = await session.get(cls, (user_id, entity_name))
return setting.filter if setting else None
@classmethod
@session_dep
async def set_filter(cls, *,
session: AsyncSession | None = None,
user_id: int,
entity_name: str,
filter: str):
setting = await session.get(cls, (user_id, entity_name))
if setting:
setting.filter = filter
else:
setting = cls(user_id = user_id, entity_name = entity_name, filter = filter)
session.add(setting)
await session.commit()