crud service
All checks were successful
Build Docs / changes (push) Successful in 30s
Build Docs / build-docs (push) Has been skipped
Build Docs / deploy-docs (push) Has been skipped

This commit is contained in:
Alexander Kalinovsky
2025-08-11 20:47:39 +03:00
parent a078cdfd86
commit 4df67c93d4
33 changed files with 2358 additions and 334 deletions

View File

@@ -1,6 +1,7 @@
from .main import QBotApp as QBotApp, Config as Config from .main import QBotApp as QBotApp, Config as Config
from .router import Router as Router from .router import Router as Router
from .model.bot_entity import BotEntity as BotEntity from .model.bot_entity import BotEntity as BotEntity
from .model.bot_process import BotProcess as BotProcess
from .model.bot_enum import BotEnum as BotEnum, EnumMember as EnumMember from .model.bot_enum import BotEnum as BotEnum, EnumMember as EnumMember
from .bot.handlers.context import ( from .bot.handlers.context import (
ContextData as ContextData, ContextData as ContextData,
@@ -20,4 +21,5 @@ from .model.descriptors import (
FieldEditButton as FieldEditButton, FieldEditButton as FieldEditButton,
InlineButton as InlineButton, InlineButton as InlineButton,
FormField as FormField, FormField as FormField,
Process as Process,
) )

View File

@@ -0,0 +1,37 @@
from typing import Annotated, TYPE_CHECKING
from fastapi import Depends, HTTPException, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from jose import JWTError
from sqlmodel.ext.asyncio.session import AsyncSession
from quickbot.auth.jwt import decode_access_token
from quickbot.db import get_db
from quickbot.model.user import UserBase
if TYPE_CHECKING:
from quickbot import QBotApp
security_scheme = HTTPBearer(
scheme_name="bearerAuth",
bearerFormat="JWT",
)
async def get_current_user(
request: Request,
db_session: Annotated[AsyncSession, Depends(get_db)],
credentials: HTTPAuthorizationCredentials = Depends(security_scheme),
) -> UserBase:
try:
payload = decode_access_token(credentials.credentials)
user_id = payload.get("sub")
if user_id is None:
raise HTTPException(status_code=401, detail="Invalid token")
app: QBotApp = request.app
user = await app.user_class.get(
session=db_session,
id=int(user_id),
)
return user
except JWTError:
raise HTTPException(status_code=401, detail="Invalid token")

View File

@@ -0,0 +1,49 @@
from fastapi import Depends, Request
from pydantic import BaseModel
from sqlmodel.ext.asyncio.session import AsyncSession
from typing import Annotated, TYPE_CHECKING
from ..db import get_db
from ..model.descriptors import EntityDescriptor
from .depends import get_current_user
if TYPE_CHECKING:
from ..main import QBotApp
from ..model.user import UserBase
class ListParams(BaseModel):
query: str = ""
order_by: str = ""
limit: int = 100
offset: int = 0
async def list_entity_items(
db_session: Annotated[AsyncSession, Depends(get_db)],
request: Request,
params: Annotated[ListParams, Depends()],
current_user=Depends(get_current_user),
):
entity_descriptor: EntityDescriptor = request.app.bot_metadata.entity_descriptors[
request.url.path.split("/")[-1]
]
entity_list = await entity_descriptor.crud.list_all(
db_session=db_session,
user=current_user,
)
return entity_list
async def get_me(
db_session: Annotated[AsyncSession, Depends(get_db)],
request: Request,
current_user: Annotated["UserBase", Depends(get_current_user)],
):
app: "QBotApp" = request.app
user = await app.user_class.bot_entity_descriptor.crud.get_by_id(
db_session=db_session,
user=current_user,
id=current_user.id,
)
return user

View File

@@ -1,10 +1,12 @@
from aiogram.types import Update from aiogram.types import Update
from fastapi import APIRouter, Request, Response, Depends from fastapi import APIRouter, Request, Response, Depends, HTTPException, Body
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from typing import Annotated from typing import Annotated
from ..db import get_db from ..db import get_db
from ..main import QBotApp from ..main import QBotApp
from ..auth.telegram import check_telegram_auth
from ..auth.jwt import create_access_token
from logging import getLogger from logging import getLogger
@@ -49,6 +51,19 @@ async def telegram_webhook(
return Response(status_code=200) return Response(status_code=200)
@router.post("/auth")
async def telegram_login(request: Request, data: dict = Body(...)):
if not check_telegram_auth(data, request.app.config.TELEGRAM_BOT_TOKEN):
raise HTTPException(status_code=401, detail="Invalid Telegram auth")
payload = {
"sub": str(data["id"]),
"first_name": data.get("first_name"),
"username": data.get("username"),
}
token = create_access_token(payload)
return {"access_token": token, "token_type": "bearer"}
# async def feed_bot_update( # async def feed_bot_update(
# app: QBotApp, # app: QBotApp,
# update: Update, # update: Update,

20
src/quickbot/auth/jwt.py Normal file
View File

@@ -0,0 +1,20 @@
from datetime import datetime, timedelta
from jose import jwt
SECRET_KEY = "your_secret_key" # TODO: вынести в конфиг
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 1 неделя
def create_access_token(data: dict, expires_delta: timedelta = None):
to_encode = data.copy()
expire = datetime.utcnow() + (
expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
def decode_access_token(token: str):
return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])

View File

@@ -0,0 +1,15 @@
import hashlib
import hmac
def check_telegram_auth(data: dict, bot_token: str) -> bool:
auth_data = data.copy()
hash_ = auth_data.pop("hash", None)
if not hash_:
return False
data_check_string = "\n".join([f"{k}={v}" for k, v in sorted(auth_data.items())])
secret_key = hashlib.sha256(bot_token.encode()).digest()
hmac_hash = hmac.new(
secret_key, data_check_string.encode(), hashlib.sha256
).hexdigest()
return hmac_hash == hash_

View File

@@ -179,17 +179,31 @@ async def render_entity_picker(
entity_filter = None entity_filter = None
list_all = EntityPermission.LIST_ALL in permissions list_all = EntityPermission.LIST_ALL in permissions
if list_all or EntityPermission.LIST in permissions: if list_all or EntityPermission.LIST_RLS in permissions:
if ( if (
field_descriptor.ep_parent_field field_descriptor.ep_parent_field
and field_descriptor.ep_child_field and field_descriptor.ep_child_field
and callback_data.entity_id and callback_data.entity_id
): ):
if callable(field_descriptor.ep_parent_field):
parent_field = field_descriptor.ep_parent_field(
field_descriptor.entity_descriptor.type_
).key
else:
parent_field = field_descriptor.ep_parent_field
if callable(field_descriptor.ep_child_field):
child_field = field_descriptor.ep_child_field(
field_descriptor.entity_descriptor.type_
).key
else:
child_field = field_descriptor.ep_child_field
entity = await field_descriptor.entity_descriptor.type_.get( entity = await field_descriptor.entity_descriptor.type_.get(
session=db_session, id=callback_data.entity_id session=db_session, id=callback_data.entity_id
) )
value = getattr(entity, field_descriptor.ep_parent_field) value = getattr(entity, parent_field)
ext_filter = column(field_descriptor.ep_child_field).__eq__(value) ext_filter = column(child_field).__eq__(value)
else: else:
ext_filter = None ext_filter = None

View File

@@ -11,9 +11,9 @@ from quickbot.model.descriptors import BotContext, EntityForm
from ....model import EntityPermission from ....model import EntityPermission
from ....model.settings import Settings from ....model.settings import Settings
from ....model.user import UserBase from ....model.user import UserBase
from ....model.permissions import check_entity_permission
from ....utils.main import ( from ....utils.main import (
build_field_sequence, build_field_sequence,
check_entity_permission,
get_field_descriptor, get_field_descriptor,
clear_state, clear_state,
) )
@@ -109,8 +109,8 @@ async def field_editor(message: Message | CallbackQuery, **kwargs):
entity = await entity_descriptor.type_.get( entity = await entity_descriptor.type_.get(
session=db_session, id=int(callback_data.entity_id) session=db_session, id=int(callback_data.entity_id)
) )
if check_entity_permission( if await check_entity_permission(
entity=entity, user=user, permission=EntityPermission.UPDATE entity=entity, user=user, permission=EntityPermission.UPDATE_RLS
): ):
old_values = {} old_values = {}
@@ -188,8 +188,8 @@ async def field_editor(message: Message | CallbackQuery, **kwargs):
entity = await entity_descriptor.type_.get( entity = await entity_descriptor.type_.get(
session=kwargs["db_session"], id=int(callback_data.entity_id) session=kwargs["db_session"], id=int(callback_data.entity_id)
) )
if check_entity_permission( if await check_entity_permission(
entity=entity, user=user, permission=EntityPermission.READ entity=entity, user=user, permission=EntityPermission.READ_RLS
): ):
if entity: if entity:
form_name = ( form_name = (

View File

@@ -19,12 +19,14 @@ from ....model.descriptors import (
EntityForm, EntityForm,
EntityList, EntityList,
FieldDescriptor, FieldDescriptor,
Filter,
FilterExpression,
) )
from ....model.language import LanguageBase from ....model.language import LanguageBase
from ....auth import authorize_command from ....auth import authorize_command
from ....model.permissions import check_entity_permission
from ....utils.main import ( from ....utils.main import (
get_user_permissions, get_user_permissions,
check_entity_permission,
clear_state, clear_state,
get_entity_descriptor, get_entity_descriptor,
get_field_descriptor, get_field_descriptor,
@@ -328,12 +330,39 @@ async def process_field_edit_callback(message: Message | CallbackQuery, **kwargs
]: ]:
user_permissions = get_user_permissions(user, entity_descriptor) user_permissions = get_user_permissions(user, entity_descriptor)
for role in user.roles: if entity_descriptor.rls_filters:
if ( filters = []
role in entity_descriptor.ownership_fields if isinstance(entity_descriptor.rls_filters, Filter):
and EntityPermission.CREATE_ALL not in user_permissions filters = [entity_descriptor.rls_filters]
elif (
isinstance(entity_descriptor.rls_filters, FilterExpression)
and entity_descriptor.rls_filters.operator == "and"
and all(
isinstance(f, Filter)
for f in entity_descriptor.rls_filters.filters
)
): ):
entity_data[entity_descriptor.ownership_fields[role]] = user.id filters = entity_descriptor.rls_filters.filters
filter_params = []
if filters and entity_descriptor.rls_filters_params:
if iscoroutinefunction(entity_descriptor.rls_filters_params):
filter_params = await entity_descriptor.rls_filters_params(
user
)
else:
filter_params = entity_descriptor.rls_filters_params(user)
for f in filters:
if f.operator == "==":
if isinstance(f.field, str):
field_name = f.field
else:
field_name = f.field(entity_descriptor.type_).key
entity_data[field_name] = (
f.value
if f.value_type == "const"
else filter_params[f.param_index]
)
deser_entity_data = { deser_entity_data = {
key: await deserialize( key: await deserialize(
@@ -356,7 +385,7 @@ async def process_field_edit_callback(message: Message | CallbackQuery, **kwargs
entity_type = entity_descriptor.type_ entity_type = entity_descriptor.type_
user_permissions = get_user_permissions(user, entity_descriptor) user_permissions = get_user_permissions(user, entity_descriptor)
if ( if (
EntityPermission.CREATE not in user_permissions EntityPermission.CREATE_RLS not in user_permissions
and EntityPermission.CREATE_ALL not in user_permissions and EntityPermission.CREATE_ALL not in user_permissions
): ):
return await message.answer( return await message.answer(
@@ -438,8 +467,8 @@ async def process_field_edit_callback(message: Message | CallbackQuery, **kwargs
text=(await Settings.get(Settings.APP_STRINGS_NOT_FOUND)) text=(await Settings.get(Settings.APP_STRINGS_NOT_FOUND))
) )
if not check_entity_permission( if not await check_entity_permission(
entity=entity, user=user, permission=EntityPermission.UPDATE entity=entity, user=user, permission=EntityPermission.UPDATE_RLS
): ):
return await message.answer( return await message.answer(
text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN)) text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN))

View File

@@ -18,8 +18,8 @@ from ....model.bot_entity import BotEntity
from ....model.settings import Settings from ....model.settings import Settings
from ....model.user import UserBase from ....model.user import UserBase
from ....model import EntityPermission from ....model import EntityPermission
from ....model.permissions import check_entity_permission
from ....utils.main import ( from ....utils.main import (
check_entity_permission,
get_send_message, get_send_message,
clear_state, clear_state,
get_value_repr, get_value_repr,
@@ -84,15 +84,15 @@ async def entity_item(
# is_owned = issubclass(entity_type, OwnedBotEntity) # is_owned = issubclass(entity_type, OwnedBotEntity)
if query and not check_entity_permission( if query and not await check_entity_permission(
entity=entity_item, user=user, permission=EntityPermission.READ entity=entity_item, user=user, permission=EntityPermission.READ_RLS
): ):
return await query.answer( return await query.answer(
text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN)) text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN))
) )
can_edit = check_entity_permission( can_edit = await check_entity_permission(
entity=entity_item, user=user, permission=EntityPermission.UPDATE entity=entity_item, user=user, permission=EntityPermission.UPDATE_RLS
) )
form: EntityForm = entity_descriptor.forms.get( form: EntityForm = entity_descriptor.forms.get(
@@ -250,8 +250,8 @@ async def entity_item(
) )
if ( if (
check_entity_permission( await check_entity_permission(
entity=entity_item, user=user, permission=EntityPermission.DELETE entity=entity_item, user=user, permission=EntityPermission.DELETE_RLS
) )
and form.show_delete_button and form.show_delete_button
): ):
@@ -303,7 +303,7 @@ async def entity_item(
) )
async def item_repr(entity_item: BotEntity, context: BotContext[UserBase]): async def item_repr(entity_item: BotEntity, context: BotContext):
entity_descriptor = entity_item.bot_entity_descriptor entity_descriptor = entity_item.bot_entity_descriptor
user = context.user user = context.user
entity_caption = ( entity_caption = (
@@ -349,20 +349,6 @@ async def item_repr(entity_item: BotEntity, context: BotContext[UserBase]):
if not field_visible: if not field_visible:
continue continue
skip = False
for own_field in entity_descriptor.ownership_fields.items():
if (
own_field[1].rstrip("_id") == field_descriptor.field_name.rstrip("_id")
and own_field[0] in user.roles
and EntityPermission.READ_ALL not in user_permissions
):
skip = True
break
if skip:
continue
if field_descriptor.caption_value: if field_descriptor.caption_value:
item_text += f"\n{ item_text += f"\n{
await get_callable_str( await get_callable_str(

View File

@@ -12,8 +12,8 @@ from ..context import ContextData, CallbackCommand
from ....model.user import UserBase from ....model.user import UserBase
from ....model.settings import Settings from ....model.settings import Settings
from ....model import EntityPermission from ....model import EntityPermission
from ....model.permissions import check_entity_permission
from ....utils.main import ( from ....utils.main import (
check_entity_permission,
get_entity_item_repr, get_entity_item_repr,
get_entity_descriptor, get_entity_descriptor,
) )
@@ -42,8 +42,8 @@ async def entity_delete_callback(query: CallbackQuery, **kwargs):
session=db_session, id=int(callback_data.entity_id) session=db_session, id=int(callback_data.entity_id)
) )
if not check_entity_permission( if not await check_entity_permission(
entity=entity, user=user, permission=EntityPermission.DELETE entity=entity, user=user, permission=EntityPermission.DELETE_RLS
): ):
return await query.answer( return await query.answer(
text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN)) text=(await Settings.get(Settings.APP_STRINGS_FORBIDDEN))

View File

@@ -95,7 +95,7 @@ async def entity_list(
) )
if ( if (
EntityPermission.CREATE in user_permissions EntityPermission.CREATE_RLS in user_permissions
or EntityPermission.CREATE_ALL in user_permissions or EntityPermission.CREATE_ALL in user_permissions
) and form_list.show_add_new_button: ) and form_list.show_add_new_button:
if form_item.edit_field_sequence: if form_item.edit_field_sequence:
@@ -136,8 +136,8 @@ async def entity_list(
) )
if ( if (
list_all list_all
or EntityPermission.LIST in user_permissions or EntityPermission.LIST_RLS in user_permissions
or EntityPermission.READ in user_permissions or EntityPermission.READ_RLS in user_permissions
): ):
if form_list.pagination: if form_list.pagination:
page_size = await Settings.get(Settings.PAGE_SIZE) page_size = await Settings.get(Settings.PAGE_SIZE)
@@ -265,10 +265,10 @@ async def entity_list(
else: else:
entity_text = entity_descriptor.name entity_text = entity_descriptor.name
if entity_descriptor.description: if entity_descriptor.ui_description:
entity_text = f"{entity_text} { entity_text = f"{entity_text} {
await get_callable_str( await get_callable_str(
callable_str=entity_descriptor.description, callable_str=entity_descriptor.ui_description,
context=context, context=context,
descriptor=entity_descriptor, descriptor=entity_descriptor,
) )

View File

@@ -42,7 +42,7 @@ async def entities_menu(
): ):
keyboard_builder = InlineKeyboardBuilder() keyboard_builder = InlineKeyboardBuilder()
entity_metadata = app.entity_metadata entity_metadata = app.bot_metadata
for entity in entity_metadata.entity_descriptors.values(): for entity in entity_metadata.entity_descriptors.values():
if entity.show_in_entities_menu: if entity.show_in_entities_menu:

View File

@@ -1,5 +1,21 @@
"""
main.py - QuickBot RAD Framework Main Application Module
Defines QBotApp, the main entry point for the QuickBot rapid application development (RAD) framework for Telegram bots.
Integrates FastAPI (HTTP API), Aiogram (Telegram bot logic), SQLModel (async DB), and i18n (internationalization).
Key Features:
- Dynamic registration of CRUD API endpoints for all entities
- Telegram bot command and webhook management
- Row-level security (RLS) and user management
- Middleware for authentication and localization
- Swagger UI with Telegram login integration
"""
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import Callable, Any, Generic, TypeVar from inspect import iscoroutinefunction
from typing import Union
from typing import Annotated, Callable, Any, Generic, TypeVar
from aiogram import Bot, Dispatcher from aiogram import Bot, Dispatcher
from aiogram.client.session.aiohttp import AiohttpSession from aiogram.client.session.aiohttp import AiohttpSession
from aiogram.client.telegram import TelegramAPIServer from aiogram.client.telegram import TelegramAPIServer
@@ -7,15 +23,21 @@ from aiogram.client.default import DefaultBotProperties
from aiogram.types import Message, BotCommand as AiogramBotCommand from aiogram.types import Message, BotCommand as AiogramBotCommand
from aiogram.utils.callback_answer import CallbackAnswerMiddleware from aiogram.utils.callback_answer import CallbackAnswerMiddleware
from aiogram.utils.i18n import I18n from aiogram.utils.i18n import I18n
from fastapi import FastAPI, Request from fastapi import Depends, FastAPI, Request, Body, Path, HTTPException
from fastapi.applications import Lifespan, AppType from fastapi.applications import Lifespan, AppType
from fastapi.datastructures import State from fastapi.datastructures import State
from logging import getLogger from logging import getLogger
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.openapi.utils import get_openapi
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from quickbot.bot.handlers.user_handlers.main import command_handler from quickbot.bot.handlers.user_handlers.main import command_handler
from quickbot.db import get_db
from quickbot.model.list_schema import ListSchema
from quickbot.plugin import Registerable
from quickbot.utils.main import clear_state from quickbot.utils.main import clear_state
from quickbot.utils.navigation import save_navigation_context from quickbot.utils.navigation import save_navigation_context
from quickbot.model.crud_service import NotFoundError, ForbiddenError
from .config import Config from .config import Config
from .bot.handlers.forms.entity_form import entity_item from .bot.handlers.forms.entity_form import entity_item
@@ -23,10 +45,18 @@ from .fsm.db_storage import DbStorage
from .middleware.telegram import AuthMiddleware, I18nMiddleware from .middleware.telegram import AuthMiddleware, I18nMiddleware
from .model.bot_entity import BotEntity from .model.bot_entity import BotEntity
from .model.user import UserBase from .model.user import UserBase
from .model.entity_metadata import EntityMetadata from .model.bot_metadata import BotMetadata
from .model.descriptors import BotCommand from .model.descriptors import (
BotCommand,
EntityDescriptor,
ProcessDescriptor,
BotContext,
)
from .model.crud_command import CrudCommand
from .bot.handlers.context import CallbackCommand, ContextData from .bot.handlers.context import CallbackCommand, ContextData
from .router import Router from .router import Router
from .api_route.models import list_entity_items, get_me
from .api_route.depends import get_current_user
logger = getLogger(__name__) logger = getLogger(__name__)
@@ -60,7 +90,20 @@ async def default_lifespan(app: "QBotApp"):
class QBotApp(Generic[UserType, ConfigType], FastAPI): class QBotApp(Generic[UserType, ConfigType], FastAPI):
""" """
Main class for the QBot application Main application class for QuickBot RAD framework.
Integrates FastAPI, Aiogram, SQLModel, and i18n for rapid Telegram bot development.
Handles bot initialization, API registration, command routing, and RLS.
Args:
config: App configuration (see quickbot/config.py)
user_class: User model class (subclass of UserBase)
bot_start: Optional custom bot start handler
lifespan: Optional FastAPI lifespan context
lifespan_bot_init: Whether to run bot init on startup
lifespan_set_webhook: Whether to set webhook on startup
webhook_handler: Optional custom webhook handler
allowed_updates: List of Telegram update types to allow
""" """
def __init__( def __init__(
@@ -82,6 +125,7 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
allowed_updates: list[str] | None = None, allowed_updates: list[str] | None = None,
**kwargs, **kwargs,
): ):
# --- Initialize default user class if not provided ---
if user_class is None: if user_class is None:
from .model.default_user import DefaultUser from .model.default_user import DefaultUser
@@ -92,14 +136,18 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
) )
self.user_class = user_class self.user_class = user_class
self.entity_metadata: EntityMetadata = user_class.entity_metadata self.bot_metadata: BotMetadata = user_class.bot_metadata
self.bot_metadata.app = self
self.config = config self.config = config
self.lifespan = lifespan self.lifespan = lifespan
# --- Setup Telegram API server and session ---
api_server = TelegramAPIServer.from_base( api_server = TelegramAPIServer.from_base(
self.config.TELEGRAM_BOT_SERVER, self.config.TELEGRAM_BOT_SERVER,
is_local=self.config.TELEGRAM_BOT_SERVER_IS_LOCAL, is_local=self.config.TELEGRAM_BOT_SERVER_IS_LOCAL,
) )
session = AiohttpSession(api=api_server) session = AiohttpSession(api=api_server)
# --- Initialize Telegram Bot instance ---
self.bot = Bot( self.bot = Bot(
token=self.config.TELEGRAM_BOT_TOKEN, token=self.config.TELEGRAM_BOT_TOKEN,
session=session, session=session,
@@ -108,26 +156,30 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
), ),
) )
# --- Setup Aiogram dispatcher with DB storage for FSM ---
dp = Dispatcher(storage=DbStorage()) dp = Dispatcher(storage=DbStorage())
# --- Setup i18n and middleware ---
self.i18n = I18n(path="locales", default_locale="en", domain="messages") self.i18n = I18n(path="locales", default_locale="en", domain="messages")
i18n_middleware = I18nMiddleware(user_class=user_class, i18n=self.i18n) i18n_middleware = I18nMiddleware(user_class=user_class, i18n=self.i18n)
i18n_middleware.setup(dp) i18n_middleware.setup(dp)
dp.callback_query.middleware(CallbackAnswerMiddleware()) dp.callback_query.middleware(CallbackAnswerMiddleware())
# --- Register core routers (start, main menu) ---
from .bot.handlers.start import router as start_router from .bot.handlers.start import router as start_router
dp.include_router(start_router) dp.include_router(start_router)
from .bot.handlers.menu.main import router as main_menu_router from .bot.handlers.menu.main import router as main_menu_router
auth = AuthMiddleware(user_class=user_class) # Register authentication middleware for menu routers
main_menu_router.message.middleware.register(auth) self.auth = AuthMiddleware(user_class=user_class)
main_menu_router.callback_query.middleware.register(auth) main_menu_router.message.middleware.register(self.auth)
main_menu_router.callback_query.middleware.register(self.auth)
dp.include_router(main_menu_router) dp.include_router(main_menu_router)
self.dp = dp self.dp = dp
# --- Extension points for custom bot start and webhook handlers ---
self.start_handler = bot_start self.start_handler = bot_start
self.webhook_handler = webhook_handler self.webhook_handler = webhook_handler
self.bot_commands = dict[str, BotCommand]() self.bot_commands = dict[str, BotCommand]()
@@ -135,8 +187,15 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
self.lifespan_bot_init = lifespan_bot_init self.lifespan_bot_init = lifespan_bot_init
self.lifespan_set_webhook = lifespan_set_webhook self.lifespan_set_webhook = lifespan_set_webhook
super().__init__(lifespan=default_lifespan, **kwargs) # --- Initialize FastAPI with custom lifespan and no default docs ---
super().__init__(lifespan=default_lifespan, docs_url=None, **kwargs)
self.bot_metadata.app_state = self.state
# --- Initialize plugins ---
self.plugins = dict[str, Any]()
# --- Register Telegram API router for /telegram endpoints (for webhook and auth) ---
from .api_route.telegram import router as telegram_router from .api_route.telegram import router as telegram_router
self.include_router(telegram_router, prefix="/telegram", tags=["telegram"]) self.include_router(telegram_router, prefix="/telegram", tags=["telegram"])
@@ -144,29 +203,422 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
self.root_router._commands = self.bot_commands self.root_router._commands = self.bot_commands
self.command = self.root_router.command self.command = self.root_router.command
# --- Register all entity CRUD endpoints dynamically (for models API) ---
self.register_models_api()
# --- Register custom Swagger UI with Telegram login (for docs) ---
self.register_swagger_ui_html()
def register_plugin(self, plugin: Any):
self.plugins[type(plugin).__name__] = plugin
if isinstance(plugin, Registerable):
plugin.register(self)
def register_swagger_ui_html(self):
"""
Register a custom /docs endpoint with Telegram login widget and JWT support for Swagger UI.
"""
def swagger_ui_html():
return HTMLResponse(f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="description" content="SwaggerUI" />
<title>QuickBot API</title>
<link rel="stylesheet" href="https://unpkg.com/swagger-ui-dist@5.26.2/swagger-ui.css" />
</head>
<body>
<div id="swagger-ui"></div>
<script src="https://unpkg.com/swagger-ui-dist@5.26.2/swagger-ui-bundle.js" crossorigin></script>
<script>
function logout() {{
localStorage.removeItem("jwt");
window.ui.preauthorizeApiKey("bearerAuth", "invalid.token");
//location.reload();
}}
function injectTelegramWidget(jwt) {{
const auth_wrapper = document.querySelector(".auth-wrapper");
if (!auth_wrapper) return;
const oldAuthBtn = auth_wrapper.querySelector(".authorize");
if (oldAuthBtn) oldAuthBtn.remove();
const authContainer = document.createElement("div");
authContainer.className = "auth-info";
/*if (jwt) {{
try {{
console.log(jwt);
const payload = JSON.parse(atob(jwt.split('.')[1]));
const username = payload.username || payload.id;
authContainer.innerHTML = `
<span>👤 ${{username}}</span>
<button class="logout-btn" onclick="logout()">Logout</button>
`;
}} catch (e) {{
authContainer.textContent = "JWT error";
}}
}} else {{*/
const script = document.createElement("script");
script.async = true;
script.src = "https://telegram.org/js/telegram-widget.js";
script.setAttribute("data-telegram-login", "{self.config.TELEGRAM_BOT_USERNAME}");
script.setAttribute("data-size", "large");
script.setAttribute("data-onauth", "handleTelegramAuth(user)");
script.setAttribute("data-request-access", "write");
authContainer.appendChild(script);
//}}
auth_wrapper.appendChild(authContainer);
}}
function waitForElement(selector, callback) {{
const el = document.querySelector(selector);
if (el) {{
callback(el);
return;
}}
const observer = new MutationObserver(() => {{
const el = document.querySelector(selector);
if (el) {{
observer.disconnect();
callback(el);
}}
}});
observer.observe(document.body, {{ childList: true, subtree: true }});
}}
function handleTelegramAuth(user) {{
fetch('/telegram/auth', {{
method: 'POST',
headers: {{ 'Content-Type': 'application/json' }},
body: JSON.stringify(user)
}})
.then(res => res.json())
.then(data => {{
localStorage.setItem("jwt", data.access_token);
window.ui.preauthorizeApiKey("bearerAuth", data.access_token);
//location.reload();
}})
.catch(() => alert("Authorization error"));
}}
window.handleTelegramAuth = handleTelegramAuth;
window.onload = function () {{
const jwt = localStorage.getItem("jwt", null);
window.ui = SwaggerUIBundle({{
url: '/openapi.json',
dom_id: '#swagger-ui',
onComplete: function () {{
if (jwt) {{
window.ui.preauthorizeApiKey("bearerAuth", jwt);
}}
waitForElement(".auth-wrapper", (el) => {{
injectTelegramWidget(jwt);
}});
}}
}});
}};
</script>
</body>
</html>
""")
self.router.add_api_route(
path="/docs",
include_in_schema=False,
endpoint=swagger_ui_html,
methods=["GET"],
tags=["docs"],
)
def openapi_json():
schema = get_openapi(
title="FastAPI + Telegram OAuth",
version="1.0.0",
description="Swagger с Telegram Login",
routes=self.routes,
)
schema["components"]["securitySchemes"] = {
"bearerAuth": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
}
}
for path in schema["paths"].values():
for op in path.values():
op.setdefault("security", [{"bearerAuth": []}])
return JSONResponse(schema)
self.router.add_api_route(
path="/openapi.json",
endpoint=openapi_json,
methods=["GET"],
tags=["docs"],
include_in_schema=False,
)
def register_models_api(self):
"""
Dynamically register CRUD API endpoints for all entities in the app's metadata.
Endpoints: list, create, get by id, update, delete.
Uses FastAPI dependency injection for database session and user authentication.
"""
def make_create_api_endpoint(entity_descriptor: EntityDescriptor):
async def create_entity(
db_session: Annotated[AsyncSession, Depends(get_db)],
request: Request,
obj_in: entity_descriptor.crud.create_schema = Body(...),
current_user=Depends(get_current_user),
):
try:
ret_obj = await entity_descriptor.crud.create(
db_session=db_session,
user=current_user,
model=obj_in,
)
except NotFoundError as e:
raise HTTPException(status_code=404, detail=e.args[0])
except ForbiddenError as e:
raise HTTPException(status_code=403, detail=e.args[0])
except Exception as e:
logger.error(f"Error creating entity: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
return ret_obj
return create_entity
def make_update_api_endpoint(entity_descriptor: EntityDescriptor):
async def update_entity(
db_session: Annotated[AsyncSession, Depends(get_db)],
request: Request,
id: int = Path(..., description="ID of the entity to update"),
obj_in: entity_descriptor.crud.update_schema = Body(...),
current_user=Depends(get_current_user),
):
try:
entity = await entity_descriptor.crud.update(
db_session=db_session,
id=id,
model=obj_in,
user=current_user,
)
except NotFoundError as e:
raise HTTPException(status_code=404, detail=e.args[0])
except ForbiddenError as e:
raise HTTPException(status_code=403, detail=e.args[0])
except Exception as e:
logger.error(f"Error updating entity: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
return entity
return update_entity
def make_delete_api_endpoint(entity_descriptor: EntityDescriptor):
async def delete_entity(
db_session: Annotated[AsyncSession, Depends(get_db)],
request: Request,
id: int = Path(..., description="ID of the entity to delete"),
current_user=Depends(get_current_user),
):
try:
entity = await entity_descriptor.crud.delete(
db_session=db_session,
id=id,
user=current_user,
)
except NotFoundError as e:
raise HTTPException(status_code=404, detail=e.args[0])
except ForbiddenError as e:
raise HTTPException(status_code=403, detail=e.args[0])
except Exception as e:
logger.error(f"Error deleting entity: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
return entity
return delete_entity
def make_get_by_id_api_endpoint(entity_descriptor: EntityDescriptor):
async def get_entity_by_id(
db_session: Annotated[AsyncSession, Depends(get_db)],
request: Request,
id: int = Path(..., description="ID of the entity to get"),
current_user=Depends(get_current_user),
):
try:
entity = await entity_descriptor.type_.get(
session=db_session,
id=id,
)
except NotFoundError as e:
raise HTTPException(status_code=404, detail=e.args[0])
except ForbiddenError as e:
raise HTTPException(status_code=403, detail=e.args[0])
except Exception as e:
logger.error(f"Error getting entity: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
return entity
return get_entity_by_id
def make_process_api_endpoint(process_descriptor: ProcessDescriptor):
async def run_process(
db_session: Annotated[AsyncSession, Depends(get_db)],
current_user: Annotated[UserBase, Depends(get_current_user)],
request: Request,
obj_in: process_descriptor.input_schema = Body(...),
):
for role in current_user.roles:
if role in process_descriptor.roles:
break
else:
raise HTTPException(status_code=403, detail="Forbidden")
run_func = process_descriptor.process_class.run
bot_context = BotContext(
db_session=db_session,
app=current_user.bot_metadata.app,
app_state=current_user.bot_metadata.app_state,
user=current_user,
)
try:
if iscoroutinefunction(run_func):
result = await run_func(
context=bot_context,
parameters=obj_in,
)
else:
result = run_func(
context=bot_context,
parameters=obj_in,
)
return result
except Exception as e:
logger.error(f"Error running process: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
return run_process
for entity_descriptor in self.bot_metadata.entity_descriptors.values():
if issubclass(entity_descriptor.type_, UserBase):
self.router.add_api_route(
path=f"/models/{entity_descriptor.name}/me",
methods=["GET"],
endpoint=get_me,
response_model=entity_descriptor.crud.schema,
summary="Get current user",
description="Get current user",
tags=["models"],
)
if CrudCommand.LIST in entity_descriptor.crud.commands:
self.router.add_api_route(
path=f"/models/{entity_descriptor.name}",
endpoint=list_entity_items,
methods=["GET"],
response_model=list[
Union[entity_descriptor.crud.schema, ListSchema]
],
summary=f"List {entity_descriptor.name}",
description=f"List {entity_descriptor.name}",
tags=["models"],
)
if CrudCommand.CREATE in entity_descriptor.crud.commands:
self.router.add_api_route(
path=f"/models/{entity_descriptor.name}",
methods=["POST"],
endpoint=make_create_api_endpoint(entity_descriptor),
response_model=entity_descriptor.crud.schema,
summary=f"Create {entity_descriptor.name}",
description=f"Create {entity_descriptor.name}",
tags=["models"],
)
if CrudCommand.GET_BY_ID in entity_descriptor.crud.commands:
self.router.add_api_route(
path=f"/models/{entity_descriptor.name}/{{id}}",
methods=["GET"],
endpoint=make_get_by_id_api_endpoint(entity_descriptor),
response_model=entity_descriptor.crud.schema,
summary=f"Get {entity_descriptor.name} by id",
description=f"Get {entity_descriptor.name} by id",
tags=["models"],
)
if CrudCommand.UPDATE in entity_descriptor.crud.commands:
self.router.add_api_route(
path=f"/models/{entity_descriptor.name}/{{id}}",
methods=["PATCH"],
endpoint=make_update_api_endpoint(entity_descriptor),
response_model=Union[entity_descriptor.crud.schema, ListSchema],
summary=f"Update {entity_descriptor.name}",
description=f"Update {entity_descriptor.name}",
tags=["models"],
)
if CrudCommand.DELETE in entity_descriptor.crud.commands:
self.router.add_api_route(
path=f"/models/{entity_descriptor.name}/{{id}}",
methods=["DELETE"],
endpoint=make_delete_api_endpoint(entity_descriptor),
response_model=Union[entity_descriptor.crud.schema, ListSchema],
summary=f"Delete {entity_descriptor.name}",
description=f"Delete {entity_descriptor.name}",
tags=["models"],
)
for process_descriptor in self.bot_metadata.process_descriptors.values():
self.router.add_api_route(
path=f"/processes/{process_descriptor.name}",
methods=["POST"],
endpoint=make_process_api_endpoint(process_descriptor),
response_model=process_descriptor.output_schema,
summary=f"Run {process_descriptor.name}",
description=process_descriptor.description,
tags=["processes"],
)
def register_routers(self, *routers: Router): def register_routers(self, *routers: Router):
# Register additional routers and their commands with the application.
# This allows modular extension of bot command sets and menu trees.
for router in routers: for router in routers:
for command_name, command in router._commands.items(): for command_name, command in router._commands.items():
self.bot_commands[command_name] = command self.bot_commands[command_name] = command
async def bot_init(self): async def bot_init(self):
# --- Set up bot commands for all locales ---
# This method collects all commands (with captions) that should be shown in the Telegram UI.
# It supports localization by grouping commands by locale.
commands_captions = dict[str, list[tuple[str, str]]]() commands_captions = dict[str, list[tuple[str, str]]]()
for command_name, command in self.bot_commands.items(): for command_name, command in self.bot_commands.items():
if command.show_in_bot_commands: if command.show_in_bot_commands:
if isinstance(command.caption, str) or command.caption is None: if isinstance(command.caption, str) or command.caption is None:
# Default locale (or no caption provided)
if "default" not in commands_captions: if "default" not in commands_captions:
commands_captions["default"] = [] commands_captions["default"] = []
commands_captions["default"].append( commands_captions["default"].append(
(command_name, command.caption or command_name) (command_name, command.caption or command_name)
) )
else: else:
# Localized captions per locale
for locale, description in command.caption.items(): for locale, description in command.caption.items():
locale = "default" if locale == "en" else locale locale = "default" if locale == "en" else locale
if locale not in commands_captions: if locale not in commands_captions:
commands_captions[locale] = [] commands_captions[locale] = []
commands_captions[locale].append((command_name, description)) commands_captions[locale].append((command_name, description))
# Register commands with Telegram for each locale
for locale, commands in commands_captions.items(): for locale, commands in commands_captions.items():
await self.bot.set_my_commands( await self.bot.set_my_commands(
[ [
@@ -177,6 +629,8 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
) )
async def set_webhook(self): async def set_webhook(self):
# --- Set Telegram webhook for receiving updates ---
# This is called on startup if lifespan_set_webhook is True.
await self.bot.set_webhook( await self.bot.set_webhook(
url=f"{self.config.TELEGRAM_WEBHOOK_URL}/telegram/webhook", url=f"{self.config.TELEGRAM_WEBHOOK_URL}/telegram/webhook",
drop_pending_updates=True, drop_pending_updates=True,
@@ -194,6 +648,8 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
form_name: str = None, form_name: str = None,
form_params: list[Any] = None, form_params: list[Any] = None,
): ):
# --- Show a form for a specific entity instance to a user ---
# Used for interactive entity editing or viewing in the Telegram bot UI.
f_params = [] f_params = []
if form_name: if form_name:
@@ -202,9 +658,11 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
if form_params: if form_params:
f_params.extend([str(p) for p in form_params]) f_params.extend([str(p) for p in form_params])
# Allow passing entity as class or string name
if isinstance(entity, type): if isinstance(entity, type):
entity = entity.bot_entity_descriptor.name entity = entity.bot_entity_descriptor.name
# Prepare callback data for navigation stack
callback_data = ContextData( callback_data = ContextData(
command=CallbackCommand.ENTITY_ITEM, command=CallbackCommand.ENTITY_ITEM,
entity_name=entity, entity_name=entity,
@@ -212,6 +670,7 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
form_params="&".join(f_params), form_params="&".join(f_params),
) )
# Get FSM state context for the user
state = self.dp.fsm.get_context(bot=self.bot, chat_id=user_id, user_id=user_id) state = self.dp.fsm.get_context(bot=self.bot, chat_id=user_id, user_id=user_id)
state_data = await state.get_data() state_data = await state.get_data()
clear_state(state_data=state_data) clear_state(state_data=state_data)
@@ -220,11 +679,13 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
) )
await state.set_data(state_data) await state.set_data(state_data)
# Fetch user object for locale and permissions
user = await self.user_class.get( user = await self.user_class.get(
session=db_session, session=db_session,
id=user_id, id=user_id,
) )
# Use i18n context for the user's language
with self.i18n.context(), self.i18n.use_locale(user.lang.value): with self.i18n.context(), self.i18n.use_locale(user.lang.value):
await entity_item( await entity_item(
query=None, query=None,
@@ -246,6 +707,8 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
user_id: int, user_id: int,
db_session: AsyncSession, db_session: AsyncSession,
): ):
# --- Execute a user command in the Telegram bot context ---
# This is used for programmatically triggering bot commands (e.g., from API or callback).
state = self.dp.fsm.get_context(bot=self.bot, chat_id=user_id, user_id=user_id) state = self.dp.fsm.get_context(bot=self.bot, chat_id=user_id, user_id=user_id)
state_data = await state.get_data() state_data = await state.get_data()
callback_data = ContextData( callback_data = ContextData(
@@ -255,22 +718,27 @@ class QBotApp(Generic[UserType, ConfigType], FastAPI):
command_name = command.split("&")[0] command_name = command.split("&")[0]
cmd = self.bot_commands.get(command_name) cmd = self.bot_commands.get(command_name)
# Fetch user object for permissions and locale
user = await self.user_class.get( user = await self.user_class.get(
session=db_session, session=db_session,
id=user_id, id=user_id,
) )
if cmd is None: if cmd is None:
# Command not found (could be a custom or unregistered command)
return return
# Optionally clear navigation stack if command requires it
if cmd.clear_navigation: if cmd.clear_navigation:
state_data.pop("navigation_stack", None) state_data.pop("navigation_stack", None)
state_data.pop("navigation_context", None) state_data.pop("navigation_context", None)
# Optionally register navigation context for this command
if cmd.register_navigation: if cmd.register_navigation:
clear_state(state_data=state_data) clear_state(state_data=state_data)
save_navigation_context(callback_data=callback_data, state_data=state_data) save_navigation_context(callback_data=callback_data, state_data=state_data)
# Use i18n context for the user's language
with self.i18n.context(), self.i18n.use_locale(user.lang.value): with self.i18n.context(), self.i18n.use_locale(user.lang.value):
await command_handler( await command_handler(
message=None, message=None,

View File

@@ -23,9 +23,15 @@ class AuthMiddleware(BaseMiddleware):
event: TelegramObject, event: TelegramObject,
data: Dict[str, Any], data: Dict[str, Any],
) -> Any: ) -> Any:
user = await self.user_class.get( if event.business_connection_id:
id=event.from_user.id, session=data["db_session"] business_connection = await event.bot.get_business_connection(
) event.business_connection_id
)
user_id = business_connection.user.id
else:
user_id = event.from_user.id
user = await self.user_class.get(id=user_id, session=data["db_session"])
if user and user.is_active: if user and user.is_active:
data["user"] = user data["user"] = user

View File

@@ -8,11 +8,11 @@ from ..db import async_session
class EntityPermission(BotEnum): class EntityPermission(BotEnum):
LIST = EnumMember("list") LIST_RLS = EnumMember("list_rls")
READ = EnumMember("read") READ_RLS = EnumMember("read_rls")
CREATE = EnumMember("create") CREATE_RLS = EnumMember("create_rls")
UPDATE = EnumMember("update") UPDATE_RLS = EnumMember("update_rls")
DELETE = EnumMember("delete") DELETE_RLS = EnumMember("delete_rls")
LIST_ALL = EnumMember("list_all") LIST_ALL = EnumMember("list_all")
READ_ALL = EnumMember("read_all") READ_ALL = EnumMember("read_all")
CREATE_ALL = EnumMember("create_all") CREATE_ALL = EnumMember("create_all")

View File

@@ -0,0 +1,46 @@
"""
BotEntity module provides a metaclass and base class for creating database entities
with enhanced functionality for bot operations, including field descriptors,
filtering, and ownership management.
"""
from typing import (
TYPE_CHECKING,
dataclass_transform,
)
from pydantic import BaseModel
from pydantic._internal._model_construction import ModelMetaclass
from sqlmodel import Field
from sqlmodel.main import FieldInfo
from .descriptors import EntityField, FieldDescriptor
if TYPE_CHECKING:
pass
@dataclass_transform(
kw_only_default=True,
field_specifiers=(Field, FieldInfo, EntityField, FieldDescriptor),
)
class AnnotatedSchemaMetaclass(ModelMetaclass):
"""
Metaclass for annotated schemas.
"""
def __new__(mcs, name, bases, namespace, **kwargs):
"""
Create a new annotated schema.
"""
# --- Create the class using parent metaclass ---
type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
return type_
class AnnotatedSchema(BaseModel, metaclass=AnnotatedSchemaMetaclass):
"""
Base class for annotated schemas.
"""

View File

@@ -1,28 +1,47 @@
"""
BotEntity module provides a metaclass and base class for creating database entities
with enhanced functionality for bot operations, including field descriptors,
filtering, and ownership management.
"""
from types import NoneType, UnionType from types import NoneType, UnionType
from typing import ( from typing import (
Any, Any,
ClassVar, ClassVar,
ForwardRef, ForwardRef,
Optional, Optional,
Self,
Union, Union,
get_args, get_args,
get_origin, get_origin,
TYPE_CHECKING, TYPE_CHECKING,
dataclass_transform, dataclass_transform,
Self,
) )
from pydantic import BaseModel from pydantic import BaseModel
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined from pydantic_core import PydanticUndefined
from sqlmodel import SQLModel, BigInteger, Field, select, func, column, col from sqlmodel import SQLModel, BigInteger, Field, select, func
from sqlmodel.main import FieldInfo from sqlmodel.main import FieldInfo
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
from .descriptors import EntityDescriptor, EntityField, FieldDescriptor, Filter from .descriptors import (
from .entity_metadata import EntityMetadata EntityDescriptor,
EntityField,
FieldDescriptor,
Filter,
FilterExpression,
)
from .bot_metadata import BotMetadata
from .crud_service import CrudService
from . import session_dep from . import session_dep
from .utils import (
_static_filter_condition,
_build_filter_condition,
_filter_condition,
_apply_rls_filters,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .user import UserBase from .user import UserBase
@@ -33,11 +52,34 @@ if TYPE_CHECKING:
field_specifiers=(Field, FieldInfo, EntityField, FieldDescriptor), field_specifiers=(Field, FieldInfo, EntityField, FieldDescriptor),
) )
class BotEntityMetaclass(SQLModelMetaclass): class BotEntityMetaclass(SQLModelMetaclass):
"""
Metaclass for BotEntity that handles field processing, descriptor creation,
and type resolution for bot-specific database entities.
This metaclass extends SQLModelMetaclass to provide additional functionality
for bot operations including field descriptors, type annotations, and
entity metadata management.
"""
# Store future references for forward-declared types
_future_references = {} _future_references = {}
def __new__(mcs, name, bases, namespace, **kwargs): def __new__(mcs, name, bases, namespace, **kwargs):
"""
Create a new class with processed field descriptors and metadata.
Args:
name: Name of the class being created
bases: Base classes
namespace: Class namespace containing attributes and annotations
**kwargs: Additional keyword arguments passed to the metaclass
Returns:
The created class with processed field descriptors and metadata
"""
bot_fields_descriptors = {} bot_fields_descriptors = {}
# --- Inherit field descriptors from parent classes (if any) ---
if bases: if bases:
bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor") bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor")
bot_fields_descriptors = ( bot_fields_descriptors = (
@@ -49,52 +91,67 @@ class BotEntityMetaclass(SQLModelMetaclass):
else {} else {}
) )
# --- Process field annotations to create field descriptors ---
if "__annotations__" in namespace: if "__annotations__" in namespace:
for annotation in namespace["__annotations__"]: for annotation in namespace["__annotations__"]:
if annotation in ["bot_entity_descriptor", "entity_metadata"]: # Skip special attributes
if annotation in ["bot_entity_descriptor", "bot_metadata"]:
continue continue
attribute_value = namespace.get(annotation) attribute_value = namespace.get(annotation, PydanticUndefined)
# Skip relationship fields (handled by SQLModel)
if isinstance(attribute_value, RelationshipInfo): if isinstance(attribute_value, RelationshipInfo):
continue continue
descriptor_kwargs = {} descriptor_kwargs = {}
descriptor_name = annotation descriptor_name = annotation
if attribute_value: # --- Process EntityField attributes to extract SQLModel field descriptors ---
if attribute_value is not PydanticUndefined:
if isinstance(attribute_value, EntityField): if isinstance(attribute_value, EntityField):
descriptor_kwargs = attribute_value.__dict__.copy() descriptor_kwargs = attribute_value.__dict__.copy()
# Extract SQLModel field descriptor if present
sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None) # type: FieldInfo sm_descriptor = descriptor_kwargs.pop("sm_descriptor", None) # type: FieldInfo
if sm_descriptor: if sm_descriptor:
# Transfer default values from EntityField to SQLModel descriptor
if ( if (
attribute_value.default is not None attribute_value.default is not PydanticUndefined
and sm_descriptor.default is PydanticUndefined and sm_descriptor.default is PydanticUndefined
): ):
sm_descriptor.default = attribute_value.default sm_descriptor.default = attribute_value.default
if ( if (
attribute_value.default_factory is not None attribute_value.default_factory is not None
and sm_descriptor.default_factory is PydanticUndefined and sm_descriptor.default_factory is None
): ):
sm_descriptor.default_factory = ( sm_descriptor.default_factory = (
attribute_value.default_factory attribute_value.default_factory
) )
if attribute_value.description is not PydanticUndefined:
sm_descriptor.description = attribute_value.description
else: else:
# Create new SQLModel field descriptor if none exists
if ( if (
attribute_value.default is not None attribute_value.default is not None
or attribute_value.default_factory is not None or attribute_value.default_factory is not None
): ):
sm_descriptor = Field() sm_descriptor = Field()
if attribute_value.default is not None: if attribute_value.default is not PydanticUndefined:
sm_descriptor.default = attribute_value.default sm_descriptor.default = attribute_value.default
if attribute_value.default_factory is not None: if attribute_value.default_factory is not None:
sm_descriptor.default_factory = ( sm_descriptor.default_factory = (
attribute_value.default_factory attribute_value.default_factory
) )
if attribute_value.description is not PydanticUndefined:
sm_descriptor.description = (
attribute_value.description
)
# Clean up internal attributes
descriptor_kwargs.pop("__orig_class__", None) descriptor_kwargs.pop("__orig_class__", None)
# Replace EntityField with SQLModel field descriptor in namespace
if sm_descriptor: if sm_descriptor:
namespace[annotation] = sm_descriptor namespace[annotation] = sm_descriptor
else: else:
@@ -102,10 +159,28 @@ class BotEntityMetaclass(SQLModelMetaclass):
descriptor_name = descriptor_kwargs.pop("name") or annotation descriptor_name = descriptor_kwargs.pop("name") or annotation
elif isinstance(attribute_value, FieldInfo):
if attribute_value.default is not PydanticUndefined:
descriptor_kwargs["default"] = attribute_value.default
if attribute_value.default_factory is not None:
descriptor_kwargs["default_factory"] = (
attribute_value.default_factory
)
if attribute_value.description is not _Unset:
descriptor_kwargs["description"] = (
attribute_value.description
)
elif isinstance(attribute_value, RelationshipInfo):
pass
else:
descriptor_kwargs["default"] = attribute_value
# --- Get the type annotation for the field ---
type_ = namespace["__annotations__"][annotation] type_ = namespace["__annotations__"][annotation]
type_origin = get_origin(type_) # --- Create field descriptor with basic information ---
field_descriptor = FieldDescriptor( field_descriptor = FieldDescriptor(
name=descriptor_name, name=descriptor_name,
field_name=annotation, field_name=annotation,
@@ -114,12 +189,18 @@ class BotEntityMetaclass(SQLModelMetaclass):
**descriptor_kwargs, **descriptor_kwargs,
) )
# --- Process type annotations to determine if field is list or optional ---
type_origin = get_origin(type_)
is_list = False is_list = False
is_optional = False is_optional = False
# Handle list types (e.g., List[str])
if type_origin is list: if type_origin is list:
field_descriptor.is_list = is_list = True field_descriptor.is_list = is_list = True
field_descriptor.type_base = type_ = get_args(type_)[0] field_descriptor.type_base = type_ = get_args(type_)[0]
# Handle Union types for optional fields (e.g., Optional[str])
if type_origin is Union: if type_origin is Union:
args = get_args(type_) args = get_args(type_)
if isinstance(args[0], ForwardRef): if isinstance(args[0], ForwardRef):
@@ -129,16 +210,17 @@ class BotEntityMetaclass(SQLModelMetaclass):
field_descriptor.is_optional = is_optional = True field_descriptor.is_optional = is_optional = True
field_descriptor.type_base = type_ = args[0] field_descriptor.type_base = type_ = args[0]
# Handle Python 3.10+ UnionType (e.g., str | None)
if type_origin is UnionType and get_args(type_)[1] is NoneType: if type_origin is UnionType and get_args(type_)[1] is NoneType:
field_descriptor.is_optional = is_optional = True field_descriptor.is_optional = is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0] field_descriptor.type_base = type_ = get_args(type_)[0]
# --- Handle string type references (forward references to other entities) ---
if isinstance(type_, str): if isinstance(type_, str):
type_not_found = True type_not_found = True
for ( for entity_descriptor in BotMetadata().entity_descriptors.values():
entity_descriptor
) in EntityMetadata().entity_descriptors.values():
if type_ == entity_descriptor.class_name: if type_ == entity_descriptor.class_name:
# Resolve the type to the actual entity class
field_descriptor.type_base = entity_descriptor.type_ field_descriptor.type_base = entity_descriptor.type_
field_descriptor.type_ = ( field_descriptor.type_ = (
list[entity_descriptor.type_] list[entity_descriptor.type_]
@@ -155,6 +237,8 @@ class BotEntityMetaclass(SQLModelMetaclass):
) )
type_not_found = False type_not_found = False
break break
# If type not found, store for future resolution
if type_not_found: if type_not_found:
if type_ in mcs._future_references: if type_ in mcs._future_references:
mcs._future_references[type_].append(field_descriptor) mcs._future_references[type_].append(field_descriptor)
@@ -163,15 +247,17 @@ class BotEntityMetaclass(SQLModelMetaclass):
bot_fields_descriptors[descriptor_name] = field_descriptor bot_fields_descriptors[descriptor_name] = field_descriptor
# --- Process entity descriptor configuration ---
descriptor_name = name descriptor_name = name
if "bot_entity_descriptor" in namespace: if "bot_entity_descriptor" in namespace:
# Extract and process custom entity descriptor
entity_descriptor = namespace.pop("bot_entity_descriptor") entity_descriptor = namespace.pop("bot_entity_descriptor")
descriptor_kwargs: dict = entity_descriptor.__dict__.copy() descriptor_kwargs: dict = entity_descriptor.__dict__.copy()
descriptor_name = descriptor_kwargs.pop("name", None) descriptor_name = descriptor_kwargs.pop("name", None)
descriptor_kwargs.pop("__orig_class__", None) descriptor_kwargs.pop("__orig_class__", None)
descriptor_name = descriptor_name or name.lower() descriptor_name = descriptor_name or name.lower()
namespace["bot_entity_descriptor"] = EntityDescriptor( entity_descriptor = namespace["bot_entity_descriptor"] = EntityDescriptor(
name=descriptor_name, name=descriptor_name,
class_name=name, class_name=name,
type_=name, type_=name,
@@ -179,39 +265,41 @@ class BotEntityMetaclass(SQLModelMetaclass):
**descriptor_kwargs, **descriptor_kwargs,
) )
else: else:
# Create default entity descriptor
descriptor_name = name.lower() descriptor_name = name.lower()
namespace["bot_entity_descriptor"] = EntityDescriptor( entity_descriptor = namespace["bot_entity_descriptor"] = EntityDescriptor(
name=descriptor_name, name=descriptor_name,
class_name=name, class_name=name,
type_=name, type_=name,
fields_descriptors=bot_fields_descriptors, fields_descriptors=bot_fields_descriptors,
) )
# --- Link field descriptors to their entity descriptor ---
for field_descriptor in bot_fields_descriptors.values(): for field_descriptor in bot_fields_descriptors.values():
field_descriptor.entity_descriptor = namespace["bot_entity_descriptor"] field_descriptor.entity_descriptor = entity_descriptor
# --- Configure table settings (set to True by default) ---
if "table" not in kwargs: if "table" not in kwargs:
kwargs["table"] = True kwargs["table"] = True
# --- If table is set to True, register entity in global metadata ---
if kwargs["table"]: if kwargs["table"]:
entity_metadata = EntityMetadata() # Register entity in global metadata
entity_metadata.entity_descriptors[descriptor_name] = namespace[ entity_metadata = BotMetadata()
"bot_entity_descriptor" entity_metadata.entity_descriptors[descriptor_name] = entity_descriptor
]
# Add entity_metadata to class annotations
if "__annotations__" in namespace: if "__annotations__" in namespace:
namespace["__annotations__"]["entity_metadata"] = ClassVar[ namespace["__annotations__"]["bot_metadata"] = ClassVar[BotMetadata]
EntityMetadata
]
else: else:
namespace["__annotations__"] = { namespace["__annotations__"] = {"bot_metadata": ClassVar[BotMetadata]}
"entity_metadata": ClassVar[EntityMetadata]
}
namespace["entity_metadata"] = entity_metadata namespace["bot_metadata"] = entity_metadata
# --- Create the class using parent metaclass ---
type_ = super().__new__(mcs, name, bases, namespace, **kwargs) type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
# --- Resolve future references now that the class exists ---
if name in mcs._future_references: if name in mcs._future_references:
for field_descriptor in mcs._future_references[name]: for field_descriptor in mcs._future_references[name]:
type_origin = get_origin(field_descriptor.type_) type_origin = get_origin(field_descriptor.type_)
@@ -230,78 +318,63 @@ class BotEntityMetaclass(SQLModelMetaclass):
) )
) )
setattr(namespace["bot_entity_descriptor"], "type_", type_) # --- Set the resolved type in the entity descriptor ---
entity_descriptor.type_ = type_
# setattr(entity_descriptor, "type_", type_)
if kwargs["table"] and entity_descriptor.crud is None:
entity_descriptor.crud = CrudService(entity_descriptor)
return type_ return type_
class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel]( class BotEntity(SQLModel, metaclass=BotEntityMetaclass, table=False):
SQLModel, metaclass=BotEntityMetaclass, table=False """
): Base class for bot entities that provides CRUD operations, filtering,
bot_entity_descriptor: ClassVar[EntityDescriptor] and Row Level Security (RLS) capabilities.
entity_metadata: ClassVar[EntityMetadata]
This class extends SQLModel and uses BotEntityMetaclass to provide
enhanced functionality for bot operations including:
- Field descriptors for UI generation
- Advanced filtering and search capabilities
- Row Level Security (RLS) access control
- Standardized CRUD operations
"""
# Class variables set by the metaclass
bot_entity_descriptor: ClassVar[EntityDescriptor]
bot_metadata: ClassVar[BotMetadata]
# Standard ID field for all entities
id: int = EntityField( id: int = EntityField(
sm_descriptor=Field(primary_key=True, sa_type=BigInteger), is_visible=False sm_descriptor=Field(primary_key=True, sa_type=BigInteger),
is_visible=False,
default=None,
) )
@classmethod @classmethod
@session_dep @session_dep
async def get(cls, *, session: AsyncSession | None = None, id: int): async def get(
return await session.get(cls, id)
@classmethod
def _static_filter_condition(
cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter]
):
for sfilt in static_filter:
column = getattr(cls, sfilt.field_name)
if sfilt.operator == "==":
condition = column.__eq__(sfilt.value)
elif sfilt.operator == "!=":
condition = column.__ne__(sfilt.value)
elif sfilt.operator == "<":
condition = column.__lt__(sfilt.value)
elif sfilt.operator == "<=":
condition = column.__le__(sfilt.value)
elif sfilt.operator == ">":
condition = column.__gt__(sfilt.value)
elif sfilt.operator == ">=":
condition = column.__ge__(sfilt.value)
elif sfilt.operator == "ilike":
condition = col(column).ilike(f"%{sfilt.value}%")
elif sfilt.operator == "like":
condition = col(column).like(f"%{sfilt.value}%")
elif sfilt.operator == "in":
condition = col(column).in_(sfilt.value)
elif sfilt.operator == "not in":
condition = col(column).notin_(sfilt.value)
elif sfilt.operator == "is none":
condition = col(column).is_(None)
elif sfilt.operator == "is not none":
condition = col(column).isnot(None)
elif sfilt.operator == "contains":
condition = sfilt.value == col(column).any_()
else:
condition = None
if condition is not None:
select_statement = select_statement.where(condition)
return select_statement
@classmethod
def _filter_condition(
cls, cls,
select_statement: SelectOfScalar[Self], *,
filter: str, session: AsyncSession | None = None,
filter_fields: list[str], id: int,
): user: "UserBase | None" = None,
condition = None ) -> Self:
for field in filter_fields: """
if condition is not None: Retrieve a single entity by ID.
condition = condition | (column(field).ilike(f"%{filter}%"))
else: Args:
condition = column(field).ilike(f"%{filter}%") session: Database session (injected by session_dep)
return select_statement.where(condition) id: Entity ID to retrieve
Returns:
The entity instance or None if not found
"""
select_statement = select(cls).where(cls.id == id)
if user:
select_statement = await _apply_rls_filters(cls, select_statement, user)
return await session.scalar(select_statement)
@classmethod @classmethod
@session_dep @session_dep
@@ -309,28 +382,49 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
cls, cls,
*, *,
session: AsyncSession | None = None, session: AsyncSession | None = None,
static_filter: list[Filter] | Any = None, user: "UserBase",
static_filter: Filter | FilterExpression | Any = None,
filter: str = None, filter: str = None,
filter_fields: list[str] = None, filter_fields: list[str] = None,
ext_filter: Any = None, ext_filter: Any = None,
user: "UserBase" = None,
) -> int: ) -> int:
"""
Get the count of entities matching the specified criteria.
Args:
session: Database session (injected by session_dep)
static_filter: List of static filter conditions
filter: Text search filter
filter_fields: Fields to search in for text filter
ext_filter: Additional custom filter conditions
user: User for RLS-based filtering
Returns:
Count of matching entities
"""
# --- Build select statement for counting entities ---
select_statement = select(func.count()).select_from(cls) select_statement = select(func.count()).select_from(cls)
# --- Apply various filter conditions ---
if static_filter: if static_filter:
if isinstance(static_filter, list): if isinstance(static_filter, list):
select_statement = cls._static_filter_condition( select_statement = _static_filter_condition(
select_statement, static_filter select_statement, static_filter
) )
else: else:
select_statement = select_statement.where(static_filter) # Handle single Filter or FilterExpression object
condition = _build_filter_condition(cls, static_filter)
if condition is not None:
select_statement = select_statement.where(condition)
if filter and filter_fields: if filter and filter_fields:
select_statement = cls._filter_condition( select_statement = _filter_condition(
select_statement, filter, filter_fields select_statement, filter, filter_fields
) )
if ext_filter: if ext_filter:
select_statement = select_statement.where(ext_filter) select_statement = select_statement.where(ext_filter)
if user:
select_statement = cls._ownership_condition(select_statement, user) select_statement = await _apply_rls_filters(cls, select_statement, user)
return await session.scalar(select_statement) return await session.scalar(select_statement)
@classmethod @classmethod
@@ -339,56 +433,60 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
cls, cls,
*, *,
session: AsyncSession | None = None, session: AsyncSession | None = None,
user: "UserBase | None" = None,
order_by=None, order_by=None,
static_filter: list[Filter] | Any = None, static_filter: Filter | FilterExpression | Any = None,
filter: str = None, filter: str = None,
filter_fields: list[str] = None, filter_fields: list[str] = None,
ext_filter: Any = None, ext_filter: Any = None,
user: "UserBase" = None,
skip: int = 0, skip: int = 0,
limit: int = None, limit: int = None,
): ) -> list[Self]:
"""
Retrieve multiple entities with filtering, pagination, and ordering.
Args:
session: Database session (injected by session_dep)
order_by: Ordering criteria
static_filter: List of static filter conditions
filter: Text search filter
filter_fields: Fields to search in for text filter
ext_filter: Additional custom filter conditions
user: User for RLS-based filtering
skip: Number of records to skip (for pagination)
limit: Maximum number of records to return
Returns:
List of matching entity instances
"""
# --- Build select statement for entity retrieval ---
select_statement = select(cls).offset(skip) select_statement = select(cls).offset(skip)
if limit: if limit:
select_statement = select_statement.limit(limit) select_statement = select_statement.limit(limit)
# --- Apply various filter conditions ---
if static_filter is not None: if static_filter is not None:
if isinstance(static_filter, list): if isinstance(static_filter, list):
select_statement = cls._static_filter_condition( select_statement = _static_filter_condition(
select_statement, static_filter cls, select_statement, static_filter
) )
else: else:
select_statement = select_statement.where(static_filter) # Handle single Filter or FilterExpression object
condition = _build_filter_condition(cls, static_filter)
if condition is not None:
select_statement = select_statement.where(condition)
if filter and filter_fields: if filter and filter_fields:
select_statement = cls._filter_condition( select_statement = _filter_condition(
select_statement, filter, filter_fields cls, select_statement, filter, filter_fields
) )
if ext_filter is not None: if ext_filter is not None:
select_statement = select_statement.where(ext_filter) select_statement = select_statement.where(ext_filter)
if user: if user:
select_statement = cls._ownership_condition(select_statement, user) select_statement = await _apply_rls_filters(cls, select_statement, user)
if order_by is not None: if order_by:
select_statement = select_statement.order_by(order_by) select_statement = select_statement.order_by(order_by)
return (await session.exec(select_statement)).all()
@classmethod return (await session.exec(select_statement)).all()
def _ownership_condition(
cls, select_statement: SelectOfScalar[Self], user: "UserBase"
):
if cls.bot_entity_descriptor.ownership_fields:
condition = None
for role in user.roles:
if role in cls.bot_entity_descriptor.ownership_fields:
owner_col = column(cls.bot_entity_descriptor.ownership_fields[role])
if condition is not None:
condition = condition | (owner_col == user.id)
else:
condition = owner_col == user.id
else:
condition = None
break
if condition is not None:
return select_statement.where(condition)
return select_statement
@classmethod @classmethod
@session_dep @session_dep
@@ -396,9 +494,21 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
cls, cls,
*, *,
session: AsyncSession | None = None, session: AsyncSession | None = None,
obj_in: CreateSchemaType, obj_in: BaseModel,
commit: bool = False, commit: bool = False,
): ) -> Self:
"""
Create a new entity instance.
Args:
session: Database session (injected by session_dep)
obj_in: Data for creating the entity (can be Pydantic model or entity instance)
commit: Whether to commit the transaction immediately
Returns:
The created entity instance
"""
# --- Accept both entity instances and Pydantic models ---
if isinstance(obj_in, cls): if isinstance(obj_in, cls):
obj = obj_in obj = obj_in
else: else:
@@ -415,13 +525,26 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
*, *,
session: AsyncSession | None = None, session: AsyncSession | None = None,
id: int, id: int,
obj_in: UpdateSchemaType, obj_in: BaseModel,
commit: bool = False, commit: bool = False,
): ) -> Self:
"""
Update an existing entity instance.
Args:
session: Database session (injected by session_dep)
id: ID of the entity to update
obj_in: Data for updating the entity
commit: Whether to commit the transaction immediately
Returns:
The updated entity instance or None if not found
"""
obj = await session.get(cls, id) obj = await session.get(cls, id)
if obj: if obj:
obj_data = obj.model_dump() obj_data = obj.model_dump()
update_data = obj_in.model_dump(exclude_unset=True) update_data = obj_in.model_dump(exclude_unset=True)
# Only update fields present in the update data
for field in obj_data: for field in obj_data:
if field in update_data: if field in update_data:
setattr(obj, field, update_data[field]) setattr(obj, field, update_data[field])
@@ -435,7 +558,18 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
@session_dep @session_dep
async def remove( async def remove(
cls, *, session: AsyncSession | None = None, id: int, commit: bool = False cls, *, session: AsyncSession | None = None, id: int, commit: bool = False
): ) -> Self:
"""
Delete an entity instance.
Args:
session: Database session (injected by session_dep)
id: ID of the entity to delete
commit: Whether to commit the transaction immediately
Returns:
The deleted entity instance or None if not found
"""
obj = await session.get(cls, id) obj = await session.get(cls, id)
if obj: if obj:
await session.delete(obj) await session.delete(obj)

View File

@@ -1,5 +1,8 @@
from aiogram.utils.i18n import I18n from aiogram.utils.i18n import I18n
from pydantic_core.core_schema import str_schema from pydantic import GetCoreSchemaHandler
# from pydantic_core.core_schema import str_schema
from pydantic_core import core_schema
from sqlalchemy.types import TypeDecorator from sqlalchemy.types import TypeDecorator
from sqlmodel import AutoString from sqlmodel import AutoString
from typing import Any, Self, overload from typing import Any, Self, overload
@@ -59,13 +62,13 @@ class BotEnumMetaclass(type):
class EnumMember(object): class EnumMember(object):
@overload @overload
def __init__(self, value: str) -> "EnumMember": ... def __init__(self, value: str) -> Self: ...
@overload @overload
def __init__(self, value: "EnumMember") -> "EnumMember": ... def __init__(self, value: Self) -> Self: ...
@overload @overload
def __init__(self, value: str, loc_obj: dict[str, str]) -> "EnumMember": ... def __init__(self, value: str, loc_obj: dict[str, str]) -> Self: ...
def __init__( def __init__(
self, self,
@@ -74,7 +77,7 @@ class EnumMember(object):
parent: type = None, parent: type = None,
name: str = None, name: str = None,
casting: bool = True, casting: bool = True,
) -> "EnumMember": ) -> Self:
if not casting: if not casting:
self._parent = parent self._parent = parent
self._name = name self._name = name
@@ -82,9 +85,9 @@ class EnumMember(object):
self.loc_obj = loc_obj self.loc_obj = loc_obj
@overload @overload
def __new__(cls: Self, *args, **kwargs) -> "EnumMember": ... def __new__(cls: Self, *args, **kwargs) -> Self: ...
def __new__(cls, *args, casting: bool = True, **kwargs) -> "EnumMember": def __new__(cls, *args, casting: bool = True, **kwargs) -> Self:
if (cls.__name__ == "EnumMember") or not casting: if (cls.__name__ == "EnumMember") or not casting:
obj = super().__new__(cls) obj = super().__new__(cls)
kwargs["casting"] = False kwargs["casting"] = False
@@ -104,8 +107,8 @@ class EnumMember(object):
else: else:
return args[0] return args[0]
def __get_pydantic_core_schema__(cls, *args, **kwargs): # def __get_pydantic_core_schema__(cls, *args, **kwargs):
return str_schema() # return str_schema()
def __get__(self, instance, owner) -> Self: def __get__(self, instance, owner) -> Self:
return { return {
@@ -159,6 +162,29 @@ class EnumMember(object):
return self.value return self.value
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: type, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
function=cls._validate_from_string,
schema=core_schema.str_schema(),
serialization=core_schema.plain_serializer_function_ser_schema(
cls._serialize_to_string, return_schema=core_schema.str_schema()
),
)
@classmethod
def _validate_from_string(cls, value: str) -> Self:
member = cls(value)
if member is None:
raise ValueError(f"Invalid value for {cls.__name__}: {value}")
return member
@classmethod
def _serialize_to_string(cls, value: Self) -> str:
return value.value
class BotEnum(EnumMember, metaclass=BotEnumMetaclass): class BotEnum(EnumMember, metaclass=BotEnumMetaclass):
all_members: dict[str, EnumMember] all_members: dict[str, EnumMember]

View File

@@ -0,0 +1,16 @@
from fastapi.datastructures import State
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from quickbot.main import QBotApp
from .descriptors import EntityDescriptor, ProcessDescriptor
from ._singleton import Singleton
class BotMetadata(metaclass=Singleton):
def __init__(self):
self.entity_descriptors: dict[str, EntityDescriptor] = {}
self.process_descriptors: dict[str, ProcessDescriptor] = {}
self.app: "QBotApp" = None
self.app_state: State = None

View File

@@ -0,0 +1,77 @@
import inspect
from typing_extensions import get_type_hints
from pydantic import BaseModel
from typing import ClassVar
from .descriptors import ProcessDescriptor, BotContext, Process
from .bot_metadata import BotMetadata
class BotProcessMetaclass(type):
def __new__(mcs, name, bases, namespace, **kwargs):
if name == "BotProcess":
namespace.pop("run")
return super().__new__(mcs, name, bases, namespace, **kwargs)
run_attr = namespace.get("run", None)
if run_attr is None or not isinstance(run_attr, staticmethod):
raise TypeError(f"{name}.run must be defined as a @staticmethod")
func = run_attr.__func__
sig = inspect.signature(func)
if list(sig.parameters.keys()) not in [["context", "parameters"], ["context"]]:
raise TypeError(
f"{name}.run must have exactly two arguments: context, parameters or one argument: context"
)
hints = get_type_hints(func, globalns=globals(), localns=namespace)
# Check arguments
ctx_type = hints.get("context")
if ctx_type is None or not issubclass(ctx_type, BotContext):
raise TypeError(f"{name}.run: 'context' must be BotContext")
param_type = hints.get("parameters")
if param_type is not None and not issubclass(param_type, BaseModel):
raise TypeError(f"{name}.run: 'parameters' must be subclass of BaseModel")
return_type = hints.get("return")
# Auto-generation of schemas
descriptor_kwargs = {"process_class": None}
process_descriptor = namespace.pop("bot_process_descriptor", None)
if process_descriptor and isinstance(process_descriptor, Process):
descriptor_kwargs.update(process_descriptor.__dict__)
descriptor_kwargs.update(
name=name,
input_schema=param_type,
output_schema=return_type,
)
descriptor = ProcessDescriptor(**descriptor_kwargs)
namespace["bot_process_descriptor"] = descriptor
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
descriptor.process_class = cls
bot_metadata = BotMetadata()
bot_metadata.process_descriptors[descriptor.name] = descriptor
namespace["bot_metadata"] = bot_metadata
return super().__new__(mcs, name, bases, namespace, **kwargs)
class BotProcess(metaclass=BotProcessMetaclass):
"""
Base class for business logic processes.
"""
bot_process_descriptor: ClassVar[ProcessDescriptor]
bot_metadata: ClassVar[BotMetadata]
@staticmethod
def run(context: BotContext, parameters: BaseModel = None): ...

View File

@@ -0,0 +1,9 @@
from enum import StrEnum
class CrudCommand(StrEnum):
LIST = "list"
GET_BY_ID = "get_by_id"
CREATE = "create"
UPDATE = "update"
DELETE = "delete"

View File

@@ -0,0 +1,374 @@
from sqlmodel import select, col
from sqlmodel.ext.asyncio.session import AsyncSession
from pydantic import BaseModel
from quickbot.model.permissions import get_user_permissions
from quickbot.model.descriptors import EntityDescriptor, BotContext
from quickbot.model.crud_command import CrudCommand
from quickbot.model.utils import (
entity_to_schema,
entity_to_list_schema,
pydantic_model,
)
from quickbot.model.descriptors import EntityPermission
from sqlalchemy.exc import IntegrityError
from asyncpg import ForeignKeyViolationError
from logging import getLogger
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from quickbot.model.user import UserBase
logger = getLogger(__name__)
class NotFoundError(Exception):
pass
class ForbiddenError(Exception):
pass
class CrudService:
def __init__(
self,
entity_descriptor: EntityDescriptor,
commands: list[CrudCommand] = [
CrudCommand.LIST,
CrudCommand.GET_BY_ID,
CrudCommand.CREATE,
CrudCommand.UPDATE,
CrudCommand.DELETE,
],
schema: type[BaseModel] = None,
create_schema: type[BaseModel] = None,
update_schema: type[BaseModel] = None,
):
self.entity_descriptor = entity_descriptor
self.commands = commands
if CrudCommand.CREATE in commands:
self.create_schema = create_schema or pydantic_model(
entity_descriptor, entity_descriptor.type_.__module__, "create"
)
else:
self.create_schema = None
if CrudCommand.UPDATE in commands:
self.update_schema = update_schema or pydantic_model(
entity_descriptor, entity_descriptor.type_.__module__, "update"
)
else:
self.update_schema = None
self.schema = schema or pydantic_model(
entity_descriptor, entity_descriptor.type_.__module__
)
async def list_all(
self, db_session: AsyncSession, user: "UserBase"
) -> list[BaseModel]:
if CrudCommand.LIST not in self.commands:
raise ForbiddenError(
f"List command not allowed for entity {self.entity_descriptor.name}"
)
user_permissions = get_user_permissions(user, self.entity_descriptor)
if (
EntityPermission.READ_ALL in user_permissions
or EntityPermission.READ_RLS in user_permissions
):
ret_list = await self.entity_descriptor.type_.get_multi(
session=db_session,
user=user
if EntityPermission.READ_ALL not in user_permissions
else None,
)
return [entity_to_schema(item) for item in ret_list]
elif (
EntityPermission.LIST_ALL in user_permissions
or EntityPermission.LIST_RLS in user_permissions
):
ret_list = await self.entity_descriptor.type_.get_multi(
session=db_session,
user=user
if EntityPermission.LIST_ALL not in user_permissions
else None,
)
context = BotContext(
db_session=db_session,
app=user.bot_metadata.app,
app_state=user.bot_metadata.app_state,
user=user,
)
return [await entity_to_list_schema(item, context) for item in ret_list]
else:
raise ForbiddenError(
f"User {user.id} does not have permission to read or list entities"
)
async def get_by_id(
self, db_session: AsyncSession, user: "UserBase", id: int
) -> BaseModel:
if CrudCommand.GET_BY_ID not in self.commands:
raise ForbiddenError(
f"Get by id command not allowed for entity {self.entity_descriptor.name}"
)
ret_obj = await self.entity_descriptor.type_.get(session=db_session, id=id)
if ret_obj is None:
raise NotFoundError(f"Entity with id {id} not found")
return entity_to_schema(ret_obj)
async def create(
self, db_session: AsyncSession, user: "UserBase", model: BaseModel
) -> BaseModel:
if CrudCommand.CREATE not in self.commands:
raise ForbiddenError(
f"Create command not allowed for entity {self.entity_descriptor.name}"
)
user_permissions = get_user_permissions(user, self.entity_descriptor)
if EntityPermission.CREATE_ALL in user_permissions:
# TODO: check if entity values are valid
pass
elif EntityPermission.CREATE_RLS in user_permissions:
# TODO: check if RLS fields are valid
# TODO: check if entity values are valid
pass
else:
raise ForbiddenError(
f"User {user.id} does not have permission to create entities"
)
obj_dict = {}
ret_obj_dict = {}
for field_descriptor in self.entity_descriptor.fields_descriptors.values():
# Only process fields present in the input model
field_name = field_descriptor.field_name
if field_name in model.__class__.model_fields:
# Handle list fields that are relations to other BotEntities
if (
field_descriptor.is_list
and isinstance(field_descriptor.type_base, type)
and hasattr(field_descriptor.type_base, "bot_entity_descriptor")
):
items = (
await db_session.exec(
select(field_descriptor.type_base).where(
col(field_descriptor.type_base.id).in_(
getattr(model, field_name)
)
)
)
).all()
obj_dict[field_name] = items
ret_obj_dict[field_name] = [item.id for item in items]
elif isinstance(field_descriptor.type_base, type) and hasattr(
field_descriptor.type_base, "all_members"
):
if field_descriptor.is_list:
obj_dict[field_name] = [
field_descriptor.type_base(item)
for item in getattr(model, field_name)
]
else:
obj_dict[field_name] = field_descriptor.type_base(
getattr(model, field_name)
)
ret_obj_dict[field_name] = getattr(model, field_name)
else:
obj_dict[field_name] = getattr(model, field_name)
ret_obj_dict[field_name] = getattr(model, field_name)
obj = self.entity_descriptor.type_(**obj_dict)
db_session.add(obj)
try:
await db_session.commit()
except IntegrityError as e:
if isinstance(e.orig.__cause__, ForeignKeyViolationError):
raise ValueError(e.orig.__cause__.detail)
raise ValueError("DB Integrity error")
except Exception as e:
logger.error(f"Error creating entity: {e}")
raise e
if "id" not in ret_obj_dict:
ret_obj_dict["id"] = obj.id
return self.schema(**ret_obj_dict)
async def update(
self, db_session: AsyncSession, user: "UserBase", id: int, model: BaseModel
) -> BaseModel:
if CrudCommand.UPDATE not in self.commands:
raise ForbiddenError(
f"Update command not allowed for entity {self.entity_descriptor.name}"
)
user_permissions = get_user_permissions(user, self.entity_descriptor)
if EntityPermission.UPDATE_ALL in user_permissions:
# TODO: check if entity values are valid
pass
elif EntityPermission.UPDATE_RLS in user_permissions:
# TODO: check if RLS fields are valid
# TODO: check if entity values are valid
pass
else:
raise ForbiddenError(
f"User {user.id} does not have permission to update entities"
)
entity = await self.entity_descriptor.type_.get(session=db_session, id=id)
if entity is None:
raise NotFoundError(f"Entity with id {id} not found")
for field_descriptor in self.entity_descriptor.fields_descriptors.values():
field_name = field_descriptor.field_name
if field_name in model.model_fields_set:
model_field_value = getattr(model, field_name)
if (
field_descriptor.is_list
and isinstance(field_descriptor.type_base, type)
and hasattr(field_descriptor.type_base, "bot_entity_descriptor")
):
items = (
await db_session.exec(
select(field_descriptor.type_base).where(
col(field_descriptor.type_base.id).in_(
model_field_value
)
)
)
).all()
setattr(entity, field_name, items)
elif isinstance(field_descriptor.type_base, type) and hasattr(
field_descriptor.type_base, "all_members"
):
if field_descriptor.is_list:
setattr(
entity,
field_name,
[
field_descriptor.type_base(item)
for item in model_field_value
],
)
else:
setattr(
entity,
field_name,
field_descriptor.type_base(model_field_value),
)
else:
setattr(entity, field_name, model_field_value)
try:
await db_session.commit()
except IntegrityError as e:
if isinstance(e.orig.__cause__, ForeignKeyViolationError):
raise ValueError(e.orig.__cause__.detail)
raise ValueError("DB Integrity error")
except Exception as e:
logger.error(f"Error updating entity: {e}")
raise e
return entity_to_schema(entity)
async def delete(
self, db_session: AsyncSession, user: "UserBase", id: int
) -> BaseModel:
if CrudCommand.DELETE not in self.commands:
raise ForbiddenError(
f"Delete command not allowed for entity {self.entity_descriptor.name}"
)
user_permissions = get_user_permissions(user, self.entity_descriptor)
if EntityPermission.DELETE_ALL in user_permissions:
pass
elif EntityPermission.DELETE_RLS in user_permissions:
# TODO: check if RLS fields are valid
pass
else:
raise ForbiddenError(
f"User {user.id} does not have permission to delete entities"
)
try:
entity = await self.entity_descriptor.type_.remove(
session=db_session, id=id, commit=True
)
except IntegrityError as e:
if isinstance(e.orig.__cause__, ForeignKeyViolationError):
raise ValueError(e.orig.__cause__.detail)
raise ValueError("DB Integrity error")
except Exception as e:
logger.error(f"Error deleting entity: {e}")
raise e
if entity is None:
raise NotFoundError(f"Entity with id {id} not found")
return entity_to_schema(entity)
# async def _create_from_schema(
# cls: type[BotEntity],
# *,
# session: AsyncSession | None = None,
# obj_in: BaseModel,
# ):
# """
# Create a new entity instance from a Pydantic model.
# Args:
# session: Database session (injected by session_dep)
# obj_in: Pydantic model to create the entity from
# Returns:
# The created entity instance
# """
# obj_dict = {}
# ret_obj_dict = {}
# for field_descriptor in cls.bot_entity_descriptor.fields_descriptors.values():
# # Only process fields present in the input model
# if field_descriptor.field_name in obj_in.__class__.model_fields:
# # Handle list fields that are relations to other BotEntities
# if (
# field_descriptor.is_list
# and isinstance(field_descriptor.type_base, type)
# and issubclass(field_descriptor.type_base, BotEntity)
# ):
# items = (
# await session.exec(
# select(field_descriptor.type_base).where(
# col(field_descriptor.type_base.id).in_(
# getattr(obj_in, field_descriptor.field_name)
# )
# )
# )
# ).all()
# obj_dict[field_descriptor.field_name] = items
# ret_obj_dict[field_descriptor.field_name] = [
# item.id for item in items
# ]
# else:
# obj_dict[field_descriptor.field_name] = getattr(
# obj_in, field_descriptor.field_name
# )
# ret_obj_dict[field_descriptor.field_name] = getattr(
# obj_in, field_descriptor.field_name
# )
# obj = cls(**obj_dict)
# session.add(obj)
# await session.commit()
# if "id" not in ret_obj_dict:
# ret_obj_dict["id"] = obj.id
# return cls.bot_entity_descriptor.schema_class(**ret_obj_dict)
# @classmethod
# def apply_filters(
# cls,
# select_statement: SelectOfScalar[Self],
# filters: Filter | FilterExpression | None = None,
# ) -> SelectOfScalar[Self]:
# """
# Apply filters to a select statement.
# Args:
# select_statement: SQLAlchemy select statement to modify
# filters: Filter or FilterExpression to apply
# Returns:
# Modified select statement with filter conditions
# """
# if filters is None:
# return select_statement
# condition = cls._build_filter_condition(filters)
# if condition is not None:
# return select_statement.where(condition)
# return select_statement

View File

@@ -6,6 +6,8 @@ from typing import Any, Callable, TYPE_CHECKING, Literal, Union
from babel.support import LazyProxy from babel.support import LazyProxy
from dataclasses import dataclass, field from dataclasses import dataclass, field
from fastapi.datastructures import State from fastapi.datastructures import State
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from sqlmodel.ext.asyncio.session import AsyncSession from sqlmodel.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute from sqlalchemy.orm import InstrumentedAttribute
@@ -17,6 +19,8 @@ if TYPE_CHECKING:
from .bot_entity import BotEntity from .bot_entity import BotEntity
from ..main import QBotApp from ..main import QBotApp
from .user import UserBase from .user import UserBase
from .crud_service import CrudService
from .bot_process import BotProcess
# EntityCaptionCallable = Callable[["EntityDescriptor"], str] # EntityCaptionCallable = Callable[["EntityDescriptor"], str]
# EntityItemCaptionCallable = Callable[["EntityDescriptor", Any], str] # EntityItemCaptionCallable = Callable[["EntityDescriptor", Any], str]
@@ -46,8 +50,8 @@ class InlineButton[T: "BotEntity"]:
@dataclass @dataclass
class Filter: class Filter[T: "BotEntity"]:
field_name: str field: str | Callable[[type[T]], InstrumentedAttribute]
operator: Literal[ operator: Literal[
"==", "==",
"!=", "!=",
@@ -67,6 +71,89 @@ class Filter:
value: Any | None = None value: Any | None = None
param_index: int | None = None param_index: int | None = None
def __or__(self, other: "Filter[T] | FilterExpression[T]") -> "FilterExpression[T]":
"""Create OR expression with another filter or expression"""
if isinstance(other, Filter):
return FilterExpression("or", [self, other])
elif isinstance(other, FilterExpression):
if other.operator == "or":
# Simplify: filter | (a | b) = (filter | a | b)
return FilterExpression("or", [self] + other.filters)
else:
return FilterExpression("or", [self, other])
else:
raise TypeError(f"Cannot combine Filter with {type(other)}")
def __and__(
self, other: "Filter[T] | FilterExpression[T]"
) -> "FilterExpression[T]":
"""Create AND expression with another filter or expression"""
if isinstance(other, Filter):
return FilterExpression("and", [self, other])
elif isinstance(other, FilterExpression):
if other.operator == "and":
# Simplify: filter & (a & b) = (filter & a & b)
return FilterExpression("and", [self] + other.filters)
else:
return FilterExpression("and", [self, other])
else:
raise TypeError(f"Cannot combine Filter with {type(other)}")
class FilterExpression[T: "BotEntity"]:
"""
Represents a logical expression combining multiple filters with AND/OR operations.
Supports expression simplification for optimal query building.
"""
def __init__(
self,
operator: Literal["or", "and"],
filters: list["Filter[T] | FilterExpression[T]"],
):
self.operator = operator
self.filters = self._simplify_filters(filters)
def _simplify_filters(
self, filters: list["Filter[T] | FilterExpression[T]"]
) -> list["Filter[T] | FilterExpression[T]"]:
"""Simplify filters by flattening nested expressions with the same operator"""
simplified = []
for filter_obj in filters:
if (
isinstance(filter_obj, FilterExpression)
and filter_obj.operator == self.operator
):
# Flatten nested expressions with the same operator
simplified.extend(filter_obj.filters)
else:
simplified.append(filter_obj)
return simplified
def __or__(self, other: "Filter[T] | FilterExpression[T]") -> "FilterExpression[T]":
"""Combine with another filter or expression using OR"""
if isinstance(other, (Filter, FilterExpression)):
if isinstance(other, FilterExpression) and other.operator == "or":
# Simplify: (a | b) | (c | d) = (a | b | c | d)
return FilterExpression("or", self.filters + other.filters)
else:
return FilterExpression("or", [self, other])
else:
raise TypeError(f"Cannot combine FilterExpression with {type(other)}")
def __and__(
self, other: "Filter[T] | FilterExpression[T]"
) -> "FilterExpression[T]":
"""Combine with another filter or expression using AND"""
if isinstance(other, (Filter, FilterExpression)):
if isinstance(other, FilterExpression) and other.operator == "and":
# Simplify: (a & b) & (c & d) = (a & b & c & d)
return FilterExpression("and", self.filters + other.filters)
else:
return FilterExpression("and", [self, other])
else:
raise TypeError(f"Cannot combine FilterExpression with {type(other)}")
@dataclass @dataclass
class EntityList[T: "BotEntity"]: class EntityList[T: "BotEntity"]:
@@ -77,7 +164,7 @@ class EntityList[T: "BotEntity"]:
show_add_new_button: bool = True show_add_new_button: bool = True
item_form: str | None = None item_form: str | None = None
pagination: bool = True pagination: bool = True
static_filters: list[Filter] = None static_filters: Filter[T] | FilterExpression[T] | None = None
filtering: bool = False filtering: bool = False
filtering_fields: list[str] = None filtering_fields: list[str] = None
order_by: str | Any | None = None order_by: str | Any | None = None
@@ -99,9 +186,7 @@ class _BaseFieldDescriptor[T: "BotEntity"]:
caption: ( caption: (
str | LazyProxy | Callable[["FieldDescriptor", "BotContext"], str] | None str | LazyProxy | Callable[["FieldDescriptor", "BotContext"], str] | None
) = None ) = None
description: ( description: str | LazyProxy | None = PydanticUndefined
str | LazyProxy | Callable[["FieldDescriptor", "BotContext"], str] | None
) = None
edit_prompt: ( edit_prompt: (
str str
| LazyProxy | LazyProxy
@@ -122,8 +207,8 @@ class _BaseFieldDescriptor[T: "BotEntity"]:
bool_false_value: str | LazyProxy = "no" bool_false_value: str | LazyProxy = "no"
bool_true_value: str | LazyProxy = "yes" bool_true_value: str | LazyProxy = "yes"
ep_form: str | Callable[["BotContext"], str] | None = None ep_form: str | Callable[["BotContext"], str] | None = None
ep_parent_field: str | None = None ep_parent_field: str | Callable[[type[T]], InstrumentedAttribute] | None = None
ep_child_field: str | None = None ep_child_field: str | Callable[[type[T]], InstrumentedAttribute] | None = None
dt_type: Literal["date", "datetime"] = "date" dt_type: Literal["date", "datetime"] = "date"
options: ( options: (
list[list[Union[Any, tuple[Any, str]]]] list[list[Union[Any, tuple[Any, str]]]]
@@ -133,7 +218,7 @@ class _BaseFieldDescriptor[T: "BotEntity"]:
options_custom_value: bool = True options_custom_value: bool = True
show_current_value_button: bool = True show_current_value_button: bool = True
show_skip_in_editor: Literal[False, "Auto"] = "Auto" show_skip_in_editor: Literal[False, "Auto"] = "Auto"
default: Any = None default: Any = PydanticUndefined
default_factory: Callable[[], Any] | None = None default_factory: Callable[[], Any] | None = None
@@ -178,7 +263,8 @@ class _BaseEntityDescriptor[T: "BotEntity"]:
full_name_plural: ( full_name_plural: (
str | LazyProxy | Callable[["EntityDescriptor", "BotContext"], str] | None str | LazyProxy | Callable[["EntityDescriptor", "BotContext"], str] | None
) = None ) = None
description: ( description: str | None = None
ui_description: (
str | LazyProxy | Callable[["EntityDescriptor", "BotContext"], str] | None str | LazyProxy | Callable[["EntityDescriptor", "BotContext"], str] | None
) = None ) = None
item_repr: Callable[[T, "BotContext"], str] | None = None item_repr: Callable[[T, "BotContext"], str] | None = None
@@ -190,11 +276,11 @@ class _BaseEntityDescriptor[T: "BotEntity"]:
ownership_fields: dict[RoleBase, str] = field(default_factory=dict[RoleBase, str]) ownership_fields: dict[RoleBase, str] = field(default_factory=dict[RoleBase, str])
permissions: dict[EntityPermission, list[RoleBase]] = field( permissions: dict[EntityPermission, list[RoleBase]] = field(
default_factory=lambda: { default_factory=lambda: {
EntityPermission.LIST: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], EntityPermission.LIST_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.READ: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], EntityPermission.READ_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.CREATE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], EntityPermission.CREATE_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.UPDATE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], EntityPermission.UPDATE_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.DELETE: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER], EntityPermission.DELETE_RLS: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER],
EntityPermission.LIST_ALL: [RoleBase.SUPER_USER], EntityPermission.LIST_ALL: [RoleBase.SUPER_USER],
EntityPermission.READ_ALL: [RoleBase.SUPER_USER], EntityPermission.READ_ALL: [RoleBase.SUPER_USER],
EntityPermission.CREATE_ALL: [RoleBase.SUPER_USER], EntityPermission.CREATE_ALL: [RoleBase.SUPER_USER],
@@ -202,6 +288,8 @@ class _BaseEntityDescriptor[T: "BotEntity"]:
EntityPermission.DELETE_ALL: [RoleBase.SUPER_USER], EntityPermission.DELETE_ALL: [RoleBase.SUPER_USER],
} }
) )
rls_filters: Filter[T] | FilterExpression[T] | None = None
rls_filters_params: Callable[["UserBase"], list[Any]] = lambda user: [user.id]
before_create: Callable[["BotContext"], Union[bool, str]] | None = None before_create: Callable[["BotContext"], Union[bool, str]] | None = None
before_create_save: Callable[[T, "BotContext"], Union[bool, str]] | None = None before_create_save: Callable[[T, "BotContext"], Union[bool, str]] | None = None
before_update_save: ( before_update_save: (
@@ -212,6 +300,7 @@ class _BaseEntityDescriptor[T: "BotEntity"]:
on_created: Callable[[T, "BotContext"], None] | None = None on_created: Callable[[T, "BotContext"], None] | None = None
on_deleted: Callable[[T, "BotContext"], None] | None = None on_deleted: Callable[[T, "BotContext"], None] | None = None
on_updated: Callable[[dict[str, Any], T, "BotContext"], None] | None = None on_updated: Callable[[dict[str, Any], T, "BotContext"], None] | None = None
crud: Union["CrudService", None] = None
@dataclass(kw_only=True) @dataclass(kw_only=True)
@@ -228,7 +317,7 @@ class EntityDescriptor(_BaseEntityDescriptor):
@dataclass(kw_only=True) @dataclass(kw_only=True)
class CommandCallbackContext[UT: UserBase]: class CommandCallbackContext:
keyboard_builder: InlineKeyboardBuilder = field( keyboard_builder: InlineKeyboardBuilder = field(
default_factory=InlineKeyboardBuilder default_factory=InlineKeyboardBuilder
) )
@@ -238,7 +327,7 @@ class CommandCallbackContext[UT: UserBase]:
message: Message | CallbackQuery message: Message | CallbackQuery
callback_data: ContextData callback_data: ContextData
db_session: AsyncSession db_session: AsyncSession
user: UT user: "UserBase"
app: "QBotApp" app: "QBotApp"
app_state: State app_state: State
state_data: dict[str, Any] state_data: dict[str, Any]
@@ -249,11 +338,11 @@ class CommandCallbackContext[UT: UserBase]:
@dataclass(kw_only=True) @dataclass(kw_only=True)
class BotContext[UT: UserBase]: class BotContext:
db_session: AsyncSession db_session: AsyncSession
app: "QBotApp" app: "QBotApp"
app_state: State app_state: State
user: UT user: "UserBase"
message: Message | CallbackQuery | None = None message: Message | CallbackQuery | None = None
default_handler: Callable[["BotEntity", "BotContext"], None] | None = None default_handler: Callable[["BotEntity", "BotContext"], None] | None = None
@@ -271,3 +360,31 @@ class BotCommand:
show_cancel_in_param_form: bool = True show_cancel_in_param_form: bool = True
show_back_in_param_form: bool = True show_back_in_param_form: bool = True
handler: Callable[[CommandCallbackContext], None] handler: Callable[[CommandCallbackContext], None]
@dataclass(kw_only=True)
class _BaseProcessDescriptor:
description: str | LazyProxy | None = None
roles: list[RoleBase] = field(
default_factory=lambda: [RoleBase.DEFAULT_USER, RoleBase.SUPER_USER]
)
icon: str | None = None
caption: str | LazyProxy | None = None
pre_check: Callable[[BotContext], bool | str] | None = None
show_in_bot_menu: bool = False
answer_message: Callable[[BotContext, BaseModel], str] | None = None
answer_inline_buttons: (
Callable[[BotContext, BaseModel], list[InlineKeyboardButton]] | None
) = None
@dataclass(kw_only=True)
class ProcessDescriptor(_BaseProcessDescriptor):
name: str
process_class: type["BotProcess"]
input_schema: type[BaseModel] | None = None
output_schema: type[BaseModel] | None = None
@dataclass(kw_only=True)
class Process(_BaseProcessDescriptor): ...

View File

@@ -1,7 +0,0 @@
from .descriptors import EntityDescriptor
from ._singleton import Singleton
class EntityMetadata(metaclass=Singleton):
def __init__(self):
self.entity_descriptors: dict[str, EntityDescriptor] = {}

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ListSchema(BaseModel):
id: int
name: str

View File

@@ -0,0 +1,192 @@
from inspect import iscoroutinefunction
from quickbot.model.descriptors import EntityDescriptor, EntityPermission
from quickbot.model.descriptors import Filter, FilterExpression
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from quickbot.model.user import UserBase
from quickbot.model.bot_entity import BotEntity
def get_user_permissions(
user: "UserBase", entity_descriptor: EntityDescriptor
) -> list[EntityPermission]:
permissions = list[EntityPermission]()
for permission, roles in entity_descriptor.permissions.items():
for role in roles:
if role in user.roles:
permissions.append(permission)
break
return permissions
async def check_entity_permission(
entity: "BotEntity", user: "UserBase", permission: EntityPermission
) -> bool:
perm_mapping = {
EntityPermission.LIST_RLS: EntityPermission.LIST_ALL,
EntityPermission.READ_RLS: EntityPermission.READ_ALL,
EntityPermission.UPDATE_RLS: EntityPermission.UPDATE_ALL,
EntityPermission.CREATE_RLS: EntityPermission.CREATE_ALL,
EntityPermission.DELETE_RLS: EntityPermission.DELETE_ALL,
}
if permission not in perm_mapping:
raise ValueError(f"Invalid permission: {permission}")
entity_descriptor = entity.__class__.bot_entity_descriptor
permissions = get_user_permissions(user, entity_descriptor)
# Check if user has the corresponding ALL permission
if perm_mapping[permission] in permissions:
return True
# Check RLS filters if they exist
if entity_descriptor.rls_filters:
# Get parameters for RLS
params = []
if entity_descriptor.rls_filters_params:
if iscoroutinefunction(entity_descriptor.rls_filters_params):
params = await entity_descriptor.rls_filters_params(user)
else:
params = entity_descriptor.rls_filters_params(user)
# Create a copy of the RLS filters with parameter values substituted
rls_filters = entity.__class__._substitute_rls_parameters(
entity_descriptor.rls_filters, params
)
# Check if the entity matches the RLS filters by evaluating the condition
# against the entity's attributes
if _entity_matches_rls_filters(entity, rls_filters):
return True
# If no RLS filters are defined, check if user has the RLS permission
if permission in permissions:
return True
return False
def _entity_matches_rls_filters(
entity: "BotEntity", rls_filters: "Filter | FilterExpression"
) -> bool:
"""
Check if an entity matches the given RLS filters by evaluating the filters
against the entity's attributes.
Args:
entity: The entity to check
rls_filters: RLS filters to evaluate
Returns:
True if the entity matches the filters, False otherwise
"""
if isinstance(rls_filters, Filter):
return _evaluate_single_filter(entity, rls_filters)
elif isinstance(rls_filters, FilterExpression):
return _evaluate_filter_expression(entity, rls_filters)
else:
return False
def _evaluate_single_filter(entity: "BotEntity", filter_obj: "Filter") -> bool:
"""Evaluate a single filter against an entity"""
# Get the field value from the entity
if isinstance(filter_obj.field, str):
field_value = getattr(entity, filter_obj.field, None)
else:
# Handle callable field (should return the field name)
field_name = filter_obj.field(entity.__class__).key
field_value = getattr(entity, field_name, None)
# Apply the operator
if filter_obj.operator == "==":
return field_value == filter_obj.value
elif filter_obj.operator == "!=":
return field_value != filter_obj.value
elif filter_obj.operator == ">":
return field_value > filter_obj.value
elif filter_obj.operator == "<":
return field_value < filter_obj.value
elif filter_obj.operator == ">=":
return field_value >= filter_obj.value
elif filter_obj.operator == "<=":
return field_value <= filter_obj.value
elif filter_obj.operator == "in":
return field_value in filter_obj.value
elif filter_obj.operator == "not in":
return field_value not in filter_obj.value
elif filter_obj.operator == "like":
return str(field_value).find(str(filter_obj.value)) != -1
elif filter_obj.operator == "ilike":
return str(field_value).lower().find(str(filter_obj.value).lower()) != -1
elif filter_obj.operator == "is none":
return field_value is None
elif filter_obj.operator == "is not none":
return field_value is not None
elif filter_obj.operator == "contains":
return filter_obj.value in field_value
else:
return False
def _evaluate_filter_expression(
entity: "BotEntity", filter_expr: "FilterExpression"
) -> bool:
"""Evaluate a filter expression against an entity"""
results = []
for sub_filter in filter_expr.filters:
if isinstance(sub_filter, Filter):
result = _evaluate_single_filter(entity, sub_filter)
elif isinstance(sub_filter, FilterExpression):
result = _evaluate_filter_expression(entity, sub_filter)
else:
result = False
results.append(result)
if not results:
return False
# Apply the logical operator
if filter_expr.operator == "and":
return all(results)
elif filter_expr.operator == "or":
return any(results)
else:
return False
def _extract_rls_filter_fields(entity_descriptor: EntityDescriptor) -> set[str]:
return _extract_filter_fields(
entity_descriptor.rls_filters, entity_descriptor.type_
)
def _extract_filter_fields(
filter: Filter | FilterExpression | None, entity_type: type
) -> set[str]:
fields = set()
if filter:
if isinstance(filter, Filter):
if filter.operator == "==":
if isinstance(filter.field, str):
fields.add(filter.field)
else:
fields.add(filter.field(entity_type).key)
elif isinstance(filter, FilterExpression):
if (
filter.operator == "and"
and all(isinstance(sub_filter, Filter) for sub_filter in filter.filters)
and all(sub_filter.operator == "==" for sub_filter in filter.filters)
):
for sub_filter in filter.filters:
if isinstance(sub_filter.field, str):
fields.add(sub_filter.field)
else:
fields.add(sub_filter.field(entity_type).key)
return fields

View File

@@ -0,0 +1,53 @@
from typing import Any, Type
from pydantic import BaseModel
from sqlmodel import TypeDecorator, JSON
class PydanticJSON(TypeDecorator):
"""
SQLAlchemy-compatible JSON type for storing Pydantic models
(including nested ones). Automatically serializes on insert
and deserializes on read.
"""
impl = JSON
cache_ok = True
def __init__(self, model_class: Type[BaseModel], *args, **kwargs):
if not issubclass(model_class, BaseModel):
raise TypeError("PydanticJSON expects a Pydantic BaseModel subclass")
self.model_class = model_class
super().__init__(*args, **kwargs)
def process_bind_param(self, value: Any, dialect) -> Any:
"""
Serialize Python object to JSON-compatible form before saving to DB.
"""
if value is None:
return None
if isinstance(value, list):
return [
item.model_dump(mode="json") if isinstance(item, BaseModel) else item
for item in value
]
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
return value # assume already JSON-serializable
def process_result_value(self, value: Any, dialect) -> Any:
"""
Deserialize JSON data from DB back into Python object.
"""
if value is None:
return None
if isinstance(value, list):
return [self.model_class(**item) for item in value]
if isinstance(value, dict):
return self.model_class(**value)
raise TypeError(f"Unsupported value type for deserialization: {type(value)}")

View File

@@ -1,3 +1,4 @@
from sqlalchemy import BigInteger
from sqlmodel import Field, ARRAY from sqlmodel import Field, ARRAY
from .bot_entity import BotEntity from .bot_entity import BotEntity
@@ -5,6 +6,7 @@ from .bot_enum import EnumType
from .language import LanguageBase from .language import LanguageBase
from .role import RoleBase from .role import RoleBase
from .descriptors import EntityField
from .settings import DbSettings as DbSettings from .settings import DbSettings as DbSettings
from .fsm_storage import FSMStorage as FSMStorage from .fsm_storage import FSMStorage as FSMStorage
from .view_setting import ViewSetting as ViewSetting from .view_setting import ViewSetting as ViewSetting
@@ -13,11 +15,24 @@ from .view_setting import ViewSetting as ViewSetting
class UserBase(BotEntity, table=False): class UserBase(BotEntity, table=False):
__tablename__ = "user" __tablename__ = "user"
lang: LanguageBase = Field(sa_type=EnumType(LanguageBase), default=LanguageBase.EN) id: int = EntityField(
is_active: bool = True description="User Telegram ID",
sm_descriptor=Field(primary_key=True, sa_type=BigInteger),
is_visible=False,
)
name: str lang: LanguageBase = Field(
description="User language",
sa_type=EnumType(LanguageBase),
default_factory=lambda: LanguageBase.EN,
)
is_active: bool = EntityField(description="User is active", default=True)
name: str = EntityField(description="User name")
roles: list[RoleBase] = Field( roles: list[RoleBase] = Field(
sa_type=ARRAY(EnumType(RoleBase)), default=[RoleBase.DEFAULT_USER] description="User roles",
sa_type=ARRAY(EnumType(RoleBase)),
default_factory=lambda: [RoleBase.DEFAULT_USER],
) )

364
src/quickbot/model/utils.py Normal file
View File

@@ -0,0 +1,364 @@
from inspect import iscoroutinefunction
from typing import Any, Optional, get_origin
from pydantic import BaseModel, Field
from pydantic.fields import _Unset
from pydantic_core import PydanticUndefined
from typing_extensions import Literal
from sqlmodel import SQLModel, col, column
from sqlmodel.sql.expression import SelectOfScalar
from typing import TYPE_CHECKING
from quickbot.model.list_schema import ListSchema
from quickbot.model.descriptors import (
EntityDescriptor,
Filter,
FilterExpression,
BotContext,
)
from quickbot.utils.main import get_entity_item_repr
if TYPE_CHECKING:
from quickbot import BotEntity, BotEnum
from quickbot.model.user import UserBase
def entity_to_schema(entity: "BotEntity") -> BaseModel:
entity_data = {}
for field_descriptor in entity.bot_entity_descriptor.fields_descriptors.values():
if field_descriptor.is_list and is_bot_entity(field_descriptor.type_base):
entity_data[field_descriptor.field_name] = [
item.id for item in getattr(entity, field_descriptor.field_name)
]
elif field_descriptor.is_list and is_bot_enum(field_descriptor.type_base):
entity_data[field_descriptor.field_name] = [
item.value for item in getattr(entity, field_descriptor.field_name)
]
elif not field_descriptor.is_list and is_bot_entity(field_descriptor.type_base):
continue
elif not field_descriptor.is_list and is_bot_enum(field_descriptor.type_base):
val: "BotEnum" | None = getattr(entity, field_descriptor.field_name)
entity_data[field_descriptor.field_name] = val.value if val else None
else:
entity_data[field_descriptor.field_name] = getattr(
entity, field_descriptor.field_name
)
return (
entity.bot_entity_descriptor.crud.schema(**entity_data)
if entity.bot_entity_descriptor.crud
else entity.bot_entity_descriptor.schema(**entity_data)
)
async def entity_to_list_schema(
entity: "BotEntity", context: "BotContext"
) -> ListSchema:
entity_repr = await get_entity_item_repr(
entity, context, entity.bot_entity_descriptor.item_repr
)
return ListSchema(id=entity.id, name=entity_repr)
def _pydantic_model_fields(
namespace: dict[str, Any],
entity_descriptor: EntityDescriptor,
schema_type: Literal["schema", "create", "update"] = "schema",
) -> dict[str, Any]:
namespace["__annotations__"] = {}
for field_descriptor in entity_descriptor.fields_descriptors.values():
type_origin = get_origin(field_descriptor.type_base)
if (
type_origin is not list
and not field_descriptor.is_list
and (
issubclass(field_descriptor.type_base, SQLModel)
or isinstance(field_descriptor.type_base, str)
)
) or (
schema_type in ["create", "update"]
and field_descriptor.field_name == "id"
and field_descriptor.default is None
):
continue
if (
type_origin is not list
and field_descriptor.is_list
and (
issubclass(field_descriptor.type_base, SQLModel)
or isinstance(field_descriptor.type_base, str)
)
):
namespace["__annotations__"][field_descriptor.field_name] = list[int]
elif type_origin is not list and is_bot_enum(field_descriptor.type_base):
enum_values = [
member.value
for member in field_descriptor.type_base.all_members.values()
]
enum_annotation = (
list[Literal[*enum_values]]
if field_descriptor.is_list
else Literal[*enum_values]
)
if field_descriptor.is_optional:
enum_annotation = Optional[enum_annotation]
namespace["__annotations__"][field_descriptor.field_name] = enum_annotation
else:
namespace["__annotations__"][field_descriptor.field_name] = (
Optional[field_descriptor.type_base]
if field_descriptor.is_optional
else field_descriptor.type_
)
description = (
field_descriptor.description if field_descriptor.description else _Unset
)
if schema_type == "schema" and field_descriptor.is_optional:
namespace[field_descriptor.field_name] = Field(description=description)
elif schema_type == "create":
if field_descriptor.default is not PydanticUndefined:
namespace[field_descriptor.field_name] = Field(
default=field_descriptor.default, description=description
)
elif field_descriptor.default_factory is not None:
namespace[field_descriptor.field_name] = Field(
default_factory=field_descriptor.default_factory,
description=description,
)
elif field_descriptor.is_optional:
namespace[field_descriptor.field_name] = Field(
default=None, description=description
)
else:
namespace[field_descriptor.field_name] = Field(description=description)
elif schema_type == "update":
namespace[field_descriptor.field_name] = Field(
default=None, description=description
)
else:
namespace[field_descriptor.field_name] = Field(description=description)
def pydantic_model(
entity_descriptor: EntityDescriptor,
module_name: str,
schema_type: Literal["schema", "create", "update"] = "schema",
) -> type[BaseModel]:
namespace = {
"__module__": module_name,
}
_pydantic_model_fields(namespace, entity_descriptor, schema_type)
return type(
f"{entity_descriptor.class_name}{schema_type.capitalize() if schema_type != 'schema' else ''}Schema",
(BaseModel,),
namespace,
)
def _build_filter_condition(
cls: type["BotEntity"], filter_obj: Filter | FilterExpression
) -> Any:
"""
Build SQLAlchemy condition from a Filter or FilterExpression object.
Args:
filter_obj: Filter or FilterExpression object to convert
Returns:
SQLAlchemy condition
"""
# --- Handle single Filter object ---
if isinstance(filter_obj, Filter):
# Support both string field names and callables for custom columns
if isinstance(filter_obj.field, str):
column = getattr(cls, filter_obj.field)
else:
column = filter_obj.field(cls)
# Map filter operator to SQLAlchemy expression
if filter_obj.operator == "==":
return column.__eq__(filter_obj.value)
elif filter_obj.operator == "!=":
return column.__ne__(filter_obj.value)
elif filter_obj.operator == "<":
return column.__lt__(filter_obj.value)
elif filter_obj.operator == "<=":
return column.__le__(filter_obj.value)
elif filter_obj.operator == ">":
return column.__gt__(filter_obj.value)
elif filter_obj.operator == ">=":
return column.__ge__(filter_obj.value)
elif filter_obj.operator == "ilike":
return col(column).ilike(f"%{filter_obj.value}%")
elif filter_obj.operator == "like":
return col(column).like(f"%{filter_obj.value}%")
elif filter_obj.operator == "in":
return col(column).in_(filter_obj.value)
elif filter_obj.operator == "not in":
return col(column).notin_(filter_obj.value)
elif filter_obj.operator == "is none":
return col(column).is_(None)
elif filter_obj.operator == "is not none":
return col(column).isnot(None)
elif filter_obj.operator == "contains":
return filter_obj.value == col(column).any_()
else:
# Unknown operator, return None (no condition)
return None
# --- Handle FilterExpression object (logical AND/OR of filters) ---
elif isinstance(filter_obj, FilterExpression):
operator = filter_obj.operator
filters = filter_obj.filters
# Recursively build conditions for all sub-filters
conditions = []
for sub_filter in filters:
condition = _build_filter_condition(cls, sub_filter)
if condition is not None:
conditions.append(condition)
if not conditions:
return None
# Combine conditions using logical AND/OR
if operator == "and":
res_condition = conditions[0]
if len(conditions) > 1:
for condition in conditions[1:]:
res_condition = res_condition & condition
return res_condition
elif operator == "or":
res_condition = conditions[0]
if len(conditions) > 1:
for condition in conditions[1:]:
res_condition = res_condition | condition
return res_condition
def _static_filter_condition(
cls,
select_statement: SelectOfScalar,
static_filter: Filter | FilterExpression,
):
"""
Apply static filters to a select statement.
Static filters are predefined conditions that don't depend on user input.
Supports both Filter and FilterExpression objects with logical operations.
Args:
select_statement: SQLAlchemy select statement to modify
static_filter: filter condition to apply
Returns:
Modified select statement with filter conditions
"""
condition = _build_filter_condition(cls, static_filter)
if condition is not None:
select_statement = select_statement.where(condition)
return select_statement
def _filter_condition(
select_statement: SelectOfScalar,
filter: str,
filter_fields: list[str],
):
"""
Apply text-based search filters to a select statement.
Creates a case-insensitive LIKE search across multiple fields.
Args:
select_statement: SQLAlchemy select statement to modify
filter: Search text to look for
filter_fields: List of field names to search in
Returns:
Modified select statement with search conditions
"""
condition = None
for field in filter_fields:
if condition is not None:
condition = condition | (column(field).ilike(f"%{filter}%"))
else:
condition = column(field).ilike(f"%{filter}%")
return select_statement.where(condition)
async def _apply_rls_filters(
cls: type["BotEntity"], select_statement: SelectOfScalar, user: "UserBase"
):
"""
Apply Row Level Security (RLS) filters to restrict access based on user roles.
This method uses the entity's rls_filters and rls_filters_params to apply
dynamic filtering conditions based on the user's roles and permissions.
Args:
select_statement: SQLAlchemy select statement to modify
user: User whose access should be restricted
Returns:
Modified select statement with RLS conditions
"""
# --- Check if RLS filters are defined for this entity ---
if cls.bot_entity_descriptor.rls_filters:
# Get parameters for RLS filters (may be sync or async)
params = []
if cls.bot_entity_descriptor.rls_filters_params:
if iscoroutinefunction(cls.bot_entity_descriptor.rls_filters_params):
params = await cls.bot_entity_descriptor.rls_filters_params(user)
else:
params = cls.bot_entity_descriptor.rls_filters_params(user)
# Create a copy of the RLS filters with parameter values substituted
rls_filters = _substitute_rls_parameters(
cls.bot_entity_descriptor.rls_filters, params
)
# Apply RLS filters with parameters
condition = _build_filter_condition(cls, rls_filters)
if condition is not None:
return select_statement.where(condition)
return select_statement
def _substitute_rls_parameters(
rls_filters: Filter | FilterExpression, params: list[Any]
) -> Filter | FilterExpression:
"""
Substitute parameter placeholders in RLS filters with actual values.
Args:
rls_filters: RLS filters that may contain parameter placeholders
params: List of parameter values to substitute
Returns:
RLS filters with parameters substituted
"""
# --- Substitute parameter in single filter ---
if isinstance(rls_filters, Filter):
if rls_filters.value_type == "param" and rls_filters.param_index is not None:
if 0 <= rls_filters.param_index < len(params):
# Create a new filter with the parameter value substituted
return Filter(
field=rls_filters.field,
operator=rls_filters.operator,
value_type="const",
value=params[rls_filters.param_index],
param_index=None,
)
return rls_filters
# --- Recursively substitute parameters in all sub-filters ---
elif isinstance(rls_filters, FilterExpression):
substituted_filters = []
for sub_filter in rls_filters.filters:
substituted_filter = _substitute_rls_parameters(sub_filter, params)
substituted_filters.append(substituted_filter)
return FilterExpression(rls_filters.operator, substituted_filters)
def is_bot_entity(type_: type) -> bool:
return hasattr(type_, "bot_entity_descriptor")
def is_bot_enum(type_: type) -> bool:
return hasattr(type_, "all_members")

9
src/quickbot/plugin.py Normal file
View File

@@ -0,0 +1,9 @@
from typing import TYPE_CHECKING, Protocol, runtime_checkable
if TYPE_CHECKING:
from quickbot import QBotApp
@runtime_checkable
class Registerable(Protocol):
def register(self, app: "QBotApp") -> None: ...

View File

@@ -7,11 +7,14 @@ from typing import Any, TYPE_CHECKING, Callable
import ujson as json import ujson as json
from quickbot.utils.serialization import deserialize from quickbot.utils.serialization import deserialize
from quickbot.model.permissions import (
from ..model.bot_entity import BotEntity _extract_rls_filter_fields,
from ..model.bot_enum import BotEnum get_user_permissions,
_extract_filter_fields,
)
from ..model.settings import Settings from ..model.settings import Settings
from ..model.descriptors import ( from ..model.descriptors import (
BotContext, BotContext,
EntityList, EntityList,
@@ -26,22 +29,11 @@ from ..model.descriptors import (
from ..bot.handlers.context import CallbackCommand, ContextData, CommandContext from ..bot.handlers.context import CallbackCommand, ContextData, CommandContext
if TYPE_CHECKING: if TYPE_CHECKING:
from ..model.bot_entity import BotEntity
from ..model.user import UserBase from ..model.user import UserBase
from ..main import QBotApp from ..main import QBotApp
def get_user_permissions(
user: "UserBase", entity_descriptor: EntityDescriptor
) -> list[EntityPermission]:
permissions = list[EntityPermission]()
for permission, roles in entity_descriptor.permissions.items():
for role in roles:
if role in user.roles:
permissions.append(permission)
break
return permissions
def get_local_text(text: str, locale: str = None) -> str: def get_local_text(text: str, locale: str = None) -> str:
if not locale: if not locale:
i18n = I18n.get_current(no_error=True) i18n = I18n.get_current(no_error=True)
@@ -57,39 +49,6 @@ def get_local_text(text: str, locale: str = None) -> str:
return obj.get(locale, obj[list(obj.keys())[0]]) return obj.get(locale, obj[list(obj.keys())[0]])
def check_entity_permission(
entity: BotEntity, user: "UserBase", permission: EntityPermission
) -> bool:
perm_mapping = {
EntityPermission.LIST: EntityPermission.LIST_ALL,
EntityPermission.READ: EntityPermission.READ_ALL,
EntityPermission.UPDATE: EntityPermission.UPDATE_ALL,
EntityPermission.CREATE: EntityPermission.CREATE_ALL,
EntityPermission.DELETE: EntityPermission.DELETE_ALL,
}
if permission not in perm_mapping:
raise ValueError(f"Invalid permission: {permission}")
entity_descriptor = entity.__class__.bot_entity_descriptor
permissions = get_user_permissions(user, entity_descriptor)
if perm_mapping[permission] in permissions:
return True
ownership_fields = entity_descriptor.ownership_fields
for role in user.roles:
if role in ownership_fields:
if getattr(entity, ownership_fields[role]) == user.id:
return True
else:
if permission in permissions:
return True
return False
def get_send_message(message: Message | CallbackQuery): def get_send_message(message: Message | CallbackQuery):
if isinstance(message, Message): if isinstance(message, Message):
return message.answer return message.answer
@@ -111,9 +70,9 @@ def clear_state(state_data: dict, clear_nav: bool = False):
async def get_entity_item_repr( async def get_entity_item_repr(
entity: BotEntity, entity: "BotEntity",
context: BotContext, context: BotContext,
item_repr: Callable[[BotEntity, BotContext], str] | None = None, item_repr: Callable[["BotEntity", BotContext], str] | None = None,
) -> str: ) -> str:
descr = entity.bot_entity_descriptor descr = entity.bot_entity_descriptor
@@ -151,7 +110,7 @@ async def get_value_repr(
if isinstance(value, bool): if isinstance(value, bool):
return "[✓]" if value else "[ ]" return "[✓]" if value else "[ ]"
elif field_descriptor.is_list: elif field_descriptor.is_list:
if issubclass(type_, BotEntity): if hasattr(type_, "bot_entity_descriptor"):
return f"[{ return f"[{
', '.join( ', '.join(
[ [
@@ -160,15 +119,15 @@ async def get_value_repr(
] ]
) )
}]" }]"
elif issubclass(type_, BotEnum): elif hasattr(type_, "all_members"):
return f"[{', '.join(item.localized(locale) for item in value)}]" return f"[{', '.join(item.localized(locale) for item in value)}]"
elif type_ is str: elif type_ is str:
return f"[{', '.join([f'"{item}"' for item in value])}]" return f"[{', '.join([f'"{item}"' for item in value])}]"
else: else:
return f"[{', '.join([str(item) for item in value])}]" return f"[{', '.join([str(item) for item in value])}]"
elif issubclass(type_, BotEntity): elif hasattr(type_, "bot_entity_descriptor"):
return await get_entity_item_repr(entity=value, context=context) return await get_entity_item_repr(entity=value, context=context)
elif issubclass(type_, BotEnum): elif hasattr(type_, "all_members"):
return value.localized(locale) return value.localized(locale)
elif isinstance(value, str): elif isinstance(value, str):
if field_descriptor and field_descriptor.localizable: if field_descriptor and field_descriptor.localizable:
@@ -187,12 +146,12 @@ async def get_callable_str(
str str
| LazyProxy | LazyProxy
| Callable[[EntityDescriptor, BotContext], str] | Callable[[EntityDescriptor, BotContext], str]
| Callable[[BotEntity, BotContext], str] | Callable[["BotEntity", BotContext], str]
| Callable[[FieldDescriptor, BotEntity, BotContext], str] | Callable[[FieldDescriptor, "BotEntity", BotContext], str]
), ),
context: BotContext, context: BotContext,
descriptor: FieldDescriptor | EntityDescriptor | None = None, descriptor: FieldDescriptor | EntityDescriptor | None = None,
entity: BotEntity | Any = None, entity: "BotEntity | Any" = None,
) -> str: ) -> str:
if isinstance(callable_str, str): if isinstance(callable_str, str):
return callable_str return callable_str
@@ -217,17 +176,13 @@ async def get_callable_str(
return callable_str(descriptor, entity, context) return callable_str(descriptor, entity, context)
else: else:
return callable_str(entity or descriptor, context) return callable_str(entity or descriptor, context)
else:
raise ValueError(
f"Invalid callable type: {type(callable_str)}. Expected str, LazyProxy or callable."
)
def get_entity_descriptor( def get_entity_descriptor(
app: "QBotApp", callback_data: ContextData app: "QBotApp", callback_data: ContextData
) -> EntityDescriptor: ) -> EntityDescriptor:
if callback_data.entity_name: if callback_data.entity_name:
return app.entity_metadata.entity_descriptors[callback_data.entity_name] return app.bot_metadata.entity_descriptors[callback_data.entity_name]
return None return None
@@ -283,7 +238,7 @@ async def build_field_sequence(
entity_data = state_data.get("entity_data", {}) entity_data = state_data.get("entity_data", {})
field_sequence = list[str]() field_sequence = list[str]()
# exclude ownership fields from edit if user has no CREATE_ALL/UPDATE_ALL permission # exclude RLS fields from edit if user has no CREATE_ALL/UPDATE_ALL permission
user_permissions = get_user_permissions(user, entity_descriptor) user_permissions = get_user_permissions(user, entity_descriptor)
for fd in entity_descriptor.fields_descriptors.values(): for fd in entity_descriptor.fields_descriptors.values():
if isinstance(fd.is_visible_in_edit_form, bool): if isinstance(fd.is_visible_in_edit_form, bool):
@@ -306,31 +261,30 @@ async def build_field_sequence(
or fd.default_factory is not None or fd.default_factory is not None
): ):
skip = True skip = True
for own_field in entity_descriptor.ownership_fields.items(): # Check RLS filters for field visibility
if ( if entity_descriptor.rls_filters:
own_field[1].rstrip("_id") == fd.field_name.rstrip("_id") # Get RLS filter fields that should be auto-filled and hidden from user
and own_field[0] in user.roles rls_filter_fields = _extract_rls_filter_fields(entity_descriptor)
and ( if fd.field_name in rls_filter_fields and (
( (
EntityPermission.CREATE_ALL not in user_permissions EntityPermission.CREATE_ALL not in user_permissions
and callback_data.context == CommandContext.ENTITY_CREATE and callback_data.context == CommandContext.ENTITY_CREATE
) )
or ( or (
EntityPermission.UPDATE_ALL not in user_permissions EntityPermission.UPDATE_ALL not in user_permissions
and callback_data.context == CommandContext.ENTITY_EDIT and callback_data.context == CommandContext.ENTITY_EDIT
)
) )
): ):
skip = True skip = True
break
if ( if prev_form_list and prev_form_list.static_filters:
prev_form_list static_filter_fields = _extract_filter_fields(
and prev_form_list.static_filters prev_form_list.static_filters, entity_descriptor.type_
and fd.field_name.rstrip("_id") )
in [f.field_name.rstrip("_id") for f in prev_form_list.static_filters] if fd.field_name.rstrip("_id") in [
): f.rstrip("_id") for f in static_filter_fields
skip = True ]:
skip = True
if not skip: if not skip:
field_sequence.append(fd.field_name) field_sequence.append(fd.field_name)

View File

@@ -6,9 +6,7 @@ from typing import Any, Union, get_origin, get_args
from types import UnionType, NoneType from types import UnionType, NoneType
import ujson as json import ujson as json
from ..model.bot_entity import BotEntity from quickbot.model.descriptors import FieldDescriptor
from ..model.bot_enum import BotEnum
from ..model.descriptors import FieldDescriptor
async def deserialize[T](session: AsyncSession, type_: type[T], value: str = None) -> T: async def deserialize[T](session: AsyncSession, type_: type[T], value: str = None) -> T:
@@ -28,7 +26,7 @@ async def deserialize[T](session: AsyncSession, type_: type[T], value: str = Non
arg_type = args[0] arg_type = args[0]
values = json.loads(value) if value else [] values = json.loads(value) if value else []
if arg_type: if arg_type:
if issubclass(arg_type, BotEntity): if hasattr(arg_type, "bot_entity_descriptor"):
ret = list[arg_type]() ret = list[arg_type]()
items = ( items = (
await session.exec(select(arg_type).where(column("id").in_(values))) await session.exec(select(arg_type).where(column("id").in_(values)))
@@ -36,17 +34,17 @@ async def deserialize[T](session: AsyncSession, type_: type[T], value: str = Non
for item in items: for item in items:
ret.append(item) ret.append(item)
return ret return ret
elif issubclass(arg_type, BotEnum): elif hasattr(arg_type, "all_members"):
return [arg_type(value) for value in values] return [arg_type(value) for value in values]
else: else:
return [arg_type(value) for value in values] return [arg_type(value) for value in values]
else: else:
return values return values
elif issubclass(type_, BotEntity): elif hasattr(type_, "bot_entity_descriptor"):
if is_optional and not value: if is_optional and not value:
return None return None
return await session.get(type_, int(value)) return await session.get(type_, int(value))
elif issubclass(type_, BotEnum): elif hasattr(type_, "all_members"):
if is_optional and not value: if is_optional and not value:
return None return None
return type_(value) return type_(value)
@@ -79,12 +77,12 @@ def serialize(value: Any, field_descriptor: FieldDescriptor) -> str:
type_ = field_descriptor.type_base type_ = field_descriptor.type_base
if field_descriptor.is_list: if field_descriptor.is_list:
if issubclass(type_, BotEntity): if hasattr(type_, "bot_entity_descriptor"):
return json.dumps([item.id for item in value], ensure_ascii=False) return json.dumps([item.id for item in value], ensure_ascii=False)
elif issubclass(type_, BotEnum): elif hasattr(type_, "all_members"):
return json.dumps([item.value for item in value], ensure_ascii=False) return json.dumps([item.value for item in value], ensure_ascii=False)
else: else:
return json.dumps(value, ensure_ascii=False) return json.dumps(value, ensure_ascii=False)
elif issubclass(type_, BotEntity): elif hasattr(type_, "bot_entity_descriptor"):
return str(value.id) if value else "" return str(value.id) if value else ""
return str(value) return str(value)