Files
fastapi-cache/fastapi_cache/coder.py

127 lines
4.0 KiB
Python
Raw Normal View History

import datetime
2020-08-26 18:04:57 +08:00
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import Any, Callable, ClassVar, Dict, Optional, TypeVar, Union, overload
2020-08-26 18:04:57 +08:00
2021-09-17 10:19:56 +08:00
import pendulum
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, ValidationError, fields
2022-11-04 17:31:37 +08:00
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
# Type variable for the ``type_`` argument of ``Coder.decode_as_type``.
_T = TypeVar("_T", bound=type)


def _parse_exact(text: str) -> Any:
    """Parse *text* with pendulum, asking for the exact parsed kind.

    ``exact=True`` asks pendulum to return the kind of value the string
    describes (e.g. a date for a date-only string) instead of always
    producing a datetime.
    """
    return pendulum.parse(text, exact=True)


# Maps each ``_spec_type`` tag written by ``JsonEncoder`` to the callable
# that rebuilds the original value from its string form.
CONVERTERS: Dict[str, Callable[[str], Any]] = dict(
    date=_parse_exact,
    datetime=_parse_exact,
    decimal=Decimal,
)
class JsonEncoder(json.JSONEncoder):
    """JSON encoder that tags dates, datetimes and decimals for round-tripping.

    Values of those types are written as ``{"val": <str>, "_spec_type": <tag>}``
    so that ``object_hook`` can rebuild them when decoding. Any other value the
    standard encoder cannot serialise is handed to FastAPI's
    ``jsonable_encoder``.
    """

    # Checked in order: ``datetime.datetime`` is a subclass of
    # ``datetime.date``, so it must be matched first, otherwise datetimes
    # would be tagged as plain dates.
    _TAGGED_TYPES = (
        (datetime.datetime, "datetime"),
        (datetime.date, "date"),
        (Decimal, "decimal"),
    )

    def default(self, obj: Any) -> Any:
        for kind, tag in self._TAGGED_TYPES:
            if isinstance(obj, kind):
                return {"val": str(obj), "_spec_type": tag}
        return jsonable_encoder(obj)
def object_hook(obj: Any) -> Any:
    """``json.loads`` object hook that rebuilds values tagged by ``JsonEncoder``.

    A dict without a truthy ``"_spec_type"`` entry is returned unchanged.
    A tagged dict is replaced by the result of passing its ``"val"`` entry to
    the ``CONVERTERS`` callable registered for the tag.

    Raises:
        TypeError: if ``"_spec_type"`` names a tag with no registered converter.
    """
    tag = obj.get("_spec_type")
    if not tag:
        # Ordinary JSON object: leave it exactly as decoded.
        return obj
    convert = CONVERTERS.get(tag)
    if convert is None:
        raise TypeError("Unknown {}".format(tag))
    return convert(obj["val"])
class Coder:
    """Base class for serialisers that turn cached values into bytes and back.

    Subclasses must implement ``encode`` and ``decode``. ``decode_as_type``
    builds on ``decode`` and can additionally coerce the result to a requested
    type.
    """

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Serialise *value* to bytes; subclasses must override this."""
        raise NotImplementedError

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Deserialise *value* from bytes; subclasses must override this."""
        raise NotImplementedError

    # Cache mapping endpoint return types to the Pydantic model field used to
    # validate them. It is shared by every subclass that does not declare its
    # own: a subclass that overrides decode_as_type and stores a different
    # kind of field for a given type must provide its own class attribute
    # for this cache.
    _type_field_cache: ClassVar[Dict[Any, fields.ModelField]] = {}

    @overload
    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: _T) -> _T:
        ...

    @overload
    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: None) -> Any:
        ...

    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Union[_T, Any]:
        """Decode *value* and, if *type_* is given, convert it to that type.

        The default implementation validates the decoded value with a Pydantic
        ``ModelField`` built for *type_* (created once per type and cached in
        ``_type_field_cache``).

        Raises:
            ValidationError: if the decoded value does not validate as *type_*.
        """
        decoded = cls.decode(value)
        if type_ is None:
            return decoded

        field = cls._type_field_cache.get(type_)
        if field is None:
            field = fields.ModelField(
                name="body", type_=type_, class_validators=None, model_config=BaseConfig
            )
            cls._type_field_cache[type_] = field

        validated, errors = field.validate(decoded, {}, loc=())
        if errors is None:
            return validated
        # field.validate may report a single error or a list of them.
        error_list = errors if isinstance(errors, list) else [errors]
        raise ValidationError(error_list, type_)
class JsonCoder(Coder):
    """Coder that stores values as UTF-8 encoded JSON.

    Dates, datetimes and decimals are tagged by ``JsonEncoder`` on the way in
    and restored by ``object_hook`` on the way out.
    """

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Serialise *value* to JSON bytes.

        A ``JSONResponse`` is not re-serialised: its ``body`` bytes are
        returned as-is.
        """
        if not isinstance(value, JSONResponse):
            return json.dumps(value, cls=JsonEncoder).encode()
        return value.body

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Parse JSON bytes, rebuilding values tagged by ``JsonEncoder``."""
        # Decode the UTF-8 bytes ourselves; otherwise json.loads() would first
        # have to detect which UTF encoding was used.
        text = value.decode()
        return json.loads(text, object_hook=object_hook)
class PickleCoder(Coder):
    """Coder that stores values with :mod:`pickle`.

    Security note: ``pickle.loads`` can execute arbitrary code while
    unpickling, so this coder is only safe when the cache backend's contents
    are fully trusted.
    """

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Pickle *value*; a template response is reduced to its ``body`` first."""
        # NOTE(review): presumably only the rendered body is kept because the
        # template response object itself is not reliably picklable - confirm.
        payload = value.body if isinstance(value, TemplateResponse) else value
        return pickle.dumps(payload)

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Unpickle *value* (trusted cache data only; see class docstring)."""
        return pickle.loads(value)  # nosec:B403,B301

    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Any:
        """Unpickle *value*, ignoring *type_*.

        Unpickling already yields an object of the original type, so the
        Pydantic validation done by the base class would only add cost.
        """
        return cls.decode(value)