mirror of
https://github.com/long2ice/fastapi-cache.git
synced 2026-03-25 04:57:54 +00:00
The header name is configurable and defaults to `X-FastAPI-Cache`; its value is either `HIT` or `MISS`. Note that the header is not set at all when the cache is disabled.
213 lines
7.7 KiB
Python
213 lines
7.7 KiB
Python
import logging
|
|
import sys
|
|
from functools import wraps
|
|
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
|
|
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union, cast
|
|
|
|
if sys.version_info >= (3, 10):
|
|
from typing import ParamSpec
|
|
else:
|
|
from typing_extensions import ParamSpec
|
|
|
|
from fastapi.concurrency import run_in_threadpool
|
|
from fastapi.dependencies.utils import get_typed_return_annotation, get_typed_signature
|
|
from starlette.requests import Request
|
|
from starlette.responses import Response
|
|
from starlette.status import HTTP_304_NOT_MODIFIED
|
|
|
|
from fastapi_cache import FastAPICache
|
|
from fastapi_cache.coder import Coder
|
|
from fastapi_cache.types import KeyBuilder
|
|
|
|
# Module-level logger. Only a NullHandler is attached here, so this module
# installs no output handler of its own; the application decides where (and
# whether) these records are emitted.
logger: logging.Logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
# P captures the parameter specification of the decorated callable and R its
# (awaited) return type; both are used in the annotations of `cache` below.
P = ParamSpec("P")
R = TypeVar("R")
|
|
|
|
|
|
def _augment_signature(signature: Signature, *extra: Parameter) -> Signature:
    """Return ``signature`` extended with the ``extra`` parameters.

    The extra parameters are inserted after all existing parameters, except
    that any trailing variadic keyword parameter (``**kwargs``) is kept at the
    very end. When no extra parameters are given, the signature object is
    returned as-is.
    """
    if len(extra) == 0:
        return signature

    existing = list(signature.parameters.values())

    # Walk backwards to find where the trailing run of VAR_KEYWORD parameters
    # begins; everything before that index stays in front of the extras.
    split_at = len(existing)
    while split_at > 0 and existing[split_at - 1].kind is Parameter.VAR_KEYWORD:
        split_at -= 1

    leading = existing[:split_at]
    # Last-to-first order, matching what popping them off the end would yield.
    trailing = existing[split_at:][::-1]

    return signature.replace(parameters=leading + list(extra) + trailing)
|
|
|
|
|
|
def _locate_param(sig: Signature, dep: Parameter, to_inject: List[Parameter]) -> Parameter:
    """Find the endpoint parameter that matches an injectable dependency.

    A parameter of ``sig`` matches when its annotation is the very same object
    as ``dep.annotation``; the first such parameter is returned.

    If nothing matches, ``dep`` is appended to ``to_inject`` (mutating the
    caller's list) and ``dep`` itself is returned.
    """
    wanted = dep.annotation
    for candidate in sig.parameters.values():
        if candidate.annotation is wanted:
            return candidate

    # The endpoint does not declare this dependency; schedule it for injection.
    to_inject.append(dep)
    return dep
|
|
|
|
|
|
def _uncacheable(request: Optional[Request]) -> bool:
    """Determine if this request should not be cached

    Returns true if:
    - Caching has been disabled globally
    - This is not a GET request
    - The request has a Cache-Control header containing a "no-store" or
      "no-cache" directive

    When caching is enabled and no request object is available (``request`` is
    ``None``), the call is considered cacheable.

    """
    if not FastAPICache.get_enable():
        return True
    if request is None:
        return False
    if request.method != "GET":
        return True

    cache_control = request.headers.get("Cache-Control")
    if not cache_control:
        return False

    # Cache-Control carries a comma-separated list of case-insensitive
    # directives, optionally with "=value" arguments (e.g.
    # "no-cache, max-age=0"). Compare directive names rather than the raw
    # header string so combined or differently-cased values are honoured too.
    directives = {
        directive.split("=", 1)[0].strip().lower()
        for directive in cache_control.split(",")
    }
    return not directives.isdisjoint(("no-store", "no-cache"))
|
|
|
|
|
|
def cache(
    expire: Optional[int] = None,
    coder: Optional[Type[Coder]] = None,
    key_builder: Optional[KeyBuilder] = None,
    namespace: str = "",
    injected_dependency_namespace: str = "__fastapi_cache",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[Union[R, Response]]]]:
    """
    Decorator that caches the result of the decorated function.

    Any argument left unset (or falsy) falls back, on every call, to the
    corresponding value configured on ``FastAPICache``.

    :param expire: expiry passed to the backend when storing a value and
        emitted as the ``Cache-Control: max-age`` value on a cache miss.
        Defaults to ``FastAPICache.get_expire()``.
    :param coder: coder used to encode results before storing them and to
        decode cached values. Defaults to ``FastAPICache.get_coder()``.
    :param key_builder: callable (sync or async) producing the cache key.
        Defaults to ``FastAPICache.get_key_builder()``.
    :param namespace: combined with the global prefix as
        ``"{prefix}:{namespace}"`` and handed to the key builder.
    :param injected_dependency_namespace: name prefix for the ``Request`` and
        ``Response`` keyword-only parameters that are added to the decorated
        function's signature when it does not already declare parameters with
        those annotations.

    :return: a decorator; the decorated function is always a coroutine
        function. On a cache hit whose ETag equals the request's
        ``If-None-Match`` header, the ``Response`` object (status 304) is
        returned instead of the decoded value.
    """

    injected_request = Parameter(
        name=f"{injected_dependency_namespace}_request",
        annotation=Request,
        kind=Parameter.KEYWORD_ONLY,
    )
    injected_response = Parameter(
        name=f"{injected_dependency_namespace}_response",
        annotation=Response,
        kind=Parameter.KEYWORD_ONLY,
    )

    def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[Union[R, Response]]]:
        # get_typed_signature ensures that any forward references are resolved first
        wrapped_signature = get_typed_signature(func)
        to_inject: List[Parameter] = []
        request_param = _locate_param(wrapped_signature, injected_request, to_inject)
        response_param = _locate_param(wrapped_signature, injected_response, to_inject)
        return_type = get_typed_return_annotation(func)

        # Defined once per decorated function: it only depends on `func` and
        # the injected parameter names, not on any per-request state.
        async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
            """Run cached sync functions in thread pool just like FastAPI."""
            # if the wrapped function does NOT have request or response in
            # its function signature, make sure we don't pass them in as
            # keyword arguments
            kwargs.pop(injected_request.name, None)
            kwargs.pop(injected_response.name, None)

            if iscoroutinefunction(func):
                # async, return as is.
                # unintuitively, we have to await once here, so that caller
                # does not have to await twice. See
                # https://stackoverflow.com/a/59268198/532513
                return await func(*args, **kwargs)
            else:
                # sync, wrap in thread and return async
                # see above why we have to await even although caller also awaits.
                return await run_in_threadpool(func, *args, **kwargs)  # type: ignore[arg-type]

        @wraps(func)
        async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
            # Work on a copy so the request/response objects are kept out of
            # the kwargs handed to the key builder, while the original kwargs
            # are still forwarded untouched to the wrapped function.
            copy_kwargs = kwargs.copy()
            request: Optional[Request] = copy_kwargs.pop(request_param.name, None)  # type: ignore[assignment]
            response: Optional[Response] = copy_kwargs.pop(response_param.name, None)  # type: ignore[assignment]

            if _uncacheable(request):
                return await ensure_async_func(*args, **kwargs)

            # Resolve the effective settings into per-call locals. Rebinding
            # the decorator arguments themselves (via ``nonlocal``) would pin
            # whatever FastAPICache returned on the first request and ignore
            # any later re-configuration.
            prefix = FastAPICache.get_prefix()
            active_coder = coder or FastAPICache.get_coder()
            active_expire = expire or FastAPICache.get_expire()
            active_key_builder = key_builder or FastAPICache.get_key_builder()
            backend = FastAPICache.get_backend()
            cache_status_header = FastAPICache.get_cache_status_header()

            cache_key = active_key_builder(
                func,
                f"{prefix}:{namespace}",
                request=request,
                response=response,
                args=args,
                kwargs=copy_kwargs,
            )
            # Key builders may be either sync or async.
            if isawaitable(cache_key):
                cache_key = await cache_key
            assert isinstance(cache_key, str)

            # A failing backend must not break the endpoint: log the error and
            # treat it as a cache miss.
            try:
                ttl, cached = await backend.get_with_ttl(cache_key)
            except Exception:
                logger.warning(
                    f"Error retrieving cache key '{cache_key}' from backend:", exc_info=True
                )
                ttl, cached = 0, None

            if cached is None:  # cache miss
                result = await ensure_async_func(*args, **kwargs)
                to_cache = active_coder.encode(result)

                # Storing is best-effort too; the freshly computed result is
                # still returned when the backend write fails.
                try:
                    await backend.set(cache_key, to_cache, active_expire)
                except Exception:
                    logger.warning(
                        f"Error setting cache key '{cache_key}' in backend:", exc_info=True
                    )

                if response:
                    # NOTE(review): if the encoded value is str/bytes, hash()
                    # is salted per process, so this ETag is presumably not
                    # stable across workers or restarts; the tag value is also
                    # not wrapped in double quotes. Left unchanged here because
                    # clients may already hold these values - confirm before
                    # altering the format.
                    response.headers.update(
                        {
                            "Cache-Control": f"max-age={active_expire}",
                            "ETag": f"W/{hash(to_cache)}",
                            cache_status_header: "MISS",
                        }
                    )

            else:  # cache hit
                if response:
                    etag = f"W/{hash(cached)}"
                    response.headers.update(
                        {
                            "Cache-Control": f"max-age={ttl}",
                            "ETag": etag,
                            cache_status_header: "HIT",
                        }
                    )

                    # Conditional request: the client already holds this
                    # version, so answer 304 without decoding the payload.
                    if_none_match = request and request.headers.get("if-none-match")
                    if if_none_match == etag:
                        response.status_code = HTTP_304_NOT_MODIFIED
                        return response

                result = cast(R, active_coder.decode_as_type(cached, type_=return_type))

            return result

        # Expose the injected Request/Response parameters in the public
        # signature so they are visible to signature inspection.
        setattr(inner, "__signature__", _augment_signature(wrapped_signature, *to_inject))
        return inner

    return wrapper
|