diff --git a/examples/in_memory/main.py b/examples/in_memory/main.py index bbf0910..2306472 100644 --- a/examples/in_memory/main.py +++ b/examples/in_memory/main.py @@ -67,6 +67,19 @@ async def cache_response_obj(): return JSONResponse({"a": 1}) +class SomeClass: + def __init__(self, value): + self.value = value + + async def handler_method(self): + return self.value + + +# register an instance method as a handler +instance = SomeClass(17) +app.get("/method")(cache(namespace="test")(instance.handler_method)) + + @app.on_event("startup") async def startup(): FastAPICache.init(InMemoryBackend()) diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index 2c06ad2..65fc02b 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -23,6 +23,36 @@ P = ParamSpec("P") R = TypeVar("R") +def _augment_signature( + signature: inspect.Signature, add_request: bool, add_response: bool +) -> inspect.Signature: + if not (add_request or add_response): + return signature + + parameters = list(signature.parameters.values()) + variadic_keyword_params = [] + while parameters and parameters[-1].kind is inspect.Parameter.VAR_KEYWORD: + variadic_keyword_params.append(parameters.pop()) + + if add_request: + parameters.append( + inspect.Parameter( + name="request", + annotation=Request, + kind=inspect.Parameter.KEYWORD_ONLY, + ), + ) + if add_response: + parameters.append( + inspect.Parameter( + name="response", + annotation=Response, + kind=inspect.Parameter.KEYWORD_ONLY, + ), + ) + return signature.replace(parameters=[*parameters, *variadic_keyword_params]) + + def cache( expire: Optional[int] = None, coder: Optional[Type[Coder]] = None, @@ -49,33 +79,6 @@ def cache( (param for param in signature.parameters.values() if param.annotation is Response), None, ) - parameters = [] - extra_params = [] - for p in signature.parameters.values(): - if p.kind <= inspect.Parameter.KEYWORD_ONLY: - parameters.append(p) - else: - extra_params.append(p) - if not request_param: - parameters.append( - inspect.Parameter( - name="request", - annotation=Request, - kind=inspect.Parameter.KEYWORD_ONLY, - ), - ) - if not response_param: - parameters.append( - inspect.Parameter( - name="response", - annotation=Response, - kind=inspect.Parameter.KEYWORD_ONLY, - ), - ) - parameters.extend(extra_params) - if parameters: - signature = signature.replace(parameters=parameters) - func.__signature__ = signature @wraps(func) async def inner(*args: P.args, **kwargs: P.kwargs) -> R: @@ -173,6 +176,9 @@ def cache( response.headers["ETag"] = etag return ret + inner.__signature__ = _augment_signature( + signature, request_param is None, response_param is None + ) return inner return wrapper diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 752a0d2..9fd8aa8 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -73,3 +73,9 @@ def test_kwargs() -> None: name = "Jon" response = client.get("/kwargs", params={"name": name}) assert response.json() == {"name": name} + + +def test_method() -> None: + with TestClient(app) as client: + response = client.get("/method") + assert response.json() == 17