Decode cache data to the correct endpoint type

Use the return annotation to decode cached data to the correct type.
This follows the same logic FastAPI uses to JSON request bodies.

For the PickleCoder, this is a no-op as pickle already stores type
information in the serialised data.
This commit is contained in:
Martijn Pieters
2023-05-08 16:42:21 +01:00
parent 550ba76df4
commit f78a599bbc
6 changed files with 137 additions and 5 deletions

View File

@@ -3,13 +3,17 @@ import datetime
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import Any, Callable
from typing import Any, Callable, TypeVar, overload
import pendulum
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, ValidationError, fields
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
_T = TypeVar("_T")
CONVERTERS: dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True),
@@ -49,6 +53,35 @@ class Coder:
def decode(cls, value: str) -> Any:
raise NotImplementedError
@overload
@classmethod
def decode_as_type(cls, value: str, type_: _T) -> _T:
...
@overload
@classmethod
def decode_as_type(cls, value: str, *, type_: None) -> Any:
...
@classmethod
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
"""Decode value to the specific given type
The default implementation uses the Pydantic model system to convert the value.
"""
result = cls.decode(value)
if type_ is not None:
field = fields.ModelField(
name="body", type_=type_, class_validators=None, model_config=BaseConfig
)
result, errors = field.validate(result, {}, loc=())
if errors is not None:
if not isinstance(errors, list):
errors = [errors]
raise ValidationError(errors, type_)
return result
class JsonCoder(Coder):
@classmethod
@@ -72,3 +105,10 @@ class PickleCoder(Coder):
@classmethod
def decode(cls, value: str) -> Any:
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
@classmethod
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
# Pickle already produces the correct type on decoding, no point
# in paying an extra performance penalty for pydantic to discover
# the same.
return cls.decode(value)