Merge pull request #131 from mjpieters/json_decoder_pydantic

Decode cache data to the correct endpoint type
This commit is contained in:
long2ice
2023-05-09 10:16:17 +08:00
committed by GitHub
6 changed files with 137 additions and 5 deletions

View File

@@ -101,6 +101,30 @@ key_builder | which key builder to use, default to builtin
You can also use `cache` as decorator like other cache tools to cache common function result.
### Supported data types
When using the (default) `JsonCoder`, the cache can store any data type that FastAPI can convert to JSON, including Pydantic models and dataclasses,
_provided_ that your endpoint has a correct return type annotation, unless
the return type is a standard JSON-supported type such as a dictionary or a list.
E.g. for an endpoint that returns a Pydantic model named `SomeModel`:
```python
from .models import SomeModel, create_some_model
@app.get("/foo")
@cache(expire=60)
async def foo() -> SomeModel:
return create_some_model
```
It is not sufficient to configure a response model in the route decorator; the cache needs to know what the method itself returns.
If no return type decorator is given, the primitive JSON type is returned instead.
For broader type support, use the `fastapi_cache.coder.PickleCoder` or implement a custom coder (see below).
### Custom coder
By default use `JsonCoder`, you can write custom coder to encode and decode cache result, just need

View File

@@ -7,6 +7,7 @@ from starlette.responses import JSONResponse, Response
from fastapi_cache import FastAPICache
from fastapi_cache.backends.inmemory import InMemoryBackend
from fastapi_cache.decorator import cache
from pydantic import BaseModel
app = FastAPI()
@@ -80,6 +81,20 @@ instance = SomeClass(17)
app.get("/method")(cache(namespace="test")(instance.handler_method))
# cache a Pydantic model instance; the return type annotation is required in this case
class Item(BaseModel):
name: str
description: str | None = None
price: float
tax: float | None = None
@app.get("/pydantic_instance")
@cache(namespace="test", expire=5)
async def pydantic_instance() -> Item:
return Item(name="Something", description="An instance of a Pydantic model", price=10.5)
@app.on_event("startup")
async def startup():
FastAPICache.init(InMemoryBackend())

View File

@@ -3,13 +3,17 @@ import datetime
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import Any, Callable
from typing import Any, Callable, TypeVar, overload
import pendulum
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, ValidationError, fields
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
_T = TypeVar("_T")
CONVERTERS: dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True),
@@ -49,6 +53,35 @@ class Coder:
def decode(cls, value: str) -> Any:
raise NotImplementedError
@overload
@classmethod
def decode_as_type(cls, value: str, type_: _T) -> _T:
...
@overload
@classmethod
def decode_as_type(cls, value: str, *, type_: None) -> Any:
...
@classmethod
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
"""Decode value to the specific given type
The default implementation uses the Pydantic model system to convert the value.
"""
result = cls.decode(value)
if type_ is not None:
field = fields.ModelField(
name="body", type_=type_, class_validators=None, model_config=BaseConfig
)
result, errors = field.validate(result, {}, loc=())
if errors is not None:
if not isinstance(errors, list):
errors = [errors]
raise ValidationError(errors, type_)
return result
class JsonCoder(Coder):
@classmethod
@@ -72,3 +105,10 @@ class PickleCoder(Coder):
@classmethod
def decode(cls, value: str) -> Any:
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
@classmethod
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
# Pickle already produces the correct type on decoding, no point
# in paying an extra performance penalty for pydantic to discover
# the same.
return cls.decode(value)

View File

@@ -10,6 +10,7 @@ else:
from typing_extensions import ParamSpec
from fastapi.concurrency import run_in_threadpool
from fastapi.dependencies.utils import get_typed_return_annotation
from starlette.requests import Request
from starlette.responses import Response
@@ -79,6 +80,7 @@ def cache(
(param for param in signature.parameters.values() if param.annotation is Response),
None,
)
return_type = get_typed_return_annotation(func)
@wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
@@ -139,7 +141,7 @@ def cache(
ttl, ret = 0, None
if not request:
if ret is not None:
return coder.decode(ret)
return coder.decode_as_type(ret, type_=return_type)
ret = await ensure_async_func(*args, **kwargs)
try:
await backend.set(cache_key, coder.encode(ret), expire)
@@ -161,7 +163,7 @@ def cache(
response.status_code = 304
return response
response.headers["ETag"] = etag
return coder.decode(ret)
return coder.decode_as_type(ret, type_=return_type)
ret = await ensure_async_func(*args, **kwargs)
encoded_ret = coder.encode(ret)

View File

@@ -1,8 +1,25 @@
from typing import Any
from dataclasses import dataclass
from typing import Any, Optional
import pytest
from pydantic import BaseModel, ValidationError
from fastapi_cache.coder import PickleCoder
from fastapi_cache.coder import JsonCoder, PickleCoder
@dataclass
class DCItem:
name: str
price: float
description: Optional[str] = None
tax: Optional[float] = None
class PDItem(BaseModel):
name: str
price: float
description: Optional[str] = None
tax: Optional[float] = None
@pytest.mark.parametrize(
@@ -13,6 +30,8 @@ from fastapi_cache.coder import PickleCoder
(1, 2),
[1, 2, 3],
{"some_key": 1, "other_key": 2},
DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2),
PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2),
],
)
def test_pickle_coder(value: Any) -> None:
@@ -20,3 +39,28 @@ def test_pickle_coder(value: Any) -> None:
assert isinstance(encoded_value, str)
decoded_value = PickleCoder.decode(encoded_value)
assert decoded_value == value
@pytest.mark.parametrize(
("value", "return_type"),
[
(1, None),
("some_string", None),
((1, 2), tuple[int, int]),
([1, 2, 3], None),
({"some_key": 1, "other_key": 2}, None),
(DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), DCItem),
(PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), PDItem),
],
)
def test_json_coder(value: Any, return_type) -> None:
encoded_value = JsonCoder.encode(value)
assert isinstance(encoded_value, str)
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
assert decoded_value == value
def test_json_coder_validation_error() -> None:
invalid = '{"name": "incomplete"}'
with pytest.raises(ValidationError):
JsonCoder.decode_as_type(invalid, type_=PDItem)

View File

@@ -79,3 +79,10 @@ def test_method() -> None:
with TestClient(app) as client:
response = client.get("/method")
assert response.json() == 17
def test_pydantic_model() -> None:
with TestClient(app) as client:
r1 = client.get("/pydantic_instance").json()
r2 = client.get("/pydantic_instance").json()
assert r1 == r2