From 4df67c93d4ad3545d94e9c895e53cb5491fa4edf Mon Sep 17 00:00:00 2001 From: Alexander Kalinovsky Date: Mon, 11 Aug 2025 20:47:39 +0300 Subject: [PATCH] crud service --- src/quickbot/__init__.py | 2 + src/quickbot/api_route/depends.py | 37 ++ src/quickbot/api_route/models.py | 49 ++ src/quickbot/api_route/telegram.py | 17 +- src/quickbot/auth/jwt.py | 20 + src/quickbot/auth/telegram.py | 15 + src/quickbot/bot/handlers/editors/entity.py | 20 +- src/quickbot/bot/handlers/editors/main.py | 10 +- .../bot/handlers/editors/main_callbacks.py | 47 +- .../bot/handlers/forms/entity_form.py | 30 +- .../handlers/forms/entity_form_callbacks.py | 6 +- .../bot/handlers/forms/entity_list.py | 10 +- src/quickbot/bot/handlers/menu/entities.py | 2 +- src/quickbot/main.py | 490 +++++++++++++++++- src/quickbot/middleware/telegram/auth.py | 12 +- src/quickbot/model/__init__.py | 10 +- src/quickbot/model/annotated_schema.py | 46 ++ src/quickbot/model/bot_entity.py | 400 +++++++++----- src/quickbot/model/bot_enum.py | 44 +- src/quickbot/model/bot_metadata.py | 16 + src/quickbot/model/bot_process.py | 77 +++ src/quickbot/model/crud_command.py | 9 + src/quickbot/model/crud_service.py | 374 +++++++++++++ src/quickbot/model/descriptors.py | 155 +++++- src/quickbot/model/entity_metadata.py | 7 - src/quickbot/model/list_schema.py | 6 + src/quickbot/model/permissions.py | 192 +++++++ src/quickbot/model/pydantic_json.py | 53 ++ src/quickbot/model/user.py | 23 +- src/quickbot/model/utils.py | 364 +++++++++++++ src/quickbot/plugin.py | 9 + src/quickbot/utils/main.py | 122 ++--- src/quickbot/utils/serialization.py | 18 +- 33 files changed, 2358 insertions(+), 334 deletions(-) create mode 100644 src/quickbot/api_route/depends.py create mode 100644 src/quickbot/api_route/models.py create mode 100644 src/quickbot/auth/jwt.py create mode 100644 src/quickbot/auth/telegram.py create mode 100644 src/quickbot/model/annotated_schema.py create mode 100644 src/quickbot/model/bot_metadata.py create mode 100644 src/quickbot/model/bot_process.py create mode 100644 src/quickbot/model/crud_command.py create mode 100644 src/quickbot/model/crud_service.py delete mode 100644 src/quickbot/model/entity_metadata.py create mode 100644 src/quickbot/model/list_schema.py create mode 100644 src/quickbot/model/permissions.py create mode 100644 src/quickbot/model/pydantic_json.py create mode 100644 src/quickbot/model/utils.py create mode 100644 src/quickbot/plugin.py diff --git a/src/quickbot/__init__.py b/src/quickbot/__init__.py index ab0ae84..aef2fa3 100644 --- a/src/quickbot/__init__.py +++ b/src/quickbot/__init__.py @@ -1,6 +1,7 @@ from .main import QBotApp as QBotApp, Config as Config from .router import Router as Router from .model.bot_entity import BotEntity as BotEntity +from .model.bot_process import BotProcess as BotProcess from .model.bot_enum import BotEnum as BotEnum, EnumMember as EnumMember from .bot.handlers.context import ( ContextData as ContextData, @@ -20,4 +21,5 @@ from .model.descriptors import ( FieldEditButton as FieldEditButton, InlineButton as InlineButton, FormField as FormField, + Process as Process, ) diff --git a/src/quickbot/api_route/depends.py b/src/quickbot/api_route/depends.py new file mode 100644 index 0000000..e61a949 --- /dev/null +++ b/src/quickbot/api_route/depends.py @@ -0,0 +1,37 @@ +from typing import Annotated, TYPE_CHECKING +from fastapi import Depends, HTTPException, Request +from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from jose import JWTError +from sqlmodel.ext.asyncio.session import AsyncSession + +from quickbot.auth.jwt import decode_access_token +from quickbot.db import get_db +from quickbot.model.user import UserBase + +if TYPE_CHECKING: + from quickbot import QBotApp + +security_scheme = HTTPBearer( + scheme_name="bearerAuth", + bearerFormat="JWT", +) + + +async def get_current_user( + request: Request, + db_session: Annotated[AsyncSession, Depends(get_db)], + credentials: HTTPAuthorizationCredentials = Depends(security_scheme), +) -> UserBase: + try: + payload = decode_access_token(credentials.credentials) + user_id = payload.get("sub") + if user_id is None: + raise HTTPException(status_code=401, detail="Invalid token") + app: QBotApp = request.app + user = await app.user_class.get( + session=db_session, + id=int(user_id), + ) + return user + except JWTError: + raise HTTPException(status_code=401, detail="Invalid token") diff --git a/src/quickbot/api_route/models.py b/src/quickbot/api_route/models.py new file mode 100644 index 0000000..ebe4016 --- /dev/null +++ b/src/quickbot/api_route/models.py @@ -0,0 +1,49 @@ +from fastapi import Depends, Request +from pydantic import BaseModel +from sqlmodel.ext.asyncio.session import AsyncSession +from typing import Annotated, TYPE_CHECKING + +from ..db import get_db +from ..model.descriptors import EntityDescriptor +from .depends import get_current_user + +if TYPE_CHECKING: + from ..main import QBotApp + from ..model.user import UserBase + + +class ListParams(BaseModel): + query: str = "" + order_by: str = "" + limit: int = 100 + offset: int = 0 + + +async def list_entity_items( + db_session: Annotated[AsyncSession, Depends(get_db)], + request: Request, + params: Annotated[ListParams, Depends()], + current_user=Depends(get_current_user), +): + entity_descriptor: EntityDescriptor = request.app.bot_metadata.entity_descriptors[ + request.url.path.split("/")[-1] + ] + entity_list = await entity_descriptor.crud.list_all( + db_session=db_session, + user=current_user, + ) + return entity_list + + +async def get_me( + db_session: Annotated[AsyncSession, Depends(get_db)], + request: Request, + current_user: Annotated["UserBase", Depends(get_current_user)], +): + app: "QBotApp" = request.app + user = await app.user_class.bot_entity_descriptor.crud.get_by_id( + db_session=db_session, + user=current_user, + id=current_user.id, + ) + return user diff --git a/src/quickbot/api_route/telegram.py b/src/quickbot/api_route/telegram.py index 2234971..2efd6a2 100644 --- a/src/quickbot/api_route/telegram.py +++ b/src/quickbot/api_route/telegram.py @@ -1,10 +1,12 @@ from aiogram.types import Update -from fastapi import APIRouter, Request, Response, Depends +from fastapi import APIRouter, Request, Response, Depends, HTTPException, Body from sqlmodel.ext.asyncio.session import AsyncSession from typing import Annotated from ..db import get_db from ..main import QBotApp +from ..auth.telegram import check_telegram_auth +from ..auth.jwt import create_access_token from logging import getLogger @@ -49,6 +51,19 @@ async def telegram_webhook( return Response(status_code=200) +@router.post("/auth") +async def telegram_login(request: Request, data: dict = Body(...)): + if not check_telegram_auth(data, request.app.config.TELEGRAM_BOT_TOKEN): + raise HTTPException(status_code=401, detail="Invalid Telegram auth") + payload = { + "sub": str(data["id"]), + "first_name": data.get("first_name"), + "username": data.get("username"), + } + token = create_access_token(payload) + return {"access_token": token, "token_type": "bearer"} + + # async def feed_bot_update( # app: QBotApp, # update: Update, diff --git a/src/quickbot/auth/jwt.py b/src/quickbot/auth/jwt.py new file mode 100644 index 0000000..8551309 --- /dev/null +++ b/src/quickbot/auth/jwt.py @@ -0,0 +1,20 @@ +from datetime import datetime, timedelta +from jose import jwt + +SECRET_KEY = "your_secret_key" # TODO: вынести в конфиг +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 1 неделя + + +def create_access_token(data: dict, expires_delta: timedelta = None): + to_encode = data.copy() + expire = datetime.utcnow() + ( + expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + ) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return encoded_jwt + + +def decode_access_token(token: str): + return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) diff --git a/src/quickbot/auth/telegram.py b/src/quickbot/auth/telegram.py new file mode 100644 index 0000000..506060d --- /dev/null +++ b/src/quickbot/auth/telegram.py @@ -0,0 +1,15 @@ +import hashlib +import hmac + + +def check_telegram_auth(data: dict, bot_token: str) -> bool: + auth_data = data.copy() + hash_ = auth_data.pop("hash", None) + if not hash_: + return False + data_check_string = "\n".join([f"{k}={v}" for k, v in sorted(auth_data.items())]) + secret_key = hashlib.sha256(bot_token.encode()).digest() + hmac_hash = hmac.new( + secret_key, data_check_string.encode(), hashlib.sha256 + ).hexdigest() + return hmac_hash == hash_ diff --git a/src/quickbot/bot/handlers/editors/entity.py b/src/quickbot/bot/handlers/editors/entity.py index bd09429..329af42 100644 --- a/src/quickbot/bot/handlers/editors/entity.py +++ b/src/quickbot/bot/handlers/editors/entity.py @@ -179,17 +179,31 @@ async def render_entity_picker( entity_filter = None list_all = EntityPermission.LIST_ALL in permissions - if list_all or EntityPermission.LIST in permissions: + if list_all or EntityPermission.LIST_RLS in permissions: if ( field_descriptor.ep_parent_field and field_descriptor.ep_child_field and callback_data.entity_id ): + if callable(field_descriptor.ep_parent_field): + parent_field = field_descriptor.ep_parent_field( + field_descriptor.entity_descriptor.type_ + ).key + else: + parent_field = field_descriptor.ep_parent_field + + if callable(field_descriptor.ep_child_field): + child_field = field_descriptor.ep_child_field( + field_descriptor.entity_descriptor.type_ + ).key + else: + child_field = field_descriptor.ep_child_field + entity = await field_descriptor.entity_descriptor.type_.get( session=db_session, id=callback_data.entity_id ) - value = getattr(entity, field_descriptor.ep_parent_field) - ext_filter = column(field_descriptor.ep_child_field).__eq__(value) + value = getattr(entity, parent_field) + ext_filter = column(child_field).__eq__(value) else: ext_filter = None diff --git a/src/quickbot/bot/handlers/editors/main.py b/src/quickbot/bot/handlers/editors/main.py index e99f0b4..6d72012 100644 --- a/src/quickbot/bot/handlers/editors/main.py +++ b/src/quickbot/bot/handlers/editors/main.py @@ -11,9 +11,9 @@ from quickbot.model.descriptors import BotContext, EntityForm from ....model import EntityPermission from ....model.settings import Settings from ....model.user import UserBase +from ....model.permissions import check_entity_permission from ....utils.main import ( build_field_sequence, - check_entity_permission, get_field_descriptor, clear_state, ) @@ -109,8 +109,8 @@ async def field_editor(message: Message | CallbackQuery, **kwargs): entity = await entity_descriptor.type_.get( session=db_session, id=int(callback_data.entity_id) ) - if check_entity_permission( - entity=entity, user=user, permission=EntityPermission.UPDATE + if await check_entity_permission( + entity=entity, user=user, permission=EntityPermission.UPDATE_RLS ): old_values = {} @@ -188,8 +188,8 @@ async def field_editor(message: Message | CallbackQuery, **kwargs): entity = await entity_descriptor.type_.get( session=kwargs["db_session"], id=int(callback_data.entity_id) ) - if check_entity_permission( - entity=entity, user=user, permission=EntityPermission.READ + if await check_entity_permission( + entity=entity, user=user, permission=EntityPermission.READ_RLS ): if entity: form_name = ( diff --git a/src/quickbot/bot/handlers/editors/main_callbacks.py b/src/quickbot/bot/handlers/editors/main_callbacks.py index 84499c0..e5a6cc6 100644 --- a/src/quickbot/bot/handlers/editors/main_callbacks.py +++ b/src/quickbot/bot/handlers/editors/main_callbacks.py @@ -19,12 +19,14 @@ from ....model.descriptors import ( EntityForm, EntityList, FieldDescriptor, + Filter, + FilterExpression, ) from ....model.language import LanguageBase from ....auth import authorize_command +from ....model.permissions import check_entity_permission from ....utils.main import ( get_user_permissions, - check_entity_permission, clear_state, get_entity_descriptor, get_field_descriptor, @@ -328,12 +330,39 @@ async def process_field_edit_callback(message: Message | CallbackQuery, **kwargs ]: user_permissions = get_user_permissions(user, entity_descriptor) - for role in user.roles: - if ( - role in entity_descriptor.ownership_fields - and EntityPermission.CREATE_ALL not in user_permissions + if entity_descriptor.rls_filters: + filters = [] + if isinstance(entity_descriptor.rls_filters, Filter): + filters = [entity_descriptor.rls_filters] + elif ( + isinstance(entity_descriptor.rls_filters, FilterExpression) + and entity_descriptor.rls_filters.operator == "and" + and all( + isinstance(f, Filter) + for f in entity_descriptor.rls_filters.filters + ) ): - entity_data[entity_descriptor.ownership_fields[role]] = user.id + filters = entity_descriptor.rls_filters.filters + filter_params = [] + if filters and entity_descriptor.rls_filters_params: + if iscoroutinefunction(entity_descriptor.rls_filters_params): + filter_params = await entity_descriptor.rls_filters_params( + user + ) + else: + filter_params = entity_descriptor.rls_filters_params(user) + + for f in filters: + if f.operator == "==": + if isinstance(f.field, str): + field_name = f.field + else: + field_name = f.field(entity_descriptor.type_).key + entity_data[field_name] = ( + f.value + if f.value_type == "const" + else filter_params[f.param_index] + ) deser_entity_data = { key: await deserialize( @@ -356,7 +385,7 @@ async def process_field_edit_callback(message: Message | CallbackQuery, **kwargs entity_type = entity_descriptor.type_ user_permissions = get_user_permissions(user, entity_descriptor) if ( - EntityPermission.CREATE not in user_permissions + EntityPermission.CREATE_RLS not in user_permissions and EntityPermission.CREATE_ALL not in user_permissions ): return await message.answer( @@ -438,8 +467,8 @@ async def process_field_edit_callback(message: Message | CallbackQuery, **kwargs text=(await Settings.get(Settings.APP_STRINGS_NOT_FOUND)) ) - if not check_entity_permission( - entity=entity, user=user, permission=EntityPermission.UPDATE + if not await check_entity_permission( + entity=entity, user=user, permission=EntityPermission.UPDATE_RLS ): return await message.answer( text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN)) diff --git a/src/quickbot/bot/handlers/forms/entity_form.py b/src/quickbot/bot/handlers/forms/entity_form.py index 90242dc..bd56562 100644 --- a/src/quickbot/bot/handlers/forms/entity_form.py +++ b/src/quickbot/bot/handlers/forms/entity_form.py @@ -18,8 +18,8 @@ from ....model.bot_entity import BotEntity from ....model.settings import Settings from ....model.user import UserBase from ....model import EntityPermission +from ....model.permissions import check_entity_permission from ....utils.main import ( - check_entity_permission, get_send_message, clear_state, get_value_repr, @@ -84,15 +84,15 @@ async def entity_item( # is_owned = issubclass(entity_type, OwnedBotEntity) - if query and not check_entity_permission( - entity=entity_item, user=user, permission=EntityPermission.READ + if query and not await check_entity_permission( + entity=entity_item, user=user, permission=EntityPermission.READ_RLS ): return await query.answer( text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN)) ) - can_edit = check_entity_permission( - entity=entity_item, user=user, permission=EntityPermission.UPDATE + can_edit = await check_entity_permission( + entity=entity_item, user=user, permission=EntityPermission.UPDATE_RLS ) form: EntityForm = entity_descriptor.forms.get( @@ -250,8 +250,8 @@ async def entity_item( ) if ( - check_entity_permission( - entity=entity_item, user=user, permission=EntityPermission.DELETE + await check_entity_permission( + entity=entity_item, user=user, permission=EntityPermission.DELETE_RLS ) and form.show_delete_button ): @@ -303,7 +303,7 @@ async def entity_item( ) -async def item_repr(entity_item: BotEntity, context: BotContext[UserBase]): +async def item_repr(entity_item: BotEntity, context: BotContext): entity_descriptor = entity_item.bot_entity_descriptor user = context.user entity_caption = ( @@ -349,20 +349,6 @@ async def item_repr(entity_item: BotEntity, context: BotContext[UserBase]): if not field_visible: continue - skip = False - - for own_field in entity_descriptor.ownership_fields.items(): - if ( - own_field[1].rstrip("_id") == field_descriptor.field_name.rstrip("_id") - and own_field[0] in user.roles - and EntityPermission.READ_ALL not in user_permissions - ): - skip = True - break - - if skip: - continue - if field_descriptor.caption_value: item_text += f"\n{ await get_callable_str( diff --git a/src/quickbot/bot/handlers/forms/entity_form_callbacks.py b/src/quickbot/bot/handlers/forms/entity_form_callbacks.py index b227101..f1cffe2 100644 --- a/src/quickbot/bot/handlers/forms/entity_form_callbacks.py +++ b/src/quickbot/bot/handlers/forms/entity_form_callbacks.py @@ -12,8 +12,8 @@ from ..context import ContextData, CallbackCommand from ....model.user import UserBase from ....model.settings import Settings from ....model import EntityPermission +from ....model.permissions import check_entity_permission from ....utils.main import ( - check_entity_permission, get_entity_item_repr, get_entity_descriptor, ) @@ -42,8 +42,8 @@ async def entity_delete_callback(query: CallbackQuery, **kwargs): session=db_session, id=int(callback_data.entity_id) ) - if not check_entity_permission( - entity=entity, user=user, permission=EntityPermission.DELETE + if not await check_entity_permission( + entity=entity, user=user, permission=EntityPermission.DELETE_RLS ): return await query.answer( text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN)) diff --git a/src/quickbot/bot/handlers/forms/entity_list.py b/src/quickbot/bot/handlers/forms/entity_list.py index 6d2d94f..1a38f15 100644 --- a/src/quickbot/bot/handlers/forms/entity_list.py +++ b/src/quickbot/bot/handlers/forms/entity_list.py @@ -95,7 +95,7 @@ async def entity_list( ) if ( - EntityPermission.CREATE in user_permissions + EntityPermission.CREATE_RLS in user_permissions or EntityPermission.CREATE_ALL in user_permissions ) and form_list.show_add_new_button: if form_item.edit_field_sequence: @@ -136,8 +136,8 @@ async def entity_list( ) if ( list_all - or EntityPermission.LIST in user_permissions - or EntityPermission.READ in user_permissions + or EntityPermission.LIST_RLS in user_permissions + or EntityPermission.READ_RLS in user_permissions ): if form_list.pagination: page_size = await Settings.get(Settings.PAGE_SIZE) @@ -265,10 +265,10 @@ async def entity_list( else: entity_text = entity_descriptor.name - if entity_descriptor.description: + if entity_descriptor.ui_description: entity_text = f"{entity_text} { await get_callable_str( - callable_str=entity_descriptor.description, + callable_str=entity_descriptor.ui_description, context=context, descriptor=entity_descriptor, ) diff --git a/src/quickbot/bot/handlers/menu/entities.py b/src/quickbot/bot/handlers/menu/entities.py index d8bc4ae..74aba6a 100644 --- a/src/quickbot/bot/handlers/menu/entities.py +++ b/src/quickbot/bot/handlers/menu/entities.py @@ -42,7 +42,7 @@ async def entities_menu( ): keyboard_builder = InlineKeyboardBuilder() - entity_metadata = app.entity_metadata + entity_metadata = app.bot_metadata for entity in entity_metadata.entity_descriptors.values(): if entity.show_in_entities_menu: diff --git a/src/quickbot/main.py b/src/quickbot/main.py index 1b10afc..b518921 100644 --- a/src/quickbot/main.py +++ b/src/quickbot/main.py @@ -1,5 +1,21 @@ +""" +main.py - QuickBot RAD Framework Main Application Module + +Defines QBotApp, the main entry point for the QuickBot rapid application development (RAD) framework for Telegram bots. +Integrates FastAPI (HTTP API), Aiogram (Telegram bot logic), SQLModel (async DB), and i18n (internationalization). + +Key Features: +- Dynamic registration of CRUD API endpoints for all entities +- Telegram bot command and webhook management +- Row-level security (RLS) and user management +- Middleware for authentication and localization +- Swagger UI with Telegram login integration +""" + from contextlib import asynccontextmanager -from typing import Callable, Any, Generic, TypeVar +from inspect import iscoroutinefunction +from typing import Union +from typing import Annotated, Callable, Any, Generic, TypeVar from aiogram import Bot, Dispatcher from aiogram.client.session.aiohttp import AiohttpSession from aiogram.client.telegram import TelegramAPIServer @@ -7,15 +23,21 @@ from aiogram.client.default import DefaultBotProperties from aiogram.types import Message, BotCommand as AiogramBotCommand from aiogram.utils.callback_answer import CallbackAnswerMiddleware from aiogram.utils.i18n import I18n -from fastapi import FastAPI, Request +from fastapi import Depends, FastAPI, Request, Body, Path, HTTPException from fastapi.applications import Lifespan, AppType from fastapi.datastructures import State from logging import getLogger +from fastapi.responses import HTMLResponse, JSONResponse +from fastapi.openapi.utils import get_openapi from sqlmodel.ext.asyncio.session import AsyncSession from quickbot.bot.handlers.user_handlers.main import command_handler +from quickbot.db import get_db +from quickbot.model.list_schema import ListSchema +from quickbot.plugin import Registerable from quickbot.utils.main import clear_state from quickbot.utils.navigation import save_navigation_context +from quickbot.model.crud_service import NotFoundError, ForbiddenError from .config import Config from .bot.handlers.forms.entity_form import entity_item @@ -23,10 +45,18 @@ from .fsm.db_storage import DbStorage from .middleware.telegram import AuthMiddleware, I18nMiddleware from .model.bot_entity import BotEntity from .model.user import UserBase -from .model.entity_metadata import EntityMetadata -from .model.descriptors import BotCommand +from .model.bot_metadata import BotMetadata +from .model.descriptors import ( + BotCommand, + EntityDescriptor, + ProcessDescriptor, + BotContext, +) +from .model.crud_command import CrudCommand from .bot.handlers.context import CallbackCommand, ContextData from .router import Router +from .api_route.models import list_entity_items, get_me +from .api_route.depends import get_current_user logger = getLogger(__name__) @@ -60,7 +90,20 @@ async def default_lifespan(app: "QBotApp"): class QBotApp(Generic[UserType, ConfigType], FastAPI): """ - Main class for the QBot application + Main application class for QuickBot RAD framework. + + Integrates FastAPI, Aiogram, SQLModel, and i18n for rapid Telegram bot development. + Handles bot initialization, API registration, command routing, and RLS. + + Args: + config: App configuration (see quickbot/config.py) + user_class: User model class (subclass of UserBase) + bot_start: Optional custom bot start handler + lifespan: Optional FastAPI lifespan context + lifespan_bot_init: Whether to run bot init on startup + lifespan_set_webhook: Whether to set webhook on startup + webhook_handler: Optional custom webhook handler + allowed_updates: List of Telegram update types to allow """ def __init__( @@ -82,6 +125,7 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): allowed_updates: list[str] | None = None, **kwargs, ): + # --- Initialize default user class if not provided --- if user_class is None: from .model.default_user import DefaultUser @@ -92,14 +136,18 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): ) self.user_class = user_class - self.entity_metadata: EntityMetadata = user_class.entity_metadata + self.bot_metadata: BotMetadata = user_class.bot_metadata + self.bot_metadata.app = self self.config = config self.lifespan = lifespan + # --- Setup Telegram API server and session --- api_server = TelegramAPIServer.from_base( self.config.TELEGRAM_BOT_SERVER, is_local=self.config.TELEGRAM_BOT_SERVER_IS_LOCAL, ) session = AiohttpSession(api=api_server) + + # --- Initialize Telegram Bot instance --- self.bot = Bot( token=self.config.TELEGRAM_BOT_TOKEN, session=session, @@ -108,26 +156,30 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): ), ) + # --- Setup Aiogram dispatcher with DB storage for FSM --- dp = Dispatcher(storage=DbStorage()) + # --- Setup i18n and middleware --- self.i18n = I18n(path="locales", default_locale="en", domain="messages") i18n_middleware = I18nMiddleware(user_class=user_class, i18n=self.i18n) i18n_middleware.setup(dp) dp.callback_query.middleware(CallbackAnswerMiddleware()) + # --- Register core routers (start, main menu) --- from .bot.handlers.start import router as start_router dp.include_router(start_router) - from .bot.handlers.menu.main import router as main_menu_router - auth = AuthMiddleware(user_class=user_class) - main_menu_router.message.middleware.register(auth) - main_menu_router.callback_query.middleware.register(auth) + # Register authentication middleware for menu routers + self.auth = AuthMiddleware(user_class=user_class) + main_menu_router.message.middleware.register(self.auth) + main_menu_router.callback_query.middleware.register(self.auth) dp.include_router(main_menu_router) self.dp = dp + # --- Extension points for custom bot start and webhook handlers --- self.start_handler = bot_start self.webhook_handler = webhook_handler self.bot_commands = dict[str, BotCommand]() @@ -135,8 +187,15 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): self.lifespan_bot_init = lifespan_bot_init self.lifespan_set_webhook = lifespan_set_webhook - super().__init__(lifespan=default_lifespan, **kwargs) + # --- Initialize FastAPI with custom lifespan and no default docs --- + super().__init__(lifespan=default_lifespan, docs_url=None, **kwargs) + self.bot_metadata.app_state = self.state + + # --- Initialize plugins --- + self.plugins = dict[str, Any]() + + # --- Register Telegram API router for /telegram endpoints (for webhook and auth) --- from .api_route.telegram import router as telegram_router self.include_router(telegram_router, prefix="/telegram", tags=["telegram"]) @@ -144,29 +203,422 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): self.root_router._commands = self.bot_commands self.command = self.root_router.command + # --- Register all entity CRUD endpoints dynamically (for models API) --- + self.register_models_api() + # --- Register custom Swagger UI with Telegram login (for docs) --- + self.register_swagger_ui_html() + + def register_plugin(self, plugin: Any): + self.plugins[type(plugin).__name__] = plugin + if isinstance(plugin, Registerable): + plugin.register(self) + + def register_swagger_ui_html(self): + """ + Register a custom /docs endpoint with Telegram login widget and JWT support for Swagger UI. + """ + + def swagger_ui_html(): + return HTMLResponse(f""" + + + + + + + QuickBot API + + + +
+ + + + + """) + + self.router.add_api_route( + path="/docs", + include_in_schema=False, + endpoint=swagger_ui_html, + methods=["GET"], + tags=["docs"], + ) + + def openapi_json(): + schema = get_openapi( + title="FastAPI + Telegram OAuth", + version="1.0.0", + description="Swagger с Telegram Login", + routes=self.routes, + ) + schema["components"]["securitySchemes"] = { + "bearerAuth": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + } + } + for path in schema["paths"].values(): + for op in path.values(): + op.setdefault("security", [{"bearerAuth": []}]) + return JSONResponse(schema) + + self.router.add_api_route( + path="/openapi.json", + endpoint=openapi_json, + methods=["GET"], + tags=["docs"], + include_in_schema=False, + ) + + def register_models_api(self): + """ + Dynamically register CRUD API endpoints for all entities in the app's metadata. + Endpoints: list, create, get by id, update, delete. + Uses FastAPI dependency injection for database session and user authentication. + """ + + def make_create_api_endpoint(entity_descriptor: EntityDescriptor): + async def create_entity( + db_session: Annotated[AsyncSession, Depends(get_db)], + request: Request, + obj_in: entity_descriptor.crud.create_schema = Body(...), + current_user=Depends(get_current_user), + ): + try: + ret_obj = await entity_descriptor.crud.create( + db_session=db_session, + user=current_user, + model=obj_in, + ) + except NotFoundError as e: + raise HTTPException(status_code=404, detail=e.args[0]) + except ForbiddenError as e: + raise HTTPException(status_code=403, detail=e.args[0]) + except Exception as e: + logger.error(f"Error creating entity: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + return ret_obj + + return create_entity + + def make_update_api_endpoint(entity_descriptor: EntityDescriptor): + async def update_entity( + db_session: Annotated[AsyncSession, Depends(get_db)], + request: Request, + id: int = Path(..., description="ID of the entity to update"), + obj_in: entity_descriptor.crud.update_schema = Body(...), + current_user=Depends(get_current_user), + ): + try: + entity = await entity_descriptor.crud.update( + db_session=db_session, + id=id, + model=obj_in, + user=current_user, + ) + except NotFoundError as e: + raise HTTPException(status_code=404, detail=e.args[0]) + except ForbiddenError as e: + raise HTTPException(status_code=403, detail=e.args[0]) + except Exception as e: + logger.error(f"Error updating entity: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + return entity + + return update_entity + + def make_delete_api_endpoint(entity_descriptor: EntityDescriptor): + async def delete_entity( + db_session: Annotated[AsyncSession, Depends(get_db)], + request: Request, + id: int = Path(..., description="ID of the entity to delete"), + current_user=Depends(get_current_user), + ): + try: + entity = await entity_descriptor.crud.delete( + db_session=db_session, + id=id, + user=current_user, + ) + except NotFoundError as e: + raise HTTPException(status_code=404, detail=e.args[0]) + except ForbiddenError as e: + raise HTTPException(status_code=403, detail=e.args[0]) + except Exception as e: + logger.error(f"Error deleting entity: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + return entity + + return delete_entity + + def make_get_by_id_api_endpoint(entity_descriptor: EntityDescriptor): + async def get_entity_by_id( + db_session: Annotated[AsyncSession, Depends(get_db)], + request: Request, + id: int = Path(..., description="ID of the entity to get"), + current_user=Depends(get_current_user), + ): + try: + entity = await entity_descriptor.type_.get( + session=db_session, + id=id, + ) + except NotFoundError as e: + raise HTTPException(status_code=404, detail=e.args[0]) + except ForbiddenError as e: + raise HTTPException(status_code=403, detail=e.args[0]) + except Exception as e: + logger.error(f"Error getting entity: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + return entity + + return get_entity_by_id + + def make_process_api_endpoint(process_descriptor: ProcessDescriptor): + async def run_process( + db_session: Annotated[AsyncSession, Depends(get_db)], + current_user: Annotated[UserBase, Depends(get_current_user)], + request: Request, + obj_in: process_descriptor.input_schema = Body(...), + ): + for role in current_user.roles: + if role in process_descriptor.roles: + break + else: + raise HTTPException(status_code=403, detail="Forbidden") + + run_func = process_descriptor.process_class.run + bot_context = BotContext( + db_session=db_session, + app=current_user.bot_metadata.app, + app_state=current_user.bot_metadata.app_state, + user=current_user, + ) + + try: + if iscoroutinefunction(run_func): + result = await run_func( + context=bot_context, + parameters=obj_in, + ) + else: + result = run_func( + context=bot_context, + parameters=obj_in, + ) + return result + except Exception as e: + logger.error(f"Error running process: {e}") + raise HTTPException(status_code=500, detail="Internal server error") + + return run_process + + for entity_descriptor in self.bot_metadata.entity_descriptors.values(): + if issubclass(entity_descriptor.type_, UserBase): + self.router.add_api_route( + path=f"/models/{entity_descriptor.name}/me", + methods=["GET"], + endpoint=get_me, + response_model=entity_descriptor.crud.schema, + summary="Get current user", + description="Get current user", + tags=["models"], + ) + + if CrudCommand.LIST in entity_descriptor.crud.commands: + self.router.add_api_route( + path=f"/models/{entity_descriptor.name}", + endpoint=list_entity_items, + methods=["GET"], + response_model=list[ + Union[entity_descriptor.crud.schema, ListSchema] + ], + summary=f"List {entity_descriptor.name}", + description=f"List {entity_descriptor.name}", + tags=["models"], + ) + + if CrudCommand.CREATE in entity_descriptor.crud.commands: + self.router.add_api_route( + path=f"/models/{entity_descriptor.name}", + methods=["POST"], + endpoint=make_create_api_endpoint(entity_descriptor), + response_model=entity_descriptor.crud.schema, + summary=f"Create {entity_descriptor.name}", + description=f"Create {entity_descriptor.name}", + tags=["models"], + ) + + if CrudCommand.GET_BY_ID in entity_descriptor.crud.commands: + self.router.add_api_route( + path=f"/models/{entity_descriptor.name}/{{id}}", + methods=["GET"], + endpoint=make_get_by_id_api_endpoint(entity_descriptor), + response_model=entity_descriptor.crud.schema, + summary=f"Get {entity_descriptor.name} by id", + description=f"Get {entity_descriptor.name} by id", + tags=["models"], + ) + + if CrudCommand.UPDATE in entity_descriptor.crud.commands: + self.router.add_api_route( + path=f"/models/{entity_descriptor.name}/{{id}}", + methods=["PATCH"], + endpoint=make_update_api_endpoint(entity_descriptor), + response_model=Union[entity_descriptor.crud.schema, ListSchema], + summary=f"Update {entity_descriptor.name}", + description=f"Update {entity_descriptor.name}", + tags=["models"], + ) + + if CrudCommand.DELETE in entity_descriptor.crud.commands: + self.router.add_api_route( + path=f"/models/{entity_descriptor.name}/{{id}}", + methods=["DELETE"], + endpoint=make_delete_api_endpoint(entity_descriptor), + response_model=Union[entity_descriptor.crud.schema, ListSchema], + summary=f"Delete {entity_descriptor.name}", + description=f"Delete {entity_descriptor.name}", + tags=["models"], + ) + + for process_descriptor in self.bot_metadata.process_descriptors.values(): + self.router.add_api_route( + path=f"/processes/{process_descriptor.name}", + methods=["POST"], + endpoint=make_process_api_endpoint(process_descriptor), + response_model=process_descriptor.output_schema, + summary=f"Run {process_descriptor.name}", + description=process_descriptor.description, + tags=["processes"], + ) + def register_routers(self, *routers: Router): + # Register additional routers and their commands with the application. + # This allows modular extension of bot command sets and menu trees. for router in routers: for command_name, command in router._commands.items(): self.bot_commands[command_name] = command async def bot_init(self): + # --- Set up bot commands for all locales --- + # This method collects all commands (with captions) that should be shown in the Telegram UI. + # It supports localization by grouping commands by locale. commands_captions = dict[str, list[tuple[str, str]]]() for command_name, command in self.bot_commands.items(): if command.show_in_bot_commands: if isinstance(command.caption, str) or command.caption is None: + # Default locale (or no caption provided) if "default" not in commands_captions: commands_captions["default"] = [] commands_captions["default"].append( (command_name, command.caption or command_name) ) else: + # Localized captions per locale for locale, description in command.caption.items(): locale = "default" if locale == "en" else locale if locale not in commands_captions: commands_captions[locale] = [] commands_captions[locale].append((command_name, description)) + # Register commands with Telegram for each locale for locale, commands in commands_captions.items(): await self.bot.set_my_commands( [ @@ -177,6 +629,8 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): ) async def set_webhook(self): + # --- Set Telegram webhook for receiving updates --- + # This is called on startup if lifespan_set_webhook is True. await self.bot.set_webhook( url=f"{self.config.TELEGRAM_WEBHOOK_URL}/telegram/webhook", drop_pending_updates=True, @@ -194,6 +648,8 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): form_name: str = None, form_params: list[Any] = None, ): + # --- Show a form for a specific entity instance to a user --- + # Used for interactive entity editing or viewing in the Telegram bot UI. f_params = [] if form_name: @@ -202,9 +658,11 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): if form_params: f_params.extend([str(p) for p in form_params]) + # Allow passing entity as class or string name if isinstance(entity, type): entity = entity.bot_entity_descriptor.name + # Prepare callback data for navigation stack callback_data = ContextData( command=CallbackCommand.ENTITY_ITEM, entity_name=entity, @@ -212,6 +670,7 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): form_params="&".join(f_params), ) + # Get FSM state context for the user state = self.dp.fsm.get_context(bot=self.bot, chat_id=user_id, user_id=user_id) state_data = await state.get_data() clear_state(state_data=state_data) @@ -220,11 +679,13 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): ) await state.set_data(state_data) + # Fetch user object for locale and permissions user = await self.user_class.get( session=db_session, id=user_id, ) + # Use i18n context for the user's language with self.i18n.context(), self.i18n.use_locale(user.lang.value): await entity_item( query=None, @@ -246,6 +707,8 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): user_id: int, db_session: AsyncSession, ): + # --- Execute a user command in the Telegram bot context --- + # This is used for programmatically triggering bot commands (e.g., from API or callback). state = self.dp.fsm.get_context(bot=self.bot, chat_id=user_id, user_id=user_id) state_data = await state.get_data() callback_data = ContextData( @@ -255,22 +718,27 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI): command_name = command.split("&")[0] cmd = self.bot_commands.get(command_name) + # Fetch user object for permissions and locale user = await self.user_class.get( session=db_session, id=user_id, ) if cmd is None: + # Command not found (could be a custom or unregistered command) return + # Optionally clear navigation stack if command requires it if cmd.clear_navigation: state_data.pop("navigation_stack", None) state_data.pop("navigation_context", None) + # Optionally register navigation context for this command if cmd.register_navigation: clear_state(state_data=state_data) save_navigation_context(callback_data=callback_data, state_data=state_data) + # Use i18n context for the user's language with self.i18n.context(), self.i18n.use_locale(user.lang.value): await command_handler( message=None, diff --git a/src/quickbot/middleware/telegram/auth.py b/src/quickbot/middleware/telegram/auth.py index 89a23fa..36b80ca 100644 --- a/src/quickbot/middleware/telegram/auth.py +++ b/src/quickbot/middleware/telegram/auth.py @@ -23,9 +23,15 @@ class AuthMiddleware(BaseMiddleware): event: TelegramObject, data: Dict[str, Any], ) -> Any: - user = await self.user_class.get( - id=event.from_user.id, session=data["db_session"] - ) + if event.business_connection_id: + business_connection = await event.bot.get_business_connection( + event.business_connection_id + ) + user_id = business_connection.user.id + else: + user_id = event.from_user.id + + user = await self.user_class.get(id=user_id, session=data["db_session"]) if user and user.is_active: data["user"] = user diff --git a/src/quickbot/model/__init__.py b/src/quickbot/model/__init__.py index 50ee14d..44968fe 100644 --- a/src/quickbot/model/__init__.py +++ b/src/quickbot/model/__init__.py @@ -8,11 +8,11 @@ from ..db import async_session class EntityPermission(BotEnum): - LIST = EnumMember("list") - READ = EnumMember("read") - CREATE = EnumMember("create") - UPDATE = EnumMember("update") - DELETE = EnumMember("delete") + LIST_RLS = EnumMember("list_rls") + READ_RLS = EnumMember("read_rls") + CREATE_RLS = EnumMember("create_rls") + UPDATE_RLS = EnumMember("update_rls") + DELETE_RLS = EnumMember("delete_rls") LIST_ALL = EnumMember("list_all") READ_ALL = EnumMember("read_all") CREATE_ALL = EnumMember("create_all") diff --git a/src/quickbot/model/annotated_schema.py b/src/quickbot/model/annotated_schema.py new file mode 100644 index 0000000..f7f0c4a --- /dev/null +++ b/src/quickbot/model/annotated_schema.py @@ -0,0 +1,46 @@ +""" +BotEntity module provides a metaclass and base class for creating database entities +with enhanced functionality for bot operations, including field descriptors, +filtering, and ownership management. +""" + +from typing import ( + TYPE_CHECKING, + dataclass_transform, +) +from pydantic import BaseModel +from pydantic._internal._model_construction import ModelMetaclass +from sqlmodel import Field +from sqlmodel.main import FieldInfo + + +from .descriptors import EntityField, FieldDescriptor + +if TYPE_CHECKING: + pass + + +@dataclass_transform( + kw_only_default=True, + field_specifiers=(Field, FieldInfo, EntityField, FieldDescriptor), +) +class AnnotatedSchemaMetaclass(ModelMetaclass): + """ + Metaclass for annotated schemas. + """ + + def __new__(mcs, name, bases, namespace, **kwargs): + """ + Create a new annotated schema. + """ + + # --- Create the class using parent metaclass --- + type_ = super().__new__(mcs, name, bases, namespace, **kwargs) + + return type_ + + +class AnnotatedSchema(BaseModel, metaclass=AnnotatedSchemaMetaclass): + """ + Base class for annotated schemas. + """ diff --git a/src/quickbot/model/bot_entity.py b/src/quickbot/model/bot_entity.py index 0307e3f..88d0347 100644 --- a/src/quickbot/model/bot_entity.py +++ b/src/quickbot/model/bot_entity.py @@ -1,28 +1,47 @@ +""" +BotEntity module provides a metaclass and base class for creating database entities +with enhanced functionality for bot operations, including field descriptors, +filtering, and ownership management. +""" + from types import NoneType, UnionType from typing import ( Any, ClassVar, ForwardRef, Optional, - Self, Union, get_args, get_origin, TYPE_CHECKING, dataclass_transform, + Self, ) from pydantic import BaseModel +from pydantic.fields import _Unset from pydantic_core import PydanticUndefined -from sqlmodel import SQLModel, BigInteger, Field, select, func, column, col +from sqlmodel import SQLModel, BigInteger, Field, select, func from sqlmodel.main import FieldInfo from sqlmodel.ext.asyncio.session import AsyncSession -from sqlmodel.sql.expression import SelectOfScalar from sqlmodel.main import SQLModelMetaclass, RelationshipInfo -from .descriptors import EntityDescriptor, EntityField, FieldDescriptor, Filter -from .entity_metadata import EntityMetadata +from .descriptors import ( + EntityDescriptor, + EntityField, + FieldDescriptor, + Filter, + FilterExpression, +) +from .bot_metadata import BotMetadata +from .crud_service import CrudService from . import session_dep +from .utils import ( + _static_filter_condition, + _build_filter_condition, + _filter_condition, + _apply_rls_filters, +) if TYPE_CHECKING: from .user import UserBase @@ -33,11 +52,34 @@ if TYPE_CHECKING: field_specifiers=(Field, FieldInfo, EntityField, FieldDescriptor), ) class BotEntityMetaclass(SQLModelMetaclass): + """ + Metaclass for BotEntity that handles field processing, descriptor creation, + and type resolution for bot-specific database entities. + + This metaclass extends SQLModelMetaclass to provide additional functionality + for bot operations including field descriptors, type annotations, and + entity metadata management. + """ + + # Store future references for forward-declared types _future_references = {} def __new__(mcs, name, bases, namespace, **kwargs): + """ + Create a new class with processed field descriptors and metadata. + + Args: + name: Name of the class being created + bases: Base classes + namespace: Class namespace containing attributes and annotations + **kwargs: Additional keyword arguments passed to the metaclass + + Returns: + The created class with processed field descriptors and metadata + """ bot_fields_descriptors = {} + # --- Inherit field descriptors from parent classes (if any) --- if bases: bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor") bot_fields_descriptors = ( @@ -49,52 +91,67 @@ class BotEntityMetaclass(SQLModelMetaclass): else {} ) + # --- Process field annotations to create field descriptors --- if "__annotations__" in namespace: for annotation in namespace["__annotations__"]: - if annotation in ["bot_entity_descriptor", "entity_metadata"]: + # Skip special attributes + if annotation in ["bot_entity_descriptor", "bot_metadata"]: continue - attribute_value = namespace.get(annotation) + attribute_value = namespace.get(annotation, PydanticUndefined) + # Skip relationship fields (handled by SQLModel) if isinstance(attribute_value, RelationshipInfo): continue descriptor_kwargs = {} descriptor_name = annotation - if attribute_value: + # --- Process EntityField attributes to extract SQLModel field descriptors --- + if attribute_value is not PydanticUndefined: if isinstance(attribute_value, EntityField): descriptor_kwargs = attribute_value.__dict__.copy() + # Extract SQLModel field descriptor if present sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None) # type: FieldInfo if sm_descriptor: + # Transfer default values from EntityField to SQLModel descriptor if ( - attribute_value.default is not None + attribute_value.default is not PydanticUndefined and sm_descriptor.default is PydanticUndefined ): sm_descriptor.default = attribute_value.default if ( attribute_value.default_factory is not None - and sm_descriptor.default_factory is PydanticUndefined + and sm_descriptor.default_factory is None ): sm_descriptor.default_factory = ( attribute_value.default_factory ) + if attribute_value.description is not PydanticUndefined: + sm_descriptor.description = attribute_value.description else: + # Create new SQLModel field descriptor if none exists if ( attribute_value.default is not None or attribute_value.default_factory is not None ): sm_descriptor = Field() - if attribute_value.default is not None: + if attribute_value.default is not PydanticUndefined: sm_descriptor.default = attribute_value.default if attribute_value.default_factory is not None: sm_descriptor.default_factory = ( attribute_value.default_factory ) + if attribute_value.description is not PydanticUndefined: + sm_descriptor.description = ( + attribute_value.description + ) + # Clean up internal attributes descriptor_kwargs.pop("__orig_class__", None) + # Replace EntityField with SQLModel field descriptor in namespace if sm_descriptor: namespace[annotation] = sm_descriptor else: @@ -102,10 +159,28 @@ class BotEntityMetaclass(SQLModelMetaclass): descriptor_name = descriptor_kwargs.pop("name") or annotation + elif isinstance(attribute_value, FieldInfo): + if attribute_value.default is not PydanticUndefined: + descriptor_kwargs["default"] = attribute_value.default + if attribute_value.default_factory is not None: + descriptor_kwargs["default_factory"] = ( + attribute_value.default_factory + ) + if attribute_value.description is not _Unset: + descriptor_kwargs["description"] = ( + attribute_value.description + ) + + elif isinstance(attribute_value, RelationshipInfo): + pass + + else: + descriptor_kwargs["default"] = attribute_value + + # --- Get the type annotation for the field --- type_ = namespace["__annotations__"][annotation] - type_origin = get_origin(type_) - + # --- Create field descriptor with basic information --- field_descriptor = FieldDescriptor( name=descriptor_name, field_name=annotation, @@ -114,12 +189,18 @@ class BotEntityMetaclass(SQLModelMetaclass): **descriptor_kwargs, ) + # --- Process type annotations to determine if field is list or optional --- + type_origin = get_origin(type_) + is_list = False is_optional = False + + # Handle list types (e.g., List[str]) if type_origin is list: field_descriptor.is_list = is_list = True field_descriptor.type_base = type_ = get_args(type_)[0] + # Handle Union types for optional fields (e.g., Optional[str]) if type_origin is Union: args = get_args(type_) if isinstance(args[0], ForwardRef): @@ -129,16 +210,17 @@ class BotEntityMetaclass(SQLModelMetaclass): field_descriptor.is_optional = is_optional = True field_descriptor.type_base = type_ = args[0] + # Handle Python 3.10+ UnionType (e.g., str | None) if type_origin is UnionType and get_args(type_)[1] is NoneType: field_descriptor.is_optional = is_optional = True field_descriptor.type_base = type_ = get_args(type_)[0] + # --- Handle string type references (forward references to other entities) --- if isinstance(type_, str): type_not_found = True - for ( - entity_descriptor - ) in EntityMetadata().entity_descriptors.values(): + for entity_descriptor in BotMetadata().entity_descriptors.values(): if type_ == entity_descriptor.class_name: + # Resolve the type to the actual entity class field_descriptor.type_base = entity_descriptor.type_ field_descriptor.type_ = ( list[entity_descriptor.type_] @@ -155,6 +237,8 @@ class BotEntityMetaclass(SQLModelMetaclass): ) type_not_found = False break + + # If type not found, store for future resolution if type_not_found: if type_ in mcs._future_references: mcs._future_references[type_].append(field_descriptor) @@ -163,15 +247,17 @@ class BotEntityMetaclass(SQLModelMetaclass): bot_fields_descriptors[descriptor_name] = field_descriptor + # --- Process entity descriptor configuration --- descriptor_name = name if "bot_entity_descriptor" in namespace: + # Extract and process custom entity descriptor entity_descriptor = namespace.pop("bot_entity_descriptor") descriptor_kwargs: dict = entity_descriptor.__dict__.copy() descriptor_name = descriptor_kwargs.pop("name", None) descriptor_kwargs.pop("__orig_class__", None) descriptor_name = descriptor_name or name.lower() - namespace["bot_entity_descriptor"] = EntityDescriptor( + entity_descriptor = namespace["bot_entity_descriptor"] = EntityDescriptor( name=descriptor_name, class_name=name, type_=name, @@ -179,39 +265,41 @@ class BotEntityMetaclass(SQLModelMetaclass): **descriptor_kwargs, ) else: + # Create default entity descriptor descriptor_name = name.lower() - namespace["bot_entity_descriptor"] = EntityDescriptor( + entity_descriptor = namespace["bot_entity_descriptor"] = EntityDescriptor( name=descriptor_name, class_name=name, type_=name, fields_descriptors=bot_fields_descriptors, ) + # --- Link field descriptors to their entity descriptor --- for field_descriptor in bot_fields_descriptors.values(): - field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"] + field_descriptor.entity_descriptor = entity_descriptor + # --- Configure table settings (set to True by default) --- if "table" not in kwargs: kwargs["table"] = True + # --- If table is set to True, register entity in global metadata --- if kwargs["table"]: - entity_metadata = EntityMetadata() - entity_metadata.entity_descriptors[descriptor_name] = namespace[ - "bot_entity_descriptor" - ] + # Register entity in global metadata + entity_metadata = BotMetadata() + entity_metadata.entity_descriptors[descriptor_name] = entity_descriptor + # Add entity_metadata to class annotations if "__annotations__" in namespace: - namespace["__annotations__"]["entity_metadata"] = ClassVar[ - EntityMetadata - ] + namespace["__annotations__"]["bot_metadata"] = ClassVar[BotMetadata] else: - namespace["__annotations__"] = { - "entity_metadata": ClassVar[EntityMetadata] - } + namespace["__annotations__"] = {"bot_metadata": ClassVar[BotMetadata]} - namespace["entity_metadata"] = entity_metadata + namespace["bot_metadata"] = entity_metadata + # --- Create the class using parent metaclass --- type_ = super().__new__(mcs, name, bases, namespace, **kwargs) + # --- Resolve future references now that the class exists --- if name in mcs._future_references: for field_descriptor in mcs._future_references[name]: type_origin = get_origin(field_descriptor.type_) @@ -230,78 +318,63 @@ class BotEntityMetaclass(SQLModelMetaclass): ) ) - setattr(namespace["bot_entity_descriptor"], "type_", type_) + # --- Set the resolved type in the entity descriptor --- + entity_descriptor.type_ = type_ + # setattr(entity_descriptor, "type_", type_) + + if kwargs["table"] and entity_descriptor.crud is None: + entity_descriptor.crud = CrudService(entity_descriptor) return type_ -class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel]( - SQLModel, metaclass=BotEntityMetaclass, table=False -): - bot_entity_descriptor: ClassVar[EntityDescriptor] - entity_metadata: ClassVar[EntityMetadata] +class BotEntity(SQLModel, metaclass=BotEntityMetaclass, table=False): + """ + Base class for bot entities that provides CRUD operations, filtering, + and Row Level Security (RLS) capabilities. + This class extends SQLModel and uses BotEntityMetaclass to provide + enhanced functionality for bot operations including: + - Field descriptors for UI generation + - Advanced filtering and search capabilities + - Row Level Security (RLS) access control + - Standardized CRUD operations + """ + + # Class variables set by the metaclass + bot_entity_descriptor: ClassVar[EntityDescriptor] + bot_metadata: ClassVar[BotMetadata] + + # Standard ID field for all entities id: int = EntityField( - sm_descriptor=Field(primary_key=True, sa_type=BigInteger), is_visible=False + sm_descriptor=Field(primary_key=True, sa_type=BigInteger), + is_visible=False, + default=None, ) @classmethod @session_dep - async def get(cls, *, session: AsyncSession | None = None, id: int): - return await session.get(cls, id) - - @classmethod - def _static_filter_condition( - cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter] - ): - for sfilt in static_filter: - column = getattr(cls, sfilt.field_name) - if sfilt.operator == "==": - condition = column.__eq__(sfilt.value) - elif sfilt.operator == "!=": - condition = column.__ne__(sfilt.value) - elif sfilt.operator == "<": - condition = column.__lt__(sfilt.value) - elif sfilt.operator == "<=": - condition = column.__le__(sfilt.value) - elif sfilt.operator == ">": - condition = column.__gt__(sfilt.value) - elif sfilt.operator == ">=": - condition = column.__ge__(sfilt.value) - elif sfilt.operator == "ilike": - condition = col(column).ilike(f"%{sfilt.value}%") - elif sfilt.operator == "like": - condition = col(column).like(f"%{sfilt.value}%") - elif sfilt.operator == "in": - condition = col(column).in_(sfilt.value) - elif sfilt.operator == "not in": - condition = col(column).notin_(sfilt.value) - elif sfilt.operator == "is none": - condition = col(column).is_(None) - elif sfilt.operator == "is not none": - condition = col(column).isnot(None) - elif sfilt.operator == "contains": - condition = sfilt.value == col(column).any_() - else: - condition = None - if condition is not None: - select_statement = select_statement.where(condition) - return select_statement - - @classmethod - def _filter_condition( + async def get( cls, - select_statement: SelectOfScalar[Self], - filter: str, - filter_fields: list[str], - ): - condition = None - for field in filter_fields: - if condition is not None: - condition = condition | (column(field).ilike(f"%{filter}%")) - else: - condition = column(field).ilike(f"%{filter}%") - return select_statement.where(condition) + *, + session: AsyncSession | None = None, + id: int, + user: "UserBase | None" = None, + ) -> Self: + """ + Retrieve a single entity by ID. + + Args: + session: Database session (injected by session_dep) + id: Entity ID to retrieve + + Returns: + The entity instance or None if not found + """ + select_statement = select(cls).where(cls.id == id) + if user: + select_statement = await _apply_rls_filters(cls, select_statement, user) + return await session.scalar(select_statement) @classmethod @session_dep @@ -309,28 +382,49 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel]( cls, *, session: AsyncSession | None = None, - static_filter: list[Filter] | Any = None, + user: "UserBase", + static_filter: Filter | FilterExpression | Any = None, filter: str = None, filter_fields: list[str] = None, ext_filter: Any = None, - user: "UserBase" = None, ) -> int: + """ + Get the count of entities matching the specified criteria. + + Args: + session: Database session (injected by session_dep) + static_filter: List of static filter conditions + filter: Text search filter + filter_fields: Fields to search in for text filter + ext_filter: Additional custom filter conditions + user: User for RLS-based filtering + + Returns: + Count of matching entities + """ + # --- Build select statement for counting entities --- select_statement = select(func.count()).select_from(cls) + + # --- Apply various filter conditions --- if static_filter: if isinstance(static_filter, list): - select_statement = cls._static_filter_condition( + select_statement = _static_filter_condition( select_statement, static_filter ) else: - select_statement = select_statement.where(static_filter) + # Handle single Filter or FilterExpression object + condition = _build_filter_condition(cls, static_filter) + if condition is not None: + select_statement = select_statement.where(condition) if filter and filter_fields: - select_statement = cls._filter_condition( + select_statement = _filter_condition( select_statement, filter, filter_fields ) if ext_filter: select_statement = select_statement.where(ext_filter) - if user: - select_statement = cls._ownership_condition(select_statement, user) + + select_statement = await _apply_rls_filters(cls, select_statement, user) + return await session.scalar(select_statement) @classmethod @@ -339,56 +433,60 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel]( cls, *, session: AsyncSession | None = None, + user: "UserBase | None" = None, order_by=None, - static_filter: list[Filter] | Any = None, + static_filter: Filter | FilterExpression | Any = None, filter: str = None, filter_fields: list[str] = None, ext_filter: Any = None, - user: "UserBase" = None, skip: int = 0, limit: int = None, - ): + ) -> list[Self]: + """ + Retrieve multiple entities with filtering, pagination, and ordering. + + Args: + session: Database session (injected by session_dep) + order_by: Ordering criteria + static_filter: List of static filter conditions + filter: Text search filter + filter_fields: Fields to search in for text filter + ext_filter: Additional custom filter conditions + user: User for RLS-based filtering + skip: Number of records to skip (for pagination) + limit: Maximum number of records to return + + Returns: + List of matching entity instances + """ + # --- Build select statement for entity retrieval --- select_statement = select(cls).offset(skip) if limit: select_statement = select_statement.limit(limit) + + # --- Apply various filter conditions --- if static_filter is not None: if isinstance(static_filter, list): - select_statement = cls._static_filter_condition( - select_statement, static_filter + select_statement = _static_filter_condition( + cls, select_statement, static_filter ) else: - select_statement = select_statement.where(static_filter) + # Handle single Filter or FilterExpression object + condition = _build_filter_condition(cls, static_filter) + if condition is not None: + select_statement = select_statement.where(condition) if filter and filter_fields: - select_statement = cls._filter_condition( - select_statement, filter, filter_fields + select_statement = _filter_condition( + cls, select_statement, filter, filter_fields ) if ext_filter is not None: select_statement = select_statement.where(ext_filter) if user: - select_statement = cls._ownership_condition(select_statement, user) - if order_by is not None: + select_statement = await _apply_rls_filters(cls, select_statement, user) + if order_by: select_statement = select_statement.order_by(order_by) - return (await session.exec(select_statement)).all() - @classmethod - def _ownership_condition( - cls, select_statement: SelectOfScalar[Self], user: "UserBase" - ): - if cls.bot_entity_descriptor.ownership_fields: - condition = None - for role in user.roles: - if role in cls.bot_entity_descriptor.ownership_fields: - owner_col = column(cls.bot_entity_descriptor.ownership_fields[role]) - if condition is not None: - condition = condition | (owner_col == user.id) - else: - condition = owner_col == user.id - else: - condition = None - break - if condition is not None: - return select_statement.where(condition) - return select_statement + return (await session.exec(select_statement)).all() @classmethod @session_dep @@ -396,9 +494,21 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel]( cls, *, session: AsyncSession | None = None, - obj_in: CreateSchemaType, + obj_in: BaseModel, commit: bool = False, - ): + ) -> Self: + """ + Create a new entity instance. + + Args: + session: Database session (injected by session_dep) + obj_in: Data for creating the entity (can be Pydantic model or entity instance) + commit: Whether to commit the transaction immediately + + Returns: + The created entity instance + """ + # --- Accept both entity instances and Pydantic models --- if isinstance(obj_in, cls): obj = obj_in else: @@ -415,13 +525,26 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel]( *, session: AsyncSession | None = None, id: int, - obj_in: UpdateSchemaType, + obj_in: BaseModel, commit: bool = False, - ): + ) -> Self: + """ + Update an existing entity instance. + + Args: + session: Database session (injected by session_dep) + id: ID of the entity to update + obj_in: Data for updating the entity + commit: Whether to commit the transaction immediately + + Returns: + The updated entity instance or None if not found + """ obj = await session.get(cls, id) if obj: obj_data = obj.model_dump() update_data = obj_in.model_dump(exclude_unset=True) + # Only update fields present in the update data for field in obj_data: if field in update_data: setattr(obj, field, update_data[field]) @@ -435,7 +558,18 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel]( @session_dep async def remove( cls, *, session: AsyncSession | None = None, id: int, commit: bool = False - ): + ) -> Self: + """ + Delete an entity instance. + + Args: + session: Database session (injected by session_dep) + id: ID of the entity to delete + commit: Whether to commit the transaction immediately + + Returns: + The deleted entity instance or None if not found + """ obj = await session.get(cls, id) if obj: await session.delete(obj) diff --git a/src/quickbot/model/bot_enum.py b/src/quickbot/model/bot_enum.py index 80e5d4f..cc2fc04 100644 --- a/src/quickbot/model/bot_enum.py +++ b/src/quickbot/model/bot_enum.py @@ -1,5 +1,8 @@ from aiogram.utils.i18n import I18n -from pydantic_core.core_schema import str_schema +from pydantic import GetCoreSchemaHandler + +# from pydantic_core.core_schema import str_schema +from pydantic_core import core_schema from sqlalchemy.types import TypeDecorator from sqlmodel import AutoString from typing import Any, Self, overload @@ -59,13 +62,13 @@ class BotEnumMetaclass(type): class EnumMember(object): @overload - def __init__(self, value: str) -> "EnumMember": ... + def __init__(self, value: str) -> Self: ... @overload - def __init__(self, value: "EnumMember") -> "EnumMember": ... + def __init__(self, value: Self) -> Self: ... @overload - def __init__(self, value: str, loc_obj: dict[str, str]) -> "EnumMember": ... + def __init__(self, value: str, loc_obj: dict[str, str]) -> Self: ... def __init__( self, @@ -74,7 +77,7 @@ class EnumMember(object): parent: type = None, name: str = None, casting: bool = True, - ) -> "EnumMember": + ) -> Self: if not casting: self._parent = parent self._name = name @@ -82,9 +85,9 @@ class EnumMember(object): self.loc_obj = loc_obj @overload - def __new__(cls: Self, *args, **kwargs) -> "EnumMember": ... + def __new__(cls: Self, *args, **kwargs) -> Self: ... - def __new__(cls, *args, casting: bool = True, **kwargs) -> "EnumMember": + def __new__(cls, *args, casting: bool = True, **kwargs) -> Self: if (cls.__name__ == "EnumMember") or not casting: obj = super().__new__(cls) kwargs["casting"] = False @@ -104,8 +107,8 @@ class EnumMember(object): else: return args[0] - def __get_pydantic_core_schema__(cls, *args, **kwargs): - return str_schema() + # def __get_pydantic_core_schema__(cls, *args, **kwargs): + # return str_schema() def __get__(self, instance, owner) -> Self: return { @@ -159,6 +162,29 @@ class EnumMember(object): return self.value + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: type, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.no_info_after_validator_function( + function=cls._validate_from_string, + schema=core_schema.str_schema(), + serialization=core_schema.plain_serializer_function_ser_schema( + cls._serialize_to_string, return_schema=core_schema.str_schema() + ), + ) + + @classmethod + def _validate_from_string(cls, value: str) -> Self: + member = cls(value) + if member is None: + raise ValueError(f"Invalid value for {cls.__name__}: {value}") + return member + + @classmethod + def _serialize_to_string(cls, value: Self) -> str: + return value.value + class BotEnum(EnumMember, metaclass=BotEnumMetaclass): all_members: dict[str, EnumMember] diff --git a/src/quickbot/model/bot_metadata.py b/src/quickbot/model/bot_metadata.py new file mode 100644 index 0000000..481c4d7 --- /dev/null +++ b/src/quickbot/model/bot_metadata.py @@ -0,0 +1,16 @@ +from fastapi.datastructures import State +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from quickbot.main import QBotApp + +from .descriptors import EntityDescriptor, ProcessDescriptor +from ._singleton import Singleton + + +class BotMetadata(metaclass=Singleton): + def __init__(self): + self.entity_descriptors: dict[str, EntityDescriptor] = {} + self.process_descriptors: dict[str, ProcessDescriptor] = {} + self.app: "QBotApp" = None + self.app_state: State = None diff --git a/src/quickbot/model/bot_process.py b/src/quickbot/model/bot_process.py new file mode 100644 index 0000000..5d4ba32 --- /dev/null +++ b/src/quickbot/model/bot_process.py @@ -0,0 +1,77 @@ +import inspect +from typing_extensions import get_type_hints +from pydantic import BaseModel +from typing import ClassVar + +from .descriptors import ProcessDescriptor, BotContext, Process +from .bot_metadata import BotMetadata + + +class BotProcessMetaclass(type): + def __new__(mcs, name, bases, namespace, **kwargs): + if name == "BotProcess": + namespace.pop("run") + return super().__new__(mcs, name, bases, namespace, **kwargs) + + run_attr = namespace.get("run", None) + if run_attr is None or not isinstance(run_attr, staticmethod): + raise TypeError(f"{name}.run must be defined as a @staticmethod") + + func = run_attr.__func__ + sig = inspect.signature(func) + if list(sig.parameters.keys()) not in [["context", "parameters"], ["context"]]: + raise TypeError( + f"{name}.run must have exactly two arguments: context, parameters or one argument: context" + ) + + hints = get_type_hints(func, globalns=globals(), localns=namespace) + + # Check arguments + ctx_type = hints.get("context") + if ctx_type is None or not issubclass(ctx_type, BotContext): + raise TypeError(f"{name}.run: 'context' must be BotContext") + + param_type = hints.get("parameters") + if param_type is not None and not issubclass(param_type, BaseModel): + raise TypeError(f"{name}.run: 'parameters' must be subclass of BaseModel") + + return_type = hints.get("return") + + # Auto-generation of schemas + + descriptor_kwargs = {"process_class": None} + process_descriptor = namespace.pop("bot_process_descriptor", None) + if process_descriptor and isinstance(process_descriptor, Process): + descriptor_kwargs.update(process_descriptor.__dict__) + + descriptor_kwargs.update( + name=name, + input_schema=param_type, + output_schema=return_type, + ) + + descriptor = ProcessDescriptor(**descriptor_kwargs) + namespace["bot_process_descriptor"] = descriptor + + cls = super().__new__(mcs, name, bases, namespace, **kwargs) + + descriptor.process_class = cls + + bot_metadata = BotMetadata() + bot_metadata.process_descriptors[descriptor.name] = descriptor + + namespace["bot_metadata"] = bot_metadata + + return super().__new__(mcs, name, bases, namespace, **kwargs) + + +class BotProcess(metaclass=BotProcessMetaclass): + """ + Base class for business logic processes. + """ + + bot_process_descriptor: ClassVar[ProcessDescriptor] + bot_metadata: ClassVar[BotMetadata] + + @staticmethod + def run(context: BotContext, parameters: BaseModel = None): ... diff --git a/src/quickbot/model/crud_command.py b/src/quickbot/model/crud_command.py new file mode 100644 index 0000000..d266e8b --- /dev/null +++ b/src/quickbot/model/crud_command.py @@ -0,0 +1,9 @@ +from enum import StrEnum + + +class CrudCommand(StrEnum): + LIST = "list" + GET_BY_ID = "get_by_id" + CREATE = "create" + UPDATE = "update" + DELETE = "delete" diff --git a/src/quickbot/model/crud_service.py b/src/quickbot/model/crud_service.py new file mode 100644 index 0000000..cf3d40b --- /dev/null +++ b/src/quickbot/model/crud_service.py @@ -0,0 +1,374 @@ +from sqlmodel import select, col +from sqlmodel.ext.asyncio.session import AsyncSession +from pydantic import BaseModel +from quickbot.model.permissions import get_user_permissions +from quickbot.model.descriptors import EntityDescriptor, BotContext +from quickbot.model.crud_command import CrudCommand +from quickbot.model.utils import ( + entity_to_schema, + entity_to_list_schema, + pydantic_model, +) +from quickbot.model.descriptors import EntityPermission +from sqlalchemy.exc import IntegrityError +from asyncpg import ForeignKeyViolationError +from logging import getLogger +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from quickbot.model.user import UserBase + +logger = getLogger(__name__) + + +class NotFoundError(Exception): + pass + + +class ForbiddenError(Exception): + pass + + +class CrudService: + def __init__( + self, + entity_descriptor: EntityDescriptor, + commands: list[CrudCommand] = [ + CrudCommand.LIST, + CrudCommand.GET_BY_ID, + CrudCommand.CREATE, + CrudCommand.UPDATE, + CrudCommand.DELETE, + ], + schema: type[BaseModel] = None, + create_schema: type[BaseModel] = None, + update_schema: type[BaseModel] = None, + ): + self.entity_descriptor = entity_descriptor + self.commands = commands + if CrudCommand.CREATE in commands: + self.create_schema = create_schema or pydantic_model( + entity_descriptor, entity_descriptor.type_.__module__, "create" + ) + else: + self.create_schema = None + if CrudCommand.UPDATE in commands: + self.update_schema = update_schema or pydantic_model( + entity_descriptor, entity_descriptor.type_.__module__, "update" + ) + else: + self.update_schema = None + self.schema = schema or pydantic_model( + entity_descriptor, entity_descriptor.type_.__module__ + ) + + async def list_all( + self, db_session: AsyncSession, user: "UserBase" + ) -> list[BaseModel]: + if CrudCommand.LIST not in self.commands: + raise ForbiddenError( + f"List command not allowed for entity {self.entity_descriptor.name}" + ) + user_permissions = get_user_permissions(user, self.entity_descriptor) + if ( + EntityPermission.READ_ALL in user_permissions + or EntityPermission.READ_RLS in user_permissions + ): + ret_list = await self.entity_descriptor.type_.get_multi( + session=db_session, + user=user + if EntityPermission.READ_ALL not in user_permissions + else None, + ) + return [entity_to_schema(item) for item in ret_list] + elif ( + EntityPermission.LIST_ALL in user_permissions + or EntityPermission.LIST_RLS in user_permissions + ): + ret_list = await self.entity_descriptor.type_.get_multi( + session=db_session, + user=user + if EntityPermission.LIST_ALL not in user_permissions + else None, + ) + context = BotContext( + db_session=db_session, + app=user.bot_metadata.app, + app_state=user.bot_metadata.app_state, + user=user, + ) + return [await entity_to_list_schema(item, context) for item in ret_list] + else: + raise ForbiddenError( + f"User {user.id} does not have permission to read or list entities" + ) + + async def get_by_id( + self, db_session: AsyncSession, user: "UserBase", id: int + ) -> BaseModel: + if CrudCommand.GET_BY_ID not in self.commands: + raise ForbiddenError( + f"Get by id command not allowed for entity {self.entity_descriptor.name}" + ) + ret_obj = await self.entity_descriptor.type_.get(session=db_session, id=id) + if ret_obj is None: + raise NotFoundError(f"Entity with id {id} not found") + return entity_to_schema(ret_obj) + + async def create( + self, db_session: AsyncSession, user: "UserBase", model: BaseModel + ) -> BaseModel: + if CrudCommand.CREATE not in self.commands: + raise ForbiddenError( + f"Create command not allowed for entity {self.entity_descriptor.name}" + ) + user_permissions = get_user_permissions(user, self.entity_descriptor) + if EntityPermission.CREATE_ALL in user_permissions: + # TODO: check if entity values are valid + pass + elif EntityPermission.CREATE_RLS in user_permissions: + # TODO: check if RLS fields are valid + # TODO: check if entity values are valid + pass + else: + raise ForbiddenError( + f"User {user.id} does not have permission to create entities" + ) + + obj_dict = {} + ret_obj_dict = {} + for field_descriptor in self.entity_descriptor.fields_descriptors.values(): + # Only process fields present in the input model + field_name = field_descriptor.field_name + if field_name in model.__class__.model_fields: + # Handle list fields that are relations to other BotEntities + if ( + field_descriptor.is_list + and isinstance(field_descriptor.type_base, type) + and hasattr(field_descriptor.type_base, "bot_entity_descriptor") + ): + items = ( + await db_session.exec( + select(field_descriptor.type_base).where( + col(field_descriptor.type_base.id).in_( + getattr(model, field_name) + ) + ) + ) + ).all() + obj_dict[field_name] = items + ret_obj_dict[field_name] = [item.id for item in items] + elif isinstance(field_descriptor.type_base, type) and hasattr( + field_descriptor.type_base, "all_members" + ): + if field_descriptor.is_list: + obj_dict[field_name] = [ + field_descriptor.type_base(item) + for item in getattr(model, field_name) + ] + else: + obj_dict[field_name] = field_descriptor.type_base( + getattr(model, field_name) + ) + ret_obj_dict[field_name] = getattr(model, field_name) + else: + obj_dict[field_name] = getattr(model, field_name) + ret_obj_dict[field_name] = getattr(model, field_name) + obj = self.entity_descriptor.type_(**obj_dict) + db_session.add(obj) + try: + await db_session.commit() + except IntegrityError as e: + if isinstance(e.orig.__cause__, ForeignKeyViolationError): + raise ValueError(e.orig.__cause__.detail) + raise ValueError("DB Integrity error") + except Exception as e: + logger.error(f"Error creating entity: {e}") + raise e + if "id" not in ret_obj_dict: + ret_obj_dict["id"] = obj.id + return self.schema(**ret_obj_dict) + + async def update( + self, db_session: AsyncSession, user: "UserBase", id: int, model: BaseModel + ) -> BaseModel: + if CrudCommand.UPDATE not in self.commands: + raise ForbiddenError( + f"Update command not allowed for entity {self.entity_descriptor.name}" + ) + user_permissions = get_user_permissions(user, self.entity_descriptor) + if EntityPermission.UPDATE_ALL in user_permissions: + # TODO: check if entity values are valid + pass + elif EntityPermission.UPDATE_RLS in user_permissions: + # TODO: check if RLS fields are valid + # TODO: check if entity values are valid + pass + else: + raise ForbiddenError( + f"User {user.id} does not have permission to update entities" + ) + + entity = await self.entity_descriptor.type_.get(session=db_session, id=id) + if entity is None: + raise NotFoundError(f"Entity with id {id} not found") + for field_descriptor in self.entity_descriptor.fields_descriptors.values(): + field_name = field_descriptor.field_name + if field_name in model.model_fields_set: + model_field_value = getattr(model, field_name) + if ( + field_descriptor.is_list + and isinstance(field_descriptor.type_base, type) + and hasattr(field_descriptor.type_base, "bot_entity_descriptor") + ): + items = ( + await db_session.exec( + select(field_descriptor.type_base).where( + col(field_descriptor.type_base.id).in_( + model_field_value + ) + ) + ) + ).all() + setattr(entity, field_name, items) + elif isinstance(field_descriptor.type_base, type) and hasattr( + field_descriptor.type_base, "all_members" + ): + if field_descriptor.is_list: + setattr( + entity, + field_name, + [ + field_descriptor.type_base(item) + for item in model_field_value + ], + ) + else: + setattr( + entity, + field_name, + field_descriptor.type_base(model_field_value), + ) + else: + setattr(entity, field_name, model_field_value) + try: + await db_session.commit() + except IntegrityError as e: + if isinstance(e.orig.__cause__, ForeignKeyViolationError): + raise ValueError(e.orig.__cause__.detail) + raise ValueError("DB Integrity error") + except Exception as e: + logger.error(f"Error updating entity: {e}") + raise e + return entity_to_schema(entity) + + async def delete( + self, db_session: AsyncSession, user: "UserBase", id: int + ) -> BaseModel: + if CrudCommand.DELETE not in self.commands: + raise ForbiddenError( + f"Delete command not allowed for entity {self.entity_descriptor.name}" + ) + user_permissions = get_user_permissions(user, self.entity_descriptor) + if EntityPermission.DELETE_ALL in user_permissions: + pass + elif EntityPermission.DELETE_RLS in user_permissions: + # TODO: check if RLS fields are valid + pass + else: + raise ForbiddenError( + f"User {user.id} does not have permission to delete entities" + ) + + try: + entity = await self.entity_descriptor.type_.remove( + session=db_session, id=id, commit=True + ) + except IntegrityError as e: + if isinstance(e.orig.__cause__, ForeignKeyViolationError): + raise ValueError(e.orig.__cause__.detail) + raise ValueError("DB Integrity error") + except Exception as e: + logger.error(f"Error deleting entity: {e}") + raise e + if entity is None: + raise NotFoundError(f"Entity with id {id} not found") + return entity_to_schema(entity) + + # async def _create_from_schema( + # cls: type[BotEntity], + # *, + # session: AsyncSession | None = None, + # obj_in: BaseModel, + # ): + # """ + # Create a new entity instance from a Pydantic model. + + # Args: + # session: Database session (injected by session_dep) + # obj_in: Pydantic model to create the entity from + + # Returns: + # The created entity instance + # """ + # obj_dict = {} + # ret_obj_dict = {} + # for field_descriptor in cls.bot_entity_descriptor.fields_descriptors.values(): + # # Only process fields present in the input model + # if field_descriptor.field_name in obj_in.__class__.model_fields: + # # Handle list fields that are relations to other BotEntities + # if ( + # field_descriptor.is_list + # and isinstance(field_descriptor.type_base, type) + # and issubclass(field_descriptor.type_base, BotEntity) + # ): + # items = ( + # await session.exec( + # select(field_descriptor.type_base).where( + # col(field_descriptor.type_base.id).in_( + # getattr(obj_in, field_descriptor.field_name) + # ) + # ) + # ) + # ).all() + # obj_dict[field_descriptor.field_name] = items + # ret_obj_dict[field_descriptor.field_name] = [ + # item.id for item in items + # ] + # else: + # obj_dict[field_descriptor.field_name] = getattr( + # obj_in, field_descriptor.field_name + # ) + # ret_obj_dict[field_descriptor.field_name] = getattr( + # obj_in, field_descriptor.field_name + # ) + # obj = cls(**obj_dict) + # session.add(obj) + # await session.commit() + # if "id" not in ret_obj_dict: + # ret_obj_dict["id"] = obj.id + # return cls.bot_entity_descriptor.schema_class(**ret_obj_dict) + + +# @classmethod +# def apply_filters( +# cls, +# select_statement: SelectOfScalar[Self], +# filters: Filter | FilterExpression | None = None, +# ) -> SelectOfScalar[Self]: +# """ +# Apply filters to a select statement. + +# Args: +# select_statement: SQLAlchemy select statement to modify +# filters: Filter or FilterExpression to apply + +# Returns: +# Modified select statement with filter conditions +# """ +# if filters is None: +# return select_statement +# condition = cls._build_filter_condition(filters) +# if condition is not None: +# return select_statement.where(condition) +# return select_statement diff --git a/src/quickbot/model/descriptors.py b/src/quickbot/model/descriptors.py index 9f7441c..4475a01 100644 --- a/src/quickbot/model/descriptors.py +++ b/src/quickbot/model/descriptors.py @@ -6,6 +6,8 @@ from typing import Any, Callable, TYPE_CHECKING, Literal, Union from babel.support import LazyProxy from dataclasses import dataclass, field from fastapi.datastructures import State +from pydantic import BaseModel +from pydantic_core import PydanticUndefined from sqlmodel.ext.asyncio.session import AsyncSession from sqlalchemy.orm import InstrumentedAttribute @@ -17,6 +19,8 @@ if TYPE_CHECKING: from .bot_entity import BotEntity from ..main import QBotApp from .user import UserBase + from .crud_service import CrudService + from .bot_process import BotProcess # EntityCaptionCallable = Callable[["EntityDescriptor"], str] # EntityItemCaptionCallable = Callable[["EntityDescriptor", Any], str] @@ -46,8 +50,8 @@ class InlineButton[T: "BotEntity"]: @dataclass -class Filter: - field_name: str +class Filter[T: "BotEntity"]: + field: str | Callable[[type[T]], InstrumentedAttribute] operator: Literal[ "==", "!=", @@ -67,6 +71,89 @@ class Filter: value: Any | None = None param_index: int | None = None + def __or__(self, other: "Filter[T] | FilterExpression[T]") -> "FilterExpression[T]": + """Create OR expression with another filter or expression""" + if isinstance(other, Filter): + return FilterExpression("or", [self, other]) + elif isinstance(other, FilterExpression): + if other.operator == "or": + # Simplify: filter | (a | b) = (filter | a | b) + return FilterExpression("or", [self] + other.filters) + else: + return FilterExpression("or", [self, other]) + else: + raise TypeError(f"Cannot combine Filter with {type(other)}") + + def __and__( + self, other: "Filter[T] | FilterExpression[T]" + ) -> "FilterExpression[T]": + """Create AND expression with another filter or expression""" + if isinstance(other, Filter): + return FilterExpression("and", [self, other]) + elif isinstance(other, FilterExpression): + if other.operator == "and": + # Simplify: filter & (a & b) = (filter & a & b) + return FilterExpression("and", [self] + other.filters) + else: + return FilterExpression("and", [self, other]) + else: + raise TypeError(f"Cannot combine Filter with {type(other)}") + + +class FilterExpression[T: "BotEntity"]: + """ + Represents a logical expression combining multiple filters with AND/OR operations. + Supports expression simplification for optimal query building. + """ + + def __init__( + self, + operator: Literal["or", "and"], + filters: list["Filter[T] | FilterExpression[T]"], + ): + self.operator = operator + self.filters = self._simplify_filters(filters) + + def _simplify_filters( + self, filters: list["Filter[T] | FilterExpression[T]"] + ) -> list["Filter[T] | FilterExpression[T]"]: + """Simplify filters by flattening nested expressions with the same operator""" + simplified = [] + for filter_obj in filters: + if ( + isinstance(filter_obj, FilterExpression) + and filter_obj.operator == self.operator + ): + # Flatten nested expressions with the same operator + simplified.extend(filter_obj.filters) + else: + simplified.append(filter_obj) + return simplified + + def __or__(self, other: "Filter[T] | FilterExpression[T]") -> "FilterExpression[T]": + """Combine with another filter or expression using OR""" + if isinstance(other, (Filter, FilterExpression)): + if isinstance(other, FilterExpression) and other.operator == "or": + # Simplify: (a | b) | (c | d) = (a | b | c | d) + return FilterExpression("or", self.filters + other.filters) + else: + return FilterExpression("or", [self, other]) + else: + raise TypeError(f"Cannot combine FilterExpression with {type(other)}") + + def __and__( + self, other: "Filter[T] | FilterExpression[T]" + ) -> "FilterExpression[T]": + """Combine with another filter or expression using AND""" + if isinstance(other, (Filter, FilterExpression)): + if isinstance(other, FilterExpression) and other.operator == "and": + # Simplify: (a & b) & (c & d) = (a & b & c & d) + return FilterExpression("and", self.filters + other.filters) + else: + return FilterExpression("and", [self, other]) + else: + raise TypeError(f"Cannot combine FilterExpression with {type(other)}") + @dataclass class EntityList[T: "BotEntity"]: @@ -77,7 +164,7 @@ class EntityList[T: "BotEntity"]: show_add_new_button: bool = True item_form: str | None = None pagination: bool = True - static_filters: list[Filter] = None + static_filters: Filter[T] | FilterExpression[T] | None = None filtering: bool = False filtering_fields: list[str] = None order_by: str | Any | None = None @@ -99,9 +186,7 @@ class _BaseFieldDescriptor[T: "BotEntity"]: caption: ( str | LazyProxy | Callable[["FieldDescriptor", "BotContext"], str] | None ) = None - description: ( - str | LazyProxy | Callable[["FieldDescriptor", "BotContext"], str] | None - ) = None + description: str | LazyProxy | None = PydanticUndefined edit_prompt: ( str | LazyProxy @@ -122,8 +207,8 @@ class _BaseFieldDescriptor[T: "BotEntity"]: bool_false_value: str | LazyProxy = "no" bool_true_value: str | LazyProxy = "yes" ep_form: str | Callable[["BotContext"], str] | None = None - ep_parent_field: str | None = None - ep_child_field: str | None = None + ep_parent_field: str | Callable[[type[T]], InstrumentedAttribute] | None = None + ep_child_field: str | Callable[[type[T]], InstrumentedAttribute] | None = None dt_type: Literal["date", "datetime"] = "date" options: ( list[list[Union[Any, tuple[Any, str]]]] @@ -133,7 +218,7 @@ class _BaseFieldDescriptor[T: "BotEntity"]: options_custom_value: bool = True show_current_value_button: bool = True show_skip_in_editor: Literal[False, "Auto"] = "Auto" - default: Any = None + default: Any = PydanticUndefined default_factory: Callable[[], Any] | None = None @@ -178,7 +263,8 @@ class _BaseEntityDescriptor[T: "BotEntity"]: full_name_plural: ( str | LazyProxy | Callable[["EntityDescriptor", "BotContext"], str] | None ) = None - description: ( + description: str | None = None + ui_description: ( str | LazyProxy | Callable[["EntityDescriptor", "BotContext"], str] | None ) = None item_repr: Callable[[T, "BotContext"], str] | None = None @@ -190,11 +276,11 @@ class _BaseEntityDescriptor[T: "BotEntity"]: ownership_fields: dict[RoleBase, str] = field(default_factory=dict[RoleBase, str]) permissions: dict[EntityPermission, list[RoleBase]] = field( default_factory=lambda: { - EntityPermission.LIST: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], - EntityPermission.READ: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], - EntityPermission.CREATE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], - EntityPermission.UPDATE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], - EntityPermission.DELETE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], + EntityPermission.LIST_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], + EntityPermission.READ_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], + EntityPermission.CREATE_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], + EntityPermission.UPDATE_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], + EntityPermission.DELETE_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], EntityPermission.LIST_ALL: [RoleBase.SUPER_USER], EntityPermission.READ_ALL: [RoleBase.SUPER_USER], EntityPermission.CREATE_ALL: [RoleBase.SUPER_USER], @@ -202,6 +288,8 @@ class _BaseEntityDescriptor[T: "BotEntity"]: EntityPermission.DELETE_ALL: [RoleBase.SUPER_USER], } ) + rls_filters: Filter[T] | FilterExpression[T] | None = None + rls_filters_params: Callable[["UserBase"], list[Any]] = lambda user: [user.id] before_create: Callable[["BotContext"], Union[bool, str]] | None = None before_create_save: Callable[[T, "BotContext"], Union[bool, str]] | None = None before_update_save: ( @@ -212,6 +300,7 @@ class _BaseEntityDescriptor[T: "BotEntity"]: on_created: Callable[[T, "BotContext"], None] | None = None on_deleted: Callable[[T, "BotContext"], None] | None = None on_updated: Callable[[dict[str, Any], T, "BotContext"], None] | None = None + crud: Union["CrudService", None] = None @dataclass(kw_only=True) @@ -228,7 +317,7 @@ class EntityDescriptor(_BaseEntityDescriptor): @dataclass(kw_only=True) -class CommandCallbackContext[UT: UserBase]: +class CommandCallbackContext: keyboard_builder: InlineKeyboardBuilder = field( default_factory=InlineKeyboardBuilder ) @@ -238,7 +327,7 @@ class CommandCallbackContext[UT: UserBase]: message: Message | CallbackQuery callback_data: ContextData db_session: AsyncSession - user: UT + user: "UserBase" app: "QBotApp" app_state: State state_data: dict[str, Any] @@ -249,11 +338,11 @@ class CommandCallbackContext[UT: UserBase]: @dataclass(kw_only=True) -class BotContext[UT: UserBase]: +class BotContext: db_session: AsyncSession app: "QBotApp" app_state: State - user: UT + user: "UserBase" message: Message | CallbackQuery | None = None default_handler: Callable[["BotEntity", "BotContext"], None] | None = None @@ -271,3 +360,31 @@ class BotCommand: show_cancel_in_param_form: bool = True show_back_in_param_form: bool = True handler: Callable[[CommandCallbackContext], None] + + +@dataclass(kw_only=True) +class _BaseProcessDescriptor: + description: str | LazyProxy | None = None + roles: list[RoleBase] = field( + default_factory=lambda: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER] + ) + icon: str | None = None + caption: str | LazyProxy | None = None + pre_check: Callable[[BotContext], bool | str] | None = None + show_in_bot_menu: bool = False + answer_message: Callable[[BotContext, BaseModel], str] | None = None + answer_inline_buttons: ( + Callable[[BotContext, BaseModel], list[InlineKeyboardButton]] | None + ) = None + + +@dataclass(kw_only=True) +class ProcessDescriptor(_BaseProcessDescriptor): + name: str + process_class: type["BotProcess"] + input_schema: type[BaseModel] | None = None + output_schema: type[BaseModel] | None = None + + +@dataclass(kw_only=True) +class Process(_BaseProcessDescriptor): ... diff --git a/src/quickbot/model/entity_metadata.py b/src/quickbot/model/entity_metadata.py deleted file mode 100644 index 47c9886..0000000 --- a/src/quickbot/model/entity_metadata.py +++ /dev/null @@ -1,7 +0,0 @@ -from .descriptors import EntityDescriptor -from ._singleton import Singleton - - -class EntityMetadata(metaclass=Singleton): - def __init__(self): - self.entity_descriptors: dict[str, EntityDescriptor] = {} diff --git a/src/quickbot/model/list_schema.py b/src/quickbot/model/list_schema.py new file mode 100644 index 0000000..0b61df6 --- /dev/null +++ b/src/quickbot/model/list_schema.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class ListSchema(BaseModel): + id: int + name: str diff --git a/src/quickbot/model/permissions.py b/src/quickbot/model/permissions.py new file mode 100644 index 0000000..f172203 --- /dev/null +++ b/src/quickbot/model/permissions.py @@ -0,0 +1,192 @@ +from inspect import iscoroutinefunction +from quickbot.model.descriptors import EntityDescriptor, EntityPermission +from quickbot.model.descriptors import Filter, FilterExpression +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from quickbot.model.user import UserBase + from quickbot.model.bot_entity import BotEntity + + +def get_user_permissions( + user: "UserBase", entity_descriptor: EntityDescriptor +) -> list[EntityPermission]: + permissions = list[EntityPermission]() + for permission, roles in entity_descriptor.permissions.items(): + for role in roles: + if role in user.roles: + permissions.append(permission) + break + return permissions + + +async def check_entity_permission( + entity: "BotEntity", user: "UserBase", permission: EntityPermission +) -> bool: + perm_mapping = { + EntityPermission.LIST_RLS: EntityPermission.LIST_ALL, + EntityPermission.READ_RLS: EntityPermission.READ_ALL, + EntityPermission.UPDATE_RLS: EntityPermission.UPDATE_ALL, + EntityPermission.CREATE_RLS: EntityPermission.CREATE_ALL, + EntityPermission.DELETE_RLS: EntityPermission.DELETE_ALL, + } + + if permission not in perm_mapping: + raise ValueError(f"Invalid permission: {permission}") + + entity_descriptor = entity.__class__.bot_entity_descriptor + permissions = get_user_permissions(user, entity_descriptor) + + # Check if user has the corresponding ALL permission + if perm_mapping[permission] in permissions: + return True + + # Check RLS filters if they exist + if entity_descriptor.rls_filters: + # Get parameters for RLS + params = [] + if entity_descriptor.rls_filters_params: + if iscoroutinefunction(entity_descriptor.rls_filters_params): + params = await entity_descriptor.rls_filters_params(user) + else: + params = entity_descriptor.rls_filters_params(user) + + # Create a copy of the RLS filters with parameter values substituted + rls_filters = entity.__class__._substitute_rls_parameters( + entity_descriptor.rls_filters, params + ) + + # Check if the entity matches the RLS filters by evaluating the condition + # against the entity's attributes + if _entity_matches_rls_filters(entity, rls_filters): + return True + + # If no RLS filters are defined, check if user has the RLS permission + if permission in permissions: + return True + + return False + + +def _entity_matches_rls_filters( + entity: "BotEntity", rls_filters: "Filter | FilterExpression" +) -> bool: + """ + Check if an entity matches the given RLS filters by evaluating the filters + against the entity's attributes. + + Args: + entity: The entity to check + rls_filters: RLS filters to evaluate + + Returns: + True if the entity matches the filters, False otherwise + """ + + if isinstance(rls_filters, Filter): + return _evaluate_single_filter(entity, rls_filters) + elif isinstance(rls_filters, FilterExpression): + return _evaluate_filter_expression(entity, rls_filters) + else: + return False + + +def _evaluate_single_filter(entity: "BotEntity", filter_obj: "Filter") -> bool: + """Evaluate a single filter against an entity""" + # Get the field value from the entity + if isinstance(filter_obj.field, str): + field_value = getattr(entity, filter_obj.field, None) + else: + # Handle callable field (should return the field name) + field_name = filter_obj.field(entity.__class__).key + field_value = getattr(entity, field_name, None) + + # Apply the operator + if filter_obj.operator == "==": + return field_value == filter_obj.value + elif filter_obj.operator == "!=": + return field_value != filter_obj.value + elif filter_obj.operator == ">": + return field_value > filter_obj.value + elif filter_obj.operator == "<": + return field_value < filter_obj.value + elif filter_obj.operator == ">=": + return field_value >= filter_obj.value + elif filter_obj.operator == "<=": + return field_value <= filter_obj.value + elif filter_obj.operator == "in": + return field_value in filter_obj.value + elif filter_obj.operator == "not in": + return field_value not in filter_obj.value + elif filter_obj.operator == "like": + return str(field_value).find(str(filter_obj.value)) != -1 + elif filter_obj.operator == "ilike": + return str(field_value).lower().find(str(filter_obj.value).lower()) != -1 + elif filter_obj.operator == "is none": + return field_value is None + elif filter_obj.operator == "is not none": + return field_value is not None + elif filter_obj.operator == "contains": + return filter_obj.value in field_value + else: + return False + + +def _evaluate_filter_expression( + entity: "BotEntity", filter_expr: "FilterExpression" +) -> bool: + """Evaluate a filter expression against an entity""" + results = [] + for sub_filter in filter_expr.filters: + if isinstance(sub_filter, Filter): + result = _evaluate_single_filter(entity, sub_filter) + elif isinstance(sub_filter, FilterExpression): + result = _evaluate_filter_expression(entity, sub_filter) + else: + result = False + results.append(result) + + if not results: + return False + + # Apply the logical operator + if filter_expr.operator == "and": + return all(results) + elif filter_expr.operator == "or": + return any(results) + else: + return False + + +def _extract_rls_filter_fields(entity_descriptor: EntityDescriptor) -> set[str]: + return _extract_filter_fields( + entity_descriptor.rls_filters, entity_descriptor.type_ + ) + + +def _extract_filter_fields( + filter: Filter | FilterExpression | None, entity_type: type +) -> set[str]: + fields = set() + + if filter: + if isinstance(filter, Filter): + if filter.operator == "==": + if isinstance(filter.field, str): + fields.add(filter.field) + else: + fields.add(filter.field(entity_type).key) + + elif isinstance(filter, FilterExpression): + if ( + filter.operator == "and" + and all(isinstance(sub_filter, Filter) for sub_filter in filter.filters) + and all(sub_filter.operator == "==" for sub_filter in filter.filters) + ): + for sub_filter in filter.filters: + if isinstance(sub_filter.field, str): + fields.add(sub_filter.field) + else: + fields.add(sub_filter.field(entity_type).key) + + return fields diff --git a/src/quickbot/model/pydantic_json.py b/src/quickbot/model/pydantic_json.py new file mode 100644 index 0000000..f15a96c --- /dev/null +++ b/src/quickbot/model/pydantic_json.py @@ -0,0 +1,53 @@ +from typing import Any, Type +from pydantic import BaseModel +from sqlmodel import TypeDecorator, JSON + + +class PydanticJSON(TypeDecorator): + """ + SQLAlchemy-compatible JSON type for storing Pydantic models + (including nested ones). Automatically serializes on insert + and deserializes on read. + """ + + impl = JSON + cache_ok = True + + def __init__(self, model_class: Type[BaseModel], *args, **kwargs): + if not issubclass(model_class, BaseModel): + raise TypeError("PydanticJSON expects a Pydantic BaseModel subclass") + self.model_class = model_class + super().__init__(*args, **kwargs) + + def process_bind_param(self, value: Any, dialect) -> Any: + """ + Serialize Python object to JSON-compatible form before saving to DB. + """ + if value is None: + return None + + if isinstance(value, list): + return [ + item.model_dump(mode="json") if isinstance(item, BaseModel) else item + for item in value + ] + + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + + return value # assume already JSON-serializable + + def process_result_value(self, value: Any, dialect) -> Any: + """ + Deserialize JSON data from DB back into Python object. + """ + if value is None: + return None + + if isinstance(value, list): + return [self.model_class(**item) for item in value] + + if isinstance(value, dict): + return self.model_class(**value) + + raise TypeError(f"Unsupported value type for deserialization: {type(value)}") diff --git a/src/quickbot/model/user.py b/src/quickbot/model/user.py index eb2dd80..2c4c351 100644 --- a/src/quickbot/model/user.py +++ b/src/quickbot/model/user.py @@ -1,3 +1,4 @@ +from sqlalchemy import BigInteger from sqlmodel import Field, ARRAY from .bot_entity import BotEntity @@ -5,6 +6,7 @@ from .bot_enum import EnumType from .language import LanguageBase from .role import RoleBase +from .descriptors import EntityField from .settings import DbSettings as DbSettings from .fsm_storage import FSMStorage as FSMStorage from .view_setting import ViewSetting as ViewSetting @@ -13,11 +15,24 @@ from .view_setting import ViewSetting as ViewSetting class UserBase(BotEntity, table=False): __tablename__ = "user" - lang: LanguageBase = Field(sa_type=EnumType(LanguageBase), default=LanguageBase.EN) - is_active: bool = True + id: int = EntityField( + description="User Telegram ID", + sm_descriptor=Field(primary_key=True, sa_type=BigInteger), + is_visible=False, + ) - name: str + lang: LanguageBase = Field( + description="User language", + sa_type=EnumType(LanguageBase), + default_factory=lambda: LanguageBase.EN, + ) + + is_active: bool = EntityField(description="User is active", default=True) + + name: str = EntityField(description="User name") roles: list[RoleBase] = Field( - sa_type=ARRAY(EnumType(RoleBase)), default=[RoleBase.DEFAULT_USER] + description="User roles", + sa_type=ARRAY(EnumType(RoleBase)), + default_factory=lambda: [RoleBase.DEFAULT_USER], ) diff --git a/src/quickbot/model/utils.py b/src/quickbot/model/utils.py new file mode 100644 index 0000000..da9e4b0 --- /dev/null +++ b/src/quickbot/model/utils.py @@ -0,0 +1,364 @@ +from inspect import iscoroutinefunction +from typing import Any, Optional, get_origin +from pydantic import BaseModel, Field +from pydantic.fields import _Unset +from pydantic_core import PydanticUndefined +from typing_extensions import Literal +from sqlmodel import SQLModel, col, column +from sqlmodel.sql.expression import SelectOfScalar +from typing import TYPE_CHECKING + +from quickbot.model.list_schema import ListSchema +from quickbot.model.descriptors import ( + EntityDescriptor, + Filter, + FilterExpression, + BotContext, +) +from quickbot.utils.main import get_entity_item_repr + +if TYPE_CHECKING: + from quickbot import BotEntity, BotEnum + from quickbot.model.user import UserBase + + +def entity_to_schema(entity: "BotEntity") -> BaseModel: + entity_data = {} + for field_descriptor in entity.bot_entity_descriptor.fields_descriptors.values(): + if field_descriptor.is_list and is_bot_entity(field_descriptor.type_base): + entity_data[field_descriptor.field_name] = [ + item.id for item in getattr(entity, field_descriptor.field_name) + ] + elif field_descriptor.is_list and is_bot_enum(field_descriptor.type_base): + entity_data[field_descriptor.field_name] = [ + item.value for item in getattr(entity, field_descriptor.field_name) + ] + elif not field_descriptor.is_list and is_bot_entity(field_descriptor.type_base): + continue + elif not field_descriptor.is_list and is_bot_enum(field_descriptor.type_base): + val: "BotEnum" | None = getattr(entity, field_descriptor.field_name) + entity_data[field_descriptor.field_name] = val.value if val else None + else: + entity_data[field_descriptor.field_name] = getattr( + entity, field_descriptor.field_name + ) + return ( + entity.bot_entity_descriptor.crud.schema(**entity_data) + if entity.bot_entity_descriptor.crud + else entity.bot_entity_descriptor.schema(**entity_data) + ) + + +async def entity_to_list_schema( + entity: "BotEntity", context: "BotContext" +) -> ListSchema: + entity_repr = await get_entity_item_repr( + entity, context, entity.bot_entity_descriptor.item_repr + ) + return ListSchema(id=entity.id, name=entity_repr) + + +def _pydantic_model_fields( + namespace: dict[str, Any], + entity_descriptor: EntityDescriptor, + schema_type: Literal["schema", "create", "update"] = "schema", +) -> dict[str, Any]: + namespace["__annotations__"] = {} + for field_descriptor in entity_descriptor.fields_descriptors.values(): + type_origin = get_origin(field_descriptor.type_base) + if ( + type_origin is not list + and not field_descriptor.is_list + and ( + issubclass(field_descriptor.type_base, SQLModel) + or isinstance(field_descriptor.type_base, str) + ) + ) or ( + schema_type in ["create", "update"] + and field_descriptor.field_name == "id" + and field_descriptor.default is None + ): + continue + + if ( + type_origin is not list + and field_descriptor.is_list + and ( + issubclass(field_descriptor.type_base, SQLModel) + or isinstance(field_descriptor.type_base, str) + ) + ): + namespace["__annotations__"][field_descriptor.field_name] = list[int] + elif type_origin is not list and is_bot_enum(field_descriptor.type_base): + enum_values = [ + member.value + for member in field_descriptor.type_base.all_members.values() + ] + enum_annotation = ( + list[Literal[*enum_values]] + if field_descriptor.is_list + else Literal[*enum_values] + ) + if field_descriptor.is_optional: + enum_annotation = Optional[enum_annotation] + namespace["__annotations__"][field_descriptor.field_name] = enum_annotation + else: + namespace["__annotations__"][field_descriptor.field_name] = ( + Optional[field_descriptor.type_base] + if field_descriptor.is_optional + else field_descriptor.type_ + ) + + description = ( + field_descriptor.description if field_descriptor.description else _Unset + ) + + if schema_type == "schema" and field_descriptor.is_optional: + namespace[field_descriptor.field_name] = Field(description=description) + elif schema_type == "create": + if field_descriptor.default is not PydanticUndefined: + namespace[field_descriptor.field_name] = Field( + default=field_descriptor.default, description=description + ) + elif field_descriptor.default_factory is not None: + namespace[field_descriptor.field_name] = Field( + default_factory=field_descriptor.default_factory, + description=description, + ) + elif field_descriptor.is_optional: + namespace[field_descriptor.field_name] = Field( + default=None, description=description + ) + else: + namespace[field_descriptor.field_name] = Field(description=description) + elif schema_type == "update": + namespace[field_descriptor.field_name] = Field( + default=None, description=description + ) + else: + namespace[field_descriptor.field_name] = Field(description=description) + + +def pydantic_model( + entity_descriptor: EntityDescriptor, + module_name: str, + schema_type: Literal["schema", "create", "update"] = "schema", +) -> type[BaseModel]: + namespace = { + "__module__": module_name, + } + _pydantic_model_fields(namespace, entity_descriptor, schema_type) + + return type( + f"{entity_descriptor.class_name}{schema_type.capitalize() if schema_type != 'schema' else ''}Schema", + (BaseModel,), + namespace, + ) + + +def _build_filter_condition( + cls: type["BotEntity"], filter_obj: Filter | FilterExpression +) -> Any: + """ + Build SQLAlchemy condition from a Filter or FilterExpression object. + + Args: + filter_obj: Filter or FilterExpression object to convert + + Returns: + SQLAlchemy condition + """ + # --- Handle single Filter object --- + if isinstance(filter_obj, Filter): + # Support both string field names and callables for custom columns + if isinstance(filter_obj.field, str): + column = getattr(cls, filter_obj.field) + else: + column = filter_obj.field(cls) + # Map filter operator to SQLAlchemy expression + if filter_obj.operator == "==": + return column.__eq__(filter_obj.value) + elif filter_obj.operator == "!=": + return column.__ne__(filter_obj.value) + elif filter_obj.operator == "<": + return column.__lt__(filter_obj.value) + elif filter_obj.operator == "<=": + return column.__le__(filter_obj.value) + elif filter_obj.operator == ">": + return column.__gt__(filter_obj.value) + elif filter_obj.operator == ">=": + return column.__ge__(filter_obj.value) + elif filter_obj.operator == "ilike": + return col(column).ilike(f"%{filter_obj.value}%") + elif filter_obj.operator == "like": + return col(column).like(f"%{filter_obj.value}%") + elif filter_obj.operator == "in": + return col(column).in_(filter_obj.value) + elif filter_obj.operator == "not in": + return col(column).notin_(filter_obj.value) + elif filter_obj.operator == "is none": + return col(column).is_(None) + elif filter_obj.operator == "is not none": + return col(column).isnot(None) + elif filter_obj.operator == "contains": + return filter_obj.value == col(column).any_() + else: + # Unknown operator, return None (no condition) + return None + # --- Handle FilterExpression object (logical AND/OR of filters) --- + elif isinstance(filter_obj, FilterExpression): + operator = filter_obj.operator + filters = filter_obj.filters + # Recursively build conditions for all sub-filters + conditions = [] + for sub_filter in filters: + condition = _build_filter_condition(cls, sub_filter) + if condition is not None: + conditions.append(condition) + if not conditions: + return None + # Combine conditions using logical AND/OR + if operator == "and": + res_condition = conditions[0] + if len(conditions) > 1: + for condition in conditions[1:]: + res_condition = res_condition & condition + return res_condition + elif operator == "or": + res_condition = conditions[0] + if len(conditions) > 1: + for condition in conditions[1:]: + res_condition = res_condition | condition + return res_condition + + +def _static_filter_condition( + cls, + select_statement: SelectOfScalar, + static_filter: Filter | FilterExpression, +): + """ + Apply static filters to a select statement. + + Static filters are predefined conditions that don't depend on user input. + Supports both Filter and FilterExpression objects with logical operations. + + Args: + select_statement: SQLAlchemy select statement to modify + static_filter: filter condition to apply + + Returns: + Modified select statement with filter conditions + """ + condition = _build_filter_condition(cls, static_filter) + if condition is not None: + select_statement = select_statement.where(condition) + return select_statement + + +def _filter_condition( + select_statement: SelectOfScalar, + filter: str, + filter_fields: list[str], +): + """ + Apply text-based search filters to a select statement. + + Creates a case-insensitive LIKE search across multiple fields. + + Args: + select_statement: SQLAlchemy select statement to modify + filter: Search text to look for + filter_fields: List of field names to search in + + Returns: + Modified select statement with search conditions + """ + condition = None + for field in filter_fields: + if condition is not None: + condition = condition | (column(field).ilike(f"%{filter}%")) + else: + condition = column(field).ilike(f"%{filter}%") + return select_statement.where(condition) + + +async def _apply_rls_filters( + cls: type["BotEntity"], select_statement: SelectOfScalar, user: "UserBase" +): + """ + Apply Row Level Security (RLS) filters to restrict access based on user roles. + + This method uses the entity's rls_filters and rls_filters_params to apply + dynamic filtering conditions based on the user's roles and permissions. + + Args: + select_statement: SQLAlchemy select statement to modify + user: User whose access should be restricted + + Returns: + Modified select statement with RLS conditions + """ + # --- Check if RLS filters are defined for this entity --- + if cls.bot_entity_descriptor.rls_filters: + # Get parameters for RLS filters (may be sync or async) + params = [] + if cls.bot_entity_descriptor.rls_filters_params: + if iscoroutinefunction(cls.bot_entity_descriptor.rls_filters_params): + params = await cls.bot_entity_descriptor.rls_filters_params(user) + else: + params = cls.bot_entity_descriptor.rls_filters_params(user) + + # Create a copy of the RLS filters with parameter values substituted + rls_filters = _substitute_rls_parameters( + cls.bot_entity_descriptor.rls_filters, params + ) + + # Apply RLS filters with parameters + condition = _build_filter_condition(cls, rls_filters) + if condition is not None: + return select_statement.where(condition) + return select_statement + + +def _substitute_rls_parameters( + rls_filters: Filter | FilterExpression, params: list[Any] +) -> Filter | FilterExpression: + """ + Substitute parameter placeholders in RLS filters with actual values. + + Args: + rls_filters: RLS filters that may contain parameter placeholders + params: List of parameter values to substitute + + Returns: + RLS filters with parameters substituted + """ + # --- Substitute parameter in single filter --- + if isinstance(rls_filters, Filter): + if rls_filters.value_type == "param" and rls_filters.param_index is not None: + if 0 <= rls_filters.param_index < len(params): + # Create a new filter with the parameter value substituted + return Filter( + field=rls_filters.field, + operator=rls_filters.operator, + value_type="const", + value=params[rls_filters.param_index], + param_index=None, + ) + return rls_filters + # --- Recursively substitute parameters in all sub-filters --- + elif isinstance(rls_filters, FilterExpression): + substituted_filters = [] + for sub_filter in rls_filters.filters: + substituted_filter = _substitute_rls_parameters(sub_filter, params) + substituted_filters.append(substituted_filter) + return FilterExpression(rls_filters.operator, substituted_filters) + + +def is_bot_entity(type_: type) -> bool: + return hasattr(type_, "bot_entity_descriptor") + + +def is_bot_enum(type_: type) -> bool: + return hasattr(type_, "all_members") diff --git a/src/quickbot/plugin.py b/src/quickbot/plugin.py new file mode 100644 index 0000000..7e1f466 --- /dev/null +++ b/src/quickbot/plugin.py @@ -0,0 +1,9 @@ +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from quickbot import QBotApp + + +@runtime_checkable +class Registerable(Protocol): + def register(self, app: "QBotApp") -> None: ... diff --git a/src/quickbot/utils/main.py b/src/quickbot/utils/main.py index 6c1a4bb..be3bd78 100644 --- a/src/quickbot/utils/main.py +++ b/src/quickbot/utils/main.py @@ -7,11 +7,14 @@ from typing import Any, TYPE_CHECKING, Callable import ujson as json from quickbot.utils.serialization import deserialize - -from ..model.bot_entity import BotEntity -from ..model.bot_enum import BotEnum +from quickbot.model.permissions import ( + _extract_rls_filter_fields, + get_user_permissions, + _extract_filter_fields, +) from ..model.settings import Settings + from ..model.descriptors import ( BotContext, EntityList, @@ -26,22 +29,11 @@ from ..model.descriptors import ( from ..bot.handlers.context import CallbackCommand, ContextData, CommandContext if TYPE_CHECKING: + from ..model.bot_entity import BotEntity from ..model.user import UserBase from ..main import QBotApp -def get_user_permissions( - user: "UserBase", entity_descriptor: EntityDescriptor -) -> list[EntityPermission]: - permissions = list[EntityPermission]() - for permission, roles in entity_descriptor.permissions.items(): - for role in roles: - if role in user.roles: - permissions.append(permission) - break - return permissions - - def get_local_text(text: str, locale: str = None) -> str: if not locale: i18n = I18n.get_current(no_error=True) @@ -57,39 +49,6 @@ def get_local_text(text: str, locale: str = None) -> str: return obj.get(locale, obj[list(obj.keys())[0]]) -def check_entity_permission( - entity: BotEntity, user: "UserBase", permission: EntityPermission -) -> bool: - perm_mapping = { - EntityPermission.LIST: EntityPermission.LIST_ALL, - EntityPermission.READ: EntityPermission.READ_ALL, - EntityPermission.UPDATE: EntityPermission.UPDATE_ALL, - EntityPermission.CREATE: EntityPermission.CREATE_ALL, - EntityPermission.DELETE: EntityPermission.DELETE_ALL, - } - - if permission not in perm_mapping: - raise ValueError(f"Invalid permission: {permission}") - - entity_descriptor = entity.__class__.bot_entity_descriptor - permissions = get_user_permissions(user, entity_descriptor) - - if perm_mapping[permission] in permissions: - return True - - ownership_fields = entity_descriptor.ownership_fields - - for role in user.roles: - if role in ownership_fields: - if getattr(entity, ownership_fields[role]) == user.id: - return True - else: - if permission in permissions: - return True - - return False - - def get_send_message(message: Message | CallbackQuery): if isinstance(message, Message): return message.answer @@ -111,9 +70,9 @@ def clear_state(state_data: dict, clear_nav: bool = False): async def get_entity_item_repr( - entity: BotEntity, + entity: "BotEntity", context: BotContext, - item_repr: Callable[[BotEntity, BotContext], str] | None = None, + item_repr: Callable[["BotEntity", BotContext], str] | None = None, ) -> str: descr = entity.bot_entity_descriptor @@ -151,7 +110,7 @@ async def get_value_repr( if isinstance(value, bool): return "[✓]" if value else "[ ]" elif field_descriptor.is_list: - if issubclass(type_, BotEntity): + if hasattr(type_, "bot_entity_descriptor"): return f"[{ ', '.join( [ @@ -160,15 +119,15 @@ async def get_value_repr( ] ) }]" - elif issubclass(type_, BotEnum): + elif hasattr(type_, "all_members"): return f"[{', '.join(item.localized(locale) for item in value)}]" elif type_ is str: return f"[{', '.join([f'"{item}"' for item in value])}]" else: return f"[{', '.join([str(item) for item in value])}]" - elif issubclass(type_, BotEntity): + elif hasattr(type_, "bot_entity_descriptor"): return await get_entity_item_repr(entity=value, context=context) - elif issubclass(type_, BotEnum): + elif hasattr(type_, "all_members"): return value.localized(locale) elif isinstance(value, str): if field_descriptor and field_descriptor.localizable: @@ -187,12 +146,12 @@ async def get_callable_str( str | LazyProxy | Callable[[EntityDescriptor, BotContext], str] - | Callable[[BotEntity, BotContext], str] - | Callable[[FieldDescriptor, BotEntity, BotContext], str] + | Callable[["BotEntity", BotContext], str] + | Callable[[FieldDescriptor, "BotEntity", BotContext], str] ), context: BotContext, descriptor: FieldDescriptor | EntityDescriptor | None = None, - entity: BotEntity | Any = None, + entity: "BotEntity | Any" = None, ) -> str: if isinstance(callable_str, str): return callable_str @@ -217,17 +176,13 @@ async def get_callable_str( return callable_str(descriptor, entity, context) else: return callable_str(entity or descriptor, context) - else: - raise ValueError( - f"Invalid callable type: {type(callable_str)}. Expected str, LazyProxy or callable." - ) def get_entity_descriptor( app: "QBotApp", callback_data: ContextData ) -> EntityDescriptor: if callback_data.entity_name: - return app.entity_metadata.entity_descriptors[callback_data.entity_name] + return app.bot_metadata.entity_descriptors[callback_data.entity_name] return None @@ -283,7 +238,7 @@ async def build_field_sequence( entity_data = state_data.get("entity_data", {}) field_sequence = list[str]() - # exclude ownership fields from edit if user has no CREATE_ALL/UPDATE_ALL permission + # exclude RLS fields from edit if user has no CREATE_ALL/UPDATE_ALL permission user_permissions = get_user_permissions(user, entity_descriptor) for fd in entity_descriptor.fields_descriptors.values(): if isinstance(fd.is_visible_in_edit_form, bool): @@ -306,31 +261,30 @@ async def build_field_sequence( or fd.default_factory is not None ): skip = True - for own_field in entity_descriptor.ownership_fields.items(): - if ( - own_field[1].rstrip("_id") == fd.field_name.rstrip("_id") - and own_field[0] in user.roles - and ( - ( - EntityPermission.CREATE_ALL not in user_permissions - and callback_data.context == CommandContext.ENTITY_CREATE - ) - or ( - EntityPermission.UPDATE_ALL not in user_permissions - and callback_data.context == CommandContext.ENTITY_EDIT - ) + # Check RLS filters for field visibility + if entity_descriptor.rls_filters: + # Get RLS filter fields that should be auto-filled and hidden from user + rls_filter_fields = _extract_rls_filter_fields(entity_descriptor) + if fd.field_name in rls_filter_fields and ( + ( + EntityPermission.CREATE_ALL not in user_permissions + and callback_data.context == CommandContext.ENTITY_CREATE + ) + or ( + EntityPermission.UPDATE_ALL not in user_permissions + and callback_data.context == CommandContext.ENTITY_EDIT ) ): skip = True - break - if ( - prev_form_list - and prev_form_list.static_filters - and fd.field_name.rstrip("_id") - in [f.field_name.rstrip("_id") for f in prev_form_list.static_filters] - ): - skip = True + if prev_form_list and prev_form_list.static_filters: + static_filter_fields = _extract_filter_fields( + prev_form_list.static_filters, entity_descriptor.type_ + ) + if fd.field_name.rstrip("_id") in [ + f.rstrip("_id") for f in static_filter_fields + ]: + skip = True if not skip: field_sequence.append(fd.field_name) diff --git a/src/quickbot/utils/serialization.py b/src/quickbot/utils/serialization.py index 2c15c54..2b5e128 100644 --- a/src/quickbot/utils/serialization.py +++ b/src/quickbot/utils/serialization.py @@ -6,9 +6,7 @@ from typing import Any, Union, get_origin, get_args from types import UnionType, NoneType import ujson as json -from ..model.bot_entity import BotEntity -from ..model.bot_enum import BotEnum -from ..model.descriptors import FieldDescriptor +from quickbot.model.descriptors import FieldDescriptor async def deserialize[T](session: AsyncSession, type_: type[T], value: str = None) -> T: @@ -28,7 +26,7 @@ async def deserialize[T](session: AsyncSession, type_: type[T], value: str = Non arg_type = args[0] values = json.loads(value) if value else [] if arg_type: - if issubclass(arg_type, BotEntity): + if hasattr(arg_type, "bot_entity_descriptor"): ret = list[arg_type]() items = ( await session.exec(select(arg_type).where(column("id").in_(values))) @@ -36,17 +34,17 @@ async def deserialize[T](session: AsyncSession, type_: type[T], value: str = Non for item in items: ret.append(item) return ret - elif issubclass(arg_type, BotEnum): + elif hasattr(arg_type, "all_members"): return [arg_type(value) for value in values] else: return [arg_type(value) for value in values] else: return values - elif issubclass(type_, BotEntity): + elif hasattr(type_, "bot_entity_descriptor"): if is_optional and not value: return None return await session.get(type_, int(value)) - elif issubclass(type_, BotEnum): + elif hasattr(type_, "all_members"): if is_optional and not value: return None return type_(value) @@ -79,12 +77,12 @@ def serialize(value: Any, field_descriptor: FieldDescriptor) -> str: type_ = field_descriptor.type_base if field_descriptor.is_list: - if issubclass(type_, BotEntity): + if hasattr(type_, "bot_entity_descriptor"): return json.dumps([item.id for item in value], ensure_ascii=False) - elif issubclass(type_, BotEnum): + elif hasattr(type_, "all_members"): return json.dumps([item.value for item in value], ensure_ascii=False) else: return json.dumps(value, ensure_ascii=False) - elif issubclass(type_, BotEntity): + elif hasattr(type_, "bot_entity_descriptor"): return str(value.id) if value else "" return str(value)