diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a15f9d..0cf54bd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### 0.2.0 - Make `request` and `response` optional. +- Add typing info to the `cache` decorator. ## 0.1 diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index 1ac5f13..be50d30 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -1,6 +1,11 @@ import inspect +import sys from functools import wraps -from typing import Callable, Optional, Type, Any +from typing import Any, Awaitable, Callable, Optional, TypeVar +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec from fastapi.concurrency import run_in_threadpool from starlette.requests import Request @@ -10,12 +15,16 @@ from fastapi_cache import FastAPICache from fastapi_cache.coder import Coder +P = ParamSpec("P") +R = TypeVar("R") + + def cache( expire: Optional[int] = None, - coder: Optional[Type[Coder]] = None, - key_builder: Optional[Callable] = None, + coder: Optional[Coder] = None, + key_builder: Optional[Callable[..., Any]] = None, namespace: Optional[str] = "", -) -> Callable: +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: """ cache all function :param namespace: @@ -26,7 +35,7 @@ def cache( :return: """ - def wrapper(func: Callable) -> Callable: + def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: signature = inspect.signature(func) request_param = next( (param for param in signature.parameters.values() if param.annotation is Request), @@ -55,15 +64,15 @@ def cache( ) if parameters: signature = signature.replace(parameters=parameters) - func.__signature__ = signature # type: ignore + func.__signature__ = signature @wraps(func) - async def inner(*args: Any, **kwargs: Any) -> Any: + async def inner(*args: P.args, **kwargs: P.kwargs) -> R: nonlocal coder nonlocal expire nonlocal key_builder - async def ensure_async_func(*args: Any, **kwargs: Any) -> Any: + async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R: """Run cached sync functions in thread pool just like FastAPI.""" # if the wrapped function does NOT have request or response in its function signature, # make sure we don't pass them in as keyword arguments @@ -83,6 +92,7 @@ def cache( # see above why we have to await even although caller also awaits. return await run_in_threadpool(func, *args, **kwargs) + copy_kwargs = kwargs.copy() request = copy_kwargs.pop("request", None) response = copy_kwargs.pop("response", None) diff --git a/fastapi_cache/py.typed b/fastapi_cache/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 2f56f04..9eafb92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ redis = { version = "^4.2.0rc1", optional = true } aiomcache = { version = "*", optional = true } pendulum = "*" aiobotocore = { version = "^1.4.1", optional = true } +typing-extensions = { version = ">=4.1.0", markers = "python_version < \"3.10\"" } [tool.poetry.dev-dependencies] flake8 = "*"