From 29426de95f95834b83b208105902449d5edcee99 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Fri, 12 May 2023 13:46:26 +0100 Subject: [PATCH] Refactor decorator, consolidate miss / hit paths Use just three code paths: uncacheable, cache miss and cache hit. This makes it much easier to follow what happens for each case. the only places where the inner function now exits early are when the call is uncacheable, or when there is a cache hit and the request included a matching If-Not-Modified header. - Use a utility function to capture when a request should not use the cache - Use the starlette.status constant for the 'not modified' status for code clarity. - Use `setattr()` for the inner function signature, avoiding the need for a type checker override comment. --- fastapi_cache/decorator.py | 79 +++++++++++++++++++++++--------------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index 0089dbc..209064f 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -13,6 +13,7 @@ from fastapi.concurrency import run_in_threadpool from fastapi.dependencies.utils import get_typed_return_annotation, get_typed_signature from starlette.requests import Request from starlette.responses import Response +from starlette.status import HTTP_304_NOT_MODIFIED from fastapi_cache import FastAPICache from fastapi_cache.coder import Coder @@ -52,6 +53,24 @@ def _locate_param(sig: Signature, dep: Parameter, to_inject: List[Parameter]) -> return param +def _uncacheable(request: Optional[Request]) -> bool: + """Determine if this request should not be cached + + Returns true if: + - Caching has been disabled globally + - This is not a GET request + - The request has a Cache-Control header with a value of "no-store" or "no-cache" + + """ + if not FastAPICache.get_enable(): + return True + if request is None: + return False + if request.method != "GET": + return True + return request.headers.get("Cache-Control") in ("no-store", "no-cache") + + def cache( expire: Optional[int] = None, coder: Optional[Type[Coder]] = None, @@ -116,9 +135,8 @@ def cache( copy_kwargs = kwargs.copy() request: Optional[Request] = copy_kwargs.pop(request_param.name, None) # type: ignore[assignment] response: Optional[Response] = copy_kwargs.pop(response_param.name, None) # type: ignore[assignment] - if ( - request and request.headers.get("Cache-Control") in ("no-store", "no-cache") - ) or not FastAPICache.get_enable(): + + if _uncacheable(request): return await ensure_async_func(*args, **kwargs) prefix = FastAPICache.get_prefix() @@ -146,47 +164,46 @@ def cache( f"Error retrieving cache key '{cache_key}' from backend:", exc_info=True ) ttl, cached = 0, None - if not request: - if cached is not None: - return cast(R, coder.decode_as_type(cached, type_=return_type)) - ret = await ensure_async_func(*args, **kwargs) + + if cached is None: # cache miss + result = await ensure_async_func(*args, **kwargs) + to_cache = coder.encode(result) + try: - await backend.set(cache_key, coder.encode(ret), expire) + await backend.set(cache_key, to_cache, expire) except Exception: logger.warning( f"Error setting cache key '{cache_key}' in backend:", exc_info=True ) - return ret - if request.method != "GET": - return await ensure_async_func(*args, **kwargs) - - if_none_match = request.headers.get("if-none-match") - if cached is not None: if response: - response.headers["Cache-Control"] = f"max-age={ttl}" + response.headers.update( + { + "Cache-Control": f"max-age={expire}", + "ETag": f"W/{hash(to_cache)}", + } + ) + + else: # cache hit + if response: etag = f"W/{hash(cached)}" + response.headers.update( + { + "Cache-Control": f"max-age={ttl}", + "ETag": etag, + } + ) + + if_none_match = request and request.headers.get("if-none-match") if if_none_match == etag: - response.status_code = 304 + response.status_code = HTTP_304_NOT_MODIFIED return response - response.headers["ETag"] = etag - return cast(R, coder.decode_as_type(cached, type_=return_type)) - ret = await ensure_async_func(*args, **kwargs) - encoded_ret = coder.encode(ret) + result = cast(R, coder.decode_as_type(cached, type_=return_type)) - try: - await backend.set(cache_key, encoded_ret, expire) - except Exception: - logger.warning(f"Error setting cache key '{cache_key}' in backend:", exc_info=True) + return result - if response: - response.headers["Cache-Control"] = f"max-age={expire}" - etag = f"W/{hash(encoded_ret)}" - response.headers["ETag"] = etag - return ret - - inner.__signature__ = _augment_signature(wrapped_signature, *to_inject) # type: ignore[attr-defined] + setattr(inner, "__signature__", _augment_signature(wrapped_signature, *to_inject)) return inner return wrapper