Merge pull request #128 from mjpieters/namespaced_injection

Inject dependencies using a namespace
This commit is contained in:
long2ice
2023-05-09 18:15:32 +08:00
committed by GitHub
4 changed files with 79 additions and 46 deletions

View File

@@ -98,9 +98,22 @@ expire | int, states a caching time in seconds
namespace | str, namespace to use to store certain cache items
coder | which coder to use, e.g. JsonCoder
key_builder | which key builder to use, default to builtin
injected_dependency_namespace | prefix for injected dependency keywords, defaults to `__fastapi_cache`.
You can also use `cache` as decorator like other cache tools to cache common function result.
### Injected Request and Response dependencies
The `cache` decorator adds dependencies for the `Request` and `Response` objects, so that it can
add cache control headers to the outgoing response, and return a 304 Not Modified response when
the incoming request has a matching If-Non-Match header. This only happens if the decorated
endpoint doesn't already list these objects directly.
The keyword arguments for these extra dependencies are named
`__fastapi_cache_request` and `__fastapi_cache_response` to minimize collisions.
Use the `injected_dependency_namespace` argument to `@cache()` to change the
prefix used if those names would clash anyway.
### Supported data types

View File

@@ -106,6 +106,17 @@ async def uncached_put():
return {"value": put_ret}
@app.get("/namespaced_injection")
@cache(namespace="test", expire=5, injected_dependency_namespace="monty_python")
def namespaced_injection(
__fastapi_cache_request: int = 42, __fastapi_cache_response: int = 17
) -> dict[str, int]:
return {
"__fastapi_cache_request": __fastapi_cache_request,
"__fastapi_cache_response": __fastapi_cache_response,
}
@app.on_event("startup")
async def startup():
FastAPICache.init(InMemoryBackend())

View File

@@ -1,7 +1,7 @@
import inspect
import logging
import sys
from functools import wraps
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
from typing import Awaitable, Callable, Optional, Type, TypeVar
if sys.version_info >= (3, 10):
@@ -10,7 +10,7 @@ else:
from typing_extensions import ParamSpec
from fastapi.concurrency import run_in_threadpool
from fastapi.dependencies.utils import get_typed_return_annotation
from fastapi.dependencies.utils import get_typed_return_annotation, get_typed_signature
from starlette.requests import Request
from starlette.responses import Response
@@ -24,34 +24,32 @@ P = ParamSpec("P")
R = TypeVar("R")
def _augment_signature(
signature: inspect.Signature, add_request: bool, add_response: bool
) -> inspect.Signature:
if not (add_request or add_response):
def _augment_signature(signature: Signature, *extra: Parameter) -> Signature:
if not extra:
return signature
parameters = list(signature.parameters.values())
variadic_keyword_params = []
while parameters and parameters[-1].kind is inspect.Parameter.VAR_KEYWORD:
while parameters and parameters[-1].kind is Parameter.VAR_KEYWORD:
variadic_keyword_params.append(parameters.pop())
if add_request:
parameters.append(
inspect.Parameter(
name="request",
annotation=Request,
kind=inspect.Parameter.KEYWORD_ONLY,
),
)
if add_response:
parameters.append(
inspect.Parameter(
name="response",
annotation=Response,
kind=inspect.Parameter.KEYWORD_ONLY,
),
)
return signature.replace(parameters=[*parameters, *variadic_keyword_params])
return signature.replace(parameters=[*parameters, *extra, *variadic_keyword_params])
def _locate_param(sig: Signature, dep: Parameter, to_inject: list[Parameter]) -> Parameter:
"""Locate an existing parameter in the decorated endpoint
If not found, returns the injectable parameter, and adds it to the to_inject list.
"""
param = next(
(param for param in sig.parameters.values() if param.annotation is dep.annotation),
None,
)
if param is None:
to_inject.append(dep)
param = dep
return param
def cache(
@@ -59,6 +57,7 @@ def cache(
coder: Optional[Type[Coder]] = None,
key_builder: Optional[KeyBuilder] = None,
namespace: Optional[str] = "",
injected_dependency_namespace: str = "__fastapi_cache",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
"""
cache all function
@@ -70,16 +69,23 @@ def cache(
:return:
"""
injected_request = Parameter(
name=f"{injected_dependency_namespace}_request",
annotation=Request,
kind=Parameter.KEYWORD_ONLY,
)
injected_response = Parameter(
name=f"{injected_dependency_namespace}_response",
annotation=Response,
kind=Parameter.KEYWORD_ONLY,
)
def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
signature = inspect.signature(func)
request_param = next(
(param for param in signature.parameters.values() if param.annotation is Request),
None,
)
response_param = next(
(param for param in signature.parameters.values() if param.annotation is Response),
None,
)
# get_typed_signature ensures that any forward references are resolved first
wrapped_signature = get_typed_signature(func)
to_inject: list[Parameter] = []
request_param = _locate_param(wrapped_signature, injected_request, to_inject)
response_param = _locate_param(wrapped_signature, injected_response, to_inject)
return_type = get_typed_return_annotation(func)
@wraps(func)
@@ -90,14 +96,13 @@ def cache(
async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
"""Run cached sync functions in thread pool just like FastAPI."""
# if the wrapped function does NOT have request or response in its function signature,
# make sure we don't pass them in as keyword arguments
if not request_param:
kwargs.pop("request", None)
if not response_param:
kwargs.pop("response", None)
# if the wrapped function does NOT have request or response in
# its function signature, make sure we don't pass them in as
# keyword arguments
kwargs.pop(injected_request.name, None)
kwargs.pop(injected_response.name, None)
if inspect.iscoroutinefunction(func):
if iscoroutinefunction(func):
# async, return as is.
# unintuitively, we have to await once here, so that caller
# does not have to await twice. See
@@ -109,8 +114,8 @@ def cache(
return await run_in_threadpool(func, *args, **kwargs)
copy_kwargs = kwargs.copy()
request: Optional[Request] = copy_kwargs.pop("request", None)
response: Optional[Response] = copy_kwargs.pop("response", None)
request: Optional[Request] = copy_kwargs.pop(request_param.name, None)
response: Optional[Response] = copy_kwargs.pop(response_param.name, None)
if (
request and request.headers.get("Cache-Control") in ("no-store", "no-cache")
) or not FastAPICache.get_enable():
@@ -129,7 +134,7 @@ def cache(
args=args,
kwargs=copy_kwargs,
)
if inspect.isawaitable(cache_key):
if isawaitable(cache_key):
cache_key = await cache_key
try:
@@ -178,9 +183,7 @@ def cache(
response.headers["ETag"] = etag
return ret
inner.__signature__ = _augment_signature(
signature, request_param is None, response_param is None
)
inner.__signature__ = _augment_signature(wrapped_signature, *to_inject)
return inner
return wrapper

View File

@@ -94,3 +94,9 @@ def test_non_get() -> None:
assert response.json() == {"value": 1}
response = client.put("/uncached_put")
assert response.json() == {"value": 2}
def test_alternate_injected_namespace() -> None:
with TestClient(app) as client:
response = client.get("/namespaced_injection")
assert response.json() == {"__fastapi_cache_request": 42, "__fastapi_cache_response": 17}