Typing cleanup

- Compatibility with older Python versions
  - use `Optional` and `Union` instead of `... | None` and `a | b`
  - use `typing_extensions.Protocol` instead of `typing.Protocol`
  - use `typing.Dict`, `typing.List`, etc. instead of the concrete types.

- Fix backend `.get()` annotations; not all were marked as `Optional[str]`
- Don't return anything from `Backend.set()` methods.
- The `Coder.decode_as_type()` type parameter must be a type to be
  compatible with `ModelField(..., type_=...)`.
- Clean up `Optional[]` use, remove where it is not needed.
- Clean up variable use in decorator, keeping the raw cached value
  separate from the return value from the wrapped endpoint.
- Annotate the wrapper as returning either the original type _or_ a
  Response (returning a 304 Not Modified response).
- Clean up small edge-case where `response` could be `None`.
- Correct type annotation on `JsonCoder.decode()` to match `Coder.decode()`.
This commit is contained in:
Martijn Pieters
2023-05-09 15:30:46 +01:00
parent 564026e189
commit 013be85f97
8 changed files with 46 additions and 42 deletions

View File

@@ -2,7 +2,7 @@ import logging
import sys
from functools import wraps
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
from typing import Awaitable, Callable, Optional, Type, TypeVar
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union
if sys.version_info >= (3, 10):
from typing import ParamSpec
@@ -36,7 +36,7 @@ def _augment_signature(signature: Signature, *extra: Parameter) -> Signature:
return signature.replace(parameters=[*parameters, *extra, *variadic_keyword_params])
def _locate_param(sig: Signature, dep: Parameter, to_inject: list[Parameter]) -> Parameter:
def _locate_param(sig: Signature, dep: Parameter, to_inject: List[Parameter]) -> Parameter:
"""Locate an existing parameter in the decorated endpoint
If not found, returns the injectable parameter, and adds it to the to_inject list.
@@ -56,9 +56,9 @@ def cache(
expire: Optional[int] = None,
coder: Optional[Type[Coder]] = None,
key_builder: Optional[KeyBuilder] = None,
namespace: Optional[str] = "",
namespace: str = "",
injected_dependency_namespace: str = "__fastapi_cache",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[Union[R, Response]]]]:
"""
cache all function
:param namespace:
@@ -80,7 +80,7 @@ def cache(
kind=Parameter.KEYWORD_ONLY,
)
def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[Union[R, Response]]]:
# get_typed_signature ensures that any forward references are resolved first
wrapped_signature = get_typed_signature(func)
to_inject: list[Parameter] = []
@@ -89,7 +89,7 @@ def cache(
return_type = get_typed_return_annotation(func)
@wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
nonlocal coder
nonlocal expire
nonlocal key_builder
@@ -139,15 +139,15 @@ def cache(
cache_key = await cache_key
try:
ttl, ret = await backend.get_with_ttl(cache_key)
ttl, cached = await backend.get_with_ttl(cache_key)
except Exception:
logger.warning(
f"Error retrieving cache key '{cache_key}' from backend:", exc_info=True
)
ttl, ret = 0, None
ttl, cached = 0, None
if not request:
if ret is not None:
return coder.decode_as_type(ret, type_=return_type)
if cached is not None:
return coder.decode_as_type(cached, type_=return_type)
ret = await ensure_async_func(*args, **kwargs)
try:
await backend.set(cache_key, coder.encode(ret), expire)
@@ -161,15 +161,15 @@ def cache(
return await ensure_async_func(*args, **kwargs)
if_none_match = request.headers.get("if-none-match")
if ret is not None:
if cached is not None:
if response:
response.headers["Cache-Control"] = f"max-age={ttl}"
etag = f"W/{hash(ret)}"
etag = f"W/{hash(cached)}"
if if_none_match == etag:
response.status_code = 304
return response
response.headers["ETag"] = etag
return coder.decode_as_type(ret, type_=return_type)
return coder.decode_as_type(cached, type_=return_type)
ret = await ensure_async_func(*args, **kwargs)
encoded_ret = coder.encode(ret)
@@ -179,9 +179,10 @@ def cache(
except Exception:
logger.warning(f"Error setting cache key '{cache_key}' in backend:", exc_info=True)
response.headers["Cache-Control"] = f"max-age={expire}"
etag = f"W/{hash(encoded_ret)}"
response.headers["ETag"] = etag
if response:
response.headers["Cache-Control"] = f"max-age={expire}"
etag = f"W/{hash(encoded_ret)}"
response.headers["ETag"] = etag
return ret
inner.__signature__ = _augment_signature(wrapped_signature, *to_inject)