mirror of
https://github.com/long2ice/fastapi-cache.git
synced 2026-03-24 20:47:54 +00:00
Full mypy --strict type checking pass
This commit is contained in:
@@ -1,11 +1,16 @@
|
||||
import datetime
|
||||
from typing import Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from aiobotocore.client import AioBaseClient
|
||||
from aiobotocore.session import get_session, AioSession
|
||||
from aiobotocore.session import AioSession, get_session
|
||||
|
||||
from fastapi_cache.backends import Backend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types_aiobotocore_dynamodb import DynamoDBClient
|
||||
else:
|
||||
DynamoDBClient = AioBaseClient
|
||||
|
||||
|
||||
class DynamoBackend(Backend):
|
||||
"""
|
||||
@@ -25,9 +30,13 @@ class DynamoBackend(Backend):
|
||||
>> FastAPICache.init(dynamodb)
|
||||
"""
|
||||
|
||||
client: DynamoDBClient
|
||||
session: AioSession
|
||||
table_name: str
|
||||
region: Optional[str]
|
||||
|
||||
def __init__(self, table_name: str, region: Optional[str] = None) -> None:
|
||||
self.session: AioSession = get_session()
|
||||
self.client: Optional[AioBaseClient] = None # Needs async init
|
||||
self.table_name = table_name
|
||||
self.region = region
|
||||
|
||||
@@ -60,6 +69,7 @@ class DynamoBackend(Backend):
|
||||
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
||||
if "Item" in response:
|
||||
return response["Item"].get("value", {}).get("B")
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||
ttl = (
|
||||
|
||||
@@ -13,18 +13,18 @@ class RedisBackend(Backend):
|
||||
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
|
||||
return await pipe.ttl(key).get(key).execute()
|
||||
return await pipe.ttl(key).get(key).execute() # type: ignore[union-attr,no-any-return]
|
||||
|
||||
async def get(self, key: str) -> Optional[bytes]:
|
||||
return await self.redis.get(key)
|
||||
return await self.redis.get(key) # type: ignore[union-attr]
|
||||
|
||||
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||
return await self.redis.set(key, value, ex=expire)
|
||||
await self.redis.set(key, value, ex=expire) # type: ignore[union-attr]
|
||||
|
||||
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
||||
if namespace:
|
||||
lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end"
|
||||
return await self.redis.eval(lua, numkeys=0)
|
||||
return await self.redis.eval(lua, numkeys=0) # type: ignore[union-attr,no-any-return]
|
||||
elif key:
|
||||
return await self.redis.delete(key)
|
||||
return await self.redis.delete(key) # type: ignore[union-attr]
|
||||
return 0
|
||||
|
||||
@@ -14,8 +14,9 @@ _T = TypeVar("_T", bound=type)
|
||||
|
||||
|
||||
CONVERTERS: Dict[str, Callable[[str], Any]] = {
|
||||
"date": lambda x: pendulum.parse(x, exact=True),
|
||||
"datetime": lambda x: pendulum.parse(x, exact=True),
|
||||
# Pendulum 3.0.0 adds parse to __all__, at which point these ignores can be removed
|
||||
"date": lambda x: pendulum.parse(x, exact=True), # type: ignore[attr-defined]
|
||||
"datetime": lambda x: pendulum.parse(x, exact=True), # type: ignore[attr-defined]
|
||||
"decimal": Decimal,
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import logging
|
||||
import sys
|
||||
from functools import wraps
|
||||
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
|
||||
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union
|
||||
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union, cast
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import ParamSpec
|
||||
@@ -111,11 +111,11 @@ def cache(
|
||||
else:
|
||||
# sync, wrap in thread and return async
|
||||
# see above why we have to await even although caller also awaits.
|
||||
return await run_in_threadpool(func, *args, **kwargs)
|
||||
return await run_in_threadpool(func, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
copy_kwargs = kwargs.copy()
|
||||
request: Optional[Request] = copy_kwargs.pop(request_param.name, None)
|
||||
response: Optional[Response] = copy_kwargs.pop(response_param.name, None)
|
||||
request: Optional[Request] = copy_kwargs.pop(request_param.name, None) # type: ignore[assignment]
|
||||
response: Optional[Response] = copy_kwargs.pop(response_param.name, None) # type: ignore[assignment]
|
||||
if (
|
||||
request and request.headers.get("Cache-Control") in ("no-store", "no-cache")
|
||||
) or not FastAPICache.get_enable():
|
||||
@@ -137,6 +137,7 @@ def cache(
|
||||
)
|
||||
if isawaitable(cache_key):
|
||||
cache_key = await cache_key
|
||||
assert isinstance(cache_key, str)
|
||||
|
||||
try:
|
||||
ttl, cached = await backend.get_with_ttl(cache_key)
|
||||
@@ -147,7 +148,7 @@ def cache(
|
||||
ttl, cached = 0, None
|
||||
if not request:
|
||||
if cached is not None:
|
||||
return coder.decode_as_type(cached, type_=return_type)
|
||||
return cast(R, coder.decode_as_type(cached, type_=return_type))
|
||||
ret = await ensure_async_func(*args, **kwargs)
|
||||
try:
|
||||
await backend.set(cache_key, coder.encode(ret), expire)
|
||||
@@ -169,7 +170,7 @@ def cache(
|
||||
response.status_code = 304
|
||||
return response
|
||||
response.headers["ETag"] = etag
|
||||
return coder.decode_as_type(cached, type_=return_type)
|
||||
return cast(R, coder.decode_as_type(cached, type_=return_type))
|
||||
|
||||
ret = await ensure_async_func(*args, **kwargs)
|
||||
encoded_ret = coder.encode(ret)
|
||||
@@ -185,7 +186,7 @@ def cache(
|
||||
response.headers["ETag"] = etag
|
||||
return ret
|
||||
|
||||
inner.__signature__ = _augment_signature(wrapped_signature, *to_inject)
|
||||
inner.__signature__ = _augment_signature(wrapped_signature, *to_inject) # type: ignore[attr-defined]
|
||||
return inner
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -8,10 +8,11 @@ from starlette.responses import Response
|
||||
def default_key_builder(
|
||||
func: Callable[..., Any],
|
||||
namespace: str = "",
|
||||
*,
|
||||
request: Optional[Request] = None,
|
||||
response: Optional[Response] = None,
|
||||
args: Optional[Tuple[Any, ...]] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> str:
|
||||
cache_key = hashlib.md5( # nosec:B303
|
||||
f"{func.__module__}:{func.__name__}:{args}:{kwargs}".encode()
|
||||
|
||||
@@ -11,12 +11,12 @@ _Func = Callable[..., Any]
|
||||
class KeyBuilder(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
_function: _Func,
|
||||
_namespace: str = ...,
|
||||
__function: _Func,
|
||||
__namespace: str = ...,
|
||||
*,
|
||||
request: Optional[Request] = ...,
|
||||
response: Optional[Response] = ...,
|
||||
args: Tuple[Any, ...] = ...,
|
||||
kwargs: Dict[str, Any] = ...,
|
||||
args: Tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
) -> Union[Awaitable[str], str]:
|
||||
...
|
||||
|
||||
Reference in New Issue
Block a user