Decode cache data to the correct endpoint type

Use the return annotation to decode cached data to the correct type.
This follows the same logic FastAPI uses to JSON request bodies.

For the PickleCoder, this is a no-op as pickle already stores type
information in the serialised data.
This commit is contained in:
Martijn Pieters
2023-05-08 16:42:21 +01:00
parent 550ba76df4
commit f78a599bbc
6 changed files with 137 additions and 5 deletions

View File

@@ -3,13 +3,17 @@ import datetime
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import Any, Callable
from typing import Any, Callable, TypeVar, overload
import pendulum
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, ValidationError, fields
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
_T = TypeVar("_T")
CONVERTERS: dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True),
@@ -49,6 +53,35 @@ class Coder:
def decode(cls, value: str) -> Any:
raise NotImplementedError
@overload
@classmethod
def decode_as_type(cls, value: str, type_: _T) -> _T:
...
@overload
@classmethod
def decode_as_type(cls, value: str, *, type_: None) -> Any:
...
@classmethod
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
"""Decode value to the specific given type
The default implementation uses the Pydantic model system to convert the value.
"""
result = cls.decode(value)
if type_ is not None:
field = fields.ModelField(
name="body", type_=type_, class_validators=None, model_config=BaseConfig
)
result, errors = field.validate(result, {}, loc=())
if errors is not None:
if not isinstance(errors, list):
errors = [errors]
raise ValidationError(errors, type_)
return result
class JsonCoder(Coder):
@classmethod
@@ -72,3 +105,10 @@ class PickleCoder(Coder):
@classmethod
def decode(cls, value: str) -> Any:
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
@classmethod
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
# Pickle already produces the correct type on decoding, no point
# in paying an extra performance penalty for pydantic to discover
# the same.
return cls.decode(value)

View File

@@ -10,6 +10,7 @@ else:
from typing_extensions import ParamSpec
from fastapi.concurrency import run_in_threadpool
from fastapi.dependencies.utils import get_typed_return_annotation
from starlette.requests import Request
from starlette.responses import Response
@@ -79,6 +80,7 @@ def cache(
(param for param in signature.parameters.values() if param.annotation is Response),
None,
)
return_type = get_typed_return_annotation(func)
@wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -139,7 +141,7 @@ def cache(
ttl, ret = 0, None
if not request:
if ret is not None:
return coder.decode(ret)
return coder.decode_as_type(ret, type_=return_type)
ret = await ensure_async_func(*args, **kwargs)
try:
await backend.set(cache_key, coder.encode(ret), expire)
@@ -161,7 +163,7 @@ def cache(
response.status_code = 304
return response
response.headers["ETag"] = etag
return coder.decode(ret)
return coder.decode_as_type(ret, type_=return_type)
ret = await ensure_async_func(*args, **kwargs)
encoded_ret = coder.encode(ret)