Typing cleanup

- Compatibility with older Python versions
  - use `Optional` and `Union` instead of `... | None` and `a | b`
  - use `typing_extensions.Protocol` instead of `typing.Protocol`
  - use `typing.Dict`, `typing.List`, etc. instead of the concrete types.

- Fix backend `.get()` annotations; not all were marked as `Optional[str]`
- Don't return anything from `Backend.set()` methods.
- The `Coder.decode_as_type()` type parameter must be a type to be
  compatible with `ModelField(..., type_=...)`.
- Clean up `Optional[]` use, remove where it is not needed.
- Clean up variable use in decorator, keeping the raw cached value
  separate from the return value from the wrapped endpoint.
- Annotate the wrapper as returning either the original type _or_ a
  Response (returning a 304 Not Modified response).
- Clean up small edge-case where `response` could be `None`.
- Correct type annotation on `JsonCoder.decode()` to match `Coder.decode()`.
This commit is contained in:
Martijn Pieters
2023-05-09 15:30:46 +01:00
parent 564026e189
commit 013be85f97
8 changed files with 46 additions and 42 deletions

View File

@@ -1,3 +1,5 @@
from typing import Dict, Optional
import pendulum import pendulum
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
@@ -84,9 +86,9 @@ app.get("/method")(cache(namespace="test")(instance.handler_method))
# cache a Pydantic model instance; the return type annotation is required in this case # cache a Pydantic model instance; the return type annotation is required in this case
class Item(BaseModel): class Item(BaseModel):
name: str name: str
description: str | None = None description: Optional[str] = None
price: float price: float
tax: float | None = None tax: Optional[float] = None
@app.get("/pydantic_instance") @app.get("/pydantic_instance")
@@ -110,7 +112,7 @@ async def uncached_put():
@cache(namespace="test", expire=5, injected_dependency_namespace="monty_python") @cache(namespace="test", expire=5, injected_dependency_namespace="monty_python")
def namespaced_injection( def namespaced_injection(
__fastapi_cache_request: int = 42, __fastapi_cache_response: int = 17 __fastapi_cache_request: int = 42, __fastapi_cache_response: int = 17
) -> dict[str, int]: ) -> Dict[str, int]:
return { return {
"__fastapi_cache_request": __fastapi_cache_request, "__fastapi_cache_request": __fastapi_cache_request,
"__fastapi_cache_response": __fastapi_cache_response, "__fastapi_cache_response": __fastapi_cache_response,

View File

@@ -10,10 +10,10 @@ from pydantic import BaseConfig, ValidationError, fields
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse from starlette.templating import _TemplateResponse as TemplateResponse
_T = TypeVar("_T") _T = TypeVar("_T", bound=type)
CONVERTERS: dict[str, Callable[[str], Any]] = { CONVERTERS: Dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True), "date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True), "datetime": lambda x: pendulum.parse(x, exact=True),
"decimal": Decimal, "decimal": Decimal,

View File

@@ -2,7 +2,7 @@ import logging
import sys import sys
from functools import wraps from functools import wraps
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
from typing import Awaitable, Callable, Optional, Type, TypeVar from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
from typing import ParamSpec from typing import ParamSpec
@@ -36,7 +36,7 @@ def _augment_signature(signature: Signature, *extra: Parameter) -> Signature:
return signature.replace(parameters=[*parameters, *extra, *variadic_keyword_params]) return signature.replace(parameters=[*parameters, *extra, *variadic_keyword_params])
def _locate_param(sig: Signature, dep: Parameter, to_inject: list[Parameter]) -> Parameter: def _locate_param(sig: Signature, dep: Parameter, to_inject: List[Parameter]) -> Parameter:
"""Locate an existing parameter in the decorated endpoint """Locate an existing parameter in the decorated endpoint
If not found, returns the injectable parameter, and adds it to the to_inject list. If not found, returns the injectable parameter, and adds it to the to_inject list.
@@ -56,9 +56,9 @@ def cache(
expire: Optional[int] = None, expire: Optional[int] = None,
coder: Optional[Type[Coder]] = None, coder: Optional[Type[Coder]] = None,
key_builder: Optional[KeyBuilder] = None, key_builder: Optional[KeyBuilder] = None,
namespace: Optional[str] = "", namespace: str = "",
injected_dependency_namespace: str = "__fastapi_cache", injected_dependency_namespace: str = "__fastapi_cache",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[Union[R, Response]]]]:
""" """
cache all function cache all function
:param namespace: :param namespace:
@@ -80,7 +80,7 @@ def cache(
kind=Parameter.KEYWORD_ONLY, kind=Parameter.KEYWORD_ONLY,
) )
def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[Union[R, Response]]]:
# get_typed_signature ensures that any forward references are resolved first # get_typed_signature ensures that any forward references are resolved first
wrapped_signature = get_typed_signature(func) wrapped_signature = get_typed_signature(func)
to_inject: list[Parameter] = [] to_inject: list[Parameter] = []
@@ -89,7 +89,7 @@ def cache(
return_type = get_typed_return_annotation(func) return_type = get_typed_return_annotation(func)
@wraps(func) @wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R: async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
nonlocal coder nonlocal coder
nonlocal expire nonlocal expire
nonlocal key_builder nonlocal key_builder
@@ -139,15 +139,15 @@ def cache(
cache_key = await cache_key cache_key = await cache_key
try: try:
ttl, ret = await backend.get_with_ttl(cache_key) ttl, cached = await backend.get_with_ttl(cache_key)
except Exception: except Exception:
logger.warning( logger.warning(
f"Error retrieving cache key '{cache_key}' from backend:", exc_info=True f"Error retrieving cache key '{cache_key}' from backend:", exc_info=True
) )
ttl, ret = 0, None ttl, cached = 0, None
if not request: if not request:
if ret is not None: if cached is not None:
return coder.decode_as_type(ret, type_=return_type) return coder.decode_as_type(cached, type_=return_type)
ret = await ensure_async_func(*args, **kwargs) ret = await ensure_async_func(*args, **kwargs)
try: try:
await backend.set(cache_key, coder.encode(ret), expire) await backend.set(cache_key, coder.encode(ret), expire)
@@ -161,15 +161,15 @@ def cache(
return await ensure_async_func(*args, **kwargs) return await ensure_async_func(*args, **kwargs)
if_none_match = request.headers.get("if-none-match") if_none_match = request.headers.get("if-none-match")
if ret is not None: if cached is not None:
if response: if response:
response.headers["Cache-Control"] = f"max-age={ttl}" response.headers["Cache-Control"] = f"max-age={ttl}"
etag = f"W/{hash(ret)}" etag = f"W/{hash(cached)}"
if if_none_match == etag: if if_none_match == etag:
response.status_code = 304 response.status_code = 304
return response return response
response.headers["ETag"] = etag response.headers["ETag"] = etag
return coder.decode_as_type(ret, type_=return_type) return coder.decode_as_type(cached, type_=return_type)
ret = await ensure_async_func(*args, **kwargs) ret = await ensure_async_func(*args, **kwargs)
encoded_ret = coder.encode(ret) encoded_ret = coder.encode(ret)
@@ -179,6 +179,7 @@ def cache(
except Exception: except Exception:
logger.warning(f"Error setting cache key '{cache_key}' in backend:", exc_info=True) logger.warning(f"Error setting cache key '{cache_key}' in backend:", exc_info=True)
if response:
response.headers["Cache-Control"] = f"max-age={expire}" response.headers["Cache-Control"] = f"max-age={expire}"
etag = f"W/{hash(encoded_ret)}" etag = f"W/{hash(encoded_ret)}"
response.headers["ETag"] = etag response.headers["ETag"] = etag

View File

@@ -1,5 +1,5 @@
import hashlib import hashlib
from typing import Any, Callable, Optional from typing import Any, Callable, Dict, Optional, Tuple
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
@@ -10,8 +10,8 @@ def default_key_builder(
namespace: str = "", namespace: str = "",
request: Optional[Request] = None, request: Optional[Request] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
args: Optional[tuple[Any, ...]] = None, args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[dict[str, Any]] = None, kwargs: Optional[Dict[str, Any]] = None,
) -> str: ) -> str:
cache_key = hashlib.md5( # nosec:B303 cache_key = hashlib.md5( # nosec:B303
f"{func.__module__}:{func.__name__}:{args}:{kwargs}".encode() f"{func.__module__}:{func.__name__}:{args}:{kwargs}".encode()

View File

@@ -1,7 +1,8 @@
from typing import Any, Awaitable, Callable, Optional, Protocol, Union from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
from typing_extensions import Protocol
_Func = Callable[..., Any] _Func = Callable[..., Any]
@@ -15,7 +16,7 @@ class KeyBuilder(Protocol):
*, *,
request: Optional[Request] = ..., request: Optional[Request] = ...,
response: Optional[Response] = ..., response: Optional[Response] = ...,
args: tuple[Any, ...] = ..., args: Tuple[Any, ...] = ...,
kwargs: dict[str, Any] = ..., kwargs: Dict[str, Any] = ...,
) -> Union[Awaitable[str], str]: ) -> Union[Awaitable[str], str]:
... ...

24
poetry.lock generated
View File

@@ -304,14 +304,14 @@ crt = ["awscrt (==0.11.24)"]
[[package]] [[package]]
name = "certifi" name = "certifi"
version = "2022.12.7" version = "2023.5.7"
description = "Python package for providing Mozilla's CA Bundle." description = "Python package for providing Mozilla's CA Bundle."
category = "dev" category = "dev"
optional = false optional = false
python-versions = ">=3.6" python-versions = ">=3.6"
files = [ files = [
{file = "certifi-2022.12.7-py3-none-any.whl", hash = "sha256:4ad3232f5e926d6718ec31cfc1fcadfde020920e278684144551c91769c7bc18"}, {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"},
{file = "certifi-2022.12.7.tar.gz", hash = "sha256:35824b4c3a97115964b408844d64aa14db1cc518f6562e8d7261699d1350a9e3"}, {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"},
] ]
[[package]] [[package]]
@@ -1099,18 +1099,18 @@ files = [
[[package]] [[package]]
name = "redis" name = "redis"
version = "4.5.4" version = "4.5.5"
description = "Python client for Redis database and key-value store" description = "Python client for Redis database and key-value store"
category = "main" category = "main"
optional = true optional = true
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "redis-4.5.4-py3-none-any.whl", hash = "sha256:2c19e6767c474f2e85167909061d525ed65bea9301c0770bb151e041b7ac89a2"}, {file = "redis-4.5.5-py3-none-any.whl", hash = "sha256:77929bc7f5dab9adf3acba2d3bb7d7658f1e0c2f1cafe7eb36434e751c471119"},
{file = "redis-4.5.4.tar.gz", hash = "sha256:73ec35da4da267d6847e47f68730fdd5f62e2ca69e3ef5885c6a78a9374c3893"}, {file = "redis-4.5.5.tar.gz", hash = "sha256:dc87a0bdef6c8bfe1ef1e1c40be7034390c2ae02d92dcd0c7ca1729443899880"},
] ]
[package.dependencies] [package.dependencies]
async-timeout = {version = ">=4.0.2", markers = "python_version <= \"3.11.2\""} async-timeout = {version = ">=4.0.2", markers = "python_full_version <= \"3.11.2\""}
importlib-metadata = {version = ">=1.0", markers = "python_version < \"3.8\""} importlib-metadata = {version = ">=1.0", markers = "python_version < \"3.8\""}
typing-extensions = {version = "*", markers = "python_version < \"3.8\""} typing-extensions = {version = "*", markers = "python_version < \"3.8\""}
@@ -1120,21 +1120,21 @@ ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"
[[package]] [[package]]
name = "requests" name = "requests"
version = "2.29.0" version = "2.30.0"
description = "Python HTTP for Humans." description = "Python HTTP for Humans."
category = "dev" category = "dev"
optional = false optional = false
python-versions = ">=3.7" python-versions = ">=3.7"
files = [ files = [
{file = "requests-2.29.0-py3-none-any.whl", hash = "sha256:e8f3c9be120d3333921d213eef078af392fba3933ab7ed2d1cba3b56f2568c3b"}, {file = "requests-2.30.0-py3-none-any.whl", hash = "sha256:10e94cc4f3121ee6da529d358cdaeaff2f1c409cd377dbc72b825852f2f7e294"},
{file = "requests-2.29.0.tar.gz", hash = "sha256:f2e34a75f4749019bb0e3effb66683630e4ffeaf75819fb51bebef1bf5aef059"}, {file = "requests-2.30.0.tar.gz", hash = "sha256:239d7d4458afcb28a692cdd298d87542235f4ca8d36d03a15bfc128a6559a2f4"},
] ]
[package.dependencies] [package.dependencies]
certifi = ">=2017.4.17" certifi = ">=2017.4.17"
charset-normalizer = ">=2,<4" charset-normalizer = ">=2,<4"
idna = ">=2.5,<4" idna = ">=2.5,<4"
urllib3 = ">=1.21.1,<1.27" urllib3 = ">=1.21.1,<3"
[package.extras] [package.extras]
socks = ["PySocks (>=1.5.6,!=1.5.7)"] socks = ["PySocks (>=1.5.6,!=1.5.7)"]
@@ -1476,4 +1476,4 @@ redis = ["redis"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.7" python-versions = "^3.7"
content-hash = "4325f045144b309e8378c02465b4e66544d28fc74751883f9a50ff391df08dbe" content-hash = "479b55889016688ab9b82e0a1c998ac2faaddf0807d0174b99289d02613387a8"

View File

@@ -22,7 +22,7 @@ redis = { version = "^4.2.0rc1", optional = true }
aiomcache = { version = "*", optional = true } aiomcache = { version = "*", optional = true }
pendulum = "*" pendulum = "*"
aiobotocore = { version = "^1.4.1", optional = true } aiobotocore = { version = "^1.4.1", optional = true }
typing-extensions = { version = ">=4.1.0", markers = "python_version < \"3.10\"" } typing-extensions = { version = ">=4.1.0" }
aiohttp = { version = ">=3.8.3", markers = "python_version >= \"3.11\"" } aiohttp = { version = ">=3.8.3", markers = "python_version >= \"3.11\"" }
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional, Tuple
import pytest import pytest
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
@@ -46,7 +46,7 @@ def test_pickle_coder(value: Any) -> None:
[ [
(1, None), (1, None),
("some_string", None), ("some_string", None),
((1, 2), tuple[int, int]), ((1, 2), Tuple[int, int]),
([1, 2, 3], None), ([1, 2, 3], None),
({"some_key": 1, "other_key": 2}, None), ({"some_key": 1, "other_key": 2}, None),
(DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), DCItem), (DCItem(name="foo", price=42.0, description="some dataclass item", tax=0.2), DCItem),