103 lines
3.4 KiB
Python
103 lines
3.4 KiB
Python
from datetime import datetime
|
|
from decimal import Decimal
|
|
from types import NoneType, UnionType
|
|
from sqlmodel import select, column
|
|
from sqlmodel.ext.asyncio.session import AsyncSession
|
|
from typing import Any, get_origin, get_args, TYPE_CHECKING
|
|
import ujson as json
|
|
|
|
from ..model.bot_entity import BotEntity
|
|
from ..model.bot_enum import BotEnum
|
|
from ..model.descriptors import EntityFieldDescriptor, EntityDescriptor
|
|
from ..model import EntityPermission
|
|
|
|
if TYPE_CHECKING:
|
|
from ..model.user import UserBase
|
|
|
|
|
|
async def deserialize[T](session: AsyncSession, type_: type[T], value: str = None) -> T:
|
|
type_origin = get_origin(type_)
|
|
is_optional = False
|
|
if type_origin == UnionType:
|
|
args = get_args(type_)
|
|
if args[1] == NoneType:
|
|
type_ = args[0]
|
|
if value is None:
|
|
return None
|
|
is_optional = True
|
|
if get_origin(type_) == list:
|
|
arg_type = None
|
|
args = get_args(type_)
|
|
if args:
|
|
arg_type = args[0]
|
|
values = json.loads(value) if value else []
|
|
if arg_type:
|
|
if issubclass(arg_type, BotEntity):
|
|
ret = list[arg_type]()
|
|
items = (await session.exec(select(arg_type).where(column("id").in_(values)))).all()
|
|
for item in items:
|
|
ret.append(item)
|
|
return ret
|
|
elif issubclass(arg_type, BotEnum):
|
|
return [arg_type(value) for value in values]
|
|
else:
|
|
return [arg_type(value) for value in values]
|
|
else:
|
|
return values
|
|
elif issubclass(type_, BotEntity):
|
|
return await session.get(type_, int(value))
|
|
elif issubclass(type_, BotEnum):
|
|
if is_optional and not value:
|
|
return None
|
|
return type_(value)
|
|
elif type_ == datetime:
|
|
if is_optional and not value:
|
|
return None
|
|
return datetime.fromisoformat(value)
|
|
elif type_ == bool:
|
|
return value == "True"
|
|
elif type_ == Decimal:
|
|
if is_optional and not value:
|
|
return None
|
|
return Decimal(value)
|
|
|
|
if is_optional and not value:
|
|
return None
|
|
return type_(value)
|
|
|
|
|
|
def serialize(value: Any, field_descriptor: EntityFieldDescriptor) -> str:
    """Convert ``value`` to the string form that ``deserialize`` reads back.

    The encoding is chosen from ``field_descriptor.type_``:

    * ``None`` -> ``""``.
    * ``list[BotEntity]`` -> JSON array of each item's ``id``.
    * ``list[BotEnum]`` -> JSON array of each item's ``value``.
    * any other ``list`` -> ``json.dumps(value)``.
    * ``BotEntity`` -> ``str(value.id)`` (``""`` if ``value`` is falsy).
    * anything else -> ``str(value)``.

    Optional annotations (``X | None``, ``None | X``, ``Optional[X]``) are
    encoded exactly like ``X``.

    Args:
        value: The field value to encode.
        field_descriptor: Descriptor whose ``type_`` selects the encoding.

    Returns:
        The serialized string.
    """
    # ``Optional[X]`` / ``Union[X, None]`` report ``typing.Union`` as their origin,
    # while ``X | None`` reports ``types.UnionType``; accept both.
    from typing import Union

    if value is None:
        return ""

    type_ = field_descriptor.type_

    # Unwrap a two-member optional union with ``None`` in either position.
    if get_origin(type_) in (UnionType, Union):
        members = get_args(type_)
        non_none = [member for member in members if member is not NoneType]
        if len(members) == 2 and len(non_none) == 1:
            type_ = non_none[0]

    # The origin must be taken from the *unwrapped* type, otherwise
    # ``list[X] | None`` would miss this branch and crash in ``issubclass`` below.
    if get_origin(type_) is list:
        type_args = get_args(type_)
        item_type = type_args[0] if type_args else None
        # ``isinstance(..., type)`` keeps ``issubclass`` from raising on
        # generic aliases such as ``list[list[int]]``.
        if isinstance(item_type, type) and issubclass(item_type, BotEntity):
            return json.dumps([item.id for item in value])
        if isinstance(item_type, type) and issubclass(item_type, BotEnum):
            return json.dumps([item.value for item in value])
        return json.dumps(value)

    if isinstance(type_, type) and issubclass(type_, BotEntity):
        return str(value.id) if value else ""

    return str(value)
|
|
|
|
|
|
def get_user_permissions(user: "UserBase", entity_descriptor: EntityDescriptor) -> list[EntityPermission]:
    """Return the permissions in ``entity_descriptor.permissions`` that ``user`` holds.

    A permission is granted when at least one of the roles mapped to it is
    present in ``user.roles``. Results follow the iteration order of
    ``entity_descriptor.permissions`` and contain one entry per mapping entry
    that matches.

    Args:
        user: The user whose ``roles`` are checked.
        entity_descriptor: Descriptor holding the permission -> roles mapping.

    Returns:
        The granted permissions.
    """
    return [
        permission
        for permission, allowed_roles in entity_descriptor.permissions.items()
        # ``any`` stops at the first matching role, like the original ``break``.
        if any(role in user.roles for role in allowed_roles)
    ]