Files
fastapi-cache/fastapi_cache/decorator.py
Martijn Pieters 1d9e126037 Switch to ruff to handle linting and formatting
Ruff handles black, flake8 and isort in one package, and is way faster.
The isort rules had not been enforced, so this commit includes a lot
of import resorting changes.

I switched to flake8-bugbear and the standard black-compatible line
length of 80 + 10% (so max 88 characters), so some line reflowing is
included too.

Finally, because bugbear rightly points out that `setattr()` is less
performant, I've switched the `__signature__` assignment back to using
a direct assignment with a type ignore comment.
2023-05-16 12:18:24 +01:00

231 lines
7.8 KiB
Python

import logging
import sys
from functools import wraps
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
from typing import (
Awaitable,
Callable,
List,
Optional,
Type,
TypeVar,
Union,
cast,
)
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
from fastapi.concurrency import run_in_threadpool
from fastapi.dependencies.utils import (
get_typed_return_annotation,
get_typed_signature,
)
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_304_NOT_MODIFIED
from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder
from fastapi_cache.types import KeyBuilder
# Module-level logger named after this module. A NullHandler is attached so
# the logger always has a handler, even if the application configures none.
logger: logging.Logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
# P captures the parameter specification of the decorated callable and R its
# return type; both are used in the annotations of `cache` below.
P = ParamSpec("P")
R = TypeVar("R")
def _augment_signature(signature: Signature, *extra: Parameter) -> Signature:
    """Return *signature* extended with the ``extra`` parameters.

    The extra parameters are placed after all existing parameters, but ahead of
    any trailing ``**kwargs``-style (``VAR_KEYWORD``) parameter. When no extra
    parameters are given, the original signature object is returned as is.
    """
    if len(extra) == 0:
        return signature

    existing = list(signature.parameters.values())

    # Walk back from the end to find where the trailing VAR_KEYWORD run starts.
    split_at = len(existing)
    while split_at > 0 and existing[split_at - 1].kind is Parameter.VAR_KEYWORD:
        split_at -= 1

    leading = existing[:split_at]
    # The trailing parameters are emitted last-to-first.
    trailing = existing[split_at:][::-1]

    return signature.replace(parameters=leading + list(extra) + trailing)
def _locate_param(
    sig: Signature, dep: Parameter, to_inject: List[Parameter]
) -> Parameter:
    """Find the parameter of the decorated endpoint that matches ``dep``.

    A parameter matches when its annotation is the very same object as
    ``dep.annotation``; the first such parameter in ``sig`` is returned.

    If there is no match, ``dep`` itself is appended to ``to_inject`` and
    returned, so the caller knows it must be added to the signature.
    """
    for candidate in sig.parameters.values():
        if candidate.annotation is dep.annotation:
            return candidate

    # Nothing suitable declared by the endpoint: schedule the dependency for
    # injection and use it as the parameter.
    to_inject.append(dep)
    return dep
def _uncacheable(request: Optional[Request]) -> bool:
    """Determine if this request should not be cached

    Returns true if:
    - Caching has been disabled globally
    - This is not a GET request
    - The request has a Cache-Control header containing a "no-store" or
      "no-cache" directive

    The Cache-Control header is treated as a comma-separated list of
    case-insensitive directives, so values such as ``"no-cache, no-store"``,
    ``"max-age=0, no-cache"`` or ``"No-Store"`` are recognised as well as the
    bare ``"no-store"`` / ``"no-cache"`` values.

    Returns false when caching is enabled and no request object is available.
    """
    if not FastAPICache.get_enable():
        return True
    if request is None:
        return False
    if request.method != "GET":
        return True

    cache_control = request.headers.get("Cache-Control")
    if not cache_control:
        return False

    # Split into individual directives; names compare case-insensitively.
    directives = {part.strip().lower() for part in cache_control.split(",")}
    return not directives.isdisjoint(("no-store", "no-cache"))
def cache(
    expire: Optional[int] = None,
    coder: Optional[Type[Coder]] = None,
    key_builder: Optional[KeyBuilder] = None,
    namespace: str = "",
    injected_dependency_namespace: str = "__fastapi_cache",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[Union[R, Response]]]]:
    """Cache the result of the decorated function in the configured backend.

    The returned wrapper is always async; sync functions are run in a thread
    pool. For each call the wrapper:

    - bypasses the cache entirely when ``_uncacheable(request)`` is true;
    - builds a cache key with the key builder and looks it up in the backend;
    - on a miss, calls the function, encodes the result with the coder and
      stores it (backend errors are logged, not raised);
    - on a hit, decodes the cached value as the function's return type, or
      returns the response with status 304 when the request's
      ``If-None-Match`` header equals the computed ETag.

    When a response object is available, ``Cache-Control``, ``ETag`` and the
    cache status header are set on it.

    :param expire: expiry passed to ``backend.set`` and reported as
        ``max-age`` on a miss; falls back to ``FastAPICache.get_expire()``
        when falsy.
    :param coder: coder used to encode and decode cached values; falls back to
        ``FastAPICache.get_coder()`` when falsy.
    :param key_builder: callable producing the cache key (it may return an
        awaitable); falls back to ``FastAPICache.get_key_builder()``.
    :param namespace: combined with the global prefix as
        ``"{prefix}:{namespace}"`` and handed to the key builder.
    :param injected_dependency_namespace: name prefix for the keyword-only
        ``Request`` / ``Response`` parameters that are added to the wrapped
        signature when the function does not declare parameters with those
        annotations itself.
    :return: a decorator producing the caching async wrapper.
    """
    injected_request = Parameter(
        name=f"{injected_dependency_namespace}_request",
        annotation=Request,
        kind=Parameter.KEYWORD_ONLY,
    )
    injected_response = Parameter(
        name=f"{injected_dependency_namespace}_response",
        annotation=Response,
        kind=Parameter.KEYWORD_ONLY,
    )

    def wrapper(
        func: Callable[P, Awaitable[R]]
    ) -> Callable[P, Awaitable[Union[R, Response]]]:
        # get_typed_signature ensures that any forward references are resolved first
        wrapped_signature = get_typed_signature(func)
        to_inject: List[Parameter] = []
        request_param = _locate_param(wrapped_signature, injected_request, to_inject)
        response_param = _locate_param(wrapped_signature, injected_response, to_inject)
        return_type = get_typed_return_annotation(func)

        async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
            """Run cached sync functions in thread pool just like FastAPI."""
            # if the wrapped function does NOT have request or response in
            # its function signature, make sure we don't pass them in as
            # keyword arguments
            kwargs.pop(injected_request.name, None)
            kwargs.pop(injected_response.name, None)

            if iscoroutinefunction(func):
                # async, return as is.
                # unintuitively, we have to await once here, so that caller
                # does not have to await twice. See
                # https://stackoverflow.com/a/59268198/532513
                return await func(*args, **kwargs)
            else:
                # sync, wrap in thread and return async
                # see above why we have to await even although caller also awaits.
                return await run_in_threadpool(func, *args, **kwargs)  # type: ignore[arg-type]

        @wraps(func)
        async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
            # Work on a copy so the request/response objects are not part of the
            # kwargs handed to the key builder, while the original kwargs are
            # still passed to the wrapped function untouched.
            copy_kwargs = kwargs.copy()
            request: Optional[Request] = copy_kwargs.pop(request_param.name, None)  # type: ignore[assignment]
            response: Optional[Response] = copy_kwargs.pop(response_param.name, None)  # type: ignore[assignment]

            if _uncacheable(request):
                return await ensure_async_func(*args, **kwargs)

            # Resolve the effective settings into locals on every call. The
            # decorator arguments are deliberately NOT rebound: doing so would
            # freeze the global defaults seen by the first call and ignore any
            # later reconfiguration of FastAPICache.
            prefix = FastAPICache.get_prefix()
            active_coder = coder or FastAPICache.get_coder()
            active_expire = expire or FastAPICache.get_expire()
            active_key_builder = key_builder or FastAPICache.get_key_builder()
            backend = FastAPICache.get_backend()
            cache_status_header = FastAPICache.get_cache_status_header()

            cache_key = active_key_builder(
                func,
                f"{prefix}:{namespace}",
                request=request,
                response=response,
                args=args,
                kwargs=copy_kwargs,
            )
            # key builders may be sync or async
            if isawaitable(cache_key):
                cache_key = await cache_key
            assert isinstance(cache_key, str)

            try:
                ttl, cached = await backend.get_with_ttl(cache_key)
            except Exception:
                # A failing backend is treated as a cache miss.
                logger.warning(
                    f"Error retrieving cache key '{cache_key}' from backend:",
                    exc_info=True,
                )
                ttl, cached = 0, None

            # NOTE(review): hash() of str/bytes is salted per interpreter
            # process (unless PYTHONHASHSEED is fixed), so the ETag values
            # below differ between worker processes and across restarts.
            if cached is None:  # cache miss
                result = await ensure_async_func(*args, **kwargs)
                to_cache = active_coder.encode(result)

                try:
                    await backend.set(cache_key, to_cache, active_expire)
                except Exception:
                    # Failing to store must not fail the request itself.
                    logger.warning(
                        f"Error setting cache key '{cache_key}' in backend:",
                        exc_info=True,
                    )

                if response:
                    response.headers.update(
                        {
                            "Cache-Control": f"max-age={active_expire}",
                            "ETag": f"W/{hash(to_cache)}",
                            cache_status_header: "MISS",
                        }
                    )

            else:  # cache hit
                if response:
                    etag = f"W/{hash(cached)}"
                    response.headers.update(
                        {
                            "Cache-Control": f"max-age={ttl}",
                            "ETag": etag,
                            cache_status_header: "HIT",
                        }
                    )

                    # The client already holds this representation: answer
                    # with 304 instead of decoding the cached value.
                    if_none_match = request and request.headers.get("if-none-match")
                    if if_none_match == etag:
                        response.status_code = HTTP_304_NOT_MODIFIED
                        return response

                result = cast(
                    R, active_coder.decode_as_type(cached, type_=return_type)
                )

            return result

        # Advertise the injected Request/Response parameters on the wrapper.
        inner.__signature__ = _augment_signature(wrapped_signature, *to_inject)  # type: ignore[attr-defined]

        return inner

    return wrapper