diff --git a/README.md b/README.md index ce038b0..5cfdb54 100644 --- a/README.md +++ b/README.md @@ -101,6 +101,30 @@ key_builder | which key builder to use, default to builtin You can also use `cache` as decorator like other cache tools to cache common function result. + +### Supported data types + +When using the (default) `JsonCoder`, the cache can store any data type that FastAPI can convert to JSON, including Pydantic models and dataclasses, +_provided_ that your endpoint has a correct return type annotation, unless +the return type is a standard JSON-supported type such as a dictionary or a list. + +E.g. for an endpoint that returns a Pydantic model named `SomeModel`: + +```python +from .models import SomeModel, create_some_model + +@app.get("/foo") +@cache(expire=60) +async def foo() -> SomeModel: + return create_some_model +``` + +It is not sufficient to configure a response model in the route decorator; the cache needs to know what the method itself returns. + +If no return type decorator is given, the primitive JSON type is returned instead. + +For broader type support, use the `fastapi_cache.coder.PickleCoder` or implement a custom coder (see below). + ### Custom coder By default use `JsonCoder`, you can write custom coder to encode and decode cache result, just need diff --git a/examples/in_memory/main.py b/examples/in_memory/main.py index 2306472..d79007d 100644 --- a/examples/in_memory/main.py +++ b/examples/in_memory/main.py @@ -7,6 +7,7 @@ from starlette.responses import JSONResponse, Response from fastapi_cache import FastAPICache from fastapi_cache.backends.inmemory import InMemoryBackend from fastapi_cache.decorator import cache +from pydantic import BaseModel app = FastAPI() @@ -80,6 +81,20 @@ instance = SomeClass(17) app.get("/method")(cache(namespace="test")(instance.handler_method)) +# cache a Pydantic model instance; the return type annotation is required in this case +class Item(BaseModel): + name: str + description: str | None = None + price: float + tax: float | None = None + + +@app.get("/pydantic_instance") +@cache(namespace="test", expire=5) +async def pydantic_instance() -> Item: + return Item(name="Something", description="An instance of a Pydantic model", price=10.5) + + @app.on_event("startup") async def startup(): FastAPICache.init(InMemoryBackend()) diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index d698db2..a7373da 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -3,13 +3,17 @@ import datetime import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any, Callable +from typing import Any, Callable, TypeVar, overload import pendulum from fastapi.encoders import jsonable_encoder +from pydantic import BaseConfig, ValidationError, fields from starlette.responses import JSONResponse from starlette.templating import _TemplateResponse as TemplateResponse +_T = TypeVar("_T") + + CONVERTERS: dict[str, Callable[[str], Any]] = { "date": lambda x: pendulum.parse(x, exact=True), "datetime": lambda x: pendulum.parse(x, exact=True), @@ -49,6 +53,35 @@ class Coder: def decode(cls, value: str) -> Any: raise NotImplementedError + @overload + @classmethod + def decode_as_type(cls, value: str, type_: _T) -> _T: + ... + + @overload + @classmethod + def decode_as_type(cls, value: str, *, type_: None) -> Any: + ... + + @classmethod + def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any: + """Decode value to the specific given type + + The default implementation uses the Pydantic model system to convert the value. + + """ + result = cls.decode(value) + if type_ is not None: + field = fields.ModelField( + name="body", type_=type_, class_validators=None, model_config=BaseConfig + ) + result, errors = field.validate(result, {}, loc=()) + if errors is not None: + if not isinstance(errors, list): + errors = [errors] + raise ValidationError(errors, type_) + return result + class JsonCoder(Coder): @classmethod @@ -72,3 +105,10 @@ class PickleCoder(Coder): @classmethod def decode(cls, value: str) -> Any: return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301 + + @classmethod + def decode_as_type(cls, value: str, *, type_: Any) -> Any: + # Pickle already produces the correct type on decoding, no point + # in paying an extra performance penalty for pydantic to discover + # the same. + return cls.decode(value) diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index 65fc02b..23ea496 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -10,6 +10,7 @@ else: from typing_extensions import ParamSpec from fastapi.concurrency import run_in_threadpool +from fastapi.dependencies.utils import get_typed_return_annotation from starlette.requests import Request from starlette.responses import Response @@ -79,6 +80,7 @@ def cache( (param for param in signature.parameters.values() if param.annotation is Response), None, ) + return_type = get_typed_return_annotation(func) @wraps(func) async def inner(*args: P.args, **kwargs: P.kwargs) -> R: @@ -139,7 +141,7 @@ def cache( ttl, ret = 0, None if not request: if ret is not None: - return coder.decode(ret) + return coder.decode_as_type(ret, type_=return_type) ret = await ensure_async_func(*args, **kwargs) try: await backend.set(cache_key, coder.encode(ret), expire) @@ -161,7 +163,7 @@ def cache( response.status_code = 304 return response response.headers["ETag"] = etag - return coder.decode(ret) + return coder.decode_as_type(ret, type_=return_type) ret = await ensure_async_func(*args, **kwargs) encoded_ret = coder.encode(ret) diff --git a/tests/test_codecs.py b/tests/test_codecs.py index 371172d..8120977 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -1,8 +1,25 @@ -from typing import Any +from dataclasses import dataclass +from typing import Any, Optional import pytest +from pydantic import BaseModel, ValidationError -from fastapi_cache.coder import PickleCoder +from fastapi_cache.coder import JsonCoder, PickleCoder + + +@dataclass +class DCItem: + name: str + price: float + description: Optional[str] = None + tax: Optional[float] = None + + +class PDItem(BaseModel): + name: str + price: float + description: Optional[str] = None + tax: Optional[float] = None @pytest.mark.parametrize( @@ -13,6 +30,8 @@ from fastapi_cache.coder import PickleCoder (1, 2), [1, 2, 3], {"some_key": 1, "other_key": 2}, + DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), + PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), ], ) def test_pickle_coder(value: Any) -> None: @@ -20,3 +39,28 @@ def test_pickle_coder(value: Any) -> None: assert isinstance(encoded_value, str) decoded_value = PickleCoder.decode(encoded_value) assert decoded_value == value + + +@pytest.mark.parametrize( + ("value", "return_type"), + [ + (1, None), + ("some_string", None), + ((1, 2), tuple[int, int]), + ([1, 2, 3], None), + ({"some_key": 1, "other_key": 2}, None), + (DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), DCItem), + (PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), PDItem), + ], +) +def test_json_coder(value: Any, return_type) -> None: + encoded_value = JsonCoder.encode(value) + assert isinstance(encoded_value, str) + decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type) + assert decoded_value == value + + +def test_json_coder_validation_error() -> None: + invalid = '{"name": "incomplete"}' + with pytest.raises(ValidationError): + JsonCoder.decode_as_type(invalid, type_=PDItem) diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 9fd8aa8..a037bed 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -79,3 +79,10 @@ def test_method() -> None: with TestClient(app) as client: response = client.get("/method") assert response.json() == 17 + + +def test_pydantic_model() -> None: + with TestClient(app) as client: + r1 = client.get("/pydantic_instance").json() + r2 = client.get("/pydantic_instance").json() + assert r1 == r2