add ruff format, ruff check, time_picker, project structure and imports reorganized

This commit is contained in:
Alexander Kalinovsky
2025-01-21 23:50:19 +01:00
parent ced47ac993
commit 9dd0708a5b
58 changed files with 3690 additions and 2583 deletions

View File

@@ -7,9 +7,7 @@ from .bot_enum import BotEnum, EnumMember
from ..db import async_session
class EntityPermission(BotEnum):
LIST = EnumMember("list")
READ = EnumMember("read")
CREATE = EnumMember("create")
@@ -23,24 +21,23 @@ class EntityPermission(BotEnum):
def session_dep(func):
@wraps(func)
async def wrapper(cls, *args, **kwargs):
if "session" in kwargs and kwargs["session"]:
return await func(cls, *args, **kwargs)
@wraps(func)
async def wrapper(cls, *args, **kwargs):
if "session" in kwargs and kwargs["session"]:
return await func(cls, *args, **kwargs)
_session = None
_session = None
state = cast(InstanceState, inspect(cls))
if hasattr(state, "async_session"):
_session = state.async_session
if not _session:
async with async_session() as session:
kwargs["session"] = session
return await func(cls, *args, **kwargs)
else:
kwargs["session"] = _session
state = cast(InstanceState, inspect(cls))
if hasattr(state, "async_session"):
_session = state.async_session
if not _session:
async with async_session() as session:
kwargs["session"] = session
return await func(cls, *args, **kwargs)
return wrapper
else:
kwargs["session"] = _session
return await func(cls, *args, **kwargs)
return wrapper

View File

@@ -1,6 +1,7 @@
class Singleton(type):
_instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls]
return cls._instances[cls]

View File

@@ -1,33 +1,49 @@
from functools import wraps
from types import NoneType, UnionType
from typing import ClassVar, ForwardRef, Optional, Union, cast, get_args, get_origin
from typing import (
Any,
ClassVar,
ForwardRef,
Optional,
Self,
Union,
get_args,
get_origin,
TYPE_CHECKING,
)
from pydantic import BaseModel
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor, Filter
from .entity_metadata import EntityMetadata
from . import session_dep
if TYPE_CHECKING:
from .user import UserBase
class BotEntityMetaclass(SQLModelMetaclass):
__future_references__ = {}
def __new__(mcs, name, bases, namespace, **kwargs):
bot_fields_descriptors = {}
if bases:
bot_entity_descriptor = bases[0].__dict__.get('bot_entity_descriptor')
bot_fields_descriptors = {key: EntityFieldDescriptor(**value.__dict__.copy())
for key, value in bot_entity_descriptor.fields_descriptors.items()} if bot_entity_descriptor else {}
if '__annotations__' in namespace:
for annotation in namespace['__annotations__']:
bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor")
bot_fields_descriptors = (
{
key: EntityFieldDescriptor(**value.__dict__.copy())
for key, value in bot_entity_descriptor.fields_descriptors.items()
}
if bot_entity_descriptor
else {}
)
if "__annotations__" in namespace:
for annotation in namespace["__annotations__"]:
if annotation in ["bot_entity_descriptor", "entity_metadata"]:
continue
@@ -36,41 +52,43 @@ class BotEntityMetaclass(SQLModelMetaclass):
if isinstance(attribute_value, RelationshipInfo):
continue
descriptor_kwargs = {}
descriptor_kwargs = {}
descriptor_name = annotation
if attribute_value:
if isinstance(attribute_value, EntityField):
descriptor_kwargs = attribute_value.__dict__.copy()
sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None)
if sm_descriptor:
namespace[annotation] = sm_descriptor
else:
namespace.pop(annotation)
namespace.pop(annotation)
descriptor_name = descriptor_kwargs.pop("name") or annotation
type_ = namespace['__annotations__'][annotation]
type_ = namespace["__annotations__"][annotation]
type_origin = get_origin(type_)
field_descriptor = EntityFieldDescriptor(
name = descriptor_name,
field_name = annotation,
type_ = type_,
type_base = type_,
**descriptor_kwargs)
name=descriptor_name,
field_name=annotation,
type_=type_,
type_base=type_,
**descriptor_kwargs,
)
is_list = False
if type_origin == list:
if type_origin is list:
field_descriptor.is_list = is_list = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
field_descriptor.is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0].__forward_arg__
field_descriptor.type_base = type_ = get_args(type_)[
0
].__forward_arg__
if type_origin == UnionType and get_args(type_)[1] == NoneType:
field_descriptor.is_optional = True
@@ -78,12 +96,26 @@ class BotEntityMetaclass(SQLModelMetaclass):
if isinstance(type_, str):
type_not_found = True
for entity_descriptor in EntityMetadata().entity_descriptors.values():
for (
entity_descriptor
) in EntityMetadata().entity_descriptors.values():
if type_ == entity_descriptor.class_name:
field_descriptor.type_ = (list[entity_descriptor.type_] if is_list
else Optional[entity_descriptor.type_] if type_origin == Optional
else entity_descriptor.type_ | None if (type_origin == UnionType and get_args(type_)[1] == NoneType)
else entity_descriptor.type_)
field_descriptor.type_ = (
list[entity_descriptor.type_]
if is_list
else (
Optional[entity_descriptor.type_]
if type_origin == Optional
else (
entity_descriptor.type_ | None
if (
type_origin == UnionType
and get_args(type_)[1] == NoneType
)
else entity_descriptor.type_
)
)
)
type_not_found = False
break
if type_not_found:
@@ -91,7 +123,7 @@ class BotEntityMetaclass(SQLModelMetaclass):
mcs.__future_references__[type_].append(field_descriptor)
else:
mcs.__future_references__[type_] = [field_descriptor]
bot_fields_descriptors[descriptor_name] = field_descriptor
descriptor_name = name
@@ -101,123 +133,246 @@ class BotEntityMetaclass(SQLModelMetaclass):
descriptor_kwargs: dict = entity_descriptor.__dict__.copy()
descriptor_name = descriptor_kwargs.pop("name", None)
descriptor_name = descriptor_name or name.lower()
descriptor_fields_sequence = descriptor_kwargs.pop("field_sequence", None)
if not descriptor_fields_sequence:
descriptor_fields_sequence = list(bot_fields_descriptors.keys())
descriptor_fields_sequence.remove("id")
namespace["bot_entity_descriptor"] = EntityDescriptor(
name = descriptor_name,
class_name = name,
type_ = name,
fields_descriptors = bot_fields_descriptors,
field_sequence = descriptor_fields_sequence,
**descriptor_kwargs)
name=descriptor_name,
class_name=name,
type_=name,
fields_descriptors=bot_fields_descriptors,
**descriptor_kwargs,
)
else:
descriptor_fields_sequence = list(bot_fields_descriptors.keys())
descriptor_fields_sequence.remove("id")
descriptor_name = name.lower()
namespace["bot_entity_descriptor"] = EntityDescriptor(
name = descriptor_name,
class_name = name,
type_ = name,
fields_descriptors = bot_fields_descriptors,
field_sequence = descriptor_fields_sequence)
name=descriptor_name,
class_name=name,
type_=name,
fields_descriptors=bot_fields_descriptors,
)
descriptor_fields_sequence = [
key
for key, val in bot_fields_descriptors.items()
if not (val.is_optional or val.name == "id")
]
entity_descriptor: EntityDescriptor = namespace["bot_entity_descriptor"]
if entity_descriptor.default_form.edit_field_sequence is None:
entity_descriptor.default_form.edit_field_sequence = (
descriptor_fields_sequence
)
for form in entity_descriptor.forms.values():
if form.edit_field_sequence is None:
form.edit_field_sequence = descriptor_fields_sequence
for field_descriptor in bot_fields_descriptors.values():
field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"]
if "table" not in kwargs:
kwargs["table"] = True
if kwargs["table"] == True:
if kwargs["table"]:
entity_metadata = EntityMetadata()
entity_metadata.entity_descriptors[descriptor_name] = namespace["bot_entity_descriptor"]
entity_metadata.entity_descriptors[descriptor_name] = namespace[
"bot_entity_descriptor"
]
if "__annotations__" in namespace:
namespace["__annotations__"]["entity_metadata"] = ClassVar[EntityMetadata]
namespace["__annotations__"]["entity_metadata"] = ClassVar[
EntityMetadata
]
else:
namespace["__annotations__"] = {"entity_metadata": ClassVar[EntityMetadata]}
namespace["__annotations__"] = {
"entity_metadata": ClassVar[EntityMetadata]
}
namespace["entity_metadata"] = entity_metadata
type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
if name in mcs.__future_references__:
for field_descriptor in mcs.__future_references__[name]:
type_origin = get_origin(field_descriptor.type_)
field_descriptor.type_base = type_
field_descriptor.type_ = (list[type_] if get_origin(field_descriptor.type_) == list else
Optional[type_] if type_origin == Union and isinstance(get_args(field_descriptor.type_)[0], ForwardRef) else
type_ | None if type_origin == UnionType else
type_)
field_descriptor.type_base = type_
field_descriptor.type_ = (
list[type_]
if get_origin(field_descriptor.type_) is list
else (
Optional[type_]
if type_origin == Union
and isinstance(get_args(field_descriptor.type_)[0], ForwardRef)
else type_ | None
if type_origin == UnionType
else type_
)
)
setattr(namespace["bot_entity_descriptor"], "type_", type_)
return type_
class BotEntity[CreateSchemaType: BaseModel,
UpdateSchemaType: BaseModel](SQLModel,
metaclass = BotEntityMetaclass,
table = False):
class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
SQLModel, metaclass=BotEntityMetaclass, table=False
):
bot_entity_descriptor: ClassVar[EntityDescriptor]
entity_metadata: ClassVar[EntityMetadata]
id: int = Field(
primary_key = True,
sa_type = BIGINT)
name: str
id: int = Field(primary_key=True, sa_type=BIGINT)
@classmethod
@session_dep
async def get(cls, *,
session: AsyncSession | None = None,
id: int):
return await session.get(cls, id, populate_existing = True)
async def get(cls, *, session: AsyncSession | None = None, id: int):
return await session.get(cls, id, populate_existing=True)
@classmethod
def _static_fiter_condition(
cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter]
):
for sfilt in static_filter:
if sfilt.operator == "==":
condition = column(sfilt.field_name).__eq__(sfilt.value)
elif sfilt.operator == "!=":
condition = column(sfilt.field_name).__ne__(sfilt.value)
elif sfilt.operator == "<":
condition = column(sfilt.field_name).__lt__(sfilt.value)
elif sfilt.operator == "<=":
condition = column(sfilt.field_name).__le__(sfilt.value)
elif sfilt.operator == ">":
condition = column(sfilt.field_name).__gt__(sfilt.value)
elif sfilt.operator == ">=":
condition = column(sfilt.field_name).__ge__(sfilt.value)
elif sfilt.operator == "ilike":
condition = column(sfilt.field_name).ilike(f"%{sfilt.value}%")
elif sfilt.operator == "like":
condition = column(sfilt.field_name).like(f"%{sfilt.value}%")
elif sfilt.operator == "in":
condition = column(sfilt.field_name).in_(sfilt.value)
elif sfilt.operator == "not in":
condition = column(sfilt.field_name).notin_(sfilt.value)
elif sfilt.operator == "is":
condition = column(sfilt.field_name).is_(None)
elif sfilt.operator == "is not":
condition = column(sfilt.field_name).isnot(None)
else:
condition = None
if condition:
select_statement = select_statement.where(condition)
return select_statement
@classmethod
def _filter_condition(
cls,
select_statement: SelectOfScalar[Self],
filter: str,
filter_fields: list[str],
):
condition = None
for field in filter_fields:
if condition is not None:
condition = condition | (column(field).ilike(f"%{filter}%"))
else:
condition = column(field).ilike(f"%{filter}%")
return select_statement.where(condition)
@classmethod
@session_dep
async def get_count(cls, *,
session: AsyncSession | None = None,
filter: str = None) -> int:
async def get_count(
cls,
*,
session: AsyncSession | None = None,
static_filter: list[Filter] | Any = None,
filter: str = None,
filter_fields: list[str] = None,
ext_filter: Any = None,
user: "UserBase" = None,
) -> int:
select_statement = select(func.count()).select_from(cls)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if static_filter:
if isinstance(static_filter, list):
select_statement = cls._static_fiter_condition(
select_statement, static_filter
)
else:
select_statement = select_statement.where(static_filter)
if filter and filter_fields:
select_statement = cls._filter_condition(
select_statement, filter, filter_fields
)
if ext_filter:
select_statement = select_statement.where(ext_filter)
if user:
select_statement = cls._ownership_condition(select_statement, user)
return await session.scalar(select_statement)
@classmethod
@session_dep
async def get_multi(cls, *,
session: AsyncSession | None = None,
order_by = None,
filter:str = None,
skip: int = 0,
limit: int = None):
async def get_multi(
cls,
*,
session: AsyncSession | None = None,
order_by=None,
static_filter: list[Filter] | Any = None,
filter: str = None,
filter_fields: list[str] = None,
ext_filter: Any = None,
user: "UserBase" = None,
skip: int = 0,
limit: int = None,
):
select_statement = select(cls).offset(skip)
if limit:
select_statement = select_statement.limit(limit)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if static_filter:
if isinstance(static_filter, list):
select_statement = cls._static_fiter_condition(
select_statement, static_filter
)
else:
select_statement = select_statement.where(static_filter)
if filter and filter_fields:
select_statement = cls._filter_condition(
select_statement, filter, filter_fields
)
if ext_filter:
select_statement = select_statement.where(ext_filter)
if user:
select_statement = cls._ownership_condition(select_statement, user)
if order_by:
select_statement = select_statement.order_by(order_by)
return (await session.exec(select_statement)).all()
@classmethod
def _ownership_condition(
cls, select_statement: SelectOfScalar[Self], user: "UserBase"
):
if cls.bot_entity_descriptor.ownership_fields:
condition = None
for role in user.roles:
if role in cls.bot_entity_descriptor.ownership_fields:
owner_col = column(cls.bot_entity_descriptor.ownership_fields[role])
if condition is not None:
condition = condition | (owner_col == user.id)
else:
condition = owner_col == user.id
else:
condition = None
break
if condition is not None:
return select_statement.where(condition)
return select_statement
@classmethod
@session_dep
async def create(cls, *,
session: AsyncSession | None = None,
obj_in: CreateSchemaType,
commit: bool = False):
async def create(
cls,
*,
session: AsyncSession | None = None,
obj_in: CreateSchemaType,
commit: bool = False,
):
if isinstance(obj_in, cls):
obj = obj_in
else:
@@ -226,16 +381,17 @@ class BotEntity[CreateSchemaType: BaseModel,
if commit:
await session.commit()
return obj
@classmethod
@session_dep
async def update(cls, *,
session: AsyncSession | None = None,
id: int,
obj_in: UpdateSchemaType,
commit: bool = False):
async def update(
cls,
*,
session: AsyncSession | None = None,
id: int,
obj_in: UpdateSchemaType,
commit: bool = False,
):
obj = await session.get(cls, id)
if obj:
obj_data = obj.model_dump()
@@ -248,14 +404,12 @@ class BotEntity[CreateSchemaType: BaseModel,
await session.commit()
return obj
return None
@classmethod
@session_dep
async def remove(cls, *,
session: AsyncSession | None = None,
id: int,
commit: bool = False):
async def remove(
cls, *, session: AsyncSession | None = None, id: int, commit: bool = False
):
obj = await session.get(cls, id)
if obj:
await session.delete(obj)
@@ -263,4 +417,3 @@ class BotEntity[CreateSchemaType: BaseModel,
await session.commit()
return obj
return None

View File

@@ -4,36 +4,46 @@ from typing import Any, Self, overload
class BotEnumMetaclass(type):
def __new__(cls, name: str, bases: tuple[type], namespace: dict[str, Any]):
all_members = {}
if bases and bases[0].__name__ != "BotEnum" and "all_members" in bases[0].__dict__:
if (
bases
and bases[0].__name__ != "BotEnum"
and "all_members" in bases[0].__dict__
):
all_members = bases[0].__dict__["all_members"]
annotations = {}
for key, value in namespace.items():
if (key.isupper() and
not key.startswith("__") and
not key.endswith("__")):
if key.isupper() and not key.startswith("__") and not key.endswith("__"):
if not isinstance(value, EnumMember):
value = EnumMember(value, None)
if key in all_members.keys() and all_members[key].value != value.value:
raise ValueError(f"Enum member {key} already exists with different value. Use same value to extend it.")
if (value.value in [member.value for member in all_members.values()] and
key not in all_members.keys()):
raise ValueError(
f"Enum member {key} already exists with different value. Use same value to extend it."
)
if (
value.value in [member.value for member in all_members.values()]
and key not in all_members.keys()
):
raise ValueError(f"Duplicate enum value {value[0]}")
member = EnumMember(value = value.value, loc_obj = value.loc_obj, parent = None, name = key, casting = False)
member = EnumMember(
value=value.value,
loc_obj=value.loc_obj,
parent=None,
name=key,
casting=False,
)
namespace[key] = member
all_members[key] = member
annotations[key] = type(member)
namespace["__annotations__"] = annotations
namespace["__annotations__"] = annotations
namespace["all_members"] = all_members
type_ = super().__new__(cls, name, bases, namespace)
@@ -46,22 +56,23 @@ class BotEnumMetaclass(type):
class EnumMember(object):
@overload
def __init__(self, value: str) -> "EnumMember": ...
@overload
def __init__(self, value: str) -> "EnumMember":...
def __init__(self, value: "EnumMember") -> "EnumMember": ...
@overload
def __init__(self, value: "EnumMember") -> "EnumMember":...
def __init__(self, value: str, loc_obj: dict[str, str]) -> "EnumMember": ...
@overload
def __init__(self, value: str, loc_obj: dict[str, str]) -> "EnumMember":...
def __init__(self,
value: str = None,
loc_obj: dict[str, str] = None,
parent: type = None,
name: str = None,
casting: bool = True) -> "EnumMember":
def __init__(
self,
value: str = None,
loc_obj: dict[str, str] = None,
parent: type = None,
name: str = None,
casting: bool = True,
) -> "EnumMember":
if not casting:
self._parent = parent
self._name = name
@@ -69,10 +80,9 @@ class EnumMember(object):
self.loc_obj = loc_obj
@overload
def __new__(cls: Self, *args, **kwargs) -> "EnumMember":...
def __new__(cls: Self, *args, **kwargs) -> "EnumMember": ...
def __new__(cls, *args, casting: bool = True, **kwargs) -> "EnumMember":
if (cls.__name__ == "EnumMember") or not casting:
obj = super().__new__(cls)
kwargs["casting"] = False
@@ -80,56 +90,59 @@ class EnumMember(object):
return obj
if args.__len__() == 0:
return list(cls.all_members.values())[0]
if args.__len__() == 1 and isinstance(args[0], str):
return {member.value: member for key, member in cls.all_members.items()}[args[0]]
if args.__len__() == 1 and isinstance(args[0], str):
return {member.value: member for key, member in cls.all_members.items()}[
args[0]
]
elif args.__len__() == 1:
return {member.value: member for key, member in cls.all_members.items()}[args[0].value]
return {member.value: member for key, member in cls.all_members.items()}[
args[0].value
]
else:
return args[0]
def __get_pydantic_core_schema__(cls, *args, **kwargs):
return str_schema()
def __get__(self, instance, owner) -> Self:
# return {member.value: member for key, member in owner.all_members.items()}[self.value]
return {member.value: member for key, member in self._parent.all_members.items()}[self.value]
return {
member.value: member for key, member in self._parent.all_members.items()
}[self.value]
def __set__(self, instance, value):
instance.__dict__[self] = value
def __repr__(self):
return f"<{self._parent.__name__ if self._parent else "EnumMember"}.{self._name}: '{self.value}'>"
return f"<{self._parent.__name__ if self._parent else 'EnumMember'}.{self._name}: '{self.value}'>"
def __str__(self):
return self.value
def __eq__(self, other : Self | str) -> bool:
if other is None:
return False
if isinstance(other, str):
return self.value == other
return self.value == other.value
def __eq__(self, other: Self | str) -> bool:
if other is None:
return False
if isinstance(other, str):
return self.value == other
return self.value == other.value
def __hash__(self):
return hash(self.value)
def localized(self, lang: str = None) -> str:
if self.loc_obj and len(self.loc_obj) > 0:
if lang and lang in self.loc_obj.keys():
return self.loc_obj[lang]
else:
return self.loc_obj[list(self.loc_obj.keys())[0]]
return self.value
class BotEnum(EnumMember, metaclass = BotEnumMetaclass):
class BotEnum(EnumMember, metaclass=BotEnumMetaclass):
all_members: dict[str, EnumMember]
class EnumType(TypeDecorator):
impl = String(256)
def __init__(self, enum_type: BotEnum):
@@ -140,8 +153,8 @@ class EnumType(TypeDecorator):
if value and isinstance(value, EnumMember):
return value.value
return None
def process_result_value(self, value, dialect):
if value:
return self._enum_type(value)
return None
return None

View File

@@ -1,4 +1,4 @@
from .user import UserBase
class DefaultUser(UserBase): ...
class DefaultUser(UserBase): ...

View File

@@ -1,20 +1,85 @@
from typing import Any, Callable, TYPE_CHECKING
from aiogram.types import Message, CallbackQuery
from aiogram.fsm.context import FSMContext
from aiogram.utils.i18n import I18n
from aiogram.utils.keyboard import InlineKeyboardBuilder
from typing import Any, Callable, TYPE_CHECKING, Literal
from babel.support import LazyProxy
from dataclasses import dataclass, field
from sqlmodel.ext.asyncio.session import AsyncSession
from .role import RoleBase
from . import EntityPermission
from ..bot.handlers.context import ContextData
if TYPE_CHECKING:
from .bot_entity import BotEntity
from ..main import QBotApp
from .user import UserBase
EntityCaptionCallable = Callable[["EntityDescriptor"], str]
EntityItemCaptionCallable = Callable[["EntityDescriptor", Any], str]
EntityFieldCaptionCallable = Callable[["EntityFieldDescriptor", Any, Any], str]
@dataclass(kw_only = True)
class _BaseEntityFieldDescriptor():
@dataclass
class FieldEditButton:
field_name: str
caption: str | LazyProxy | EntityFieldCaptionCallable | None = None
@dataclass
class CommandButton:
command: str
caption: str | LazyProxy | EntityItemCaptionCallable | None = None
context_data: ContextData | None = None
@dataclass
class Filter:
field_name: str
operator: Literal[
"==",
"!=",
">",
"<",
">=",
"<=",
"in",
"not in",
"like",
"ilike",
"is",
"is not",
]
value_type: Literal["const", "param"]
value: Any | None = None
param_index: int | None = None
@dataclass
class EntityList:
caption: str | LazyProxy | EntityCaptionCallable | None = None
item_repr: EntityItemCaptionCallable | None = None
show_add_new_button: bool = True
item_form: str | None = None
pagination: bool = True
static_filters: list[Filter] | Any = None
filtering: bool = True
filtering_fields: list[str] = None
order_by: str | Any | None = None
@dataclass
class EntityForm:
item_repr: EntityItemCaptionCallable | None = None
edit_field_sequence: list[str] = None
form_buttons: list[list[FieldEditButton | CommandButton]] = None
show_edit_button: bool = True
show_delete_button: bool = True
@dataclass(kw_only=True)
class _BaseEntityFieldDescriptor:
icon: str = None
caption: str | LazyProxy | EntityFieldCaptionCallable | None = None
description: str | LazyProxy | EntityFieldCaptionCallable | None = None
@@ -24,21 +89,25 @@ class _BaseEntityFieldDescriptor():
localizable: bool = False
bool_false_value: str | LazyProxy = "no"
bool_true_value: str | LazyProxy = "yes"
ep_form: str | None = None
ep_parent_field: str | None = None
ep_child_field: str | None = None
dt_type: Literal["date", "datetime"] = "date"
default: Any = None
@dataclass(kw_only = True)
@dataclass(kw_only=True)
class EntityField(_BaseEntityFieldDescriptor):
name: str | None = None
sm_descriptor: Any = None
@dataclass(kw_only = True)
@dataclass(kw_only=True)
class Setting(_BaseEntityFieldDescriptor):
name: str | None = None
@dataclass(kw_only = True)
@dataclass(kw_only=True)
class EntityFieldDescriptor(_BaseEntityFieldDescriptor):
name: str
field_name: str
@@ -52,42 +121,80 @@ class EntityFieldDescriptor(_BaseEntityFieldDescriptor):
return self.name.__hash__()
@dataclass(kw_only = True)
@dataclass(kw_only=True)
class _BaseEntityDescriptor:
icon: str = "📘"
caption: str | LazyProxy | EntityCaptionCallable | None = None
caption_plural: str | LazyProxy | EntityCaptionCallable | None = None
full_name: str | LazyProxy | EntityCaptionCallable | None = None
full_name_plural: str | LazyProxy | EntityCaptionCallable | None = None
description: str | LazyProxy | EntityCaptionCallable | None = None
item_caption: EntityItemCaptionCallable | None = None
item_repr: EntityItemCaptionCallable | None = None
default_list: EntityList = field(default_factory=EntityList)
default_form: EntityForm = field(default_factory=EntityForm)
lists: dict[str, EntityList] = field(default_factory=dict[str, EntityList])
forms: dict[str, EntityForm] = field(default_factory=dict[str, EntityForm])
show_in_entities_menu: bool = True
field_sequence: list[str] = None
edit_button_visible: bool = True
edit_buttons: list[list[str | tuple[str, str | LazyProxy | EntityFieldCaptionCallable]]] = None
permissions: dict[EntityPermission, list[RoleBase]] = field(default_factory = lambda: {
EntityPermission.LIST: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.READ: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.CREATE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.UPDATE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.DELETE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.LIST_ALL: [RoleBase.SUPER_USER],
EntityPermission.READ_ALL: [RoleBase.SUPER_USER],
EntityPermission.CREATE_ALL: [RoleBase.SUPER_USER],
EntityPermission.UPDATE_ALL: [RoleBase.SUPER_USER],
EntityPermission.DELETE_ALL: [RoleBase.SUPER_USER]
})
ownership_fields: dict[RoleBase, str] = field(default_factory=dict[RoleBase, str])
permissions: dict[EntityPermission, list[RoleBase]] = field(
default_factory=lambda: {
EntityPermission.LIST: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.READ: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.CREATE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.UPDATE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.DELETE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.LIST_ALL: [RoleBase.SUPER_USER],
EntityPermission.READ_ALL: [RoleBase.SUPER_USER],
EntityPermission.CREATE_ALL: [RoleBase.SUPER_USER],
EntityPermission.UPDATE_ALL: [RoleBase.SUPER_USER],
EntityPermission.DELETE_ALL: [RoleBase.SUPER_USER],
}
)
@dataclass(kw_only = True)
@dataclass(kw_only=True)
class Entity(_BaseEntityDescriptor):
name: str | None = None
@dataclass
class EntityDescriptor(_BaseEntityDescriptor):
name: str
class_name: str
type_: type["BotEntity"]
fields_descriptors: dict[str, EntityFieldDescriptor]
@dataclass(kw_only=True)
class CommandCallbackContext[UT: UserBase]:
keyboard_builder: InlineKeyboardBuilder = field(
default_factory=InlineKeyboardBuilder
)
message_text: str | None = None
register_navigation: bool = True
message: Message | CallbackQuery
callback_data: ContextData
db_session: AsyncSession
user: UT
app: "QBotApp"
state_data: dict[str, Any]
state: FSMContext
i18n: I18n
kwargs: dict[str, Any] = field(default_factory=dict)
@dataclass(kw_only=True)
class _BotCommand:
name: str
caption: str | dict[str, str] | None = None
show_in_bot_commands: bool = False
register_navigation: bool = True
clear_navigation: bool = False
clear_state: bool = True
@dataclass(kw_only=True)
class BotCommand(_BotCommand):
handler: Callable[[CommandCallbackContext], None]
@dataclass(kw_only=True)
class Command(_BotCommand): ...

View File

@@ -2,7 +2,6 @@ from .descriptors import EntityDescriptor
from ._singleton import Singleton
class EntityMetadata(metaclass = Singleton):
class EntityMetadata(metaclass=Singleton):
def __init__(self):
self.entity_descriptors: dict[str, EntityDescriptor] = {}

View File

@@ -1,45 +0,0 @@
from dataclasses import dataclass
from babel.support import LazyProxy
from typing import TypeVar
from .bot_entity import BotEntity
@dataclass
class FieldType: ...
class LocStr(str): ...
class String(FieldType):
localizable: bool = False
class Integer(FieldType):
pass
@dataclass
class Decimal:
precision: int = 0
@dataclass
class Boolean:
true_value: str | LazyProxy = "true"
false_value: str | LazyProxy = "false"
@dataclass
class DateTime:
pass
EntityType = TypeVar('EntityType', bound = BotEntity)
@dataclass
class EntityReference:
entity_type: type[EntityType]

View File

@@ -1,8 +1,7 @@
from sqlmodel import SQLModel, Field
class FSMStorage(SQLModel, table = True):
class FSMStorage(SQLModel, table=True):
__tablename__ = "fsm_storage"
key: str = Field(primary_key = True)
value: str | None = None
key: str = Field(primary_key=True)
value: str | None = None

View File

@@ -2,5 +2,4 @@ from .bot_enum import BotEnum, EnumMember
class LanguageBase(BotEnum):
EN = EnumMember("en", {"en": "🇬🇧 english"})
EN = EnumMember("en", {"en": "🇬🇧 english"})

View File

@@ -1,26 +0,0 @@
# from aiogram.types import Message, CallbackQuery
# from aiogram.utils.keyboard import InlineKeyboardBuilder
# from typing import Any, Callable, Self, Union, overload
# from babel.support import LazyProxy
# from dataclasses import dataclass
# from ..bot.handlers.context import ContextData
# class Menu:
# @overload
# def __init__(self, description: str | LazyProxy): ...
# @overload
# def __init__(self, menu_factory: Callable[[InlineKeyboardBuilder, Union[Message, CallbackQuery], Any], str]): ...
# def __init__(self, description: str | LazyProxy = None,
# menu_factory: Callable[[InlineKeyboardBuilder, Union[Message, CallbackQuery], Any], str] = None) -> None:
# self.menu_factory = menu_factory
# self.description = description
# self.parent: Menu = None
# self.items: list[list[Menu]] = []

View File

@@ -1,49 +0,0 @@
from sqlmodel import BIGINT, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
from .bot_entity import BotEntity
from .descriptors import EntityField
from .user import UserBase
from . import session_dep
class OwnedBotEntity(BotEntity, table = False):
user_id: int | None = EntityField(
sm_descriptor = Field(sa_type = BIGINT, foreign_key = "user.id", ondelete="SET NULL"),
is_visible = False)
@classmethod
@session_dep
async def get_multi_by_user(cls, *,
session: AsyncSession | None = None,
user_id: int,
filter: str = None,
order_by = None,
skip: int = 0,
limit: int = None):
select_statement = select(cls).where(cls.user_id == user_id).offset(skip)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if limit:
select_statement = select_statement.limit(limit)
if order_by:
select_statement = select_statement.order_by(order_by)
return (await session.exec(select_statement)).all()
@classmethod
@session_dep
async def get_count_by_user(cls, *,
session: AsyncSession | None = None,
user_id: int,
filter: str = None) -> int:
select_statement = select(func.count()).select_from(cls).where(cls.user_id == user_id)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
return await session.scalar(select_statement)

View File

@@ -2,6 +2,5 @@ from .bot_enum import BotEnum, EnumMember
class RoleBase(BotEnum):
SUPER_USER = EnumMember("super_user")
DEFAULT_USER = EnumMember("default_user")
DEFAULT_USER = EnumMember("default_user")

View File

@@ -7,131 +7,193 @@ from typing import Any, get_args, get_origin
from ..db import async_session
from .role import RoleBase
from .descriptors import EntityFieldDescriptor, Setting
from ..utils import deserialize, serialize
from ..utils.serialization import deserialize, serialize
import ujson as json
class DbSettings(SQLModel, table = True):
class DbSettings(SQLModel, table=True):
__tablename__ = "settings"
name: str = Field(primary_key = True)
name: str = Field(primary_key=True)
value: str
class SettingsMetaclass(type):
def __new__(cls, class_name, base_classes, attributes):
settings_descriptors = {}
if base_classes:
settings_descriptors = base_classes[0].__dict__.get("_settings_descriptors", {})
settings_descriptors = base_classes[0].__dict__.get(
"_settings_descriptors", {}
)
for annotation in attributes.get('__annotations__', {}):
for annotation in attributes.get("__annotations__", {}):
if annotation in ["_settings_descriptors", "_cache", "_cached_settings"]:
continue
attr_value = attributes.get(annotation)
name = annotation
type_ = attributes['__annotations__'][annotation]
type_ = attributes["__annotations__"][annotation]
if isinstance(attr_value, Setting):
descriptor_kwargs = attr_value.__dict__.copy()
name = descriptor_kwargs.pop("name") or annotation
attributes[annotation] = EntityFieldDescriptor(
name = name,
field_name = annotation,
type_ = type_,
type_base = type_,
**descriptor_kwargs)
name=name,
field_name=annotation,
type_=type_,
type_base=type_,
**descriptor_kwargs,
)
else:
attributes[annotation] = EntityFieldDescriptor(
name = annotation,
field_name = annotation,
type_ = type_,
type_base = type_,
default = attr_value)
name=annotation,
field_name=annotation,
type_=type_,
type_base=type_,
default=attr_value,
)
type_origin = get_origin(type_)
if type_origin == list:
if type_origin is list:
attributes[annotation].is_list = True
attributes[annotation].type_base = type_ = get_args(type_)[0]
elif type_origin == UnionType and get_args(type_)[1] == NoneType:
attributes[annotation].is_optional = True
attributes[annotation].type_base = type_ = get_args(type_)[0]
settings_descriptors[name] = attributes[annotation]
if base_classes and base_classes[0].__name__ == "Settings" and hasattr(base_classes[0], annotation):
if (
base_classes
and base_classes[0].__name__ == "Settings"
and hasattr(base_classes[0], annotation)
):
setattr(base_classes[0], annotation, attributes[annotation])
attributes["__annotations__"] = {}
attributes["_settings_descriptors"] = settings_descriptors
return super().__new__(cls, class_name, base_classes, attributes)
class Settings(metaclass = SettingsMetaclass):
class Settings(metaclass=SettingsMetaclass):
_cache: dict[str, Any] = dict[str, Any]()
_settings_descriptors: dict[str, EntityFieldDescriptor] = {}
PAGE_SIZE: int = Setting(default = 10, )
SECURITY_PARAMETERS_ROLES: list[RoleBase] = Setting(name = "SECPARAMS_ROLES", default = [RoleBase.SUPER_USER], is_visible = False)
APP_STRINGS_WELCOME_P_NAME: str = Setting(name = "AS_WELCOME", default = "Welcome, {name}", is_visible = False)
APP_STRINGS_GREETING_P_NAME: str = Setting(name = "AS_GREETING", default = "Hello, {name}", is_visible = False)
APP_STRINGS_INTERNAL_ERROR_P_ERROR: str = Setting(name = "AS_INTERNAL_ERROR", default = "Internal error\n{error}", is_visible = False)
APP_STRINGS_USER_BLOCKED_P_NAME: str = Setting(name = "AS_USER_BLOCKED", default = "User {name} is blocked", is_visible = False)
APP_STRINGS_FORBIDDEN: str = Setting(name = "AS_FORBIDDEN", default = "Forbidden", is_visible = False)
APP_STRINGS_NOT_FOUND: str = Setting(name = "AS_NOT_FOUND", default = "Object not found", is_visible = False)
APP_STRINGS_MAIN_NENU: str = Setting(name = "AS_MAIN_MENU", default = "Main menu", is_visible = False)
APP_STRINGS_REFERENCES: str = Setting(name = "AS_REFERENCES", default = "References", is_visible = False)
APP_STRINGS_REFERENCES_BTN: str = Setting(name = "AS_REFERENCES_BTN", default = "📚 References", is_visible = False)
APP_STRINGS_SETTINGS: str = Setting(name = "AS_SETTINGS", default = "Settings", is_visible = False)
APP_STRINGS_SETTINGS_BTN: str = Setting(name = "AS_SETTINGS_BTN", default = "⚙️ Settings", is_visible = False)
APP_STRINGS_PARAMETERS: str = Setting(name = "AS_PARAMETERS", default = "Parameters", is_visible = False)
APP_STRINGS_PARAMETERS_BTN: str = Setting(name = "AS_PARAMETERS_BTN", default = "🎛️ Parameters", is_visible = False)
APP_STRINGS_LANGUAGE: str = Setting(name = "AS_LANGUAGE", default = "Language", is_visible = False)
APP_STRINGS_LANGUAGE_BTN: str = Setting(name = "AS_LANGUAGE_BTN", default = "🗣️ Language", is_visible = False)
APP_STRINGS_BACK_BTN: str = Setting(name = "AS_BACK_BTN", default = "⬅️ Back", is_visible = False)
APP_STRINGS_DELETE_BTN: str = Setting(name = "AS_DELETE_BTN", default = "🗑️ Delete", is_visible = False)
PAGE_SIZE: int = Setting(
default=10,
)
SECURITY_PARAMETERS_ROLES: list[RoleBase] = Setting(
name="SECPARAMS_ROLES", default=[RoleBase.SUPER_USER], is_visible=False
)
APP_STRINGS_WELCOME_P_NAME: str = Setting(
name="AS_WELCOME", default="Welcome, {name}", is_visible=False
)
APP_STRINGS_GREETING_P_NAME: str = Setting(
name="AS_GREETING", default="Hello, {name}", is_visible=False
)
APP_STRINGS_INTERNAL_ERROR_P_ERROR: str = Setting(
name="AS_INTERNAL_ERROR", default="Internal error\n{error}", is_visible=False
)
APP_STRINGS_USER_BLOCKED_P_NAME: str = Setting(
name="AS_USER_BLOCKED", default="User {name} is blocked", is_visible=False
)
APP_STRINGS_FORBIDDEN: str = Setting(
name="AS_FORBIDDEN", default="Forbidden", is_visible=False
)
APP_STRINGS_NOT_FOUND: str = Setting(
name="AS_NOT_FOUND", default="Object not found", is_visible=False
)
APP_STRINGS_MAIN_NENU: str = Setting(
name="AS_MAIN_MENU", default="Main menu", is_visible=False
)
APP_STRINGS_REFERENCES: str = Setting(
name="AS_REFERENCES", default="References", is_visible=False
)
APP_STRINGS_REFERENCES_BTN: str = Setting(
name="AS_REFERENCES_BTN", default="📚 References", is_visible=False
)
APP_STRINGS_SETTINGS: str = Setting(
name="AS_SETTINGS", default="Settings", is_visible=False
)
APP_STRINGS_SETTINGS_BTN: str = Setting(
name="AS_SETTINGS_BTN", default="⚙️ Settings", is_visible=False
)
APP_STRINGS_PARAMETERS: str = Setting(
name="AS_PARAMETERS", default="Parameters", is_visible=False
)
APP_STRINGS_PARAMETERS_BTN: str = Setting(
name="AS_PARAMETERS_BTN", default="🎛️ Parameters", is_visible=False
)
APP_STRINGS_LANGUAGE: str = Setting(
name="AS_LANGUAGE", default="Language", is_visible=False
)
APP_STRINGS_LANGUAGE_BTN: str = Setting(
name="AS_LANGUAGE_BTN", default="🗣️ Language", is_visible=False
)
APP_STRINGS_BACK_BTN: str = Setting(
name="AS_BACK_BTN", default="⬅️ Back", is_visible=False
)
APP_STRINGS_DELETE_BTN: str = Setting(
name="AS_DELETE_BTN", default="🗑️ Delete", is_visible=False
)
APP_STRINGS_CONFIRM_DELETE_P_NAME: str = Setting(
name = "AS_CONFIRM_DEL",
default = "Are you sure you want to delete \"{name}\"?",
is_visible = False)
APP_STRINGS_EDIT_BTN: str = Setting(name = "AS_EDIT_BTN", default = "✏️ Edit", is_visible = False)
APP_STRINGS_ADD_BTN: str = Setting(name = "AS_ADD_BTN", default = " Add", is_visible = False)
APP_STRINGS_YES_BTN: str = Setting(name = "AS_YES_BTN", default = "✅ Yes", is_visible = False)
APP_STRINGS_NO_BTN: str = Setting(name = "AS_NO_BTN", default = "❌ No", is_visible = False)
APP_STRINGS_CANCEL_BTN: str = Setting(name = "AS_CANCEL_BTN", default = "❌ Cancel", is_visible = False)
APP_STRINGS_CLEAR_BTN: str = Setting(name = "AS_CLEAR_BTN", default = "⌫ Clear", is_visible = False)
APP_STRINGS_DONE_BTN: str = Setting(name = "AS_DONE_BTN", default = "✅ Done", is_visible = False)
APP_STRINGS_SKIP_BTN: str = Setting(name = "AS_SKIP_BTN", default = "⏩️ Skip", is_visible = False)
name="AS_CONFIRM_DEL",
default='Are you sure you want to delete "{name}"?',
is_visible=False,
)
APP_STRINGS_EDIT_BTN: str = Setting(
name="AS_EDIT_BTN", default="✏️ Edit", is_visible=False
)
APP_STRINGS_ADD_BTN: str = Setting(
name="AS_ADD_BTN", default=" Add", is_visible=False
)
APP_STRINGS_YES_BTN: str = Setting(
name="AS_YES_BTN", default="✅ Yes", is_visible=False
)
APP_STRINGS_NO_BTN: str = Setting(
name="AS_NO_BTN", default="❌ No", is_visible=False
)
APP_STRINGS_CANCEL_BTN: str = Setting(
name="AS_CANCEL_BTN", default="❌ Cancel", is_visible=False
)
APP_STRINGS_CLEAR_BTN: str = Setting(
name="AS_CLEAR_BTN", default="⌫ Clear", is_visible=False
)
APP_STRINGS_DONE_BTN: str = Setting(
name="AS_DONE_BTN", default="✅ Done", is_visible=False
)
APP_STRINGS_SKIP_BTN: str = Setting(
name="AS_SKIP_BTN", default="⏩️ Skip", is_visible=False
)
APP_STRINGS_FIELD_EDIT_PROMPT_TEMPLATE_P_NAME_VALUE: str = Setting(
name = "AS_FIELDEDIT_PROMPT",
default = "Enter new value for \"{name}\" (current value: {value})",
is_visible = False)
name="AS_FIELDEDIT_PROMPT",
default='Enter new value for "{name}" (current value: {value})',
is_visible=False,
)
APP_STRINGS_FIELD_CREATE_PROMPT_TEMPLATE_P_NAME: str = Setting(
name = "AS_FIELDCREATE_PROMPT",
default = "Enter new value for \"{name}\"",
is_visible = False)
name="AS_FIELDCREATE_PROMPT",
default='Enter new value for "{name}"',
is_visible=False,
)
APP_STRINGS_STRING_EDITOR_LOCALE_TEMPLATE_P_NAME: str = Setting(
name = "AS_STREDIT_LOC_TEMPLATE",
default = "string for \"{name}\"",
is_visible = False)
APP_STRINGS_VIEW_FILTER_EDIT_PROMPT: str = Setting(name = "AS_FILTEREDIT_PROMPT", default = "Enter filter value", is_visible = False)
APP_STRINGS_INVALID_INPUT: str = Setting(name = "AS_INVALID_INPUT", default = "Invalid input", is_visible = False)
name="AS_STREDIT_LOC_TEMPLATE", default='string for "{name}"', is_visible=False
)
APP_STRINGS_VIEW_FILTER_EDIT_PROMPT: str = Setting(
name="AS_FILTEREDIT_PROMPT", default="Enter filter value", is_visible=False
)
APP_STRINGS_INVALID_INPUT: str = Setting(
name="AS_INVALID_INPUT", default="Invalid input", is_visible=False
)
@classmethod
async def get[T](cls, param: T, all_locales = False, locale: str = None) -> T:
async def get[T](cls, param: T, all_locales=False, locale: str = None) -> T:
name = param.field_name
if name not in cls._cache.keys():
@@ -144,73 +206,79 @@ class Settings(metaclass = SettingsMetaclass):
locale = get_i18n().current_locale
try:
obj = json.loads(ret_val)
except:
except Exception:
return ret_val
return obj.get(locale, obj[list(obj.keys())[0]])
return ret_val
@classmethod
async def load_param(cls, param: EntityFieldDescriptor) -> Any:
async with async_session() as session:
db_setting = (await session.exec(
select(DbSettings)
.where(DbSettings.name == param.field_name))).first()
if db_setting:
return await deserialize(session = session,
type_ = param.type_,
value = db_setting.value)
return (param.default if param.default else
[] if (get_origin(param.type_) is list or param.type_ == list) else
datetime(2000, 1, 1) if param.type_ == datetime else
param.type_())
db_setting = (
await session.exec(
select(DbSettings).where(DbSettings.name == param.field_name)
)
).first()
if db_setting:
return await deserialize(
session=session, type_=param.type_, value=db_setting.value
)
return (
param.default
if param.default
else (
[]
if (get_origin(param.type_) is list or param.type_ is list)
else datetime(2000, 1, 1)
if param.type_ == datetime
else param.type_()
)
)
@classmethod
async def load_params(cls):
async with async_session() as session:
db_settings = (await session.exec(select(DbSettings))).all()
for db_setting in db_settings:
if db_setting.name in cls.__dict__:
setting = cls.__dict__[db_setting.name]
cls._cache[db_setting.name] = await deserialize(session = session,
type_ = setting.type_,
value = db_setting.value,
default = setting.default)
setting = cls.__dict__[db_setting.name] # type: EntityFieldDescriptor
cls._cache[db_setting.name] = await deserialize(
session=session,
type_=setting.type_,
value=db_setting.value,
default=setting.default,
)
cls._loaded = True
@classmethod
async def set_param(cls, param: str | EntityFieldDescriptor, value) -> None:
if isinstance(param, str):
param = cls._settings_descriptors[param]
param = cls._settings_descriptors[param]
ser_value = serialize(value, param)
async with async_session() as session:
db_setting = (await session.exec(
select(DbSettings)
.where(DbSettings.name == param.field_name))).first()
db_setting = (
await session.exec(
select(DbSettings).where(DbSettings.name == param.field_name)
)
).first()
if db_setting is None:
db_setting = DbSettings(name = param.field_name)
db_setting = DbSettings(name=param.field_name)
db_setting.value = str(ser_value)
session.add(db_setting)
await session.commit()
cls._cache[param.field_name] = value
@classmethod
def list_params(cls) -> dict[str, EntityFieldDescriptor]:
return cls._settings_descriptors
@classmethod
async def get_params(cls) -> dict[EntityFieldDescriptor, Any]:
params = cls.list_params()
return {param: await cls.get(param, all_locales = True) for _, param in params.items()}
return {
param: await cls.get(param, all_locales=True) for _, param in params.items()
}

View File

@@ -10,11 +10,12 @@ from .fsm_storage import FSMStorage as FSMStorage
from .view_setting import ViewSetting as ViewSetting
class UserBase(BotEntity, table = False):
class UserBase(BotEntity, table=False):
__tablename__ = "user"
lang: LanguageBase = Field(sa_type = EnumType(LanguageBase), default = LanguageBase.EN)
lang: LanguageBase = Field(sa_type=EnumType(LanguageBase), default=LanguageBase.EN)
is_active: bool = True
roles: list[RoleBase] = Field(sa_type=ARRAY(EnumType(RoleBase)), default = [RoleBase.DEFAULT_USER])
roles: list[RoleBase] = Field(
sa_type=ARRAY(EnumType(RoleBase)), default=[RoleBase.DEFAULT_USER]
)

View File

@@ -4,36 +4,36 @@ from sqlalchemy.ext.asyncio.session import AsyncSession
from . import session_dep
class ViewSetting(SQLModel, table = True):
class ViewSetting(SQLModel, table=True):
__tablename__ = "view_setting"
user_id: int = Field(sa_type = BIGINT, primary_key = True, foreign_key="user.id", ondelete="CASCADE")
entity_name: str = Field(primary_key = True)
filter: str | None = None
user_id: int = Field(
sa_type=BIGINT, primary_key=True, foreign_key="user.id", ondelete="CASCADE"
)
entity_name: str = Field(primary_key=True)
filter: str | None = None
@classmethod
@session_dep
async def get_filter(cls, *,
session: AsyncSession | None = None,
user_id: int,
entity_name: str):
async def get_filter(
cls, *, session: AsyncSession | None = None, user_id: int, entity_name: str
):
setting = await session.get(cls, (user_id, entity_name))
return setting.filter if setting else None
@classmethod
@session_dep
async def set_filter(cls, *,
session: AsyncSession | None = None,
user_id: int,
entity_name: str,
filter: str):
setting = await session.get(cls, (user_id, entity_name))
if setting:
setting.filter = filter
else:
setting = cls(user_id = user_id, entity_name = entity_name, filter = filter)
session.add(setting)
await session.commit()
async def set_filter(
cls,
*,
session: AsyncSession | None = None,
user_id: int,
entity_name: str,
filter: str,
):
setting = await session.get(cls, (user_id, entity_name))
if setting:
setting.filter = filter
else:
setting = cls(user_id=user_id, entity_name=entity_name, filter=filter)
session.add(setting)
await session.commit()