Factor out sync handling and use everywhere

This commit is contained in:
Charl P. Botha
2022-10-14 13:44:49 +02:00
parent 34415ad50a
commit 2822ab5d71

View File

@@ -62,13 +62,24 @@ def cache(
nonlocal coder nonlocal coder
nonlocal expire nonlocal expire
nonlocal key_builder nonlocal key_builder
async def ensure_async_func(*args, **kwargs):
"""Run cached sync functions in thread pool just like FastAPI."""
if inspect.iscoroutinefunction(func):
# async, return as is
return func(*args, **kwargs)
else:
# sync, wrap in thread and return async
return run_in_threadpool(func, *args, **kwargs)
copy_kwargs = kwargs.copy() copy_kwargs = kwargs.copy()
request = copy_kwargs.pop("request", None) request = copy_kwargs.pop("request", None)
response = copy_kwargs.pop("response", None) response = copy_kwargs.pop("response", None)
if ( if (
request and request.headers.get("Cache-Control") == "no-store" request and request.headers.get("Cache-Control") == "no-store"
) or not FastAPICache.get_enable(): ) or not FastAPICache.get_enable():
return await func(*args, **kwargs) return await ensure_async_func(*args, **kwargs)
coder = coder or FastAPICache.get_coder() coder = coder or FastAPICache.get_coder()
expire = expire or FastAPICache.get_expire() expire = expire or FastAPICache.get_expire()
@@ -82,12 +93,13 @@ def cache(
if not request: if not request:
if ret is not None: if ret is not None:
return coder.decode(ret) return coder.decode(ret)
ret = await func(*args, **kwargs) ret = await ensure_async_func(*args, **kwargs)
await backend.set(cache_key, coder.encode(ret), expire or FastAPICache.get_expire()) await backend.set(cache_key, coder.encode(ret), expire or FastAPICache.get_expire())
return ret return ret
if request.method != "GET": if request.method != "GET":
return await func(request, *args, **kwargs) return await ensure_async_func(request, *args, **kwargs)
if_none_match = request.headers.get("if-none-match") if_none_match = request.headers.get("if-none-match")
if ret is not None: if ret is not None:
if response: if response:
@@ -102,10 +114,8 @@ def cache(
kwargs.pop("request") kwargs.pop("request")
if not response_param: if not response_param:
kwargs.pop("response") kwargs.pop("response")
if inspect.iscoroutinefunction(func):
ret = await func(*args, **kwargs) ret = await ensure_async_func(*args, **kwargs)
else:
ret = await run_in_threadpool(func, *args, **kwargs)
await backend.set(cache_key, coder.encode(ret), expire or FastAPICache.get_expire()) await backend.set(cache_key, coder.encode(ret), expire or FastAPICache.get_expire())
return ret return ret