feat: add more type hints

This commit is contained in:
Ivan Moiseev
2022-10-22 20:59:37 +04:00
parent 1ef80ff457
commit 4c6abcf786
9 changed files with 71 additions and 67 deletions

View File

@@ -1,28 +1,29 @@
from typing import Callable from typing import Callable, Optional, Type
from fastapi_cache.backends import Backend
from fastapi_cache.coder import Coder, JsonCoder from fastapi_cache.coder import Coder, JsonCoder
from fastapi_cache.key_builder import default_key_builder from fastapi_cache.key_builder import default_key_builder
class FastAPICache: class FastAPICache:
_backend = None _backend: Optional[Backend] = None
_prefix = None _prefix: Optional[str] = None
_expire = None _expire: Optional[int] = None
_init = False _init = False
_coder = None _coder: Optional[Type[Coder]] = None
_key_builder = None _key_builder: Optional[Callable] = None
_enable = True _enable = True
@classmethod @classmethod
def init( def init(
cls, cls,
backend, backend: Backend,
prefix: str = "", prefix: str = "",
expire: int = None, expire: Optional[int] = None,
coder: Coder = JsonCoder, coder: Type[Coder] = JsonCoder,
key_builder: Callable = default_key_builder, key_builder: Callable = default_key_builder,
enable: bool = True, enable: bool = True,
): ) -> None:
if cls._init: if cls._init:
return return
cls._init = True cls._init = True
@@ -34,31 +35,31 @@ class FastAPICache:
cls._enable = enable cls._enable = enable
@classmethod @classmethod
def get_backend(cls): def get_backend(cls) -> Backend:
assert cls._backend, "You must call init first!" # nosec: B101 assert cls._backend, "You must call init first!" # nosec: B101
return cls._backend return cls._backend
@classmethod @classmethod
def get_prefix(cls): def get_prefix(cls) -> Optional[str]:
return cls._prefix return cls._prefix
@classmethod @classmethod
def get_expire(cls): def get_expire(cls) -> Optional[int]:
return cls._expire return cls._expire
@classmethod @classmethod
def get_coder(cls): def get_coder(cls) -> Optional[Type[Coder]]:
return cls._coder return cls._coder
@classmethod @classmethod
def get_key_builder(cls): def get_key_builder(cls) -> Optional[Callable]:
return cls._key_builder return cls._key_builder
@classmethod @classmethod
def get_enable(cls): def get_enable(cls) -> bool:
return cls._enable return cls._enable
@classmethod @classmethod
async def clear(cls, namespace: str = None, key: str = None): async def clear(cls, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
namespace = cls._prefix + (":" + namespace if namespace else "") namespace = cls._prefix + (":" + namespace if namespace else "")
return await cls._backend.clear(namespace, key) return await cls._backend.clear(namespace, key)

View File

@@ -1,20 +1,20 @@
import abc import abc
from typing import Tuple from typing import Tuple, Optional
class Backend: class Backend:
@abc.abstractmethod @abc.abstractmethod
async def get_with_ttl(self, key: str) -> Tuple[int, str]: async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def get(self, key: str) -> str: async def get(self, key: str) -> Optional[str]:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def clear(self, namespace: str = None, key: str = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError raise NotImplementedError

View File

@@ -1,6 +1,7 @@
import datetime import datetime
from typing import Tuple from typing import Tuple, Optional
from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session from aiobotocore.session import get_session
from fastapi_cache.backends import Backend from fastapi_cache.backends import Backend
@@ -24,18 +25,18 @@ class DynamoBackend(Backend):
>> FastAPICache.init(dynamodb) >> FastAPICache.init(dynamodb)
""" """
def __init__(self, table_name, region=None): def __init__(self, table_name: str, region: Optional[str] = None) -> None:
self.session = get_session() self.session = get_session()
self.client = None # Needs async init self.client: Optional[AioBaseClient] = None # Needs async init
self.table_name = table_name self.table_name = table_name
self.region = region self.region = region
async def init(self): async def init(self) -> None:
self.client = await self.session.create_client( self.client = await self.session.create_client(
"dynamodb", region_name=self.region "dynamodb", region_name=self.region
).__aenter__() ).__aenter__()
async def close(self): async def close(self) -> None:
self.client = await self.client.__aexit__(None, None, None) self.client = await self.client.__aexit__(None, None, None)
async def get_with_ttl(self, key: str) -> Tuple[int, str]: async def get_with_ttl(self, key: str) -> Tuple[int, str]:
@@ -55,12 +56,12 @@ class DynamoBackend(Backend):
return 0, None return 0, None
async def get(self, key) -> str: async def get(self, key: str) -> str:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
if "Item" in response: if "Item" in response:
return response["Item"].get("value", {}).get("S") return response["Item"].get("value", {}).get("S")
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
ttl = ( ttl = (
{ {
"ttl": { "ttl": {
@@ -88,5 +89,5 @@ class DynamoBackend(Backend):
}, },
) )
async def clear(self, namespace: str = None, key: str = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError raise NotImplementedError

View File

@@ -20,13 +20,14 @@ class InMemoryBackend(Backend):
def _now(self) -> int: def _now(self) -> int:
return int(time.time()) return int(time.time())
def _get(self, key: str): def _get(self, key: str) -> Value | None:
v = self._store.get(key) v = self._store.get(key)
if v: if v:
if v.ttl_ts < self._now: if v.ttl_ts < self._now:
del self._store[key] del self._store[key]
else: else:
return v return v
return None
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
async with self._lock: async with self._lock:
@@ -35,17 +36,18 @@ class InMemoryBackend(Backend):
return v.ttl_ts - self._now, v.data return v.ttl_ts - self._now, v.data
return 0, None return 0, None
async def get(self, key: str) -> str: async def get(self, key: str) -> Optional[str]:
async with self._lock: async with self._lock:
v = self._get(key) v = self._get(key)
if v: if v:
return v.data return v.data
return None
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
async with self._lock: async with self._lock:
self._store[key] = Value(value, self._now + (expire or 0)) self._store[key] = Value(value, self._now + (expire or 0))
async def clear(self, namespace: str = None, key: str = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
count = 0 count = 0
if namespace: if namespace:
keys = list(self._store.keys()) keys = list(self._store.keys())

View File

@@ -1,4 +1,4 @@
from typing import Tuple from typing import Tuple, Optional
from aiomcache import Client from aiomcache import Client
@@ -9,14 +9,14 @@ class MemcachedBackend(Backend):
def __init__(self, mcache: Client): def __init__(self, mcache: Client):
self.mcache = mcache self.mcache = mcache
async def get_with_ttl(self, key: str) -> Tuple[int, str]: async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
return 3600, await self.mcache.get(key.encode()) return 3600, await self.mcache.get(key.encode())
async def get(self, key: str): async def get(self, key: str) -> Optional[str]:
return await self.mcache.get(key, key.encode()) return await self.mcache.get(key, key.encode())
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
return await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0)
async def clear(self, namespace: str = None, key: str = None): async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError raise NotImplementedError

View File

@@ -1,4 +1,4 @@
from typing import Tuple from typing import Tuple, Optional
from redis.asyncio.client import Redis from redis.asyncio.client import Redis
@@ -13,15 +13,16 @@ class RedisBackend(Backend):
async with self.redis.pipeline(transaction=True) as pipe: async with self.redis.pipeline(transaction=True) as pipe:
return await (pipe.ttl(key).get(key).execute()) return await (pipe.ttl(key).get(key).execute())
async def get(self, key) -> str: async def get(self, key: str) -> Optional[str]:
return await self.redis.get(key) return await self.redis.get(key)
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
return await self.redis.set(key, value, ex=expire) return await self.redis.set(key, value, ex=expire)
async def clear(self, namespace: str = None, key: str = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
if namespace: if namespace:
lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end" lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end"
return await self.redis.eval(lua, numkeys=0) return await self.redis.eval(lua, numkeys=0)
elif key: elif key:
return await self.redis.delete(key) return await self.redis.delete(key)
return 0

View File

@@ -2,7 +2,7 @@ import datetime
import json import json
import pickle # nosec:B403 import pickle # nosec:B403
from decimal import Decimal from decimal import Decimal
from typing import Any from typing import Any, Dict, Union
import pendulum import pendulum
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
@@ -15,7 +15,7 @@ CONVERTERS = {
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj: Any) -> Any:
if isinstance(obj, datetime.datetime): if isinstance(obj, datetime.datetime):
return {"val": str(obj), "_spec_type": "datetime"} return {"val": str(obj), "_spec_type": "datetime"}
elif isinstance(obj, datetime.date): elif isinstance(obj, datetime.date):
@@ -26,42 +26,42 @@ class JsonEncoder(json.JSONEncoder):
return jsonable_encoder(obj) return jsonable_encoder(obj)
def object_hook(obj): def object_hook(obj: Any) -> Any:
_spec_type = obj.get("_spec_type") _spec_type = obj.get("_spec_type")
if not _spec_type: if not _spec_type:
return obj return obj
if _spec_type in CONVERTERS: if _spec_type in CONVERTERS:
return CONVERTERS[_spec_type](obj["val"]) return CONVERTERS[_spec_type](obj["val"]) # type: ignore
else: else:
raise TypeError("Unknown {}".format(_spec_type)) raise TypeError("Unknown {}".format(_spec_type))
class Coder: class Coder:
@classmethod @classmethod
def encode(cls, value: Any): def encode(cls, value: Any) -> Union[str, bytes]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def decode(cls, value: Any): def decode(cls, value: Any) -> Any:
raise NotImplementedError raise NotImplementedError
class JsonCoder(Coder): class JsonCoder(Coder):
@classmethod @classmethod
def encode(cls, value: Any): def encode(cls, value: Any) -> str:
return json.dumps(value, cls=JsonEncoder) return json.dumps(value, cls=JsonEncoder)
@classmethod @classmethod
def decode(cls, value: Any): def decode(cls, value: Any) -> str:
return json.loads(value, object_hook=object_hook) return json.loads(value, object_hook=object_hook)
class PickleCoder(Coder): class PickleCoder(Coder):
@classmethod @classmethod
def encode(cls, value: Any): def encode(cls, value: Any) -> Union[str, bytes]:
return pickle.dumps(value) return pickle.dumps(value)
@classmethod @classmethod
def decode(cls, value: Any): def decode(cls, value: Any) -> Any:
return pickle.loads(value) # nosec:B403,B301 return pickle.loads(value) # nosec:B403,B301

View File

@@ -1,6 +1,6 @@
import inspect import inspect
from functools import wraps from functools import wraps
from typing import Callable, Optional, Type from typing import Callable, Optional, Type, Any
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
from starlette.requests import Request from starlette.requests import Request
@@ -11,11 +11,11 @@ from fastapi_cache.coder import Coder
def cache( def cache(
expire: int = None, expire: Optional[int] = None,
coder: Type[Coder] = None, coder: Optional[Type[Coder]] = None,
key_builder: Callable = None, key_builder: Optional[Callable] = None,
namespace: Optional[str] = "", namespace: Optional[str] = "",
): ) -> Callable:
""" """
cache all function cache all function
:param namespace: :param namespace:
@@ -26,7 +26,7 @@ def cache(
:return: :return:
""" """
def wrapper(func): def wrapper(func: Callable) -> Callable:
signature = inspect.signature(func) signature = inspect.signature(func)
request_param = next( request_param = next(
(param for param in signature.parameters.values() if param.annotation is Request), (param for param in signature.parameters.values() if param.annotation is Request),
@@ -55,15 +55,15 @@ def cache(
) )
if parameters: if parameters:
signature = signature.replace(parameters=parameters) signature = signature.replace(parameters=parameters)
func.__signature__ = signature func.__signature__ = signature # type: ignore
@wraps(func) @wraps(func)
async def inner(*args, **kwargs): async def inner(*args: Any, **kwargs: Any) -> Any:
nonlocal coder nonlocal coder
nonlocal expire nonlocal expire
nonlocal key_builder nonlocal key_builder
async def ensure_async_func(*args, **kwargs): async def ensure_async_func(*args: Any, **kwargs: Any) -> Any:
"""Run cached sync functions in thread pool just like FastAPI.""" """Run cached sync functions in thread pool just like FastAPI."""
# if the wrapped function does NOT have request or response in its function signature, # if the wrapped function does NOT have request or response in its function signature,
# make sure we don't pass them in as keyword arguments # make sure we don't pass them in as keyword arguments
@@ -83,7 +83,6 @@ def cache(
# see above why we have to await even although caller also awaits. # see above why we have to await even although caller also awaits.
return await run_in_threadpool(func, *args, **kwargs) return await run_in_threadpool(func, *args, **kwargs)
copy_kwargs = kwargs.copy() copy_kwargs = kwargs.copy()
request = copy_kwargs.pop("request", None) request = copy_kwargs.pop("request", None)
response = copy_kwargs.pop("response", None) response = copy_kwargs.pop("response", None)

View File

@@ -1,18 +1,18 @@
import hashlib import hashlib
from typing import Optional from typing import Optional, Callable
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
def default_key_builder( def default_key_builder(
func, func: Callable,
namespace: Optional[str] = "", namespace: Optional[str] = "",
request: Optional[Request] = None, request: Optional[Request] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
args: Optional[tuple] = None, args: Optional[tuple] = None,
kwargs: Optional[dict] = None, kwargs: Optional[dict] = None,
): ) -> str:
from fastapi_cache import FastAPICache from fastapi_cache import FastAPICache
prefix = f"{FastAPICache.get_prefix()}:{namespace}:" prefix = f"{FastAPICache.get_prefix()}:{namespace}:"