From 832650347b05311004b21ecf0ee586b7736f6f59 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Thu, 27 Apr 2023 18:14:59 +0100 Subject: [PATCH] Attach updated endpoint signature to inner Not all endpoints accept a __signature__ attribute, nor should the cache decorator modify the decorated endpoint. Attach the signature to the returned inner function instead. While here, refactor the signature updating code, and extract it to a separate function. --- examples/in_memory/main.py | 13 +++++++++ fastapi_cache/decorator.py | 60 +++++++++++++++++++++----------------- tests/test_decorator.py | 6 ++++ 3 files changed, 52 insertions(+), 27 deletions(-) diff --git a/examples/in_memory/main.py b/examples/in_memory/main.py index bbf0910..2306472 100644 --- a/examples/in_memory/main.py +++ b/examples/in_memory/main.py @@ -67,6 +67,19 @@ async def cache_response_obj(): return JSONResponse({"a": 1}) +class SomeClass: + def __init__(self, value): + self.value = value + + async def handler_method(self): + return self.value + + +# register an instance method as a handler +instance = SomeClass(17) +app.get("/method")(cache(namespace="test")(instance.handler_method)) + + @app.on_event("startup") async def startup(): FastAPICache.init(InMemoryBackend()) diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index b0601aa..27feadb 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -22,6 +22,36 @@ P = ParamSpec("P") R = TypeVar("R") +def _augment_signature( + signature: inspect.Signature, add_request: bool, add_response: bool +) -> inspect.Signature: + if not (add_request or add_response): + return signature + + parameters = list(signature.parameters.values()) + variadic_keyword_params = [] + while parameters and parameters[-1].kind is inspect.Parameter.VAR_KEYWORD: + variadic_keyword_params.append(parameters.pop()) + + if add_request: + parameters.append( + inspect.Parameter( + name="request", + annotation=Request, + kind=inspect.Parameter.KEYWORD_ONLY, + ), + ) + if add_response: + parameters.append( + inspect.Parameter( + name="response", + annotation=Response, + kind=inspect.Parameter.KEYWORD_ONLY, + ), + ) + return signature.replace(parameters=[*parameters, *variadic_keyword_params]) + + def cache( expire: Optional[int] = None, coder: Optional[Type[Coder]] = None, @@ -48,33 +78,6 @@ def cache( (param for param in signature.parameters.values() if param.annotation is Response), None, ) - parameters = [] - extra_params = [] - for p in signature.parameters.values(): - if p.kind <= inspect.Parameter.KEYWORD_ONLY: - parameters.append(p) - else: - extra_params.append(p) - if not request_param: - parameters.append( - inspect.Parameter( - name="request", - annotation=Request, - kind=inspect.Parameter.KEYWORD_ONLY, - ), - ) - if not response_param: - parameters.append( - inspect.Parameter( - name="response", - annotation=Response, - kind=inspect.Parameter.KEYWORD_ONLY, - ), - ) - parameters.extend(extra_params) - if parameters: - signature = signature.replace(parameters=parameters) - func.__signature__ = signature @wraps(func) async def inner(*args: P.args, **kwargs: P.kwargs) -> R: @@ -179,6 +182,9 @@ def cache( response.headers["ETag"] = etag return ret + inner.__signature__ = _augment_signature( + signature, request_param is None, response_param is None + ) return inner return wrapper diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 752a0d2..9fd8aa8 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -73,3 +73,9 @@ def test_kwargs() -> None: name = "Jon" response = client.get("/kwargs", params={"name": name}) assert response.json() == {"name": name} + + +def test_method() -> None: + with TestClient(app) as client: + response = client.get("/method") + assert response.json() == 17