Files
fastapi-cache/fastapi_cache/decorator.py
Martijn Pieters 29426de95f Refactor decorator, consolidate miss / hit paths
Use just three code paths: uncacheable, cache miss and cache hit. This
makes it much easier to follow what happens for each case. The only
places where the inner function now exits early are when the call is
uncacheable, or when there is a cache hit and the request included a
matching If-None-Match header.

- Use a utility function to capture when a request should not use the
  cache
- Use the starlette.status constant for the 'not modified' status for
  code clarity.
- Use `setattr()` for the inner function signature, avoiding the need
  for a type checker override comment.
2023-05-14 17:03:57 +01:00

210 lines
7.5 KiB
Python

import hashlib
import logging
import sys
from functools import wraps
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union, cast

if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

from fastapi.concurrency import run_in_threadpool
from fastapi.dependencies.utils import get_typed_return_annotation, get_typed_signature
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_304_NOT_MODIFIED

from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder
from fastapi_cache.types import KeyBuilder
# Module-level logger. Attaching a NullHandler is the usual convention for
# library code: records are emitted without any output unless the host
# application configures logging itself.
logger: logging.Logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# P captures the parameters of the decorated callable and R its awaited
# return type; both are used in the annotations of the `cache` decorator.
P = ParamSpec("P")
R = TypeVar("R")
def _augment_signature(signature: Signature, *extra: Parameter) -> Signature:
if not extra:
return signature
parameters = list(signature.parameters.values())
variadic_keyword_params: List[Parameter] = []
while parameters and parameters[-1].kind is Parameter.VAR_KEYWORD:
variadic_keyword_params.append(parameters.pop())
return signature.replace(parameters=[*parameters, *extra, *variadic_keyword_params])
def _locate_param(sig: Signature, dep: Parameter, to_inject: List[Parameter]) -> Parameter:
"""Locate an existing parameter in the decorated endpoint
If not found, returns the injectable parameter, and adds it to the to_inject list.
"""
param = next(
(param for param in sig.parameters.values() if param.annotation is dep.annotation),
None,
)
if param is None:
to_inject.append(dep)
param = dep
return param
def _uncacheable(request: Optional[Request]) -> bool:
    """Determine if this request should not be cached

    Returns true if:
    - Caching has been disabled globally
    - This is not a GET request
    - The request has a Cache-Control header that contains a "no-store" or
      "no-cache" directive

    When there is no request object at all (``request`` is ``None``) and
    caching is enabled, the call is treated as cacheable.
    """
    if not FastAPICache.get_enable():
        return True
    if request is None:
        return False
    if request.method != "GET":
        return True

    cache_control = request.headers.get("Cache-Control")
    if not cache_control:
        return False

    # Cache-Control is a comma-separated list of case-insensitive directives,
    # some of which carry an "=value" argument. Compare directive names only,
    # so that e.g. "no-cache, max-age=0" or "No-Store" are recognised too.
    directives = {
        part.partition("=")[0].strip().lower() for part in cache_control.split(",")
    }
    return not directives.isdisjoint(("no-store", "no-cache"))
def _stable_etag(payload: Union[bytes, str]) -> str:
    """Build a weak ETag value for an encoded cache payload.

    The built-in ``hash()`` is salted per interpreter process for ``str`` and
    ``bytes``, so a tag derived from it differs between worker processes and
    across restarts, and a client's ``If-None-Match`` could never match a tag
    issued by another process. A content digest is identical everywhere.
    """
    raw = payload if isinstance(payload, bytes) else str(payload).encode("utf-8")
    return f"W/{hashlib.sha256(raw).hexdigest()}"


def cache(
    expire: Optional[int] = None,
    coder: Optional[Type[Coder]] = None,
    key_builder: Optional[KeyBuilder] = None,
    namespace: str = "",
    injected_dependency_namespace: str = "__fastapi_cache",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[Union[R, Response]]]]:
    """Cache the result of the decorated function.

    Each call takes one of three paths: uncacheable (the function is simply
    called), cache miss (the function is called and its encoded result is
    stored) or cache hit (the stored value is decoded and returned, or the
    response is returned with a 304 status when the request's If-None-Match
    header equals the ETag of the cached value).

    :param expire: expiry passed to the backend when storing a result, and
        advertised as ``max-age`` on a cache miss. When falsy,
        ``FastAPICache.get_expire()`` is used.
    :param coder: coder used to encode results and decode cached values. When
        not given, ``FastAPICache.get_coder()`` is used.
    :param key_builder: callable producing the cache key; it may return an
        awaitable. When not given, ``FastAPICache.get_key_builder()`` is used.
    :param namespace: combined with the global prefix as
        ``"{prefix}:{namespace}"`` and passed to the key builder.
    :param injected_dependency_namespace: name prefix for the keyword-only
        ``Request`` / ``Response`` parameters that are added to the decorated
        function's signature when it does not declare such parameters itself.
    :return: a decorator producing an async wrapper around the function.
    """
    injected_request = Parameter(
        name=f"{injected_dependency_namespace}_request",
        annotation=Request,
        kind=Parameter.KEYWORD_ONLY,
    )
    injected_response = Parameter(
        name=f"{injected_dependency_namespace}_response",
        annotation=Response,
        kind=Parameter.KEYWORD_ONLY,
    )

    def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[Union[R, Response]]]:
        # get_typed_signature ensures that any forward references are resolved first
        wrapped_signature = get_typed_signature(func)
        to_inject: List[Parameter] = []
        request_param = _locate_param(wrapped_signature, injected_request, to_inject)
        response_param = _locate_param(wrapped_signature, injected_response, to_inject)
        return_type = get_typed_return_annotation(func)

        @wraps(func)
        async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
            async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
                """Run cached sync functions in thread pool just like FastAPI."""
                # if the wrapped function does NOT have request or response in
                # its function signature, make sure we don't pass them in as
                # keyword arguments
                kwargs.pop(injected_request.name, None)
                kwargs.pop(injected_response.name, None)

                if iscoroutinefunction(func):
                    # async, return as is.
                    # unintuitively, we have to await once here, so that caller
                    # does not have to await twice. See
                    # https://stackoverflow.com/a/59268198/532513
                    return await func(*args, **kwargs)
                else:
                    # sync, wrap in thread and return async
                    # see above why we have to await even although caller also awaits.
                    return await run_in_threadpool(func, *args, **kwargs)  # type: ignore[arg-type]

            # The request / response objects are not part of the cache key
            # arguments; take them out of a copy so the original kwargs can
            # still be forwarded to the wrapped function untouched.
            copy_kwargs = kwargs.copy()
            request: Optional[Request] = copy_kwargs.pop(request_param.name, None)  # type: ignore[assignment]
            response: Optional[Response] = copy_kwargs.pop(response_param.name, None)  # type: ignore[assignment]

            if _uncacheable(request):
                return await ensure_async_func(*args, **kwargs)

            # Resolve the effective settings on every call, into locals. The
            # decorator arguments are deliberately not rebound, so that global
            # defaults configured (or re-configured) after the first call are
            # still honoured.
            prefix = FastAPICache.get_prefix()
            active_coder = coder or FastAPICache.get_coder()
            active_expire = expire or FastAPICache.get_expire()
            active_key_builder = key_builder or FastAPICache.get_key_builder()
            backend = FastAPICache.get_backend()

            cache_key = active_key_builder(
                func,
                f"{prefix}:{namespace}",
                request=request,
                response=response,
                args=args,
                kwargs=copy_kwargs,
            )
            if isawaitable(cache_key):
                cache_key = await cache_key
            assert isinstance(cache_key, str)

            try:
                ttl, cached = await backend.get_with_ttl(cache_key)
            except Exception:
                # A failing backend must not break the endpoint: log and
                # continue as if this were a cache miss.
                logger.warning(
                    "Error retrieving cache key '%s' from backend:", cache_key, exc_info=True
                )
                ttl, cached = 0, None

            if cached is None:  # cache miss
                result = await ensure_async_func(*args, **kwargs)
                to_cache = active_coder.encode(result)

                try:
                    await backend.set(cache_key, to_cache, active_expire)
                except Exception:
                    # Best effort: the freshly computed result is still returned.
                    logger.warning(
                        "Error setting cache key '%s' in backend:", cache_key, exc_info=True
                    )

                if response:
                    response.headers.update(
                        {
                            "Cache-Control": f"max-age={active_expire}",
                            "ETag": _stable_etag(to_cache),
                        }
                    )

            else:  # cache hit
                if response:
                    etag = _stable_etag(cached)
                    response.headers.update(
                        {
                            "Cache-Control": f"max-age={ttl}",
                            "ETag": etag,
                        }
                    )

                    if_none_match = request and request.headers.get("if-none-match")
                    if if_none_match == etag:
                        # The client already holds this exact payload.
                        response.status_code = HTTP_304_NOT_MODIFIED
                        return response

                result = cast(R, active_coder.decode_as_type(cached, type_=return_type))

            return result

        # Advertise the injected request / response parameters in the signature
        # of the wrapper.
        setattr(inner, "__signature__", _augment_signature(wrapped_signature, *to_inject))
        return inner

    return wrapper