From f78a599bbc2b91f4302360bbb28977bbea6bea26 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Mon, 8 May 2023 16:42:21 +0100 Subject: [PATCH] Decode cache data to the correct endpoint type Use the return annotation to decode cached data to the correct type. This follows the same logic FastAPI uses to JSON request bodies. For the PickleCoder, this is a no-op as pickle already stores type information in the serialised data. --- README.md | 24 +++++++++++++++++++ examples/in_memory/main.py | 15 ++++++++++++ fastapi_cache/coder.py | 42 ++++++++++++++++++++++++++++++++- fastapi_cache/decorator.py | 6 +++-- tests/test_codecs.py | 48 ++++++++++++++++++++++++++++++++++++-- tests/test_decorator.py | 7 ++++++ 6 files changed, 137 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index ce038b0..5cfdb54 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,30 @@ key_builder | which key builder to use, default to builtin You can also use `cache` as decorator like other cache tools to cache common function result. + +### Supported data types + +When using the (default) `JsonCoder`, the cache can store any data type that FastAPI can convert to JSON, including Pydantic models and dataclasses, +_provided_ that your endpoint has a correct return type annotation, unless +the return type is a standard JSON-supported type such as a dictionary or a list. + +E.g. for an endpoint that returns a Pydantic model named `SomeModel`: + +```python +from .models import SomeModel, create_some_model + +@app.get("/foo") +@cache(expire=60) +async def foo() -> SomeModel: + return create_some_model +``` + +It is not sufficient to configure a response model in the route decorator; the cache needs to know what the method itself returns. + +If no return type decorator is given, the primitive JSON type is returned instead. + +For broader type support, use the `fastapi_cache.coder.PickleCoder` or implement a custom coder (see below). + ### Custom coder By default use `JsonCoder`, you can write custom coder to encode and decode cache result, just need diff --git a/examples/in_memory/main.py b/examples/in_memory/main.py index 2306472..d79007d 100644 --- a/examples/in_memory/main.py +++ b/examples/in_memory/main.py @@ -7,6 +7,7 @@ from starlette.responses import JSONResponse, Response from fastapi_cache import FastAPICache from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi_cache.decorator import cache +from pydantic import BaseModel app = FastAPI() @@ -80,6 +81,20 @@ instance = SomeClass(17) app.get("/method")(cache(namespace="test")(instance.handler_method)) +# cache a Pydantic model instance; the return type annotation is required in this case +class Item(BaseModel): + name: str + description: str | None = None + price: float + tax: float | None = None + + +@app.get("/pydantic_instance") +@cache(namespace="test", expire=5) +async def pydantic_instance() -> Item: + return Item(name="Something", description="An instance of a Pydantic model", price=10.5) + + @app.on_event("startup") async def startup(): FastAPICache.init(InMemoryBackend()) diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index d698db2..a7373da 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -3,13 +3,17 @@ import datetime import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any, Callable +from typing import Any, Callable, TypeVar, overload import pendulum from fastapi.encoders import jsonable_encoder +from pydantic import BaseConfig, ValidationError, fields from starlette.responses import JSONResponse from starlette.templating import _TemplateResponse as TemplateResponse +_T = TypeVar("_T") + + CONVERTERS: dict[str, Callable[[str], Any]] = { "date": lambda x: pendulum.parse(x, exact=True), "datetime": lambda x: pendulum.parse(x, exact=True), @@ -49,6 +53,35 @@ class Coder: def decode(cls, value: str) -> Any: raise NotImplementedError + @overload + @classmethod + def decode_as_type(cls, value: str, type_: _T) -> _T: + ... + + @overload + @classmethod + def decode_as_type(cls, value: str, *, type_: None) -> Any: + ... + + @classmethod + def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any: + """Decode value to the specific given type + + The default implementation uses the Pydantic model system to convert the value. + + """ + result = cls.decode(value) + if type_ is not None: + field = fields.ModelField( + name="body", type_=type_, class_validators=None, model_config=BaseConfig + ) + result, errors = field.validate(result, {}, loc=()) + if errors is not None: + if not isinstance(errors, list): + errors = [errors] + raise ValidationError(errors, type_) + return result + class JsonCoder(Coder): @classmethod @@ -72,3 +105,10 @@ class PickleCoder(Coder): @classmethod def decode(cls, value: str) -> Any: return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301 + + @classmethod + def decode_as_type(cls, value: str, *, type_: Any) -> Any: + # Pickle already produces the correct type on decoding, no point + # in paying an extra performance penalty for pydantic to discover + # the same. + return cls.decode(value) diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index 65fc02b..23ea496 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -10,6 +10,7 @@ else: from typing_extensions import ParamSpec from fastapi.concurrency import run_in_threadpool +from fastapi.dependencies.utils import get_typed_return_annotation from starlette.requests import Request from starlette.responses import Response @@ -79,6 +80,7 @@ def cache( (param for param in signature.parameters.values() if param.annotation is Response), None, ) + return_type = get_typed_return_annotation(func) @wraps(func) async def inner(*args: P.args, **kwargs: P.kwargs) -> R: @@ -139,7 +141,7 @@ def cache( ttl, ret = 0, None if not request: if ret is not None: - return coder.decode(ret) + return coder.decode_as_type(ret, type_=return_type) ret = await ensure_async_func(*args, **kwargs) try: await backend.set(cache_key, coder.encode(ret), expire) @@ -161,7 +163,7 @@ def cache( response.status_code = 304 return response response.headers["ETag"] = etag - return coder.decode(ret) + return coder.decode_as_type(ret, type_=return_type) ret = await ensure_async_func(*args, **kwargs) encoded_ret = coder.encode(ret) diff --git a/tests/test_codecs.py b/tests/test_codecs.py index 371172d..8120977 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -1,8 +1,25 @@ -from typing import Any +from dataclasses import dataclass +from typing import Any, Optional import pytest +from pydantic import BaseModel, ValidationError -from fastapi_cache.coder import PickleCoder +from fastapi_cache.coder import JsonCoder, PickleCoder + + +@dataclass +class DCItem: + name: str + price: float + description: Optional[str] = None + tax: Optional[float] = None + + +class PDItem(BaseModel): + name: str + price: float + description: Optional[str] = None + tax: Optional[float] = None @pytest.mark.parametrize( @@ -13,6 +30,8 @@ from fastapi_cache.coder import PickleCoder (1, 2), [1, 2, 3], {"some_key": 1, "other_key": 2}, + DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), + PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), ], ) def test_pickle_coder(value: Any) -> None: @@ -20,3 +39,28 @@ def test_pickle_coder(value: Any) -> None: assert isinstance(encoded_value, str) decoded_value = PickleCoder.decode(encoded_value) assert decoded_value == value + + +@pytest.mark.parametrize( + ("value", "return_type"), + [ + (1, None), + ("some_string", None), + ((1, 2), tuple[int, int]), + ([1, 2, 3], None), + ({"some_key": 1, "other_key": 2}, None), + (DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), DCItem), + (PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), PDItem), + ], +) +def test_json_coder(value: Any, return_type) -> None: + encoded_value = JsonCoder.encode(value) + assert isinstance(encoded_value, str) + decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type) + assert decoded_value == value + + +def test_json_coder_validation_error() -> None: + invalid = '{"name": "incomplete"}' + with pytest.raises(ValidationError): + JsonCoder.decode_as_type(invalid, type_=PDItem) diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 9fd8aa8..a037bed 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -79,3 +79,10 @@ def test_method() -> None: with TestClient(app) as client: response = client.get("/method") assert response.json() == 17 + + +def test_pydantic_model() -> None: + with TestClient(app) as client: + r1 = client.get("/pydantic_instance").json() + r2 = client.get("/pydantic_instance").json() + assert r1 == r2