Added typing to the decorator

This commit is contained in:
squaresmile
2022-10-25 08:52:59 +07:00
parent 5781593829
commit f3f134a318
2 changed files with 18 additions and 8 deletions

View File

@@ -1,6 +1,11 @@
import inspect import inspect
import sys
from functools import wraps from functools import wraps
from typing import Callable, Optional, Type from typing import Any, Awaitable, Callable, Optional, TypeVar
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
from starlette.requests import Request from starlette.requests import Request
@@ -10,12 +15,16 @@ from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder from fastapi_cache.coder import Coder
P = ParamSpec("P")
R = TypeVar("R")
def cache( def cache(
expire: int = None, expire: Optional[int] = None,
coder: Type[Coder] = None, coder: Optional[Coder] = None,
key_builder: Callable = None, key_builder: Optional[Callable[..., Any]] = None,
namespace: Optional[str] = "", namespace: Optional[str] = "",
): ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
""" """
cache all function cache all function
:param namespace: :param namespace:
@@ -26,7 +35,7 @@ def cache(
:return: :return:
""" """
def wrapper(func): def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
signature = inspect.signature(func) signature = inspect.signature(func)
request_param = next( request_param = next(
(param for param in signature.parameters.values() if param.annotation is Request), (param for param in signature.parameters.values() if param.annotation is Request),
@@ -58,12 +67,12 @@ def cache(
func.__signature__ = signature func.__signature__ = signature
@wraps(func) @wraps(func)
async def inner(*args, **kwargs): async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
nonlocal coder nonlocal coder
nonlocal expire nonlocal expire
nonlocal key_builder nonlocal key_builder
async def ensure_async_func(*args, **kwargs): async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
"""Run cached sync functions in thread pool just like FastAPI.""" """Run cached sync functions in thread pool just like FastAPI."""
# if the wrapped function does NOT have request or response in its function signature, # if the wrapped function does NOT have request or response in its function signature,
# make sure we don't pass them in as keyword arguments # make sure we don't pass them in as keyword arguments

View File

@@ -22,6 +22,7 @@ redis = { version = "^4.2.0rc1", optional = true }
aiomcache = { version = "*", optional = true } aiomcache = { version = "*", optional = true }
pendulum = "*" pendulum = "*"
aiobotocore = { version = "^1.4.1", optional = true } aiobotocore = { version = "^1.4.1", optional = true }
typing-extensions = { version = ">=4.1.0", markers = "python_version < \"3.10\"" }
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
flake8 = "*" flake8 = "*"