Files
fastapi-cache/fastapi_cache/coder.py
Martijn Pieters 1d9e126037 Switch to ruff to handle linting and formatting
Ruff handles black, flake8 and isort in one package, and is way faster.
The isort rules had not been enforced, so this commit includes a lot
of import resorting changes.

I switched to flake8-bugbear and the standard black-compatible line
length of 80 + 10% (so max 88 characters), so some line reflowing is
included too.

Finally, because bugbear rightly points out that `setattr()` is less
performant, I've switched the `__signature__` assignment back to using
a direct assignment with a type ignore comment.
2023-05-16 12:18:24 +01:00

139 lines
4.2 KiB
Python

import datetime
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import (
Any,
Callable,
ClassVar,
Dict,
Optional,
TypeVar,
Union,
overload,
)
import pendulum
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, ValidationError, fields
from starlette.responses import JSONResponse
from starlette.templating import (
_TemplateResponse as TemplateResponse, # pyright: ignore[reportPrivateUsage]
)
# Type variable (bound to ``type``) for the ``type_`` argument of
# ``Coder.decode_as_type``.
_T = TypeVar("_T", bound=type)


def _parse_temporal(text: str) -> Any:
    """Parse a stored date or datetime string with ``pendulum.parse(exact=True)``."""
    # Pendulum 3.0.0 adds parse to __all__, at which point this ignore can be removed
    return pendulum.parse(text, exact=True)  # type: ignore[attr-defined]


# Maps a "_spec_type" tag to the callable that rebuilds the tagged value
# from its string form.
CONVERTERS: Dict[str, Callable[[str], Any]] = {
    "date": _parse_temporal,
    "datetime": _parse_temporal,
    "decimal": Decimal,
}
class JsonEncoder(json.JSONEncoder):
    """JSON encoder that tags datetimes, dates and decimals.

    Values of those types are emitted as ``{"val": <str(value)>,
    "_spec_type": <tag>}``; every other value is passed to
    ``jsonable_encoder``.
    """

    def default(self, o: Any) -> Any:
        """Return a JSON-serialisable stand-in for *o*."""
        # Order matters: datetime.datetime is a subclass of datetime.date,
        # so it has to be matched before the plain date entry.
        tagged_types = (
            (datetime.datetime, "datetime"),
            (datetime.date, "date"),
            (Decimal, "decimal"),
        )
        for kind, tag in tagged_types:
            if isinstance(o, kind):
                return {"val": str(o), "_spec_type": tag}
        return jsonable_encoder(o)
def object_hook(obj: Any) -> Any:
    """Rebuild a tagged value while decoding JSON.

    Objects without a truthy ``"_spec_type"`` entry are returned unchanged.
    Otherwise the tag selects a converter from ``CONVERTERS``, which is
    applied to ``obj["val"]``.

    Raises:
        TypeError: if the tag has no entry in ``CONVERTERS``.
    """
    tag = obj.get("_spec_type")
    if tag:
        if tag not in CONVERTERS:
            raise TypeError("Unknown {}".format(tag))
        return CONVERTERS[tag](obj["val"])
    return obj
class Coder:
    """Base class for cache value coders.

    Subclasses supply ``encode`` and ``decode``; ``decode_as_type`` builds on
    ``decode`` to additionally coerce the result to a requested type.
    """

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Serialise *value* to bytes. Must be overridden by subclasses."""
        raise NotImplementedError

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Deserialise *value* from bytes. Must be overridden by subclasses."""
        raise NotImplementedError

    # (Shared) cache for endpoint return types to Pydantic model fields.
    # Note that subclasses share this cache! If a subclass overrides the
    # decode_as_type method and then stores a different kind of field for a
    # given type, do make sure that the subclass provides its own class
    # attribute for this cache.
    _type_field_cache: ClassVar[Dict[Any, fields.ModelField]] = {}

    @overload
    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: _T) -> _T:
        ...

    @overload
    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: None) -> Any:
        ...

    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Union[_T, Any]:
        """Decode *value* and convert the result to *type_*.

        With ``type_=None`` the output of ``decode`` is returned as-is.
        Otherwise the decoded value is validated through a Pydantic model
        field built for *type_* (and cached per type).

        Raises:
            ValidationError: if the field validation reports errors.
        """
        decoded = cls.decode(value)
        if type_ is None:
            return decoded

        # Build the model field for this type once, then reuse it.
        model_field = cls._type_field_cache.get(type_)
        if model_field is None:
            model_field = fields.ModelField(
                name="body", type_=type_, class_validators=None, model_config=BaseConfig
            )
            cls._type_field_cache[type_] = model_field

        validated, errors = model_field.validate(decoded, {}, loc=())
        if errors is None:
            return validated

        # validate() may hand back a single error or a list; ValidationError
        # is always given a list here.
        error_list = errors if isinstance(errors, list) else [errors]
        raise ValidationError(error_list, type_)
class JsonCoder(Coder):
    """Coder that stores values as encoded JSON text."""

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Return the JSON bytes for *value*.

        A ``JSONResponse`` contributes its ``body`` unchanged; anything else
        is dumped with ``JsonEncoder`` and encoded to bytes.
        """
        if not isinstance(value, JSONResponse):
            serialised = json.dumps(value, cls=JsonEncoder)
            return serialised.encode()
        return value.body

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Parse JSON bytes, restoring tagged values via ``object_hook``."""
        # Decode the UTF-8 bytes to str up front; handing raw bytes to
        # json.loads() would make it detect which UTF-* encoding was
        # used first.
        text = value.decode()
        return json.loads(text, object_hook=object_hook)
class PickleCoder(Coder):
    """Coder that stores values using ``pickle``."""

    @classmethod
    def encode(cls, value: Any) -> bytes:
        """Pickle *value*; for a ``TemplateResponse`` only its ``body`` is pickled."""
        payload = value.body if isinstance(value, TemplateResponse) else value
        return pickle.dumps(payload)

    @classmethod
    def decode(cls, value: bytes) -> Any:
        """Unpickle *value*."""
        # NOTE: unpickling can execute arbitrary code, so the bytes passed
        # in must come from a trusted source.
        return pickle.loads(value)  # nosec:B403,B301

    @classmethod
    def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Any:
        """Decode *value*; *type_* is accepted but not used."""
        # Unpickling already yields an object of the right type, so the
        # extra cost of having pydantic re-derive it is skipped.
        return cls.decode(value)