add command params

This commit is contained in:
Alexander Kalinovsky
2025-01-29 23:40:43 +01:00
parent b40e588379
commit f666bcfba3
33 changed files with 547 additions and 340 deletions

View File

@@ -11,13 +11,13 @@ from typing import (
TYPE_CHECKING,
)
from pydantic import BaseModel
from sqlmodel import SQLModel, BIGINT, Field, select, func, column
from sqlmodel import SQLModel, BigInteger, Field, select, func, column
from sqlmodel.ext.asyncio.session import AsyncSession
from sqlmodel.sql.expression import SelectOfScalar
from sqlmodel.main import SQLModelMetaclass, RelationshipInfo
from .descriptors import EntityDescriptor, EntityField, EntityFieldDescriptor, Filter
from .descriptors import EntityDescriptor, EntityField, FieldDescriptor, Filter
from .entity_metadata import EntityMetadata
from . import session_dep
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
class BotEntityMetaclass(SQLModelMetaclass):
__future_references__ = {}
_future_references = {}
def __new__(mcs, name, bases, namespace, **kwargs):
bot_fields_descriptors = {}
@@ -35,7 +35,7 @@ class BotEntityMetaclass(SQLModelMetaclass):
bot_entity_descriptor = bases[0].__dict__.get("bot_entity_descriptor")
bot_fields_descriptors = (
{
key: EntityFieldDescriptor(**value.__dict__.copy())
key: FieldDescriptor(**value.__dict__.copy())
for key, value in bot_entity_descriptor.fields_descriptors.items()
}
if bot_entity_descriptor
@@ -71,7 +71,7 @@ class BotEntityMetaclass(SQLModelMetaclass):
type_origin = get_origin(type_)
field_descriptor = EntityFieldDescriptor(
field_descriptor = FieldDescriptor(
name=descriptor_name,
field_name=annotation,
type_=type_,
@@ -80,18 +80,19 @@ class BotEntityMetaclass(SQLModelMetaclass):
)
is_list = False
is_optional = False
if type_origin is list:
field_descriptor.is_list = is_list = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if type_origin == Union and isinstance(get_args(type_)[0], ForwardRef):
field_descriptor.is_optional = True
field_descriptor.is_optional = is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[
0
].__forward_arg__
if type_origin == UnionType and get_args(type_)[1] == NoneType:
field_descriptor.is_optional = True
field_descriptor.is_optional = is_optional = True
field_descriptor.type_base = type_ = get_args(type_)[0]
if isinstance(type_, str):
@@ -100,18 +101,16 @@ class BotEntityMetaclass(SQLModelMetaclass):
entity_descriptor
) in EntityMetadata().entity_descriptors.values():
if type_ == entity_descriptor.class_name:
field_descriptor.type_base = entity_descriptor.type_
field_descriptor.type_ = (
list[entity_descriptor.type_]
if is_list
else (
Optional[entity_descriptor.type_]
if type_origin == Optional
if type_origin == Union and is_optional
else (
entity_descriptor.type_ | None
if (
type_origin == UnionType
and get_args(type_)[1] == NoneType
)
if (type_origin == UnionType and is_optional)
else entity_descriptor.type_
)
)
@@ -119,10 +118,10 @@ class BotEntityMetaclass(SQLModelMetaclass):
type_not_found = False
break
if type_not_found:
if type_ in mcs.__future_references__:
mcs.__future_references__[type_].append(field_descriptor)
if type_ in mcs._future_references:
mcs._future_references[type_].append(field_descriptor)
else:
mcs.__future_references__[type_] = [field_descriptor]
mcs._future_references[type_] = [field_descriptor]
bot_fields_descriptors[descriptor_name] = field_descriptor
@@ -191,14 +190,14 @@ class BotEntityMetaclass(SQLModelMetaclass):
type_ = super().__new__(mcs, name, bases, namespace, **kwargs)
if name in mcs.__future_references__:
for field_descriptor in mcs.__future_references__[name]:
if name in mcs._future_references:
for field_descriptor in mcs._future_references[name]:
type_origin = get_origin(field_descriptor.type_)
field_descriptor.type_base = type_
field_descriptor.type_ = (
list[type_]
if get_origin(field_descriptor.type_) is list
if type_origin is list
else (
Optional[type_]
if type_origin == Union
@@ -220,7 +219,9 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
bot_entity_descriptor: ClassVar[EntityDescriptor]
entity_metadata: ClassVar[EntityMetadata]
id: int = Field(primary_key=True, sa_type=BIGINT)
id: int = EntityField(
sm_descriptor=Field(primary_key=True, sa_type=BigInteger), is_visible=False
)
@classmethod
@session_dep
@@ -228,7 +229,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
return await session.get(cls, id, populate_existing=True)
@classmethod
def _static_fiter_condition(
def _static_filter_condition(
cls, select_statement: SelectOfScalar[Self], static_filter: list[Filter]
):
for sfilt in static_filter:
@@ -292,7 +293,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
select_statement = select(func.count()).select_from(cls)
if static_filter:
if isinstance(static_filter, list):
select_statement = cls._static_fiter_condition(
select_statement = cls._static_filter_condition(
select_statement, static_filter
)
else:
@@ -327,7 +328,7 @@ class BotEntity[CreateSchemaType: BaseModel, UpdateSchemaType: BaseModel](
select_statement = select_statement.limit(limit)
if static_filter:
if isinstance(static_filter, list):
select_statement = cls._static_fiter_condition(
select_statement = cls._static_filter_condition(
select_statement, static_filter
)
else: