refactoring
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
from functools import wraps
|
||||
from typing import ClassVar, cast, get_args, get_origin
|
||||
from types import NoneType, UnionType
|
||||
from typing import ClassVar, ForwardRef, Optional, Union, cast, get_args, get_origin
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import SQLModel, BIGINT, Field, select, func
|
||||
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
|
||||
|
||||
@@ -53,24 +54,36 @@ class BotEntityMetaclass(SQLModelMetaclass):
|
||||
|
||||
type_ = namespace['__annotations__'][annotation]
|
||||
|
||||
type_origin = get_origin(type_)
|
||||
|
||||
field_descriptor = EntityFieldDescriptor(
|
||||
name = descriptor_name,
|
||||
field_name = annotation,
|
||||
type_ = type_,
|
||||
type_base = type_,
|
||||
**descriptor_kwargs)
|
||||
|
||||
type_origin = get_origin(type_)
|
||||
|
||||
is_list = False
|
||||
if type_origin == list:
|
||||
is_list = True
|
||||
type_ = get_args(type_)[0]
|
||||
field_descriptor.is_list = is_list = True
|
||||
field_descriptor.type_base = type_ = get_args(type_)[0]
|
||||
|
||||
if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
|
||||
field_descriptor.is_optional = True
|
||||
field_descriptor.type_base = type_ = get_args(type_)[0].__forward_arg__
|
||||
|
||||
if type_origin == UnionType and get_args(type_)[1] == NoneType:
|
||||
field_descriptor.is_optional = True
|
||||
field_descriptor.type_base = type_ = get_args(type_)[0]
|
||||
|
||||
if isinstance(type_, str):
|
||||
type_not_found = True
|
||||
for entity_descriptor in EntityMetadata().entity_descriptors.values():
|
||||
if type_ == entity_descriptor.class_name:
|
||||
field_descriptor.type_ = list[entity_descriptor.type_] if is_list else entity_descriptor.type_
|
||||
field_descriptor.type_ = (list[entity_descriptor.type_] if is_list
|
||||
else Optional[entity_descriptor.type_] if type_origin == Optional
|
||||
else entity_descriptor.type_ | None if (type_origin == UnionType and get_args(type_)[1] == NoneType)
|
||||
else entity_descriptor.type_)
|
||||
type_not_found = False
|
||||
break
|
||||
if type_not_found:
|
||||
@@ -131,8 +144,12 @@ class BotEntityMetaclass(SQLModelMetaclass):
|
||||
|
||||
if name in mcs.__future_references__:
|
||||
for field_descriptor in mcs.__future_references__[name]:
|
||||
field_descriptor.type_ = list[type_] if get_origin(field_descriptor.type_) == list else type_
|
||||
a = field_descriptor
|
||||
type_origin = get_origin(field_descriptor.type_)
|
||||
field_descriptor.type_base = type_
|
||||
field_descriptor.type_ = (list[type_] if get_origin(field_descriptor.type_) == list else
|
||||
Optional[type_] if type_origin == Union and isinstance(get_args(field_descriptor.type_)[0], ForwardRef) else
|
||||
type_ | None if type_origin == UnionType else
|
||||
type_)
|
||||
|
||||
setattr(namespace["bot_entity_descriptor"], "type_", type_)
|
||||
|
||||
@@ -160,15 +177,19 @@ class BotEntity[CreateSchemaType: BaseModel,
|
||||
session: AsyncSession | None = None,
|
||||
id: int):
|
||||
|
||||
return await session.get(cls, id)
|
||||
return await session.get(cls, id, populate_existing = True)
|
||||
|
||||
|
||||
@classmethod
|
||||
@session_dep
|
||||
async def get_count(cls, *,
|
||||
session: AsyncSession | None = None) -> int:
|
||||
session: AsyncSession | None = None,
|
||||
filter: str = None) -> int:
|
||||
|
||||
return await session.scalar(select(func.count()).select_from(cls))
|
||||
select_statement = select(func.count()).select_from(cls)
|
||||
if filter:
|
||||
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
|
||||
return await session.scalar(select_statement)
|
||||
|
||||
|
||||
@classmethod
|
||||
@@ -176,12 +197,15 @@ class BotEntity[CreateSchemaType: BaseModel,
|
||||
async def get_multi(cls, *,
|
||||
session: AsyncSession | None = None,
|
||||
order_by = None,
|
||||
filter:str = None,
|
||||
skip: int = 0,
|
||||
limit: int = None):
|
||||
|
||||
select_statement = select(cls).offset(skip)
|
||||
if limit:
|
||||
select_statement = select_statement.limit(limit)
|
||||
if filter:
|
||||
select_statement = select_statement.where(column("name").ilike(f"%{filter}%"))
|
||||
if order_by:
|
||||
select_statement = select_statement.order_by(order_by)
|
||||
return (await session.exec(select_statement)).all()
|
||||
@@ -238,4 +262,5 @@ class BotEntity[CreateSchemaType: BaseModel,
|
||||
if commit:
|
||||
await session.commit()
|
||||
return obj
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user