Refactor decorator, consolidate miss / hit paths

Use just three code paths: uncacheable, cache miss and cache hit. This
makes it much easier to follow what happens for each case. the only
places where the inner function now exits early are when the call is
uncacheable, or when there is a cache hit and the request included a
matching If-Not-Modified header.

- Use a utility function to capture when a request should not use the
  cache
- Use the starlette.status constant for the 'not modified' status for
  code clarity.
- Use `setattr()` for the inner function signature, avoiding the need
  for a type checker override comment.
This commit is contained in:
Martijn Pieters
2023-05-12 13:46:26 +01:00
committed by Martijn Pieters
parent d10f4af6d6
commit 29426de95f

View File

@@ -13,6 +13,7 @@ from fastapi.concurrency import run_in_threadpool
from fastapi.dependencies.utils import get_typed_return_annotation, get_typed_signature
from starlette.requests import Request
from starlette.responses import Response
from starlette.status import HTTP_304_NOT_MODIFIED
from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder
@@ -52,6 +53,24 @@ def _locate_param(sig: Signature, dep: Parameter, to_inject: List[Parameter]) ->
return param
def _uncacheable(request: Optional[Request]) -> bool:
"""Determine if this request should not be cached
Returns true if:
- Caching has been disabled globally
- This is not a GET request
- The request has a Cache-Control header with a value of "no-store" or "no-cache"
"""
if not FastAPICache.get_enable():
return True
if request is None:
return False
if request.method != "GET":
return True
return request.headers.get("Cache-Control") in ("no-store", "no-cache")
def cache(
expire: Optional[int] = None,
coder: Optional[Type[Coder]] = None,
@@ -116,9 +135,8 @@ def cache(
copy_kwargs = kwargs.copy()
request: Optional[Request] = copy_kwargs.pop(request_param.name, None) # type: ignore[assignment]
response: Optional[Response] = copy_kwargs.pop(response_param.name, None) # type: ignore[assignment]
if (
request and request.headers.get("Cache-Control") in ("no-store", "no-cache")
) or not FastAPICache.get_enable():
if _uncacheable(request):
return await ensure_async_func(*args, **kwargs)
prefix = FastAPICache.get_prefix()
@@ -146,47 +164,46 @@ def cache(
f"Error retrieving cache key '{cache_key}' from backend:", exc_info=True
)
ttl, cached = 0, None
if not request:
if cached is not None:
return cast(R, coder.decode_as_type(cached, type_=return_type))
ret = await ensure_async_func(*args, **kwargs)
if cached is None: # cache miss
result = await ensure_async_func(*args, **kwargs)
to_cache = coder.encode(result)
try:
await backend.set(cache_key, coder.encode(ret), expire)
await backend.set(cache_key, to_cache, expire)
except Exception:
logger.warning(
f"Error setting cache key '{cache_key}' in backend:", exc_info=True
)
return ret
if request.method != "GET":
return await ensure_async_func(*args, **kwargs)
if_none_match = request.headers.get("if-none-match")
if cached is not None:
if response:
response.headers["Cache-Control"] = f"max-age={ttl}"
response.headers.update(
{
"Cache-Control": f"max-age={expire}",
"ETag": f"W/{hash(to_cache)}",
}
)
else: # cache hit
if response:
etag = f"W/{hash(cached)}"
response.headers.update(
{
"Cache-Control": f"max-age={ttl}",
"ETag": etag,
}
)
if_none_match = request and request.headers.get("if-none-match")
if if_none_match == etag:
response.status_code = 304
response.status_code = HTTP_304_NOT_MODIFIED
return response
response.headers["ETag"] = etag
return cast(R, coder.decode_as_type(cached, type_=return_type))
ret = await ensure_async_func(*args, **kwargs)
encoded_ret = coder.encode(ret)
result = cast(R, coder.decode_as_type(cached, type_=return_type))
try:
await backend.set(cache_key, encoded_ret, expire)
except Exception:
logger.warning(f"Error setting cache key '{cache_key}' in backend:", exc_info=True)
return result
if response:
response.headers["Cache-Control"] = f"max-age={expire}"
etag = f"W/{hash(encoded_ret)}"
response.headers["ETag"] = etag
return ret
inner.__signature__ = _augment_signature(wrapped_signature, *to_inject) # type: ignore[attr-defined]
setattr(inner, "__signature__", _augment_signature(wrapped_signature, *to_inject))
return inner
return wrapper