mirror of
https://github.com/long2ice/fastapi-cache.git
synced 2026-03-25 04:57:54 +00:00
In `timeit` tests, 10,000 calls to `ModelField()` could take up to half a second on my MacBook Pro M1, depending on the type annotation used. Given that the method is called for every cache hit, this can really add up. However, the number of distinct return types across endpoints is finite, so caching the fields is a definite win here.
125 lines
3.9 KiB
Python
125 lines
3.9 KiB
Python
import codecs
import datetime
import json
import pickle  # nosec:B403
from decimal import Decimal
from typing import Any, Callable, ClassVar, Dict, Optional, TypeVar, Union, overload

import pendulum
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, ValidationError, fields
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
|
|
|
|
_T = TypeVar("_T")
|
|
|
|
|
|
CONVERTERS: dict[str, Callable[[str], Any]] = {
|
|
"date": lambda x: pendulum.parse(x, exact=True),
|
|
"datetime": lambda x: pendulum.parse(x, exact=True),
|
|
"decimal": Decimal,
|
|
}
|
|
|
|
|
|
class JsonEncoder(json.JSONEncoder):
    """JSON encoder that tags dates, datetimes and decimals.

    Such values are written as ``{"val": str(obj), "_spec_type": <tag>}`` so
    that a decoder can recognise and rebuild them. Anything else is passed
    through FastAPI's ``jsonable_encoder``.
    """

    # Checked in order. datetime.datetime is a subclass of datetime.date,
    # so it must come before date to get the more specific tag.
    _TAGGED_TYPES = (
        (datetime.datetime, "datetime"),
        (datetime.date, "date"),
        (Decimal, "decimal"),
    )

    def default(self, obj: Any) -> Any:
        for kind, tag in self._TAGGED_TYPES:
            if isinstance(obj, kind):
                return {"val": str(obj), "_spec_type": tag}
        return jsonable_encoder(obj)
|
|
|
|
|
|
def object_hook(obj: Any) -> Any:
    """``json.loads`` object hook that rebuilds tagged values.

    Dicts without a truthy ``_spec_type`` key are returned untouched.
    Otherwise the tag selects a decoder from ``CONVERTERS``, which is applied
    to ``obj["val"]``.

    Raises:
        TypeError: If the ``_spec_type`` tag has no registered converter.
    """
    spec = obj.get("_spec_type")
    if not spec:
        return obj

    convert = CONVERTERS.get(spec)
    if convert is None:
        raise TypeError("Unknown {}".format(spec))
    return convert(obj["val"])
|
|
|
|
|
|
class Coder:
    """Base class for converting cached values to and from strings.

    Subclasses implement :meth:`encode` and :meth:`decode`. The
    :meth:`decode_as_type` method builds on :meth:`decode` and can also
    validate/convert the decoded value into a requested type.
    """

    @classmethod
    def encode(cls, value: Any) -> str:
        """Serialize *value* to a string. Must be implemented by subclasses."""
        raise NotImplementedError

    @classmethod
    def decode(cls, value: str) -> Any:
        """Deserialize *value* into a Python object. Must be implemented by subclasses."""
        raise NotImplementedError

    # (Shared) cache for endpoint return types to Pydantic model fields.
    # Building a ModelField is comparatively slow, and decode_as_type runs on
    # every cache hit.
    # Note that subclasses share this cache! If a subclass overrides the
    # decode_as_type method and then stores a different kind of field for a
    # given type, make sure that the subclass provides its own class
    # attribute for this cache.
    _type_field_cache: ClassVar[Dict[Any, fields.ModelField]] = {}

    @classmethod
    def _get_type_field(cls, type_: Any) -> fields.ModelField:
        """Return a Pydantic field that validates values as *type_*, memoized per type."""
        try:
            return cls._type_field_cache[type_]
        except KeyError:
            pass
        except TypeError:
            # Unhashable annotations (e.g. ``Annotated`` with unhashable
            # metadata) cannot be dict keys. Build a fresh field instead of
            # failing the whole decode.
            return cls._new_type_field(type_)
        field = cls._type_field_cache[type_] = cls._new_type_field(type_)
        return field

    @staticmethod
    def _new_type_field(type_: Any) -> fields.ModelField:
        """Construct an uncached Pydantic field named ``body`` for *type_*."""
        return fields.ModelField(
            name="body", type_=type_, class_validators=None, model_config=BaseConfig
        )

    @overload
    @classmethod
    def decode_as_type(cls, value: str, *, type_: _T) -> _T:
        ...

    @overload
    @classmethod
    def decode_as_type(cls, value: str, *, type_: None) -> Any:
        ...

    @classmethod
    def decode_as_type(cls, value: str, *, type_: Optional[_T]) -> Union[_T, Any]:
        """Decode *value* and, if *type_* is given, convert it to that type.

        The default implementation decodes with :meth:`decode`. It then runs
        the result through Pydantic validation for *type_*.

        Args:
            value: The encoded cache value.
            type_: The target type (e.g. an endpoint's return annotation), or
                ``None`` to return the decoded value unchanged.

        Returns:
            The decoded value, validated/converted to *type_* when one is given.

        Raises:
            pydantic.ValidationError: If the decoded value does not validate
                as *type_*.
        """
        result = cls.decode(value)
        if type_ is None:
            return result

        field = cls._get_type_field(type_)
        result, errors = field.validate(result, {}, loc=())
        if errors is not None:
            # validate() may report a single error rather than a list;
            # normalise so ValidationError always receives a list.
            if not isinstance(errors, list):
                errors = [errors]
            raise ValidationError(errors, type_)
        return result
|
|
|
|
|
|
class JsonCoder(Coder):
    """Coder that stores values as JSON text.

    Encoding uses ``JsonEncoder``, which tags dates, datetimes and decimals.
    Decoding uses ``object_hook``, which turns those tags back into values.
    """

    @classmethod
    def encode(cls, value: Any) -> str:
        """Serialize *value* to a JSON string.

        A ``JSONResponse`` already carries a rendered body, so that body is
        decoded and returned instead of serializing the response object.
        """
        if isinstance(value, JSONResponse):
            return value.body.decode()
        return json.dumps(value, cls=JsonEncoder)

    @classmethod
    def decode(cls, value: str) -> Any:
        """Parse a JSON string produced by :meth:`encode`.

        The result may be any JSON value (dict, list, str, number, bool or
        None), not just a string. Tagged ``_spec_type`` objects are converted
        back by ``object_hook``.
        """
        return json.loads(value, object_hook=object_hook)
|
|
|
|
|
|
class PickleCoder(Coder):
    """Coder that stores values as base64-encoded pickles.

    Because pickling preserves the Python type, decoding needs no Pydantic
    conversion step.
    """

    @classmethod
    def encode(cls, value: Any) -> str:
        """Pickle *value* and return it as base64 text.

        A ``TemplateResponse`` is reduced to its rendered ``body`` first.
        """
        payload = value.body if isinstance(value, TemplateResponse) else value
        raw = pickle.dumps(payload)
        return codecs.encode(raw, "base64").decode()

    @classmethod
    def decode(cls, value: str) -> Any:
        """Reverse :meth:`encode`: base64-decode *value* and unpickle it.

        NOTE(security): unpickling can execute arbitrary code. Anyone able to
        write to the cache backend can therefore run code in this process, so
        the backend must be trusted.
        """
        raw = codecs.decode(value.encode(), "base64")
        return pickle.loads(raw)  # nosec:B403,B301

    @classmethod
    def decode_as_type(cls, value: str, *, type_: Any) -> Any:
        """Decode *value*, ignoring *type_*.

        Unpickling already yields an object of the original type. Running it
        through Pydantic would only add overhead to rediscover the same type.
        """
        return cls.decode(value)
|