2020-10-16 16:55:33 +08:00
|
|
|
import datetime
|
2020-08-26 18:04:57 +08:00
|
|
|
import json
|
|
|
|
|
import pickle # nosec:B403
|
2020-10-16 16:55:33 +08:00
|
|
|
from decimal import Decimal
|
2023-05-16 12:09:50 +01:00
|
|
|
from typing import (
|
|
|
|
|
Any,
|
|
|
|
|
Callable,
|
|
|
|
|
ClassVar,
|
|
|
|
|
Dict,
|
|
|
|
|
Optional,
|
|
|
|
|
TypeVar,
|
|
|
|
|
Union,
|
|
|
|
|
overload,
|
|
|
|
|
)
|
2020-08-26 18:04:57 +08:00
|
|
|
|
2021-09-17 10:19:56 +08:00
|
|
|
import pendulum
|
2021-07-26 16:33:22 +08:00
|
|
|
from fastapi.encoders import jsonable_encoder
|
2023-05-08 16:42:21 +01:00
|
|
|
from pydantic import BaseConfig, ValidationError, fields
|
2022-11-04 17:31:37 +08:00
|
|
|
from starlette.responses import JSONResponse
|
2023-05-09 17:33:07 +01:00
|
|
|
from starlette.templating import (
|
|
|
|
|
_TemplateResponse as TemplateResponse, # pyright: ignore[reportPrivateUsage]
|
|
|
|
|
)
|
2020-10-16 16:55:33 +08:00
|
|
|
|
2023-05-09 15:30:46 +01:00
|
|
|
_T = TypeVar("_T", bound=type)


def _parse_exact(x: str) -> Any:
    """Parse a date/datetime string produced by ``JsonEncoder`` using pendulum.

    ``exact=True`` is passed so that (per pendulum's parsing docs) a date-only
    string yields a date-like object rather than always a DateTime.
    """
    # Pendulum 3.0.0 adds parse to __all__, at which point this ignore can be removed
    return pendulum.parse(x, exact=True)  # type: ignore[attr-defined]


# Maps each "_spec_type" tag emitted by JsonEncoder to a callable that rebuilds
# the original value from its string "val" (consumed by object_hook).
CONVERTERS: Dict[str, Callable[[str], Any]] = {
    "date": _parse_exact,
    "datetime": _parse_exact,
    "decimal": Decimal,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JsonEncoder(json.JSONEncoder):
    """JSON encoder that tags datetimes, dates and Decimals so they can be restored.

    Such values are written as ``{"val": str(value), "_spec_type": tag}``, the
    shape that ``object_hook`` looks for when decoding.  Any other object that
    the stock encoder cannot handle is passed to FastAPI's ``jsonable_encoder``.
    """

    def default(self, o: Any) -> Any:
        # Order matters: datetime.datetime subclasses datetime.date, so it has
        # to be checked first or datetimes would be tagged as plain dates.
        tagged_types = (
            (datetime.datetime, "datetime"),
            (datetime.date, "date"),
            (Decimal, "decimal"),
        )
        for kind, tag in tagged_types:
            if isinstance(o, kind):
                return {"val": str(o), "_spec_type": tag}
        return jsonable_encoder(o)
|
2020-10-16 16:55:33 +08:00
|
|
|
|
|
|
|
|
|
2022-10-22 20:59:37 +04:00
|
|
|
def object_hook(obj: Any) -> Any:
    """``json.loads`` object hook that reverses the tagging done by ``JsonEncoder``.

    A dict without a truthy ``"_spec_type"`` entry is returned unchanged.  A
    tagged dict is rebuilt by applying ``CONVERTERS[tag]`` to its ``"val"``
    entry.

    Raises:
        TypeError: if the tag has no entry in ``CONVERTERS``.
    """
    tag = obj.get("_spec_type")
    if not tag:
        return obj
    if tag not in CONVERTERS:
        raise TypeError(f"Unknown {tag}")
    return CONVERTERS[tag](obj["val"])
|
2020-10-16 16:55:33 +08:00
|
|
|
|
2020-08-26 18:04:57 +08:00
|
|
|
|
|
|
|
|
class Coder:
    """Base class for converting cached values to and from ``bytes``.

    Subclasses supply ``encode`` and ``decode``.  ``decode_as_type`` builds on
    ``decode`` and optionally validates the result with Pydantic.
    """

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Serialise ``value`` to bytes.  Subclasses must override this."""
        raise NotImplementedError

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Deserialise ``value`` from bytes.  Subclasses must override this."""
        raise NotImplementedError

    # Shared cache mapping endpoint return types to Pydantic model fields.
    # Because the dict is defined on this base class, every subclass reads and
    # writes the very same object.  A subclass whose decode_as_type stores a
    # different kind of field for a given type should declare its own
    # _type_field_cache class attribute.
    _type_field_cache: ClassVar[Dict[Any, fields.ModelField]] = {}

    @overload
    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: _T) -> _T:
        ...

    @overload
    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: None) -> Any:
        ...

    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Union[_T, Any]:
        """Decode ``value`` and, when ``type_`` is given, validate it as that type.

        With ``type_=None`` the output of ``decode`` is returned as-is.
        Otherwise it is run through a Pydantic ``ModelField`` built for
        ``type_``; that field is created on first use and then reused from
        ``_type_field_cache``.

        Raises:
            ValidationError: if the decoded value does not validate as ``type_``.
        """
        decoded = cls.decode(value)
        if type_ is None:
            return decoded

        field = cls._type_field_cache.get(type_)
        if field is None:
            field = fields.ModelField(
                name="body", type_=type_, class_validators=None, model_config=BaseConfig
            )
            cls._type_field_cache[type_] = field

        validated, errors = field.validate(decoded, {}, loc=())
        if errors is None:
            return validated
        # ValidationError expects a list of error wrappers.
        if not isinstance(errors, list):
            errors = [errors]
        raise ValidationError(errors, type_)
|
|
|
|
|
|
2020-08-26 18:04:57 +08:00
|
|
|
|
|
|
|
|
class JsonCoder(Coder):
    """Coder that stores values as UTF-8 JSON, using ``JsonEncoder``/``object_hook``."""

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Serialise ``value`` to JSON bytes.

        For a ``JSONResponse`` its existing ``body`` is returned unchanged
        instead of being re-serialised.
        """
        if not isinstance(value, JSONResponse):
            return json.dumps(value, cls=JsonEncoder).encode()
        return value.body

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Deserialise JSON bytes, restoring values tagged by ``JsonEncoder``."""
        # Decode the UTF-8 bytes explicitly; handing bytes to json.loads()
        # would make it detect the UTF encoding first.
        text = value.decode()
        return json.loads(text, object_hook=object_hook)
|
2020-08-26 18:04:57 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class PickleCoder(Coder):
    """Coder that stores values with :mod:`pickle`.

    SECURITY NOTE: ``decode`` unpickles whatever bytes it receives, and
    unpickling untrusted data can execute arbitrary code.  Only use this coder
    with a cache backend whose contents are fully trusted.
    """

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Pickle ``value``; a ``TemplateResponse`` is reduced to its ``body`` first."""
        payload = value.body if isinstance(value, TemplateResponse) else value
        return pickle.dumps(payload)

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Unpickle ``value`` (see the class-level security note)."""
        return pickle.loads(value)  # noqa: S301

    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Any:
        """Return ``decode(value)``, ignoring ``type_``."""
        # Unpickling already reproduces the stored object's type, so the
        # Pydantic validation pass done by Coder.decode_as_type is skipped to
        # avoid its extra cost.
        return cls.decode(value)
|