mirror of
https://github.com/long2ice/fastapi-cache.git
synced 2026-03-25 04:57:54 +00:00
Decode cache data to the correct endpoint type
Use the return annotation to decode cached data to the correct type. This follows the same logic FastAPI uses to JSON request bodies. For the PickleCoder, this is a no-op as pickle already stores type information in the serialised data.
This commit is contained in:
24
README.md
24
README.md
@@ -101,6 +101,30 @@ key_builder | which key builder to use, default to builtin
|
|||||||
|
|
||||||
You can also use `cache` as decorator like other cache tools to cache common function result.
|
You can also use `cache` as decorator like other cache tools to cache common function result.
|
||||||
|
|
||||||
|
|
||||||
|
### Supported data types
|
||||||
|
|
||||||
|
When using the (default) `JsonCoder`, the cache can store any data type that FastAPI can convert to JSON, including Pydantic models and dataclasses,
|
||||||
|
_provided_ that your endpoint has a correct return type annotation, unless
|
||||||
|
the return type is a standard JSON-supported type such as a dictionary or a list.
|
||||||
|
|
||||||
|
E.g. for an endpoint that returns a Pydantic model named `SomeModel`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .models import SomeModel, create_some_model
|
||||||
|
|
||||||
|
@app.get("/foo")
|
||||||
|
@cache(expire=60)
|
||||||
|
async def foo() -> SomeModel:
|
||||||
|
return create_some_model
|
||||||
|
```
|
||||||
|
|
||||||
|
It is not sufficient to configure a response model in the route decorator; the cache needs to know what the method itself returns.
|
||||||
|
|
||||||
|
If no return type decorator is given, the primitive JSON type is returned instead.
|
||||||
|
|
||||||
|
For broader type support, use the `fastapi_cache.coder.PickleCoder` or implement a custom coder (see below).
|
||||||
|
|
||||||
### Custom coder
|
### Custom coder
|
||||||
|
|
||||||
By default use `JsonCoder`, you can write custom coder to encode and decode cache result, just need
|
By default use `JsonCoder`, you can write custom coder to encode and decode cache result, just need
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from starlette.responses import JSONResponse, Response
|
|||||||
from fastapi_cache import FastAPICache
|
from fastapi_cache import FastAPICache
|
||||||
from fastapi_cache.backends.inmemory import InMemoryBackend
|
from fastapi_cache.backends.inmemory import InMemoryBackend
|
||||||
from fastapi_cache.decorator import cache
|
from fastapi_cache.decorator import cache
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@@ -80,6 +81,20 @@ instance = SomeClass(17)
|
|||||||
app.get("/method")(cache(namespace="test")(instance.handler_method))
|
app.get("/method")(cache(namespace="test")(instance.handler_method))
|
||||||
|
|
||||||
|
|
||||||
|
# cache a Pydantic model instance; the return type annotation is required in this case
|
||||||
|
class Item(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
price: float
|
||||||
|
tax: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/pydantic_instance")
|
||||||
|
@cache(namespace="test", expire=5)
|
||||||
|
async def pydantic_instance() -> Item:
|
||||||
|
return Item(name="Something", description="An instance of a Pydantic model", price=10.5)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup():
|
async def startup():
|
||||||
FastAPICache.init(InMemoryBackend())
|
FastAPICache.init(InMemoryBackend())
|
||||||
|
|||||||
@@ -3,13 +3,17 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import pickle # nosec:B403
|
import pickle # nosec:B403
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable, TypeVar, overload
|
||||||
|
|
||||||
import pendulum
|
import pendulum
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from pydantic import BaseConfig, ValidationError, fields
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.templating import _TemplateResponse as TemplateResponse
|
from starlette.templating import _TemplateResponse as TemplateResponse
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
CONVERTERS: dict[str, Callable[[str], Any]] = {
|
CONVERTERS: dict[str, Callable[[str], Any]] = {
|
||||||
"date": lambda x: pendulum.parse(x, exact=True),
|
"date": lambda x: pendulum.parse(x, exact=True),
|
||||||
"datetime": lambda x: pendulum.parse(x, exact=True),
|
"datetime": lambda x: pendulum.parse(x, exact=True),
|
||||||
@@ -49,6 +53,35 @@ class Coder:
|
|||||||
def decode(cls, value: str) -> Any:
|
def decode(cls, value: str) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def decode_as_type(cls, value: str, type_: _T) -> _T:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def decode_as_type(cls, value: str, *, type_: None) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
|
||||||
|
"""Decode value to the specific given type
|
||||||
|
|
||||||
|
The default implementation uses the Pydantic model system to convert the value.
|
||||||
|
|
||||||
|
"""
|
||||||
|
result = cls.decode(value)
|
||||||
|
if type_ is not None:
|
||||||
|
field = fields.ModelField(
|
||||||
|
name="body", type_=type_, class_validators=None, model_config=BaseConfig
|
||||||
|
)
|
||||||
|
result, errors = field.validate(result, {}, loc=())
|
||||||
|
if errors is not None:
|
||||||
|
if not isinstance(errors, list):
|
||||||
|
errors = [errors]
|
||||||
|
raise ValidationError(errors, type_)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class JsonCoder(Coder):
|
class JsonCoder(Coder):
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -72,3 +105,10 @@ class PickleCoder(Coder):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, value: str) -> Any:
|
def decode(cls, value: str) -> Any:
|
||||||
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
|
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
|
||||||
|
# Pickle already produces the correct type on decoding, no point
|
||||||
|
# in paying an extra performance penalty for pydantic to discover
|
||||||
|
# the same.
|
||||||
|
return cls.decode(value)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ else:
|
|||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from fastapi.concurrency import run_in_threadpool
|
from fastapi.concurrency import run_in_threadpool
|
||||||
|
from fastapi.dependencies.utils import get_typed_return_annotation
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
@@ -79,6 +80,7 @@ def cache(
|
|||||||
(param for param in signature.parameters.values() if param.annotation is Response),
|
(param for param in signature.parameters.values() if param.annotation is Response),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
return_type = get_typed_return_annotation(func)
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
@@ -139,7 +141,7 @@ def cache(
|
|||||||
ttl, ret = 0, None
|
ttl, ret = 0, None
|
||||||
if not request:
|
if not request:
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
return coder.decode(ret)
|
return coder.decode_as_type(ret, type_=return_type)
|
||||||
ret = await ensure_async_func(*args, **kwargs)
|
ret = await ensure_async_func(*args, **kwargs)
|
||||||
try:
|
try:
|
||||||
await backend.set(cache_key, coder.encode(ret), expire)
|
await backend.set(cache_key, coder.encode(ret), expire)
|
||||||
@@ -161,7 +163,7 @@ def cache(
|
|||||||
response.status_code = 304
|
response.status_code = 304
|
||||||
return response
|
return response
|
||||||
response.headers["ETag"] = etag
|
response.headers["ETag"] = etag
|
||||||
return coder.decode(ret)
|
return coder.decode_as_type(ret, type_=return_type)
|
||||||
|
|
||||||
ret = await ensure_async_func(*args, **kwargs)
|
ret = await ensure_async_func(*args, **kwargs)
|
||||||
encoded_ret = coder.encode(ret)
|
encoded_ret = coder.encode(ret)
|
||||||
|
|||||||
@@ -1,8 +1,25 @@
|
|||||||
from typing import Any
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from fastapi_cache.coder import PickleCoder
|
from fastapi_cache.coder import JsonCoder, PickleCoder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DCItem:
|
||||||
|
name: str
|
||||||
|
price: float
|
||||||
|
description: Optional[str] = None
|
||||||
|
tax: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PDItem(BaseModel):
|
||||||
|
name: str
|
||||||
|
price: float
|
||||||
|
description: Optional[str] = None
|
||||||
|
tax: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -13,6 +30,8 @@ from fastapi_cache.coder import PickleCoder
|
|||||||
(1, 2),
|
(1, 2),
|
||||||
[1, 2, 3],
|
[1, 2, 3],
|
||||||
{"some_key": 1, "other_key": 2},
|
{"some_key": 1, "other_key": 2},
|
||||||
|
DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2),
|
||||||
|
PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_pickle_coder(value: Any) -> None:
|
def test_pickle_coder(value: Any) -> None:
|
||||||
@@ -20,3 +39,28 @@ def test_pickle_coder(value: Any) -> None:
|
|||||||
assert isinstance(encoded_value, str)
|
assert isinstance(encoded_value, str)
|
||||||
decoded_value = PickleCoder.decode(encoded_value)
|
decoded_value = PickleCoder.decode(encoded_value)
|
||||||
assert decoded_value == value
|
assert decoded_value == value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("value", "return_type"),
|
||||||
|
[
|
||||||
|
(1, None),
|
||||||
|
("some_string", None),
|
||||||
|
((1, 2), tuple[int, int]),
|
||||||
|
([1, 2, 3], None),
|
||||||
|
({"some_key": 1, "other_key": 2}, None),
|
||||||
|
(DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), DCItem),
|
||||||
|
(PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), PDItem),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_json_coder(value: Any, return_type) -> None:
|
||||||
|
encoded_value = JsonCoder.encode(value)
|
||||||
|
assert isinstance(encoded_value, str)
|
||||||
|
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
|
||||||
|
assert decoded_value == value
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_coder_validation_error() -> None:
|
||||||
|
invalid = '{"name": "incomplete"}'
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
JsonCoder.decode_as_type(invalid, type_=PDItem)
|
||||||
|
|||||||
@@ -79,3 +79,10 @@ def test_method() -> None:
|
|||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
response = client.get("/method")
|
response = client.get("/method")
|
||||||
assert response.json() == 17
|
assert response.json() == 17
|
||||||
|
|
||||||
|
|
||||||
|
def test_pydantic_model() -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
r1 = client.get("/pydantic_instance").json()
|
||||||
|
r2 = client.get("/pydantic_instance").json()
|
||||||
|
assert r1 == r2
|
||||||
|
|||||||
Reference in New Issue
Block a user