2022-08-01 00:04:10 +02:00
|
|
|
import inspect
|
2022-10-25 08:52:59 +07:00
|
|
|
import sys
|
2022-09-09 19:35:48 +03:00
|
|
|
from functools import wraps
|
2022-10-25 08:52:59 +07:00
|
|
|
from typing import Any, Awaitable, Callable, Optional, TypeVar
|
|
|
|
|
if sys.version_info >= (3, 10):
|
|
|
|
|
from typing import ParamSpec
|
|
|
|
|
else:
|
|
|
|
|
from typing_extensions import ParamSpec
|
2022-09-09 19:35:48 +03:00
|
|
|
|
|
|
|
|
from fastapi.concurrency import run_in_threadpool
|
2022-09-10 20:06:37 +08:00
|
|
|
from starlette.requests import Request
|
|
|
|
|
from starlette.responses import Response
|
2020-08-26 18:04:57 +08:00
|
|
|
|
|
|
|
|
from fastapi_cache import FastAPICache
|
2020-10-16 16:55:33 +08:00
|
|
|
from fastapi_cache.coder import Coder
|
2020-08-26 18:04:57 +08:00
|
|
|
|
|
|
|
|
|
2022-10-25 08:52:59 +07:00
|
|
|
P = ParamSpec("P")
|
|
|
|
|
R = TypeVar("R")
|
|
|
|
|
|
|
|
|
|
|
2020-08-26 18:04:57 +08:00
|
|
|
def _augment_signature(
    signature: inspect.Signature,
    add_request: bool,
    add_response: bool,
) -> inspect.Signature:
    """Return *signature* extended with keyword-only ``request``/``response`` parameters."""
    extra = []
    if add_request:
        extra.append(
            inspect.Parameter(
                name="request",
                annotation=Request,
                kind=inspect.Parameter.KEYWORD_ONLY,
            )
        )
    if add_response:
        extra.append(
            inspect.Parameter(
                name="response",
                annotation=Response,
                kind=inspect.Parameter.KEYWORD_ONLY,
            )
        )
    if not extra:
        return signature
    parameters = list(signature.parameters.values())
    # Keyword-only parameters must come before ``**kwargs``; appending them after it
    # would make ``Signature.replace`` raise ValueError.
    if parameters and parameters[-1].kind is inspect.Parameter.VAR_KEYWORD:
        parameters[-1:-1] = extra
    else:
        parameters.extend(extra)
    return signature.replace(parameters=parameters)


def _no_cache_requested(request: Request) -> bool:
    """Return True when the request's Cache-Control header contains no-store or no-cache."""
    header = request.headers.get("Cache-Control")
    if not header:
        return False
    # Cache-Control is a comma-separated, case-insensitive directive list.
    directives = {directive.strip().lower() for directive in header.split(",")}
    return "no-store" in directives or "no-cache" in directives


def _weak_etag(cached: Any) -> str:
    """Build a weak ETag for an encoded cached value that is stable across processes."""
    import hashlib

    if isinstance(cached, str):
        data = cached.encode("utf-8")
    elif isinstance(cached, (bytes, bytearray, memoryview)):
        data = bytes(cached)
    else:
        # NOTE(review): backends are expected to return str/bytes; repr() is a fallback.
        data = repr(cached).encode("utf-8")
    # The built-in hash() of str/bytes is salted per process (PYTHONHASHSEED), so it
    # would yield different ETags on different workers or after a restart.
    return f'W/"{hashlib.sha256(data).hexdigest()}"'


def cache(
    expire: Optional[int] = None,
    coder: Optional[Coder] = None,
    key_builder: Optional[Callable[..., Any]] = None,
    namespace: Optional[str] = "",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
    """
    Cache the result of an endpoint (async or sync) in the FastAPICache backend.

    Unless the endpoint already declares parameters annotated as ``Request`` /
    ``Response``, keyword-only ``request`` and ``response`` parameters are added to
    the wrapper's signature so FastAPI injects them. They are stripped again before
    the endpoint itself is called.

    :param expire: entry lifetime passed to ``backend.set``; falls back to
        ``FastAPICache.get_expire()`` when falsy.
    :param coder: coder used to encode/decode cached values; falls back to
        ``FastAPICache.get_coder()``.
    :param key_builder: callable producing the cache key; falls back to
        ``FastAPICache.get_key_builder()``.
    :param namespace: namespace forwarded to the key builder.
    :return: a decorator that wraps an endpoint with caching.
    """

    def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
        signature = inspect.signature(func)
        request_param = next(
            (param for param in signature.parameters.values() if param.annotation is Request),
            None,
        )
        response_param = next(
            (param for param in signature.parameters.values() if param.annotation is Response),
            None,
        )
        # Keyword under which FastAPI passes the request/response: the endpoint's own
        # parameter name when it declares one, otherwise the injected name.
        request_name = request_param.name if request_param is not None else "request"
        response_name = response_param.name if response_param is not None else "response"
        wrapped_signature = _augment_signature(
            signature,
            add_request=request_param is None,
            add_response=response_param is None,
        )
        is_async = inspect.iscoroutinefunction(func)

        async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
            """Run cached sync functions in thread pool just like FastAPI."""
            # Remove the request/response that were only injected for the cache; the
            # endpoint does not accept them. Default to None because direct callers
            # (not going through FastAPI) may not pass them at all.
            if request_param is None:
                kwargs.pop("request", None)
            if response_param is None:
                kwargs.pop("response", None)
            if is_async:
                # Await here so the caller does not have to await twice.
                return await func(*args, **kwargs)
            # Sync endpoint: run it in the thread pool, as FastAPI itself does.
            return await run_in_threadpool(func, *args, **kwargs)

        @wraps(func)
        async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
            copy_kwargs = kwargs.copy()
            request: Optional[Request] = copy_kwargs.pop(request_name, None)
            response: Optional[Response] = copy_kwargs.pop(response_name, None)

            if not FastAPICache.get_enable() or (
                request is not None and _no_cache_requested(request)
            ):
                return await ensure_async_func(*args, **kwargs)

            if request is not None and request.method != "GET":
                # Only GET responses are cached; skip key building and backend lookup.
                return await ensure_async_func(*args, **kwargs)

            # Resolve defaults per call instead of rebinding the closure variables, so
            # the first request does not freeze FastAPICache's configuration.
            active_coder = coder or FastAPICache.get_coder()
            active_expire = expire or FastAPICache.get_expire()
            active_key_builder = key_builder or FastAPICache.get_key_builder()
            backend = FastAPICache.get_backend()

            cache_key = active_key_builder(
                func, namespace, request=request, response=response, args=args, kwargs=copy_kwargs
            )
            ttl, ret = await backend.get_with_ttl(cache_key)

            if ret is not None:
                if request is not None and response is not None:
                    response.headers["Cache-Control"] = f"max-age={ttl}"
                    etag = _weak_etag(ret)
                    if request.headers.get("if-none-match") == etag:
                        response.status_code = 304
                        return response
                    response.headers["ETag"] = etag
                return active_coder.decode(ret)

            ret = await ensure_async_func(*args, **kwargs)
            await backend.set(cache_key, active_coder.encode(ret), active_expire)
            return ret

        # Publish the augmented signature on the wrapper only, so FastAPI injects
        # request/response while the user's original function is left untouched.
        inner.__signature__ = wrapped_signature  # type: ignore[attr-defined]
        return inner

    return wrapper
|