Files
quickbot/bot/handlers/editors/__init__.py
Alexander Kalinovsky 6dbe0536ca init
2025-01-04 12:00:12 +01:00

371 lines
16 KiB
Python

from datetime import datetime
from decimal import Decimal
from logging import getLogger
from types import NoneType, UnionType
from typing import Union, get_args, get_origin

from aiogram import Router, F
from aiogram.fsm.context import FSMContext
from aiogram.types import Message, CallbackQuery
from sqlmodel.ext.asyncio.session import AsyncSession
import ujson as json

from ....main import QBotApp
from ....model import EntityPermission
from ....model.bot_entity import BotEntity
from ....model.owned_bot_entity import OwnedBotEntity
from ....model.bot_enum import BotEnum
from ....model.language import LanguageBase
from ....model.settings import Settings
from ....model.user import UserBase
from ....model.descriptors import EntityFieldDescriptor
from ....utils import deserialize, get_user_permissions, serialize
from ...command_context_filter import CallbackCommandFilter
from ..context import ContextData, CallbackCommand, CommandContext
from ..common import (get_value_repr, authorize_command, get_callable_str,
                      get_entity_descriptor, get_field_descriptor)
from ..menu.parameters import parameters_menu
from .string import string_editor, router as string_editor_router
from .date import date_picker, router as date_picker_router
from .boolean import bool_editor, router as bool_editor_router
from .entity import entity_picker, router as entity_picker_router
logger = getLogger(__name__)

# Root router of the field-editor package. The per-type editor routers
# (string, date, boolean, entity picker) are nested inside it.
router = Router()
router.include_routers(*(
    string_editor_router,
    date_picker_router,
    bool_editor_router,
    entity_picker_router,
))
@router.callback_query(ContextData.filter(F.command == CallbackCommand.FIELD_EDITOR))
async def settings_field_editor(message: Message | CallbackQuery, **kwargs):
    """Open the editor for a single setting or entity field.

    First drops the per-editor FSM keys (``current_value``, ``value``,
    ``locale_index``) left by a previous edit. Then:

    * ``SETTING_EDIT`` on a ``bool`` setting: toggles it in place (if
      ``authorize_command`` allows) and shows the parameters menu again.
    * any other ``SETTING_EDIT``: opens the editor with the setting's current
      value.
    * entity contexts: for ``ENTITY_EDIT`` with no cached ``entity_data``, loads
      the entity (subject to READ / READ_ALL and ownership) and caches its
      serialized fields in FSM state under ``entity_data``. It then opens the
      editor with the deserialized value of ``callback_data.field_name``.
    """
    callback_data: ContextData = kwargs.get("callback_data", None)
    db_session: AsyncSession = kwargs["db_session"]
    user: UserBase = kwargs["user"]
    app: QBotApp = kwargs["app"]
    state: FSMContext = kwargs["state"]

    state_data = await state.get_data()
    entity_data = state_data.get("entity_data")
    # Forget values belonging to a previous editor session before rewriting state.
    for key in ("current_value", "value", "locale_index"):
        state_data.pop(key, None)
    await state.clear()
    await state.update_data(state_data)

    entity_descriptor = None
    if callback_data.context == CommandContext.SETTING_EDIT:
        field_descriptor = get_field_descriptor(app, callback_data)
        if field_descriptor.type_ == bool:
            # Boolean settings are flipped immediately instead of opening an editor.
            if not await authorize_command(user = user, callback_data = callback_data):
                return await message.answer(text = (await Settings.get(Settings.APP_STRINGS_FORBIDDEN)))
            await Settings.set_param(field_descriptor, not await Settings.get(field_descriptor))
            stack, _ = await get_navigation_context(state = state)
            return await parameters_menu(message = message,
                                         navigation_stack = stack,
                                         **kwargs)
        current_value = await Settings.get(field_descriptor)
    else:
        entity_descriptor = get_entity_descriptor(app, callback_data)
        field_descriptor = get_field_descriptor(app, callback_data)
        current_value = None
        if not entity_data and callback_data.context == CommandContext.ENTITY_EDIT:
            permissions = get_user_permissions(user, entity_descriptor)
            if EntityPermission.READ_ALL in permissions or EntityPermission.READ in permissions:
                entity = await entity_descriptor.type_.get(session = db_session,
                                                           id = int(callback_data.entity_id))
                # Ownership has to be checked on the loaded entity: the cached
                # entity_data is empty on this path, so it cannot hold user_id.
                if entity and (EntityPermission.READ_ALL in permissions or
                               not issubclass(entity_descriptor.type_, OwnedBotEntity) or
                               entity.user_id == user.id):
                    entity_data = {key: serialize(getattr(entity, key), entity_descriptor.fields_descriptors[key])
                                   for key in entity_descriptor.field_sequence}
                    await state.update_data({"entity_data": entity_data})
        if entity_data:
            current_value = await deserialize(session = db_session,
                                              type_ = field_descriptor.type_,
                                              value = entity_data.get(callback_data.field_name))

    await show_editor(message = message,
                      field_descriptor = field_descriptor,
                      entity_descriptor = entity_descriptor,
                      current_value = current_value,
                      **kwargs)
async def show_editor(message: Message | CallbackQuery,
                      **kwargs):
    """Build the edit prompt and dispatch to the editor matching the field type.

    Reads ``field_descriptor``, ``current_value``, ``user``, ``state`` and
    ``callback_data`` from ``kwargs``. It adds ``edit_prompt`` and forwards all
    of ``kwargs`` to the selected editor.

    Dispatch, after unwrapping an Optional type:

    * ``str``/``int``/``float``/``Decimal`` go to ``string_editor``;
    * ``bool`` goes to ``bool_editor``;
    * ``datetime`` goes to ``date_picker``;
    * ``list[BotEntity | BotEnum subclass]`` goes to ``entity_picker``; any
      other list goes to ``string_editor``;
    * a ``BotEntity``/``BotEnum`` subclass goes to ``entity_picker``.

    Raises:
        ValueError: if no editor supports the field type.
    """
    field_descriptor: EntityFieldDescriptor = kwargs["field_descriptor"]
    current_value = kwargs["current_value"]
    user: UserBase = kwargs["user"]
    callback_data: ContextData = kwargs.get("callback_data", None)
    state: FSMContext = kwargs["state"]

    if field_descriptor.edit_prompt:
        edit_prompt = get_callable_str(field_descriptor.edit_prompt, field_descriptor, None, current_value)
    else:
        if field_descriptor.caption_str:
            caption_str = get_callable_str(field_descriptor.caption_str, field_descriptor, None, current_value)
        else:
            caption_str = field_descriptor.name
        if callback_data.context == CommandContext.ENTITY_EDIT:
            edit_prompt = (await Settings.get(Settings.APP_STRINGS_FIELD_EDIT_PROMPT_TEMPLATE_P_NAME_VALUE)).format(
                name = caption_str, value = get_value_repr(current_value, field_descriptor, user.lang))
        else:
            edit_prompt = (await Settings.get(Settings.APP_STRINGS_FIELD_CREATE_PROMPT_TEMPLATE_P_NAME)).format(
                name = caption_str)
    kwargs["edit_prompt"] = edit_prompt

    # Unwrap Optional written either as ``X | None`` or ``typing.Optional[X]``.
    value_type = field_descriptor.type_
    if get_origin(value_type) in (UnionType, Union):
        non_none_args = [arg for arg in get_args(value_type) if arg is not NoneType]
        if len(non_none_args) == 1:
            value_type = non_none_args[0]
    # Take the origin AFTER unwrapping so ``list[X] | None`` is recognized as a list.
    type_origin = get_origin(value_type)

    if value_type not in [int, float, Decimal, str]:
        # Keep the packed callback context in FSM state for non-text editors.
        await state.update_data({"context_data": callback_data.pack()})

    if value_type in (str, int, float, Decimal):
        await string_editor(message = message, **kwargs)
    elif value_type == bool:
        await bool_editor(message = message, **kwargs)
    elif value_type == datetime:
        await date_picker(message = message, **kwargs)
    elif type_origin is list:
        type_args = get_args(value_type)
        item_type = type_args[0] if type_args else None
        # Guard with isinstance: issubclass() raises TypeError for non-class items.
        if isinstance(item_type, type) and issubclass(item_type, (BotEntity, BotEnum)):
            await entity_picker(message = message, **kwargs)
        else:
            await string_editor(message = message, **kwargs)
    elif isinstance(value_type, type) and issubclass(value_type, (BotEntity, BotEnum)):
        await entity_picker(message = message, **kwargs)
    else:
        raise ValueError(f"Unsupported field type: {value_type}")
@router.message(CallbackCommandFilter(CallbackCommand.FIELD_EDITOR_CALLBACK))
@router.callback_query(ContextData.filter(F.command == CallbackCommand.FIELD_EDITOR_CALLBACK))
async def field_editor_callback(message: Message | CallbackQuery, **kwargs):
    """Receive input from a field editor and hand it to ``process_field_edit_callback``.

    Text messages (``Message``):

    * The editor context is restored from the packed ``context_data`` kept in
      FSM state, when present.
    * Localizable string fields are collected one language at a time. Each
      answer is merged into a JSON object of translations (FSM key ``value``).
      For every language but the last, the editor is re-shown for the next
      ``locale_index``.
    * ``int``/``float``/``Decimal`` input that does not parse gets the
      invalid-input message.

    Callback queries: the value is ``callback_data.data``, or the ``value``
    stored in FSM state when that is empty.
    """
    callback_data: ContextData = kwargs.get("callback_data", None)
    app: QBotApp = kwargs["app"]
    state: FSMContext = kwargs["state"]
    state_data = await state.get_data()

    if isinstance(message, Message):
        packed_context = state_data.get("context_data")
        if packed_context:
            callback_data = ContextData.unpack(packed_context)
        value = message.text
        field_descriptor = get_field_descriptor(app, callback_data)

        # Unwrap Optional written either as ``X | None`` or ``typing.Optional[X]``.
        base_type = field_descriptor.type_
        if get_origin(base_type) in (UnionType, Union):
            non_none_args = [arg for arg in get_args(base_type) if arg is not NoneType]
            if len(non_none_args) == 1:
                base_type = non_none_args[0]

        if base_type == str and field_descriptor.localizable:
            languages = list(LanguageBase.all_members.values())
            # NOTE(review): locale_index is presumably written to state by the string
            # editor; a missing value is treated as the first language.
            locale_index = int(state_data.get("locale_index") or 0)
            stored_translations = state_data.get("value")
            translations = json.loads(stored_translations) if stored_translations else {}
            translations[languages[locale_index]] = message.text
            value = json.dumps(translations)
            if locale_index < len(languages) - 1:
                # More languages to go: persist the partial result and ask for the next one.
                await state.update_data({"value": value})
                kwargs.update({"callback_data": callback_data})
                return await show_editor(message = message,
                                         locale_index = locale_index + 1,
                                         field_descriptor = field_descriptor,
                                         entity_descriptor = get_entity_descriptor(app, callback_data),
                                         current_value = state_data.get("current_value"),
                                         value = value,
                                         **kwargs)
        elif base_type in (int, float, Decimal):
            try:
                base_type(value) #@IgnoreException
            except (ValueError, TypeError, ArithmeticError):
                # ValueError: int/float parse failure; ArithmeticError covers
                # decimal.InvalidOperation; TypeError: non-text message (text is None).
                return await message.answer(text = (await Settings.get(Settings.APP_STRINGS_INVALID_INPUT)))
    else:
        value = callback_data.data or state_data.get("value")
        field_descriptor = get_field_descriptor(app, callback_data)

    kwargs.update({"callback_data": callback_data})
    await process_field_edit_callback(message = message,
                                      value = value,
                                      field_descriptor = field_descriptor,
                                      **kwargs)
async def process_field_edit_callback(message: Message | CallbackQuery, **kwargs):
    """Apply a value produced by a field editor.

    Reads ``user``, ``db_session``, ``callback_data``, ``state``, ``value`` and
    ``field_descriptor`` from ``kwargs`` (plus ``app`` for entity contexts).

    * ``SETTING_EDIT``: clears FSM state. Unless the editor answered
      ``"cancel"``, it stores the deserialized value (if authorized). It then
      shows the parameters menu again.
    * ``ENTITY_CREATE`` / ``ENTITY_EDIT``: records the value in ``entity_data``.
      If more fields follow in ``field_sequence``, the next field's editor is
      shown. Otherwise the entity is created or updated after permission and
      ownership checks, state is cleared and the user is routed onward.
    """
    user: UserBase = kwargs["user"]
    db_session: AsyncSession = kwargs["db_session"]
    callback_data: ContextData = kwargs.get("callback_data", None)
    state: FSMContext = kwargs["state"]
    value = kwargs["value"]
    field_descriptor: EntityFieldDescriptor = kwargs["field_descriptor"]

    if callback_data.context == CommandContext.SETTING_EDIT:
        await clear_state(state = state)
        if callback_data.data != "cancel":
            if not await authorize_command(user = user, callback_data = callback_data):
                return await message.answer(text = (await Settings.get(Settings.APP_STRINGS_FORBIDDEN)))
            value = await deserialize(session = db_session, type_ = field_descriptor.type_, value = value)
            await Settings.set_param(field_descriptor, value)
        stack, _ = await get_navigation_context(state = state)
        return await parameters_menu(message = message,
                                     navigation_stack = stack,
                                     **kwargs)

    if callback_data.context not in (CommandContext.ENTITY_CREATE, CommandContext.ENTITY_EDIT):
        return None

    app: QBotApp = kwargs["app"]
    entity_descriptor = get_entity_descriptor(app, callback_data)
    field_sequence = entity_descriptor.field_sequence
    current_index = field_sequence.index(callback_data.field_name)
    state_data = await state.get_data()
    # `or {}` also covers an entity_data key explicitly stored as None.
    entity_data = state_data.get("entity_data") or {}

    if current_index < len(field_sequence) - 1:
        # Not the last field yet: remember this value and open the next field's editor.
        entity_data[field_descriptor.field_name] = value
        await state.update_data({"entity_data": entity_data})
        next_field_name = field_sequence[current_index + 1]
        next_field_descriptor = entity_descriptor.fields_descriptors[next_field_name]
        kwargs.update({"field_descriptor": next_field_descriptor})
        callback_data.field_name = next_field_name
        state_entity_val = entity_data.get(next_field_descriptor.field_name)
        current_value = await deserialize(session = db_session, type_ = next_field_descriptor.type_,
                                          value = state_entity_val) if state_entity_val else None
        await show_editor(message = message,
                          entity_descriptor = entity_descriptor,
                          current_value = current_value,
                          **kwargs)
        return None

    entity_type: type[BotEntity] = entity_descriptor.type_
    user_permissions = get_user_permissions(user, entity_descriptor)
    is_create = callback_data.context == CommandContext.ENTITY_CREATE
    if is_create:
        own_permission, all_permission = EntityPermission.CREATE, EntityPermission.CREATE_ALL
    else:
        own_permission, all_permission = EntityPermission.UPDATE, EntityPermission.UPDATE_ALL
    if own_permission not in user_permissions and all_permission not in user_permissions:
        return await message.answer(text = (await Settings.get(Settings.APP_STRINGS_FORBIDDEN)))

    is_owned = issubclass(entity_type, OwnedBotEntity)
    entity_data[field_descriptor.field_name] = value
    # Without the *_ALL permission for THIS operation, a user may only create or
    # keep entities owned by themselves, so the owner is pinned to them. Keying
    # this on the operation's own permission stops an editor holding UPDATE_ALL
    # (but not CREATE_ALL) from taking over another user's entity on edit.
    if is_owned and all_permission not in user_permissions:
        entity_data["user_id"] = user.id
    deser_entity_data = {key: await deserialize(session = db_session,
                                                 type_ = entity_descriptor.fields_descriptors[key].type_,
                                                 value = raw_value)
                         for key, raw_value in entity_data.items()}

    if is_create:
        new_entity = await entity_type.create(session = db_session,
                                              obj_in = entity_type(**deser_entity_data),
                                              commit = True)
        await save_navigation_context(state = state, callback_data = ContextData(
            command = CallbackCommand.ENTITY_ITEM,
            entity_name = entity_descriptor.name,
            entity_id = str(new_entity.id)
        ))
    else:
        entity = await entity_type.get(session = db_session, id = int(callback_data.entity_id))
        if not entity:
            return await message.answer(text = (await Settings.get(Settings.APP_STRINGS_NOT_FOUND)))
        if (is_owned and entity.user_id != user.id and
                EntityPermission.UPDATE_ALL not in user_permissions):
            return await message.answer(text = (await Settings.get(Settings.APP_STRINGS_FORBIDDEN)))
        for key, field_value in deser_entity_data.items():
            setattr(entity, key, field_value)
        await db_session.commit()

    await clear_state(state = state)
    await route_callback(message = message, back = False, **kwargs)
from ..navigation import get_navigation_context, route_callback, clear_state, save_navigation_context