mirror of
https://github.com/long2ice/fastapi-cache.git
synced 2026-03-25 04:57:54 +00:00
Merge pull request #131 from mjpieters/json_decoder_pydantic
Decode cache data to the correct endpoint type
This commit is contained in:
24
README.md
24
README.md
@@ -101,6 +101,30 @@ key_builder | which key builder to use, default to builtin
|
|||||||
|
|
||||||
You can also use `cache` as decorator like other cache tools to cache common function result.
|
You can also use `cache` as decorator like other cache tools to cache common function result.
|
||||||
|
|
||||||
|
|
||||||
|
### Supported data types
|
||||||
|
|
||||||
|
When using the (default) `JsonCoder`, the cache can store any data type that FastAPI can convert to JSON, including Pydantic models and dataclasses,
|
||||||
|
_provided_ that your endpoint has a correct return type annotation, unless
|
||||||
|
the return type is a standard JSON-supported type such as a dictionary or a list.
|
||||||
|
|
||||||
|
E.g. for an endpoint that returns a Pydantic model named `SomeModel`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .models import SomeModel, create_some_model
|
||||||
|
|
||||||
|
@app.get("/foo")
|
||||||
|
@cache(expire=60)
|
||||||
|
async def foo() -> SomeModel:
|
||||||
|
return create_some_model
|
||||||
|
```
|
||||||
|
|
||||||
|
It is not sufficient to configure a response model in the route decorator; the cache needs to know what the method itself returns.
|
||||||
|
|
||||||
|
If no return type decorator is given, the primitive JSON type is returned instead.
|
||||||
|
|
||||||
|
For broader type support, use the `fastapi_cache.coder.PickleCoder` or implement a custom coder (see below).
|
||||||
|
|
||||||
### Custom coder
|
### Custom coder
|
||||||
|
|
||||||
By default use `JsonCoder`, you can write custom coder to encode and decode cache result, just need
|
By default use `JsonCoder`, you can write custom coder to encode and decode cache result, just need
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from starlette.responses import JSONResponse, Response
|
|||||||
from fastapi_cache import FastAPICache
|
from fastapi_cache import FastAPICache
|
||||||
from fastapi_cache.backends.inmemory import InMemoryBackend
|
from fastapi_cache.backends.inmemory import InMemoryBackend
|
||||||
from fastapi_cache.decorator import cache
|
from fastapi_cache.decorator import cache
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
@@ -80,6 +81,20 @@ instance = SomeClass(17)
|
|||||||
app.get("/method")(cache(namespace="test")(instance.handler_method))
|
app.get("/method")(cache(namespace="test")(instance.handler_method))
|
||||||
|
|
||||||
|
|
||||||
|
# cache a Pydantic model instance; the return type annotation is required in this case
|
||||||
|
class Item(BaseModel):
|
||||||
|
name: str
|
||||||
|
description: str | None = None
|
||||||
|
price: float
|
||||||
|
tax: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/pydantic_instance")
|
||||||
|
@cache(namespace="test", expire=5)
|
||||||
|
async def pydantic_instance() -> Item:
|
||||||
|
return Item(name="Something", description="An instance of a Pydantic model", price=10.5)
|
||||||
|
|
||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup():
|
async def startup():
|
||||||
FastAPICache.init(InMemoryBackend())
|
FastAPICache.init(InMemoryBackend())
|
||||||
|
|||||||
@@ -3,13 +3,17 @@ import datetime
|
|||||||
import json
|
import json
|
||||||
import pickle # nosec:B403
|
import pickle # nosec:B403
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable, TypeVar, overload
|
||||||
|
|
||||||
import pendulum
|
import pendulum
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
from pydantic import BaseConfig, ValidationError, fields
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.templating import _TemplateResponse as TemplateResponse
|
from starlette.templating import _TemplateResponse as TemplateResponse
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
|
|
||||||
CONVERTERS: dict[str, Callable[[str], Any]] = {
|
CONVERTERS: dict[str, Callable[[str], Any]] = {
|
||||||
"date": lambda x: pendulum.parse(x, exact=True),
|
"date": lambda x: pendulum.parse(x, exact=True),
|
||||||
"datetime": lambda x: pendulum.parse(x, exact=True),
|
"datetime": lambda x: pendulum.parse(x, exact=True),
|
||||||
@@ -49,6 +53,35 @@ class Coder:
|
|||||||
def decode(cls, value: str) -> Any:
|
def decode(cls, value: str) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def decode_as_type(cls, value: str, type_: _T) -> _T:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
@classmethod
|
||||||
|
def decode_as_type(cls, value: str, *, type_: None) -> Any:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
|
||||||
|
"""Decode value to the specific given type
|
||||||
|
|
||||||
|
The default implementation uses the Pydantic model system to convert the value.
|
||||||
|
|
||||||
|
"""
|
||||||
|
result = cls.decode(value)
|
||||||
|
if type_ is not None:
|
||||||
|
field = fields.ModelField(
|
||||||
|
name="body", type_=type_, class_validators=None, model_config=BaseConfig
|
||||||
|
)
|
||||||
|
result, errors = field.validate(result, {}, loc=())
|
||||||
|
if errors is not None:
|
||||||
|
if not isinstance(errors, list):
|
||||||
|
errors = [errors]
|
||||||
|
raise ValidationError(errors, type_)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class JsonCoder(Coder):
|
class JsonCoder(Coder):
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -72,3 +105,10 @@ class PickleCoder(Coder):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, value: str) -> Any:
|
def decode(cls, value: str) -> Any:
|
||||||
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
|
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
|
||||||
|
# Pickle already produces the correct type on decoding, no point
|
||||||
|
# in paying an extra performance penalty for pydantic to discover
|
||||||
|
# the same.
|
||||||
|
return cls.decode(value)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ else:
|
|||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
|
|
||||||
from fastapi.concurrency import run_in_threadpool
|
from fastapi.concurrency import run_in_threadpool
|
||||||
|
from fastapi.dependencies.utils import get_typed_return_annotation
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
from starlette.responses import Response
|
from starlette.responses import Response
|
||||||
|
|
||||||
@@ -79,6 +80,7 @@ def cache(
|
|||||||
(param for param in signature.parameters.values() if param.annotation is Response),
|
(param for param in signature.parameters.values() if param.annotation is Response),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
return_type = get_typed_return_annotation(func)
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
@@ -139,7 +141,7 @@ def cache(
|
|||||||
ttl, ret = 0, None
|
ttl, ret = 0, None
|
||||||
if not request:
|
if not request:
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
return coder.decode(ret)
|
return coder.decode_as_type(ret, type_=return_type)
|
||||||
ret = await ensure_async_func(*args, **kwargs)
|
ret = await ensure_async_func(*args, **kwargs)
|
||||||
try:
|
try:
|
||||||
await backend.set(cache_key, coder.encode(ret), expire)
|
await backend.set(cache_key, coder.encode(ret), expire)
|
||||||
@@ -161,7 +163,7 @@ def cache(
|
|||||||
response.status_code = 304
|
response.status_code = 304
|
||||||
return response
|
return response
|
||||||
response.headers["ETag"] = etag
|
response.headers["ETag"] = etag
|
||||||
return coder.decode(ret)
|
return coder.decode_as_type(ret, type_=return_type)
|
||||||
|
|
||||||
ret = await ensure_async_func(*args, **kwargs)
|
ret = await ensure_async_func(*args, **kwargs)
|
||||||
encoded_ret = coder.encode(ret)
|
encoded_ret = coder.encode(ret)
|
||||||
|
|||||||
@@ -1,8 +1,25 @@
|
|||||||
from typing import Any
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from fastapi_cache.coder import PickleCoder
|
from fastapi_cache.coder import JsonCoder, PickleCoder
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DCItem:
|
||||||
|
name: str
|
||||||
|
price: float
|
||||||
|
description: Optional[str] = None
|
||||||
|
tax: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class PDItem(BaseModel):
|
||||||
|
name: str
|
||||||
|
price: float
|
||||||
|
description: Optional[str] = None
|
||||||
|
tax: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -13,6 +30,8 @@ from fastapi_cache.coder import PickleCoder
|
|||||||
(1, 2),
|
(1, 2),
|
||||||
[1, 2, 3],
|
[1, 2, 3],
|
||||||
{"some_key": 1, "other_key": 2},
|
{"some_key": 1, "other_key": 2},
|
||||||
|
DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2),
|
||||||
|
PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_pickle_coder(value: Any) -> None:
|
def test_pickle_coder(value: Any) -> None:
|
||||||
@@ -20,3 +39,28 @@ def test_pickle_coder(value: Any) -> None:
|
|||||||
assert isinstance(encoded_value, str)
|
assert isinstance(encoded_value, str)
|
||||||
decoded_value = PickleCoder.decode(encoded_value)
|
decoded_value = PickleCoder.decode(encoded_value)
|
||||||
assert decoded_value == value
|
assert decoded_value == value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("value", "return_type"),
|
||||||
|
[
|
||||||
|
(1, None),
|
||||||
|
("some_string", None),
|
||||||
|
((1, 2), tuple[int, int]),
|
||||||
|
([1, 2, 3], None),
|
||||||
|
({"some_key": 1, "other_key": 2}, None),
|
||||||
|
(DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), DCItem),
|
||||||
|
(PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), PDItem),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_json_coder(value: Any, return_type) -> None:
|
||||||
|
encoded_value = JsonCoder.encode(value)
|
||||||
|
assert isinstance(encoded_value, str)
|
||||||
|
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
|
||||||
|
assert decoded_value == value
|
||||||
|
|
||||||
|
|
||||||
|
def test_json_coder_validation_error() -> None:
|
||||||
|
invalid = '{"name": "incomplete"}'
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
JsonCoder.decode_as_type(invalid, type_=PDItem)
|
||||||
|
|||||||
@@ -79,3 +79,10 @@ def test_method() -> None:
|
|||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
response = client.get("/method")
|
response = client.get("/method")
|
||||||
assert response.json() == 17
|
assert response.json() == 17
|
||||||
|
|
||||||
|
|
||||||
|
def test_pydantic_model() -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
r1 = client.get("/pydantic_instance").json()
|
||||||
|
r2 = client.get("/pydantic_instance").json()
|
||||||
|
assert r1 == r2
|
||||||
|
|||||||
Reference in New Issue
Block a user