feat: add more type hints

This commit is contained in:
Ivan Moiseev
2022-10-22 20:59:37 +04:00
parent 1ef80ff457
commit 4c6abcf786
9 changed files with 71 additions and 67 deletions

View File

@@ -1,28 +1,29 @@
from typing import Callable
from typing import Callable, Optional, Type
from fastapi_cache.backends import Backend
from fastapi_cache.coder import Coder, JsonCoder
from fastapi_cache.key_builder import default_key_builder
class FastAPICache:
_backend = None
_prefix = None
_expire = None
_backend: Optional[Backend] = None
_prefix: Optional[str] = None
_expire: Optional[int] = None
_init = False
_coder = None
_key_builder = None
_coder: Optional[Type[Coder]] = None
_key_builder: Optional[Callable] = None
_enable = True
@classmethod
def init(
cls,
backend,
backend: Backend,
prefix: str = "",
expire: int = None,
coder: Coder = JsonCoder,
expire: Optional[int] = None,
coder: Type[Coder] = JsonCoder,
key_builder: Callable = default_key_builder,
enable: bool = True,
):
) -> None:
if cls._init:
return
cls._init = True
@@ -34,31 +35,31 @@ class FastAPICache:
cls._enable = enable
@classmethod
def get_backend(cls):
def get_backend(cls) -> Backend:
assert cls._backend, "You must call init first!" # nosec: B101
return cls._backend
@classmethod
def get_prefix(cls):
def get_prefix(cls) -> Optional[str]:
return cls._prefix
@classmethod
def get_expire(cls):
def get_expire(cls) -> Optional[int]:
return cls._expire
@classmethod
def get_coder(cls):
def get_coder(cls) -> Optional[Type[Coder]]:
return cls._coder
@classmethod
def get_key_builder(cls):
def get_key_builder(cls) -> Optional[Callable]:
return cls._key_builder
@classmethod
def get_enable(cls):
def get_enable(cls) -> bool:
return cls._enable
@classmethod
async def clear(cls, namespace: str = None, key: str = None):
async def clear(cls, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
namespace = cls._prefix + (":" + namespace if namespace else "")
return await cls._backend.clear(namespace, key)

View File

@@ -1,20 +1,20 @@
import abc
from typing import Tuple
from typing import Tuple, Optional
class Backend:
@abc.abstractmethod
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
raise NotImplementedError
@abc.abstractmethod
async def get(self, key: str) -> str:
async def get(self, key: str) -> Optional[str]:
raise NotImplementedError
@abc.abstractmethod
async def set(self, key: str, value: str, expire: int = None):
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
raise NotImplementedError
@abc.abstractmethod
async def clear(self, namespace: str = None, key: str = None) -> int:
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError

View File

@@ -1,6 +1,7 @@
import datetime
from typing import Tuple
from typing import Tuple, Optional
from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session
from fastapi_cache.backends import Backend
@@ -24,18 +25,18 @@ class DynamoBackend(Backend):
>> FastAPICache.init(dynamodb)
"""
def __init__(self, table_name, region=None):
def __init__(self, table_name: str, region: Optional[str] = None) -> None:
self.session = get_session()
self.client = None # Needs async init
self.client: Optional[AioBaseClient] = None # Needs async init
self.table_name = table_name
self.region = region
async def init(self):
async def init(self) -> None:
self.client = await self.session.create_client(
"dynamodb", region_name=self.region
).__aenter__()
async def close(self):
async def close(self) -> None:
self.client = await self.client.__aexit__(None, None, None)
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
@@ -55,12 +56,12 @@ class DynamoBackend(Backend):
return 0, None
async def get(self, key) -> str:
async def get(self, key: str) -> str:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
if "Item" in response:
return response["Item"].get("value", {}).get("S")
async def set(self, key: str, value: str, expire: int = None):
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
ttl = (
{
"ttl": {
@@ -88,5 +89,5 @@ class DynamoBackend(Backend):
},
)
async def clear(self, namespace: str = None, key: str = None) -> int:
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError

View File

@@ -20,13 +20,14 @@ class InMemoryBackend(Backend):
def _now(self) -> int:
return int(time.time())
def _get(self, key: str):
def _get(self, key: str) -> Value | None:
v = self._store.get(key)
if v:
if v.ttl_ts < self._now:
del self._store[key]
else:
return v
return None
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
async with self._lock:
@@ -35,17 +36,18 @@ class InMemoryBackend(Backend):
return v.ttl_ts - self._now, v.data
return 0, None
async def get(self, key: str) -> str:
async def get(self, key: str) -> Optional[str]:
async with self._lock:
v = self._get(key)
if v:
return v.data
return None
async def set(self, key: str, value: str, expire: int = None):
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
async with self._lock:
self._store[key] = Value(value, self._now + (expire or 0))
async def clear(self, namespace: str = None, key: str = None) -> int:
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
count = 0
if namespace:
keys = list(self._store.keys())

View File

@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Tuple, Optional
from aiomcache import Client
@@ -9,14 +9,14 @@ class MemcachedBackend(Backend):
def __init__(self, mcache: Client):
self.mcache = mcache
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
return 3600, await self.mcache.get(key.encode())
async def get(self, key: str):
async def get(self, key: str) -> Optional[str]:
return await self.mcache.get(key, key.encode())
async def set(self, key: str, value: str, expire: int = None):
return await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0)
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0)
async def clear(self, namespace: str = None, key: str = None):
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError

View File

@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Tuple, Optional
from redis.asyncio.client import Redis
@@ -13,15 +13,16 @@ class RedisBackend(Backend):
async with self.redis.pipeline(transaction=True) as pipe:
return await (pipe.ttl(key).get(key).execute())
async def get(self, key) -> str:
async def get(self, key: str) -> Optional[str]:
return await self.redis.get(key)
async def set(self, key: str, value: str, expire: int = None):
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
return await self.redis.set(key, value, ex=expire)
async def clear(self, namespace: str = None, key: str = None) -> int:
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
if namespace:
lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end"
return await self.redis.eval(lua, numkeys=0)
elif key:
return await self.redis.delete(key)
return 0

View File

@@ -2,7 +2,7 @@ import datetime
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import Any
from typing import Any, Dict, Union
import pendulum
from fastapi.encoders import jsonable_encoder
@@ -15,7 +15,7 @@ CONVERTERS = {
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
def default(self, obj: Any) -> Any:
if isinstance(obj, datetime.datetime):
return {"val": str(obj), "_spec_type": "datetime"}
elif isinstance(obj, datetime.date):
@@ -26,42 +26,42 @@ class JsonEncoder(json.JSONEncoder):
return jsonable_encoder(obj)
def object_hook(obj):
def object_hook(obj: Any) -> Any:
_spec_type = obj.get("_spec_type")
if not _spec_type:
return obj
if _spec_type in CONVERTERS:
return CONVERTERS[_spec_type](obj["val"])
return CONVERTERS[_spec_type](obj["val"]) # type: ignore
else:
raise TypeError("Unknown {}".format(_spec_type))
class Coder:
@classmethod
def encode(cls, value: Any):
def encode(cls, value: Any) -> Union[str, bytes]:
raise NotImplementedError
@classmethod
def decode(cls, value: Any):
def decode(cls, value: Any) -> Any:
raise NotImplementedError
class JsonCoder(Coder):
@classmethod
def encode(cls, value: Any):
def encode(cls, value: Any) -> str:
return json.dumps(value, cls=JsonEncoder)
@classmethod
def decode(cls, value: Any):
def decode(cls, value: Any) -> str:
return json.loads(value, object_hook=object_hook)
class PickleCoder(Coder):
@classmethod
def encode(cls, value: Any):
def encode(cls, value: Any) -> Union[str, bytes]:
return pickle.dumps(value)
@classmethod
def decode(cls, value: Any):
def decode(cls, value: Any) -> Any:
return pickle.loads(value) # nosec:B403,B301

View File

@@ -1,6 +1,6 @@
import inspect
from functools import wraps
from typing import Callable, Optional, Type
from typing import Callable, Optional, Type, Any
from fastapi.concurrency import run_in_threadpool
from starlette.requests import Request
@@ -11,11 +11,11 @@ from fastapi_cache.coder import Coder
def cache(
expire: int = None,
coder: Type[Coder] = None,
key_builder: Callable = None,
expire: Optional[int] = None,
coder: Optional[Type[Coder]] = None,
key_builder: Optional[Callable] = None,
namespace: Optional[str] = "",
):
) -> Callable:
"""
cache all function
:param namespace:
@@ -26,7 +26,7 @@ def cache(
:return:
"""
def wrapper(func):
def wrapper(func: Callable) -> Callable:
signature = inspect.signature(func)
request_param = next(
(param for param in signature.parameters.values() if param.annotation is Request),
@@ -55,15 +55,15 @@ def cache(
)
if parameters:
signature = signature.replace(parameters=parameters)
func.__signature__ = signature
func.__signature__ = signature # type: ignore
@wraps(func)
async def inner(*args, **kwargs):
async def inner(*args: Any, **kwargs: Any) -> Any:
nonlocal coder
nonlocal expire
nonlocal key_builder
async def ensure_async_func(*args, **kwargs):
async def ensure_async_func(*args: Any, **kwargs: Any) -> Any:
"""Run cached sync functions in thread pool just like FastAPI."""
# if the wrapped function does NOT have request or response in its function signature,
# make sure we don't pass them in as keyword arguments
@@ -83,7 +83,6 @@ def cache(
# see above why we have to await even although caller also awaits.
return await run_in_threadpool(func, *args, **kwargs)
copy_kwargs = kwargs.copy()
request = copy_kwargs.pop("request", None)
response = copy_kwargs.pop("response", None)

View File

@@ -1,18 +1,18 @@
import hashlib
from typing import Optional
from typing import Optional, Callable
from starlette.requests import Request
from starlette.responses import Response
def default_key_builder(
func,
func: Callable,
namespace: Optional[str] = "",
request: Optional[Request] = None,
response: Optional[Response] = None,
args: Optional[tuple] = None,
kwargs: Optional[dict] = None,
):
) -> str:
from fastapi_cache import FastAPICache
prefix = f"{FastAPICache.get_prefix()}:{namespace}:"