BotEntity params hints, extended BotEnum db support

This commit is contained in:
Alexander Kalinovsky
2025-02-09 21:37:39 +01:00
parent 50a52d6aa7
commit 7a4936d2ef
3 changed files with 21 additions and 13 deletions

View File

@@ -9,10 +9,11 @@ from typing import (
get_args, get_args,
get_origin, get_origin,
TYPE_CHECKING, TYPE_CHECKING,
dataclass_transform,
) )
from pydantic import BaseModel from pydantic import BaseModel
from sqlmodel import SQLModel, BigInteger, Field, select, func, column, col from sqlmodel import SQLModel, BigInteger, Field, select, func, column, col
from sqlmodel.main import FieldInfo
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.sql.expression import SelectOfScalar
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
@@ -26,6 +27,10 @@ if TYPE_CHECKING:
from .user import UserBase from .user import UserBase
@dataclass_transform(
kw_only_default=True,
field_specifiers=(Field, FieldInfo, EntityField, FieldDescriptor),
)
class BotEntityMetaclass(SQLModelMetaclass): class BotEntityMetaclass(SQLModelMetaclass):
_future_references = {} _future_references = {}

View File

@@ -1,6 +1,7 @@
from aiogram.utils.i18n import I18n from aiogram.utils.i18n import I18n
from pydantic_core.core_schema import str_schema from pydantic_core.core_schema import str_schema
from sqlalchemy.types import TypeDecorator, String from sqlalchemy.types import TypeDecorator
from sqlmodel import AutoString
from typing import Any, Self, overload from typing import Any, Self, overload
@@ -119,12 +120,14 @@ class EnumMember(object):
def __str__(self): def __str__(self):
return self.value return self.value
def __eq__(self, other: Self | str) -> bool: def __eq__(self, other: Self | str | Any | None) -> bool:
if other is None: if other is None:
return False return False
if isinstance(other, str): if isinstance(other, str):
return self.value == other return self.value == other
return self.value == other.value if isinstance(other, EnumMember):
return self.value == other.value and self._parent is other._parent
return other.__eq__(self.value)
def __hash__(self): def __hash__(self):
return hash(self.value) return hash(self.value)
@@ -151,24 +154,25 @@ class BotEnum(EnumMember, metaclass=BotEnumMetaclass):
class EnumType(TypeDecorator): class EnumType(TypeDecorator):
impl = String(64) impl = AutoString
cache_ok = True cache_ok = True
# class comparator_factory(TypeDecorator.Comparator):
# def __eq__(self, other):
# expr = type_coerce(self.expr, String)
# return expr != other.value
def __init__(self, enum_type: BotEnum): def __init__(self, enum_type: BotEnum):
self._enum_type = enum_type self._enum_type = enum_type
super().__init__() super().__init__()
def process_bind_param(self, value, dialect): def _process_param(self, value):
if value and isinstance(value, EnumMember): if value and isinstance(value, EnumMember):
return value.value return value.value
return str(value) return str(value)
def process_bind_param(self, value, dialect):
return self._process_param(value)
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
if value: if value:
return self._enum_type(value) return self._enum_type(value)
return None return None
def process_literal_param(self, value, dialect):
return self._process_param(value)

View File

@@ -1,7 +1,6 @@
from aiogram.types import CallbackQuery, Message
from functools import wraps from functools import wraps
from types import UnionType from types import UnionType
from typing import Callable, Union, get_args, get_origin, Any from typing import Callable, Union, get_args, get_origin
from .model.descriptors import ( from .model.descriptors import (
BotCommand, BotCommand,