add command params
This commit is contained in:
@@ -11,13 +11,13 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
|
||||
from sqlmodel import SQLModel, BigInteger, Field, select, func, column
|
||||
|
||||
from sqlmodel.ext.asyncio.session import AsyncSession
|
||||
from sqlmodel.sql.expression import SelectOfScalar
|
||||
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
|
||||
|
||||
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor, Filter
|
||||
from .descriptors import EntityDescriptor, EntityField, FieldDescriptor, Filter
|
||||
from .entity_metadata import EntityMetadata
|
||||
from . import session_dep
|
||||
|
||||
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class BotEntityMetaclass(SQLModelMetaclass):
|
||||
__future_references__ = {}
|
||||
_future_references = {}
|
||||
|
||||
def __new__(mcs, name, bases, namespace, **kwargs):
|
||||
bot_fields_descriptors = {}
|
||||
@@ -35,7 +35,7 @@ class BotEntityMetaclass(SQLModelMetaclass):
|
||||
bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor")
|
||||
bot_fields_descriptors = (
|
||||
{
|
||||
key: EntityFieldDescriptor(**value.__dict__.copy())
|
||||
key: FieldDescriptor(**value.__dict__.copy())
|
||||
for key, value in bot_entity_descriptor.fields_descriptors.items()
|
||||
}
|
||||
if bot_entity_descriptor
|
||||
@@ -71,7 +71,7 @@ class BotEntityMetaclass(SQLModelMetaclass):
|
||||
|
||||
type_origin = get_origin(type_)
|
||||
|
||||
field_descriptor = EntityFieldDescriptor(
|
||||
field_descriptor = FieldDescriptor(
|
||||
name=descriptor_name,
|
||||
field_name=annotation,
|
||||
type_=type_,
|
||||
@@ -80,18 +80,19 @@ class BotEntityMetaclass(SQLModelMetaclass):
|
||||
)
|
||||
|
||||
is_list = False
|
||||
is_optional = False
|
||||
if type_origin is list:
|
||||
field_descriptor.is_list = is_list = True
|
||||
field_descriptor.type_base = type_ = get_args(type_)[0]
|
||||
|
||||
if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
|
||||
field_descriptor.is_optional = True
|
||||
field_descriptor.is_optional = is_optional = True
|
||||
field_descriptor.type_base = type_ = get_args(type_)[
|
||||
0
|
||||
].__forward_arg__
|
||||
|
||||
if type_origin == UnionType and get_args(type_)[1] == NoneType:
|
||||
field_descriptor.is_optional = True
|
||||
field_descriptor.is_optional = is_optional = True
|
||||
field_descriptor.type_base = type_ = get_args(type_)[0]
|
||||
|
||||
if isinstance(type_, str):
|
||||
@@ -100,18 +101,16 @@ class BotEntityMetaclass(SQLModelMetaclass):
|
||||
entity_descriptor
|
||||
) in EntityMetadata().entity_descriptors.values():
|
||||
if type_ == entity_descriptor.class_name:
|
||||
field_descriptor.type_base = entity_descriptor.type_
|
||||
field_descriptor.type_ = (
|
||||
list[entity_descriptor.type_]
|
||||
if is_list
|
||||
else (
|
||||
Optional[entity_descriptor.type_]
|
||||
if type_origin == Optional
|
||||
if type_origin == Union and is_optional
|
||||
else (
|
||||
entity_descriptor.type_ | None
|
||||
if (
|
||||
type_origin == UnionType
|
||||
and get_args(type_)[1] == NoneType
|
||||
)
|
||||
if (type_origin == UnionType and is_optional)
|
||||
else entity_descriptor.type_
|
||||
)
|
||||
)
|
||||
@@ -119,10 +118,10 @@ class BotEntityMetaclass(SQLModelMetaclass):
|
||||
type_not_found = False
|
||||
break
|
||||
if type_not_found:
|
||||
if type_ in mcs.__future_references__:
|
||||
mcs.__future_references__[type_].append(field_descriptor)
|
||||
if type_ in mcs._future_references:
|
||||
mcs._future_references[type_].append(field_descriptor)
|
||||
else:
|
||||
mcs.__future_references__[type_] = [field_descriptor]
|
||||
mcs._future_references[type_] = [field_descriptor]
|
||||
|
||||
bot_fields_descriptors[descriptor_name] = field_descriptor
|
||||
|
||||
@@ -191,14 +190,14 @@ class BotEntityMetaclass(SQLModelMetaclass):
|
||||
|
||||
type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
|
||||
|
||||
if name in mcs.__future_references__:
|
||||
for field_descriptor in mcs.__future_references__[name]:
|
||||
if name in mcs._future_references:
|
||||
for field_descriptor in mcs._future_references[name]:
|
||||
type_origin = get_origin(field_descriptor.type_)
|
||||
field_descriptor.type_base = type_
|
||||
|
||||
field_descriptor.type_ = (
|
||||
list[type_]
|
||||
if get_origin(field_descriptor.type_) is list
|
||||
if type_origin is list
|
||||
else (
|
||||
Optional[type_]
|
||||
if type_origin == Union
|
||||
@@ -220,7 +219,9 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
|
||||
bot_entity_descriptor: ClassVar[EntityDescriptor]
|
||||
entity_metadata: ClassVar[EntityMetadata]
|
||||
|
||||
id: int = Field(primary_key=True, sa_type=BIGINT)
|
||||
id: int = EntityField(
|
||||
sm_descriptor=Field(primary_key=True, sa_type=BigInteger), is_visible=False
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@session_dep
|
||||
@@ -228,7 +229,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
|
||||
return await session.get(cls, id, populate_existing=True)
|
||||
|
||||
@classmethod
|
||||
def _static_fiter_condition(
|
||||
def _static_filter_condition(
|
||||
cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter]
|
||||
):
|
||||
for sfilt in static_filter:
|
||||
@@ -292,7 +293,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
|
||||
select_statement = select(func.count()).select_from(cls)
|
||||
if static_filter:
|
||||
if isinstance(static_filter, list):
|
||||
select_statement = cls._static_fiter_condition(
|
||||
select_statement = cls._static_filter_condition(
|
||||
select_statement, static_filter
|
||||
)
|
||||
else:
|
||||
@@ -327,7 +328,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
|
||||
select_statement = select_statement.limit(limit)
|
||||
if static_filter:
|
||||
if isinstance(static_filter, list):
|
||||
select_statement = cls._static_fiter_condition(
|
||||
select_statement = cls._static_filter_condition(
|
||||
select_statement, static_filter
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user