Compare commits

...

3 Commits

Author SHA1 Message Date
Alexander Kalinovsky
3208721d9e fix save state when entity not found 2025-02-13 02:37:49 +01:00
Alexander Kalinovsky
b7211368cc merge from remote 2025-02-13 02:13:22 +01:00
Alexander Kalinovsky
ca374cdea0 upd get_callable_str async 2025-02-13 02:00:20 +01:00
12 changed files with 124 additions and 86 deletions

View File

@@ -6,7 +6,7 @@ from ....utils.main import get_callable_str
from ..context import ContextData, CallbackCommand from ..context import ContextData, CallbackCommand
def add_filter_controls( async def add_filter_controls(
keyboard_builder: InlineKeyboardBuilder, keyboard_builder: InlineKeyboardBuilder,
entity_descriptor: EntityDescriptor, entity_descriptor: EntityDescriptor,
filter: str = None, filter: str = None,
@@ -15,7 +15,7 @@ def add_filter_controls(
): ):
caption = ", ".join( caption = ", ".join(
[ [
get_callable_str( await get_callable_str(
entity_descriptor.fields_descriptors[field_name].caption, entity_descriptor.fields_descriptors[field_name].caption,
entity_descriptor, entity_descriptor,
) )

View File

@@ -17,7 +17,6 @@ class CallbackCommand(StrEnum):
DATE_PICKER_MONTH = "dm" DATE_PICKER_MONTH = "dm"
DATE_PICKER_YEAR = "dy" DATE_PICKER_YEAR = "dy"
TIME_PICKER = "tp" TIME_PICKER = "tp"
# STRING_EDITOR_LOCALE = "sl"
ENTITY_PICKER_PAGE = "ep" ENTITY_PICKER_PAGE = "ep"
ENTITY_PICKER_TOGGLE_ITEM = "et" ENTITY_PICKER_TOGGLE_ITEM = "et"
VIEW_FILTER_EDIT = "vf" VIEW_FILTER_EDIT = "vf"

View File

@@ -25,13 +25,23 @@ async def show_editor(message: Message | CallbackQuery, **kwargs):
value_type = field_descriptor.type_base value_type = field_descriptor.type_base
if field_descriptor.edit_prompt: if field_descriptor.edit_prompt:
edit_prompt = get_callable_str( edit_prompt = await get_callable_str(
field_descriptor.edit_prompt, field_descriptor, None, current_value field_descriptor.edit_prompt,
field_descriptor,
callback_data
if callback_data.context == CommandContext.COMMAND_FORM
else None,
current_value,
) )
else: else:
if field_descriptor.caption: if field_descriptor.caption:
caption_str = get_callable_str( caption_str = await get_callable_str(
field_descriptor.caption, field_descriptor, None, current_value field_descriptor.caption,
field_descriptor,
callback_data
if callback_data.context == CommandContext.COMMAND_FORM
else None,
current_value,
) )
else: else:
caption_str = field_descriptor.name caption_str = field_descriptor.name
@@ -42,7 +52,7 @@ async def show_editor(message: Message | CallbackQuery, **kwargs):
) )
).format( ).format(
name=caption_str, name=caption_str,
value=get_value_repr(current_value, field_descriptor, user.lang), value=await get_value_repr(current_value, field_descriptor, user.lang),
) )
else: else:
edit_prompt = ( edit_prompt = (

View File

@@ -56,12 +56,20 @@ async def time_picker(
if not current_value: if not current_value:
current_value = time(0, 0) current_value = time(0, 0)
is_datetime = False
else: else:
is_datetime = isinstance(current_value, datetime)
if not is_datetime:
current_time = datetime.combine(datetime.now(), current_value)
remainder = current_value.minute % 5 remainder = current_value.minute % 5
if remainder >= 3: if remainder >= 3:
current_value += timedelta(minutes=(5 - remainder)) current_time += timedelta(minutes=(5 - remainder))
else: else:
current_value -= timedelta(minutes=remainder) current_time -= timedelta(minutes=remainder)
if is_datetime:
current_value = datetime.combine(current_value.date(), current_time.time())
else:
current_value = current_time.time()
for i in range(12): for i in range(12):
keyboard_builder.row( keyboard_builder.row(

View File

@@ -209,7 +209,7 @@ async def render_entity_picker(
type_.bot_entity_descriptor, item type_.bot_entity_descriptor, item
) )
if type_.bot_entity_descriptor.item_repr if type_.bot_entity_descriptor.item_repr
else get_callable_str( else await get_callable_str(
type_.bot_entity_descriptor.full_name, type_.bot_entity_descriptor.full_name,
type_.bot_entity_descriptor, type_.bot_entity_descriptor,
item, item,
@@ -259,7 +259,7 @@ async def render_entity_picker(
and form_list.filtering and form_list.filtering
and form_list.filtering_fields and form_list.filtering_fields
): ):
add_filter_controls( await add_filter_controls(
keyboard_builder=keyboard_builder, keyboard_builder=keyboard_builder,
entity_descriptor=type_.bot_entity_descriptor, entity_descriptor=type_.bot_entity_descriptor,
filter=entity_filter, filter=entity_filter,

View File

@@ -62,6 +62,10 @@ async def entity_item(
entity_item = await entity_type.get(session=db_session, id=callback_data.entity_id) entity_item = await entity_type.get(session=db_session, id=callback_data.entity_id)
state: FSMContext = kwargs["state"]
state_data = kwargs["state_data"]
await state.set_data(state_data)
if not entity_item: if not entity_item:
return await query.answer( return await query.answer(
text=(await Settings.get(Settings.APP_STRINGS_NOT_FOUND)) text=(await Settings.get(Settings.APP_STRINGS_NOT_FOUND))
@@ -100,13 +104,13 @@ async def entity_item(
] ]
field_value = getattr(entity_item, field_descriptor.field_name) field_value = getattr(entity_item, field_descriptor.field_name)
if btn_caption: if btn_caption:
btn_text = get_callable_str( btn_text = await get_callable_str(
btn_caption, field_descriptor, entity_item, field_value btn_caption, field_descriptor, entity_item, field_value
) )
else: else:
if field_descriptor.type_base is bool: if field_descriptor.type_base is bool:
btn_text = f"{'【✔︎】 ' if field_value else '【 】 '}{ btn_text = f"{'【✔︎】 ' if field_value else '【 】 '}{
get_callable_str( await get_callable_str(
field_descriptor.caption, field_descriptor.caption,
field_descriptor, field_descriptor,
entity_item, entity_item,
@@ -116,18 +120,20 @@ async def entity_item(
else field_name else field_name
}" }"
else: else:
btn_text = ( btn_text = f"{
f"✏️ { field_descriptor.icon
get_callable_str( if field_descriptor.icon
else '✏️'
} {
await get_callable_str(
field_descriptor.caption, field_descriptor.caption,
field_descriptor, field_descriptor,
entity_item, entity_item,
field_value, field_value,
) )
}"
if field_descriptor.caption if field_descriptor.caption
else f"✏️ {field_name}" else field_name
) }"
btn_row.append( btn_row.append(
InlineKeyboardButton( InlineKeyboardButton(
text=btn_text, text=btn_text,
@@ -144,7 +150,7 @@ async def entity_item(
elif isinstance(button, CommandButton): elif isinstance(button, CommandButton):
btn_caption = button.caption btn_caption = button.caption
btn_text = get_callable_str( btn_text = await get_callable_str(
btn_caption, entity_descriptor, entity_item btn_caption, entity_descriptor, entity_item
) )
@@ -215,7 +221,7 @@ async def entity_item(
item_text = form.item_repr(entity_descriptor, entity_item) item_text = form.item_repr(entity_descriptor, entity_item)
else: else:
entity_caption = ( entity_caption = (
get_callable_str( await get_callable_str(
entity_descriptor.full_name, entity_descriptor, entity_item entity_descriptor.full_name, entity_descriptor, entity_item
) )
if entity_descriptor.full_name if entity_descriptor.full_name
@@ -223,7 +229,7 @@ async def entity_item(
) )
entity_item_repr = ( entity_item_repr = (
get_callable_str( await get_callable_str(
entity_descriptor.item_repr, entity_descriptor, entity_item entity_descriptor.item_repr, entity_descriptor, entity_item
) )
if entity_descriptor.item_repr if entity_descriptor.item_repr
@@ -234,18 +240,18 @@ async def entity_item(
for field_descriptor in entity_descriptor.fields_descriptors.values(): for field_descriptor in entity_descriptor.fields_descriptors.values():
if field_descriptor.is_visible: if field_descriptor.is_visible:
field_caption = get_callable_str( field_caption = await get_callable_str(
field_descriptor.caption, field_descriptor, entity_item field_descriptor.caption, field_descriptor, entity_item
) )
if field_descriptor.caption_value: if field_descriptor.caption_value:
value = get_callable_str( value = await get_callable_str(
field_descriptor.caption_value, field_descriptor.caption_value,
field_descriptor, field_descriptor,
entity_item, entity_item,
getattr(entity_item, field_descriptor.field_name), getattr(entity_item, field_descriptor.field_name),
) )
else: else:
value = get_value_repr( value = await get_value_repr(
value=getattr(entity_item, field_descriptor.field_name), value=getattr(entity_item, field_descriptor.field_name),
field_descriptor=field_descriptor, field_descriptor=field_descriptor,
locale=user.lang, locale=user.lang,
@@ -261,9 +267,9 @@ async def entity_item(
) )
) )
state: FSMContext = kwargs["state"] # state: FSMContext = kwargs["state"]
state_data = kwargs["state_data"] # state_data = kwargs["state_data"]
await state.set_data(state_data) # await state.set_data(state_data)
send_message = get_send_message(query) send_message = get_send_message(query)

View File

@@ -61,7 +61,7 @@ async def entity_delete_callback(query: CallbackQuery, **kwargs):
return await query.message.edit_text( return await query.message.edit_text(
text=( text=(
await Settings.get(Settings.APP_STRINGS_CONFIRM_DELETE_P_NAME) await Settings.get(Settings.APP_STRINGS_CONFIRM_DELETE_P_NAME)
).format(name=get_entity_item_repr(entity=entity)), ).format(name=await get_entity_item_repr(entity=entity)),
reply_markup=InlineKeyboardBuilder() reply_markup=InlineKeyboardBuilder()
.row( .row(
InlineKeyboardButton( InlineKeyboardButton(

View File

@@ -197,7 +197,7 @@ async def entity_list(
caption = entity_descriptor.item_repr(entity_descriptor, item) caption = entity_descriptor.item_repr(entity_descriptor, item)
elif entity_descriptor.full_name: elif entity_descriptor.full_name:
caption = f"{ caption = f"{
get_callable_str( await get_callable_str(
callable_str=entity_descriptor.full_name, callable_str=entity_descriptor.full_name,
descriptor=entity_descriptor, descriptor=entity_descriptor,
entity=item, entity=item,
@@ -228,7 +228,7 @@ async def entity_list(
) )
if form_list.filtering and form_list.filtering_fields: if form_list.filtering and form_list.filtering_fields:
add_filter_controls( await add_filter_controls(
keyboard_builder=keyboard_builder, keyboard_builder=keyboard_builder,
entity_descriptor=entity_descriptor, entity_descriptor=entity_descriptor,
filter=entity_filter, filter=entity_filter,
@@ -245,17 +245,17 @@ async def entity_list(
) )
if form_list.caption: if form_list.caption:
entity_text = get_callable_str(form_list.caption, entity_descriptor) entity_text = await get_callable_str(form_list.caption, entity_descriptor)
else: else:
if entity_descriptor.full_name_plural: if entity_descriptor.full_name_plural:
entity_text = get_callable_str( entity_text = await get_callable_str(
entity_descriptor.full_name_plural, entity_descriptor entity_descriptor.full_name_plural, entity_descriptor
) )
else: else:
entity_text = entity_descriptor.name entity_text = entity_descriptor.name
if entity_descriptor.description: if entity_descriptor.description:
entity_text = f"{entity_text} {get_callable_str(entity_descriptor.description, entity_descriptor)}" entity_text = f"{entity_text} {await get_callable_str(entity_descriptor.description, entity_descriptor)}"
state: FSMContext = kwargs["state"] state: FSMContext = kwargs["state"]
state_data = kwargs["state_data"] state_data = kwargs["state_data"]

View File

@@ -54,12 +54,12 @@ async def parameters_menu(
continue continue
if key.caption_value: if key.caption_value:
caption = get_callable_str( caption = await get_callable_str(
callable_str=key.caption_value, descriptor=key, entity=None, value=value callable_str=key.caption_value, descriptor=key, entity=None, value=value
) )
else: else:
if key.caption: if key.caption:
caption = get_callable_str( caption = await get_callable_str(
callable_str=key.caption, descriptor=key, entity=None, value=value callable_str=key.caption, descriptor=key, entity=None, value=value
) )
else: else:
@@ -68,7 +68,7 @@ async def parameters_menu(
if key.type_ is bool: if key.type_ is bool:
caption = f"{'【✔︎】' if value else '【 】'} {caption}" caption = f"{'【✔︎】' if value else '【 】'} {caption}"
else: else:
caption = f"{caption}: {get_value_repr(value=value, field_descriptor=key, locale=user.lang)}" caption = f"{caption}: {await get_value_repr(value=value, field_descriptor=key, locale=user.lang)}"
keyboard_builder.row( keyboard_builder.row(
InlineKeyboardButton( InlineKeyboardButton(

View File

@@ -8,6 +8,7 @@ from sqlmodel.ext.asyncio.session import AsyncSession
from ...main import QBotApp from ...main import QBotApp
from ...model.settings import Settings from ...model.settings import Settings
from ...model.language import LanguageBase from ...model.language import LanguageBase
from ...model.user import UserBase
from ...utils.main import clear_state from ...utils.main import clear_state
@@ -16,9 +17,24 @@ router = Router()
@router.message(CommandStart()) @router.message(CommandStart())
async def start( async def start(message: Message, **kwargs):
message: Message, db_session: AsyncSession, app: QBotApp, state: FSMContext app: QBotApp = kwargs["app"]
):
if app.start_handler:
await app.start_handler(
default_start_handler, message, **kwargs
)
else:
await default_start_handler(message, **kwargs)
async def default_start_handler[UserType: UserBase](
message: Message,
db_session: AsyncSession,
app: QBotApp,
state: FSMContext,
**kwargs,
) -> tuple[UserType, bool]:
state_data = await state.get_data() state_data = await state.get_data()
clear_state(state_data=state_data, clear_nav=True) clear_state(state_data=state_data, clear_nav=True)
@@ -27,6 +43,7 @@ async def start(
user = await User.get(session=db_session, id=message.from_user.id) user = await User.get(session=db_session, id=message.from_user.id)
if not user: if not user:
is_new = True
msg_text = (await Settings.get(Settings.APP_STRINGS_WELCOME_P_NAME)).format( msg_text = (await Settings.get(Settings.APP_STRINGS_WELCOME_P_NAME)).format(
name=message.from_user.full_name name=message.from_user.full_name
) )
@@ -61,6 +78,7 @@ async def start(
return return
else: else:
is_new = False
if user.is_active: if user.is_active:
msg_text = ( msg_text = (
await Settings.get(Settings.APP_STRINGS_GREETING_P_NAME) await Settings.get(Settings.APP_STRINGS_GREETING_P_NAME)
@@ -71,3 +89,5 @@ async def start(
).format(name=user.name) ).format(name=user.name)
await message.answer(msg_text) await message.answer(msg_text)
return user, is_new

29
main.py
View File

@@ -1,6 +1,5 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Annotated, Callable, Any from typing import Callable, Any
from typing_extensions import Doc
from aiogram import Bot, Dispatcher from aiogram import Bot, Dispatcher
from aiogram.client.default import DefaultBotProperties from aiogram.client.default import DefaultBotProperties
from aiogram.types import Message, BotCommand as AiogramBotCommand from aiogram.types import Message, BotCommand as AiogramBotCommand
@@ -52,37 +51,23 @@ async def default_lifespan(app: "QBotApp"):
logger.info("qbot app stopped") logger.info("qbot app stopped")
class QBotApp(FastAPI): class QBotApp[UserType: UserBase](FastAPI):
""" """
Main class for the QBot application Main class for the QBot application
""" """
def __init__[UserType: UserBase]( def __init__(
self, self,
user_class: ( user_class: UserType = None,
Annotated[
type[UserType], Doc("User class that will be used in the application")
]
| None
) = None,
config: Config | None = None, config: Config | None = None,
bot_start: ( bot_start: Callable[
Annotated[
Callable[
[ [
Annotated[ Callable[[Message, Any], tuple[UserType, bool]],
Callable[[Message, Any], None],
Doc("Default handler for the start command"),
],
Message, Message,
Any, Any,
], ],
None, None,
], ] = None,
Doc("Handler for the start command"),
]
| None
) = None,
lifespan: Lifespan[AppType] | None = None, lifespan: Lifespan[AppType] | None = None,
lifespan_bot_init: bool = True, lifespan_bot_init: bool = True,
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,

View File

@@ -1,5 +1,5 @@
from babel.support import LazyProxy from babel.support import LazyProxy
from inspect import signature from inspect import iscoroutinefunction, signature
from aiogram.types import Message, CallbackQuery from aiogram.types import Message, CallbackQuery
from aiogram.utils.i18n import I18n from aiogram.utils.i18n import I18n
from typing import Any, TYPE_CHECKING from typing import Any, TYPE_CHECKING
@@ -105,7 +105,7 @@ def clear_state(state_data: dict, clear_nav: bool = False):
state_data["navigation_context"] = context state_data["navigation_context"] = context
def get_entity_item_repr( async def get_entity_item_repr(
entity: BotEntity, item_repr: EntityItemCaptionCallable | None = None entity: BotEntity, item_repr: EntityItemCaptionCallable | None = None
) -> str: ) -> str:
descr = entity.bot_entity_descriptor descr = entity.bot_entity_descriptor
@@ -115,14 +115,14 @@ def get_entity_item_repr(
descr.item_repr(descr, entity) descr.item_repr(descr, entity)
if descr.item_repr if descr.item_repr
else f"{ else f"{
get_callable_str(descr.full_name, descr, entity) await get_callable_str(descr.full_name, descr, entity)
if descr.full_name if descr.full_name
else descr.name else descr.name
}: {str(entity.id)}" }: {str(entity.id)}"
) )
def get_value_repr( async def get_value_repr(
value: Any, field_descriptor: FieldDescriptor, locale: str | None = None value: Any, field_descriptor: FieldDescriptor, locale: str | None = None
) -> str: ) -> str:
if value is None: if value is None:
@@ -133,7 +133,9 @@ def get_value_repr(
return "【✔︎】" if value else "【 】" return "【✔︎】" if value else "【 】"
elif field_descriptor.is_list: elif field_descriptor.is_list:
if issubclass(type_, BotEntity): if issubclass(type_, BotEntity):
return f"[{', '.join([get_entity_item_repr(item) for item in value])}]" return (
f"[{', '.join([await get_entity_item_repr(item) for item in value])}]"
)
elif issubclass(type_, BotEnum): elif issubclass(type_, BotEnum):
return f"[{', '.join(item.localized(locale) for item in value)}]" return f"[{', '.join(item.localized(locale) for item in value)}]"
elif type_ is str: elif type_ is str:
@@ -141,7 +143,7 @@ def get_value_repr(
else: else:
return f"[{', '.join([str(item) for item in value])}]" return f"[{', '.join([str(item) for item in value])}]"
elif issubclass(type_, BotEntity): elif issubclass(type_, BotEntity):
return get_entity_item_repr(value) return await get_entity_item_repr(value)
elif issubclass(type_, BotEnum): elif issubclass(type_, BotEnum):
return value.localized(locale) return value.localized(locale)
elif isinstance(value, str): elif isinstance(value, str):
@@ -156,7 +158,7 @@ def get_value_repr(
return str(value) return str(value)
def get_callable_str( async def get_callable_str(
callable_str: ( callable_str: (
str str
| LazyProxy | LazyProxy
@@ -174,6 +176,14 @@ def get_callable_str(
return callable_str.value return callable_str.value
elif callable(callable_str): elif callable(callable_str):
args = signature(callable_str).parameters args = signature(callable_str).parameters
if iscoroutinefunction(callable_str):
if len(args) == 1:
return await callable_str(descriptor)
elif len(args) == 2:
return await callable_str(descriptor, entity)
elif len(args) == 3:
return await callable_str(descriptor, entity, value)
else:
if len(args) == 1: if len(args) == 1:
return callable_str(descriptor) return callable_str(descriptor)
elif len(args) == 2: elif len(args) == 2: