diff --git a/src/quickbot/api_route/telegram.py b/src/quickbot/api_route/telegram.py index 172f5e0..0135833 100644 --- a/src/quickbot/api_route/telegram.py +++ b/src/quickbot/api_route/telegram.py @@ -16,7 +16,7 @@ router = APIRouter() @router.post("/webhook") async def telegram_webhook( - db_session: Annotated[AsyncSession, Depends(get_db)], + # db_session: Annotated[AsyncSession, Depends(get_db)], request: Request, background_tasks: BackgroundTasks, ): @@ -47,7 +47,7 @@ async def feed_bot_update( update: Update, app_state: State, ): - async with async_session() as db_session: + async with get_db() as db_session: await app.dp.feed_webhook_update( bot=app.bot, update=update, diff --git a/src/quickbot/db/__init__.py b/src/quickbot/db/__init__.py index 9ee288a..a84f282 100644 --- a/src/quickbot/db/__init__.py +++ b/src/quickbot/db/__init__.py @@ -1,19 +1,41 @@ +from contextlib import asynccontextmanager +from logging import getLogger +from typing import AsyncGenerator from sqlmodel.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.orm import sessionmaker from ..config import config -# import logging -# logger = logging.getLogger('sqlalchemy.engine') -# logger.setLevel(logging.DEBUG) +logger = getLogger(__name__) + +class TracedSession(AsyncSession): + def __del__(self): + logger.warning(f"💥 __del__ called — session was not properly closed! {id(self)}") + # This is a workaround for the issue with SQLAlchemy 2.0 + # where the session is not closed properly and __del__ is called + # when the session is garbage collected. + # This is not a good practice, but it is a workaround for now + # to avoid the session being closed too early. + + # You can also use the following line to print the stack trace + # to see where the session was created and not closed properly. + import traceback + traceback.print_stack() + + async_engine = create_async_engine(config.DATABASE_URI, pool_size=20, max_overflow=60) async_session = sessionmaker[AsyncSession]( async_engine, class_=AsyncSession, expire_on_commit=False ) - -async def get_db() -> AsyncSession: # type: ignore - async with async_session() as session: +@asynccontextmanager +async def get_db() -> AsyncGenerator[AsyncSession, None]: + session = async_session() + logger.warning(f"🟢 Session created: {id(session)}") + try: yield session + finally: + await session.close() + logger.warning(f"❌ Session closed: {id(session)}") diff --git a/src/quickbot/fsm/db_storage.py b/src/quickbot/fsm/db_storage.py index ebc1bd2..ea55cc8 100644 --- a/src/quickbot/fsm/db_storage.py +++ b/src/quickbot/fsm/db_storage.py @@ -10,7 +10,7 @@ from sqlmodel import select from typing import Any, Dict import ujson as json -from ..db import async_session +from ..db import async_session, get_db from ..model.fsm_storage import FSMStorage @@ -22,7 +22,7 @@ class DbStorage(BaseStorage): async def set_state(self, key: StorageKey, state: StateType = None) -> None: db_key = self.key_builder.build(key, "state") - async with async_session() as session: + async with get_db() as session: db_state = ( await session.exec(select(FSMStorage).where(FSMStorage.key == db_key)) ).first() @@ -44,7 +44,7 @@ class DbStorage(BaseStorage): async def get_state(self, key: StorageKey) -> str | None: db_key = self.key_builder.build(key, "state") - async with async_session() as session: + async with get_db() as session: db_state = ( await session.exec(select(FSMStorage).where(FSMStorage.key == db_key)) ).first() @@ -52,7 +52,7 @@ class DbStorage(BaseStorage): async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None: db_key = self.key_builder.build(key, "data") - async with async_session() as session: + async with get_db() as session: db_data = ( await session.exec(select(FSMStorage).where(FSMStorage.key == db_key)) ).first() @@ -74,7 +74,7 @@ class DbStorage(BaseStorage): async def get_data(self, key: StorageKey) -> Dict[str, Any]: db_key = self.key_builder.build(key, "data") - async with async_session() as session: + async with get_db() as session: db_data = ( await session.exec(select(FSMStorage).where(FSMStorage.key == db_key)) ).first() diff --git a/src/quickbot/model/__init__.py b/src/quickbot/model/__init__.py index 50ee14d..705928b 100644 --- a/src/quickbot/model/__init__.py +++ b/src/quickbot/model/__init__.py @@ -4,7 +4,7 @@ from sqlalchemy.orm.state import InstanceState from typing import cast from .bot_enum import BotEnum, EnumMember -from ..db import async_session +from ..db import async_session, get_db class EntityPermission(BotEnum): @@ -33,7 +33,7 @@ def session_dep(func): _session = state.async_session if not _session: - async with async_session() as session: + async with get_db() as session: kwargs["session"] = session return await func(cls, *args, **kwargs) else: diff --git a/src/quickbot/model/settings.py b/src/quickbot/model/settings.py index cd8d26d..96ecb77 100644 --- a/src/quickbot/model/settings.py +++ b/src/quickbot/model/settings.py @@ -5,7 +5,7 @@ from sqlmodel import SQLModel, Field, select from sqlmodel.ext.asyncio.session import AsyncSession from typing import Any, get_args, get_origin -from ..db import async_session +from ..db import async_session, get_db from .role import RoleBase from .descriptors import FieldDescriptor, Setting from ..utils.serialization import deserialize, serialize @@ -205,7 +205,7 @@ class Settings(metaclass=SettingsMetaclass): if name not in cls._cache.keys(): if session is None: - async with async_session() as session: + async with get_db() as session: cls._cache[name] = await cls.load_param( session=session, param=param ) @@ -272,7 +272,7 @@ class Settings(metaclass=SettingsMetaclass): if isinstance(param, str): param = cls._settings_descriptors[param] ser_value = serialize(value, param) - async with async_session() as session: + async with get_db() as session: db_setting = ( await session.exec( select(DbSettings).where(DbSettings.name == param.field_name)