Decode cache data to the correct endpoint type

Use the return annotation to decode cached data to the correct type.
This follows the same logic FastAPI uses to JSON request bodies.

For the PickleCoder, this is a no-op as pickle already stores type
information in the serialised data.
This commit is contained in:
Martijn Pieters
2023-05-08 16:42:21 +01:00
parent 550ba76df4
commit f78a599bbc
6 changed files with 137 additions and 5 deletions

View File

@@ -101,6 +101,30 @@ key_builder | which key builder to use, default to builtin
You can also use `cache` as decorator like other cache tools to cache common function result.
### Supported data types
When using the (default) `JsonCoder`, the cache can store any data type that FastAPI can convert to JSON, including Pydantic models and dataclasses,
_provided_ that your endpoint has a correct return type annotation, unless
the return type is a standard JSON-supported type such as a dictionary or a list.
E.g. for an endpoint that returns a Pydantic model named `SomeModel`:
```python
from .models import SomeModel, create_some_model
@app.get("/foo")
@cache(expire=60)
async def foo() -> SomeModel:
return create_some_model
```
It is not sufficient to configure a response model in the route decorator; the cache needs to know what the method itself returns.
If no return type decorator is given, the primitive JSON type is returned instead.
For broader type support, use the `fastapi_cache.coder.PickleCoder` or implement a custom coder (see below).
### Custom coder
By default use `JsonCoder`, you can write custom coder to encode and decode cache result, just need

View File

@@ -7,6 +7,7 @@ from starlette.responses import JSONResponse, Response
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
from pydantic import BaseModel
app = FastAPI()
@@ -80,6 +81,20 @@ instance = SomeClass(17)
app.get("/method")(cache(namespace="test")(instance.handler_method))
# cache a Pydantic model instance; the return type annotation is required in this case
class Item(BaseModel):
name: str
description: str | None = None
price: float
tax: float | None = None
@app.get("/pydantic_instance")
@cache(namespace="test", expire=5)
async def pydantic_instance() -> Item:
return Item(name="Something", description="An instance of a Pydantic model", price=10.5)
@app.on_event("startup")
async def startup():
FastAPICache.init(InMemoryBackend())

View File

@@ -3,13 +3,17 @@ import datetime
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import Any, Callable
from typing import Any, Callable, TypeVar, overload
import pendulum
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, ValidationError, fields
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
_T = TypeVar("_T")
CONVERTERS: dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True),
@@ -49,6 +53,35 @@ class Coder:
def decode(cls, value: str) -> Any:
raise NotImplementedError
@overload
@classmethod
def decode_as_type(cls, value: str, type_: _T) -> _T:
...
@overload
@classmethod
def decode_as_type(cls, value: str, *, type_: None) -> Any:
...
@classmethod
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
"""Decode value to the specific given type
The default implementation uses the Pydantic model system to convert the value.
"""
result = cls.decode(value)
if type_ is not None:
field = fields.ModelField(
name="body", type_=type_, class_validators=None, model_config=BaseConfig
)
result, errors = field.validate(result, {}, loc=())
if errors is not None:
if not isinstance(errors, list):
errors = [errors]
raise ValidationError(errors, type_)
return result
class JsonCoder(Coder):
@classmethod
@@ -72,3 +105,10 @@ class PickleCoder(Coder):
@classmethod
def decode(cls, value: str) -> Any:
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
@classmethod
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
# Pickle already produces the correct type on decoding, no point
# in paying an extra performance penalty for pydantic to discover
# the same.
return cls.decode(value)

View File

@@ -10,6 +10,7 @@ else:
from typing_extensions import ParamSpec
from fastapi.concurrency import run_in_threadpool
from fastapi.dependencies.utils import get_typed_return_annotation
from starlette.requests import Request
from starlette.responses import Response
@@ -79,6 +80,7 @@ def cache(
(param for param in signature.parameters.values() if param.annotation is Response),
None,
)
return_type = get_typed_return_annotation(func)
@wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -139,7 +141,7 @@ def cache(
ttl, ret = 0, None
if not request:
if ret is not None:
return coder.decode(ret)
return coder.decode_as_type(ret, type_=return_type)
ret = await ensure_async_func(*args, **kwargs)
try:
await backend.set(cache_key, coder.encode(ret), expire)
@@ -161,7 +163,7 @@ def cache(
response.status_code = 304
return response
response.headers["ETag"] = etag
return coder.decode(ret)
return coder.decode_as_type(ret, type_=return_type)
ret = await ensure_async_func(*args, **kwargs)
encoded_ret = coder.encode(ret)

View File

@@ -1,8 +1,25 @@
from typing import Any
from dataclasses import dataclass
from typing import Any, Optional
import pytest
from pydantic import BaseModel, ValidationError
from fastapi_cache.coder import PickleCoder
from fastapi_cache.coder import JsonCoder, PickleCoder
@dataclass
class DCItem:
name: str
price: float
description: Optional[str] = None
tax: Optional[float] = None
class PDItem(BaseModel):
name: str
price: float
description: Optional[str] = None
tax: Optional[float] = None
@pytest.mark.parametrize(
@@ -13,6 +30,8 @@ from fastapi_cache.coder import PickleCoder
(1, 2),
[1, 2, 3],
{"some_key": 1, "other_key": 2},
DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2),
PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2),
],
)
def test_pickle_coder(value: Any) -> None:
@@ -20,3 +39,28 @@ def test_pickle_coder(value: Any) -> None:
assert isinstance(encoded_value, str)
decoded_value = PickleCoder.decode(encoded_value)
assert decoded_value == value
@pytest.mark.parametrize(
("value", "return_type"),
[
(1, None),
("some_string", None),
((1, 2), tuple[int, int]),
([1, 2, 3], None),
({"some_key": 1, "other_key": 2}, None),
(DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), DCItem),
(PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), PDItem),
],
)
def test_json_coder(value: Any, return_type) -> None:
encoded_value = JsonCoder.encode(value)
assert isinstance(encoded_value, str)
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
assert decoded_value == value
def test_json_coder_validation_error() -> None:
invalid = '{"name": "incomplete"}'
with pytest.raises(ValidationError):
JsonCoder.decode_as_type(invalid, type_=PDItem)

View File

@@ -79,3 +79,10 @@ def test_method() -> None:
with TestClient(app) as client:
response = client.get("/method")
assert response.json() == 17
def test_pydantic_model() -> None:
with TestClient(app) as client:
r1 = client.get("/pydantic_instance").json()
r2 = client.get("/pydantic_instance").json()
assert r1 == r2