diff --git a/examples/in_memory/main.py b/examples/in_memory/main.py index 1df6bc2..471899e 100644 --- a/examples/in_memory/main.py +++ b/examples/in_memory/main.py @@ -2,7 +2,7 @@ import pendulum import uvicorn from fastapi import FastAPI from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import Response, JSONResponse from fastapi_cache import FastAPICache from fastapi_cache.backends.inmemory import InMemoryBackend @@ -42,6 +42,7 @@ async def get_date(): async def get_datetime(request: Request, response: Response): return {"now": pendulum.now()} + @app.get("/sync-me") @cache(namespace="test") def sync_me(): @@ -50,6 +51,12 @@ def sync_me(): return 42 +@app.get("/cache_response_obj") +@cache(namespace="test", expire=5) +async def cache_response_obj(): + return JSONResponse({"a": 1}) + + @app.on_event("startup") async def startup(): FastAPICache.init(InMemoryBackend()) diff --git a/examples/redis/main.py b/examples/redis/main.py index 773416e..98d1401 100644 --- a/examples/redis/main.py +++ b/examples/redis/main.py @@ -6,7 +6,7 @@ import uvicorn from fastapi import FastAPI from redis.asyncio.connection import ConnectionPool from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import Response, JSONResponse from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates @@ -73,6 +73,12 @@ async def cache_html(request: Request): }) +@app.get("/cache_response_obj") +@cache(namespace="test", expire=5) +async def cache_response_obj(): + return JSONResponse({"a": 1}) + + @app.on_event("startup") async def startup(): pool = ConnectionPool.from_url(url="redis://redis") diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 6d54aef..9ebb1ac 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -58,3 +58,13 @@ def test_sync() -> None: with TestClient(app) as client: response = client.get("/sync-me") assert response.json() == 42 + + +def test_cache_response_obj() -> None: + with TestClient(app) as client: + cache_response = client.get("cache_response_obj") + assert cache_response.json() == {"a": 1} + get_cache_response = client.get("cache_response_obj") + assert get_cache_response.json() == {"a": 1} + assert get_cache_response.headers.get("cache-control") + assert get_cache_response.headers.get("etag")