diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index 0089dbc..209064f 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -13,6 +13,7 @@ from fastapi.concurrency import run_in_threadpool from fastapi.dependencies.utils import get_typed_return_annotation, get_typed_signature from starlette.requests import Request from starlette.responses import Response +from starlette.status import HTTP_304_NOT_MODIFIED from fastapi_cache import FastAPICache from fastapi_cache.coder import Coder @@ -52,6 +53,24 @@ def _locate_param(sig: Signature, dep: Parameter, to_inject: List[Parameter]) -> return param +def _uncacheable(request: Optional[Request]) -> bool: + """Determine if this request should not be cached + + Returns true if: + - Caching has been disabled globally + - This is not a GET request + - The request has a Cache-Control header with a value of "no-store" or "no-cache" + + """ + if not FastAPICache.get_enable(): + return True + if request is None: + return False + if request.method != "GET": + return True + return request.headers.get("Cache-Control") in ("no-store", "no-cache") + + def cache( expire: Optional[int] = None, coder: Optional[Type[Coder]] = None, @@ -116,9 +135,8 @@ def cache( copy_kwargs = kwargs.copy() request: Optional[Request] = copy_kwargs.pop(request_param.name, None) # type: ignore[assignment] response: Optional[Response] = copy_kwargs.pop(response_param.name, None) # type: ignore[assignment] - if ( - request and request.headers.get("Cache-Control") in ("no-store", "no-cache") - ) or not FastAPICache.get_enable(): + + if _uncacheable(request): return await ensure_async_func(*args, **kwargs) prefix = FastAPICache.get_prefix() @@ -146,47 +164,46 @@ def cache( f"Error retrieving cache key '{cache_key}' from backend:", exc_info=True ) ttl, cached = 0, None - if not request: - if cached is not None: - return cast(R, coder.decode_as_type(cached, type_=return_type)) - ret = await ensure_async_func(*args, **kwargs) + + if cached is None: # cache miss + result = await ensure_async_func(*args, **kwargs) + to_cache = coder.encode(result) + try: - await backend.set(cache_key, coder.encode(ret), expire) + await backend.set(cache_key, to_cache, expire) except Exception: logger.warning( f"Error setting cache key '{cache_key}' in backend:", exc_info=True ) - return ret - if request.method != "GET": - return await ensure_async_func(*args, **kwargs) - - if_none_match = request.headers.get("if-none-match") - if cached is not None: if response: - response.headers["Cache-Control"] = f"max-age={ttl}" + response.headers.update( + { + "Cache-Control": f"max-age={expire}", + "ETag": f"W/{hash(to_cache)}", + } + ) + + else: # cache hit + if response: etag = f"W/{hash(cached)}" + response.headers.update( + { + "Cache-Control": f"max-age={ttl}", + "ETag": etag, + } + ) + + if_none_match = request and request.headers.get("if-none-match") if if_none_match == etag: - response.status_code = 304 + response.status_code = HTTP_304_NOT_MODIFIED return response - response.headers["ETag"] = etag - return cast(R, coder.decode_as_type(cached, type_=return_type)) - ret = await ensure_async_func(*args, **kwargs) - encoded_ret = coder.encode(ret) + result = cast(R, coder.decode_as_type(cached, type_=return_type)) - try: - await backend.set(cache_key, encoded_ret, expire) - except Exception: - logger.warning(f"Error setting cache key '{cache_key}' in backend:", exc_info=True) + return result - if response: - response.headers["Cache-Control"] = f"max-age={expire}" - etag = f"W/{hash(encoded_ret)}" - response.headers["ETag"] = etag - return ret - - inner.__signature__ = _augment_signature(wrapped_signature, *to_inject) # type: ignore[attr-defined] + setattr(inner, "__signature__", _augment_signature(wrapped_signature, *to_inject)) return inner return wrapper