Full mypy --strict type checking pass

This commit is contained in:
Martijn Pieters
2023-05-09 17:08:32 +01:00
parent e92604802e
commit 941cd044c7
10 changed files with 693 additions and 40 deletions

View File

@@ -1,11 +1,16 @@
import datetime
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session, AioSession
from aiobotocore.session import AioSession, get_session
from fastapi_cache.backends import Backend
if TYPE_CHECKING:
from types_aiobotocore_dynamodb import DynamoDBClient
else:
DynamoDBClient = AioBaseClient
class DynamoBackend(Backend):
"""
@@ -25,9 +30,13 @@ class DynamoBackend(Backend):
>> FastAPICache.init(dynamodb)
"""
client: DynamoDBClient
session: AioSession
table_name: str
region: Optional[str]
def __init__(self, table_name: str, region: Optional[str] = None) -> None:
self.session: AioSession = get_session()
self.client: Optional[AioBaseClient] = None # Needs async init
self.table_name = table_name
self.region = region
@@ -60,6 +69,7 @@ class DynamoBackend(Backend):
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
if "Item" in response:
return response["Item"].get("value", {}).get("B")
return None
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
ttl = (

View File

@@ -13,18 +13,18 @@ class RedisBackend(Backend):
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
return await pipe.ttl(key).get(key).execute()
return await pipe.ttl(key).get(key).execute() # type: ignore[union-attr,no-any-return]
async def get(self, key: str) -> Optional[bytes]:
return await self.redis.get(key)
return await self.redis.get(key) # type: ignore[union-attr]
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
return await self.redis.set(key, value, ex=expire)
await self.redis.set(key, value, ex=expire) # type: ignore[union-attr]
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
if namespace:
lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end"
return await self.redis.eval(lua, numkeys=0)
return await self.redis.eval(lua, numkeys=0) # type: ignore[union-attr,no-any-return]
elif key:
return await self.redis.delete(key)
return await self.redis.delete(key) # type: ignore[union-attr]
return 0

View File

@@ -14,8 +14,9 @@ _T = TypeVar("_T", bound=type)
CONVERTERS: Dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True),
# Pendulum 3.0.0 adds parse to __all__, at which point these ignores can be removed
"date": lambda x: pendulum.parse(x, exact=True), # type: ignore[attr-defined]
"datetime": lambda x: pendulum.parse(x, exact=True), # type: ignore[attr-defined]
"decimal": Decimal,
}

View File

@@ -2,7 +2,7 @@ import logging
import sys
from functools import wraps
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union, cast
if sys.version_info >= (3, 10):
from typing import ParamSpec
@@ -111,11 +111,11 @@ def cache(
else:
# sync, wrap in thread and return async
# see above why we have to await even although caller also awaits.
return await run_in_threadpool(func, *args, **kwargs)
return await run_in_threadpool(func, *args, **kwargs) # type: ignore[arg-type]
copy_kwargs = kwargs.copy()
request: Optional[Request] = copy_kwargs.pop(request_param.name, None)
response: Optional[Response] = copy_kwargs.pop(response_param.name, None)
request: Optional[Request] = copy_kwargs.pop(request_param.name, None) # type: ignore[assignment]
response: Optional[Response] = copy_kwargs.pop(response_param.name, None) # type: ignore[assignment]
if (
request and request.headers.get("Cache-Control") in ("no-store", "no-cache")
) or not FastAPICache.get_enable():
@@ -137,6 +137,7 @@ def cache(
)
if isawaitable(cache_key):
cache_key = await cache_key
assert isinstance(cache_key, str)
try:
ttl, cached = await backend.get_with_ttl(cache_key)
@@ -147,7 +148,7 @@ def cache(
ttl, cached = 0, None
if not request:
if cached is not None:
return coder.decode_as_type(cached, type_=return_type)
return cast(R, coder.decode_as_type(cached, type_=return_type))
ret = await ensure_async_func(*args, **kwargs)
try:
await backend.set(cache_key, coder.encode(ret), expire)
@@ -169,7 +170,7 @@ def cache(
response.status_code = 304
return response
response.headers["ETag"] = etag
return coder.decode_as_type(cached, type_=return_type)
return cast(R, coder.decode_as_type(cached, type_=return_type))
ret = await ensure_async_func(*args, **kwargs)
encoded_ret = coder.encode(ret)
@@ -185,7 +186,7 @@ def cache(
response.headers["ETag"] = etag
return ret
inner.__signature__ = _augment_signature(wrapped_signature, *to_inject)
inner.__signature__ = _augment_signature(wrapped_signature, *to_inject) # type: ignore[attr-defined]
return inner
return wrapper

View File

@@ -8,10 +8,11 @@ from starlette.responses import Response
def default_key_builder(
func: Callable[..., Any],
namespace: str = "",
*,
request: Optional[Request] = None,
response: Optional[Response] = None,
args: Optional[Tuple[Any, ...]] = None,
kwargs: Optional[Dict[str, Any]] = None,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> str:
cache_key = hashlib.md5( # nosec:B303
f"{func.__module__}:{func.__name__}:{args}:{kwargs}".encode()

View File

@@ -11,12 +11,12 @@ _Func = Callable[..., Any]
class KeyBuilder(Protocol):
def __call__(
self,
_function: _Func,
_namespace: str = ...,
__function: _Func,
__namespace: str = ...,
*,
request: Optional[Request] = ...,
response: Optional[Response] = ...,
args: Tuple[Any, ...] = ...,
kwargs: Dict[str, Any] = ...,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Union[Awaitable[str], str]:
...

621
poetry.lock generated

File diff suppressed because one or more lines are too long

View File

@@ -22,6 +22,7 @@ redis = { version = "^4.2.0rc1", optional = true }
aiomcache = { version = "*", optional = true }
pendulum = "*"
aiobotocore = { version = "^1.4.1", optional = true }
types-aiobotocore = { extras = ["dynamodb"], version = "^2.5.0.post2", optional = true }
typing-extensions = { version = ">=4.1.0" }
aiohttp = { version = ">=3.8.3", markers = "python_version >= \"3.11\"" }
@@ -33,6 +34,8 @@ pytest = "*"
requests = "*"
coverage = "^6.5.0"
httpx = "*"
mypy = "^1.2.0"
types-redis = "^4.5.4.2"
[build-system]
requires = ["poetry>=0.12"]
@@ -41,9 +44,31 @@ build-backend = "poetry.masonry.api"
[tool.poetry.extras]
redis = ["redis"]
memcache = ["aiomcache"]
dynamodb = ["aiobotocore"]
all = ["redis", "aiomcache", "aiobotocore"]
dynamodb = ["aiobotocore", "types-aiobotocore"]
all = ["redis", "aiomcache", "aiobotocore", "types-aiobotocore"]
[tool.black]
line-length = 100
target-version = ['py36', 'py37', 'py38', 'py39']
[tool.mypy]
files = ["fastapi_cache", "examples", "tests"]
# equivalent of --strict
warn_unused_configs = true
disallow_any_generics = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
check_untyped_defs = true
disallow_untyped_decorators = true
warn_redundant_casts = true
warn_unused_ignores = true
warn_return_any = true
no_implicit_reexport = true
strict_equality = true
strict_concatenate = true
[[tool.mypy.overrides]]
module = "examples.*.main"
ignore_errors = true

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, Optional, Tuple
from typing import Any, Optional, Tuple, Type
import pytest
from pydantic import BaseModel, ValidationError
@@ -53,7 +53,7 @@ def test_pickle_coder(value: Any) -> None:
(PDItem(name="foo", price=42.0, description="some pydantic item", tax=0.2), PDItem),
],
)
def test_json_coder(value: Any, return_type) -> None:
def test_json_coder(value: Any, return_type: Type[Any]) -> None:
encoded_value = JsonCoder.encode(value)
assert isinstance(encoded_value, bytes)
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)

View File

@@ -1,5 +1,5 @@
import time
from typing import Generator
from typing import Generator, Any
import pendulum
import pytest
@@ -11,7 +11,7 @@ from fastapi_cache.backends.inmemory import InMemoryBackend
@pytest.fixture(autouse=True)
def init_cache() -> Generator:
def init_cache() -> Generator[Any, Any, None]:
FastAPICache.init(InMemoryBackend())
yield
FastAPICache.reset()
@@ -21,33 +21,33 @@ def test_datetime() -> None:
with TestClient(app) as client:
response = client.get("/datetime")
now = response.json().get("now")
now_ = pendulum.now().replace(microsecond=0)
assert pendulum.parse(now).replace(microsecond=0) == now_
now_ = pendulum.now().replace(microsecond=0) # type: ignore[no-untyped-call]
assert pendulum.parse(now).replace(microsecond=0) == now_ # type: ignore[attr-defined]
response = client.get("/datetime")
now = response.json().get("now")
assert pendulum.parse(now).replace(microsecond=0) == now_
assert pendulum.parse(now).replace(microsecond=0) == now_ # type: ignore[attr-defined]
time.sleep(3)
response = client.get("/datetime")
now = response.json().get("now")
now = pendulum.parse(now).replace(microsecond=0)
now = pendulum.parse(now).replace(microsecond=0) # type: ignore[attr-defined]
assert now != now_
assert now == pendulum.now().replace(microsecond=0)
assert now == pendulum.now().replace(microsecond=0) # type: ignore[no-untyped-call]
def test_date() -> None:
"""Test path function without request or response arguments."""
with TestClient(app) as client:
response = client.get("/date")
assert pendulum.parse(response.json()) == pendulum.today()
assert pendulum.parse(response.json()) == pendulum.today() # type: ignore[attr-defined]
# do it again to test cache
response = client.get("/date")
assert pendulum.parse(response.json()) == pendulum.today()
assert pendulum.parse(response.json()) == pendulum.today() # type: ignore[attr-defined]
# now test with cache disabled, as that's a separate code path
FastAPICache._enable = False
response = client.get("/date")
assert pendulum.parse(response.json()) == pendulum.today()
assert pendulum.parse(response.json()) == pendulum.today() # type: ignore[attr-defined]
FastAPICache._enable = True