refactoring

This commit is contained in:
Alexander Kalinovsky
2025-01-09 13:11:10 +01:00
parent 7793a0cb77
commit 3898a333fa
29 changed files with 1065 additions and 381 deletions

View File

@@ -1,7 +1,8 @@
from functools import wraps
from typing import ClassVar, cast, get_args, get_origin
from types import NoneType, UnionType
from typing import ClassVar, ForwardRef, Optional, Union, cast, get_args, get_origin
from pydantic import BaseModel
from sqlmodel import SQLModel, BIGINT, Field, select, func
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
@@ -53,24 +54,36 @@ class BotEntityMetaclass(SQLModelMetaclass):
type_ = namespace['__annotations__'][annotation]
type_origin = get_origin(type_)
field_descriptor = EntityFieldDescriptor(
name = descriptor_name,
field_name = annotation,
type_ = type_,
type_base = type_,
**descriptor_kwargs)
type_origin = get_origin(type_)
is_list = False
if type_origin == list:
is_list = True
type_ = get_args(type_)[0]
field_descriptor.is_list = is_list = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
field_descriptor.is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0].__forward_arg__
if type_origin == UnionType and get_args(type_)[1] == NoneType:
field_descriptor.is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if isinstance(type_, str):
type_not_found = True
for entity_descriptor in EntityMetadata().entity_descriptors.values():
if type_ == entity_descriptor.class_name:
field_descriptor.type_ = list[entity_descriptor.type_] if is_list else entity_descriptor.type_
field_descriptor.type_ = (list[entity_descriptor.type_] if is_list
else Optional[entity_descriptor.type_] if type_origin == Optional
else entity_descriptor.type_ | None if (type_origin == UnionType and get_args(type_)[1] == NoneType)
else entity_descriptor.type_)
type_not_found = False
break
if type_not_found:
@@ -131,8 +144,12 @@ class BotEntityMetaclass(SQLModelMetaclass):
if name in mcs.__future_references__:
for field_descriptor in mcs.__future_references__[name]:
field_descriptor.type_ = list[type_] if get_origin(field_descriptor.type_) == list else type_
a = field_descriptor
type_origin = get_origin(field_descriptor.type_)
field_descriptor.type_base = type_
field_descriptor.type_ = (list[type_] if get_origin(field_descriptor.type_) == list else
Optional[type_] if type_origin == Union and isinstance(get_args(field_descriptor.type_)[0], ForwardRef) else
type_ | None if type_origin == UnionType else
type_)
setattr(namespace["bot_entity_descriptor"], "type_", type_)
@@ -160,15 +177,19 @@ class BotEntity[CreateSchemaType: BaseModel,
session: AsyncSession | None = None,
id: int):
return await session.get(cls, id)
return await session.get(cls, id, populate_existing = True)
@classmethod
@session_dep
async def get_count(cls, *,
session: AsyncSession | None = None) -> int:
session: AsyncSession | None = None,
filter: str = None) -> int:
return await session.scalar(select(func.count()).select_from(cls))
select_statement = select(func.count()).select_from(cls)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
return await session.scalar(select_statement)
@classmethod
@@ -176,12 +197,15 @@ class BotEntity[CreateSchemaType: BaseModel,
async def get_multi(cls, *,
session: AsyncSession | None = None,
order_by = None,
filter:str = None,
skip: int = 0,
limit: int = None):
select_statement = select(cls).offset(skip)
if limit:
select_statement = select_statement.limit(limit)
if filter:
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
if order_by:
select_statement = select_statement.order_by(order_by)
return (await session.exec(select_statement)).all()
@@ -238,4 +262,5 @@ class BotEntity[CreateSchemaType: BaseModel,
if commit:
await session.commit()
return obj
return None
return None