diff --git a/fastapi_cache/__init__.py b/fastapi_cache/__init__.py index 4248d22..eaa65f4 100644 --- a/fastapi_cache/__init__.py +++ b/fastapi_cache/__init__.py @@ -1,28 +1,29 @@ -from typing import Callable +from typing import Callable, Optional, Type +from fastapi_cache.backends import Backend from fastapi_cache.coder import Coder, JsonCoder from fastapi_cache.key_builder import default_key_builder class FastAPICache: - _backend = None - _prefix = None - _expire = None + _backend: Optional[Backend] = None + _prefix: Optional[str] = None + _expire: Optional[int] = None _init = False - _coder = None - _key_builder = None + _coder: Optional[Type[Coder]] = None + _key_builder: Optional[Callable] = None _enable = True @classmethod def init( cls, - backend, + backend: Backend, prefix: str = "", - expire: int = None, - coder: Coder = JsonCoder, + expire: Optional[int] = None, + coder: Type[Coder] = JsonCoder, key_builder: Callable = default_key_builder, enable: bool = True, - ): + ) -> None: if cls._init: return cls._init = True @@ -34,31 +35,31 @@ class FastAPICache: cls._enable = enable @classmethod - def get_backend(cls): + def get_backend(cls) -> Backend: assert cls._backend, "You must call init first!" # nosec: B101 return cls._backend @classmethod - def get_prefix(cls): + def get_prefix(cls) -> Optional[str]: return cls._prefix @classmethod - def get_expire(cls): + def get_expire(cls) -> Optional[int]: return cls._expire @classmethod - def get_coder(cls): + def get_coder(cls) -> Optional[Type[Coder]]: return cls._coder @classmethod - def get_key_builder(cls): + def get_key_builder(cls) -> Optional[Callable]: return cls._key_builder @classmethod - def get_enable(cls): + def get_enable(cls) -> bool: return cls._enable @classmethod - async def clear(cls, namespace: str = None, key: str = None): + async def clear(cls, namespace: Optional[str] = None, key: Optional[str] = None) -> int: namespace = cls._prefix + (":" + namespace if namespace else "") return await cls._backend.clear(namespace, key) diff --git a/fastapi_cache/backends/__init__.py b/fastapi_cache/backends/__init__.py index 8adaa31..7b3d070 100644 --- a/fastapi_cache/backends/__init__.py +++ b/fastapi_cache/backends/__init__.py @@ -1,20 +1,20 @@ import abc -from typing import Tuple +from typing import Tuple, Optional class Backend: @abc.abstractmethod - async def get_with_ttl(self, key: str) -> Tuple[int, str]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: raise NotImplementedError @abc.abstractmethod - async def get(self, key: str) -> str: + async def get(self, key: str) -> Optional[str]: raise NotImplementedError @abc.abstractmethod - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: raise NotImplementedError @abc.abstractmethod - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/dynamodb.py b/fastapi_cache/backends/dynamodb.py index 1a6aa2c..fc3650b 100644 --- a/fastapi_cache/backends/dynamodb.py +++ b/fastapi_cache/backends/dynamodb.py @@ -1,6 +1,7 @@ import datetime -from typing import Tuple +from typing import Tuple, Optional +from aiobotocore.client import AioBaseClient from aiobotocore.session import get_session from fastapi_cache.backends import Backend @@ -24,18 +25,18 @@ class DynamoBackend(Backend): >> FastAPICache.init(dynamodb) """ - def __init__(self, table_name, region=None): + def __init__(self, table_name: str, region: Optional[str] = None) -> None: self.session = get_session() - self.client = None # Needs async init + self.client: Optional[AioBaseClient] = None # Needs async init self.table_name = table_name self.region = region - async def init(self): + async def init(self) -> None: self.client = await self.session.create_client( "dynamodb", region_name=self.region ).__aenter__() - async def close(self): + async def close(self) -> None: self.client = await self.client.__aexit__(None, None, None) async def get_with_ttl(self, key: str) -> Tuple[int, str]: @@ -55,12 +56,12 @@ class DynamoBackend(Backend): return 0, None - async def get(self, key) -> str: + async def get(self, key: str) -> str: response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) if "Item" in response: return response["Item"].get("value", {}).get("S") - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: ttl = ( { "ttl": { @@ -88,5 +89,5 @@ class DynamoBackend(Backend): }, ) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/inmemory.py b/fastapi_cache/backends/inmemory.py index 4238b24..94a4e09 100644 --- a/fastapi_cache/backends/inmemory.py +++ b/fastapi_cache/backends/inmemory.py @@ -20,13 +20,14 @@ class InMemoryBackend(Backend): def _now(self) -> int: return int(time.time()) - def _get(self, key: str): + def _get(self, key: str) -> Value | None: v = self._store.get(key) if v: if v.ttl_ts < self._now: del self._store[key] else: return v + return None async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: async with self._lock: @@ -35,17 +36,18 @@ class InMemoryBackend(Backend): return v.ttl_ts - self._now, v.data return 0, None - async def get(self, key: str) -> str: + async def get(self, key: str) -> Optional[str]: async with self._lock: v = self._get(key) if v: return v.data + return None - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: async with self._lock: self._store[key] = Value(value, self._now + (expire or 0)) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: count = 0 if namespace: keys = list(self._store.keys()) diff --git a/fastapi_cache/backends/memcached.py b/fastapi_cache/backends/memcached.py index 8702f70..3c96ef4 100644 --- a/fastapi_cache/backends/memcached.py +++ b/fastapi_cache/backends/memcached.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional from aiomcache import Client @@ -9,14 +9,14 @@ class MemcachedBackend(Backend): def __init__(self, mcache: Client): self.mcache = mcache - async def get_with_ttl(self, key: str) -> Tuple[int, str]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: return 3600, await self.mcache.get(key.encode()) - async def get(self, key: str): + async def get(self, key: str) -> Optional[str]: return await self.mcache.get(key, key.encode()) - async def set(self, key: str, value: str, expire: int = None): - return await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: + await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) - async def clear(self, namespace: str = None, key: str = None): + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/redis.py b/fastapi_cache/backends/redis.py index b5147dc..20b3538 100644 --- a/fastapi_cache/backends/redis.py +++ b/fastapi_cache/backends/redis.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional from redis.asyncio.client import Redis @@ -13,15 +13,16 @@ class RedisBackend(Backend): async with self.redis.pipeline(transaction=True) as pipe: return await (pipe.ttl(key).get(key).execute()) - async def get(self, key) -> str: + async def get(self, key: str) -> Optional[str]: return await self.redis.get(key) - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: return await self.redis.set(key, value, ex=expire) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: if namespace: lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end" return await self.redis.eval(lua, numkeys=0) elif key: return await self.redis.delete(key) + return 0 diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index 0d45079..5d7e4a5 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -2,7 +2,7 @@ import datetime import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any +from typing import Any, Dict, Union import pendulum from fastapi.encoders import jsonable_encoder @@ -15,7 +15,7 @@ CONVERTERS = { class JsonEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: Any) -> Any: if isinstance(obj, datetime.datetime): return {"val": str(obj), "_spec_type": "datetime"} elif isinstance(obj, datetime.date): @@ -26,42 +26,42 @@ class JsonEncoder(json.JSONEncoder): return jsonable_encoder(obj) -def object_hook(obj): +def object_hook(obj: Any) -> Any: _spec_type = obj.get("_spec_type") if not _spec_type: return obj if _spec_type in CONVERTERS: - return CONVERTERS[_spec_type](obj["val"]) + return CONVERTERS[_spec_type](obj["val"]) # type: ignore else: raise TypeError("Unknown {}".format(_spec_type)) class Coder: @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> Union[str, bytes]: raise NotImplementedError @classmethod - def decode(cls, value: Any): + def decode(cls, value: Any) -> Any: raise NotImplementedError class JsonCoder(Coder): @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> str: return json.dumps(value, cls=JsonEncoder) @classmethod - def decode(cls, value: Any): + def decode(cls, value: Any) -> str: return json.loads(value, object_hook=object_hook) class PickleCoder(Coder): @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> Union[str, bytes]: return pickle.dumps(value) @classmethod - def decode(cls, value: Any): + def decode(cls, value: Any) -> Any: return pickle.loads(value) # nosec:B403,B301 diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index d185e6e..1ac5f13 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -1,6 +1,6 @@ import inspect from functools import wraps -from typing import Callable, Optional, Type +from typing import Callable, Optional, Type, Any from fastapi.concurrency import run_in_threadpool from starlette.requests import Request @@ -11,11 +11,11 @@ from fastapi_cache.coder import Coder def cache( - expire: int = None, - coder: Type[Coder] = None, - key_builder: Callable = None, + expire: Optional[int] = None, + coder: Optional[Type[Coder]] = None, + key_builder: Optional[Callable] = None, namespace: Optional[str] = "", -): +) -> Callable: """ cache all function :param namespace: @@ -26,7 +26,7 @@ def cache( :return: """ - def wrapper(func): + def wrapper(func: Callable) -> Callable: signature = inspect.signature(func) request_param = next( (param for param in signature.parameters.values() if param.annotation is Request), @@ -55,15 +55,15 @@ def cache( ) if parameters: signature = signature.replace(parameters=parameters) - func.__signature__ = signature + func.__signature__ = signature # type: ignore @wraps(func) - async def inner(*args, **kwargs): + async def inner(*args: Any, **kwargs: Any) -> Any: nonlocal coder nonlocal expire nonlocal key_builder - async def ensure_async_func(*args, **kwargs): + async def ensure_async_func(*args: Any, **kwargs: Any) -> Any: """Run cached sync functions in thread pool just like FastAPI.""" # if the wrapped function does NOT have request or response in its function signature, # make sure we don't pass them in as keyword arguments @@ -83,7 +83,6 @@ def cache( # see above why we have to await even although caller also awaits. return await run_in_threadpool(func, *args, **kwargs) - copy_kwargs = kwargs.copy() request = copy_kwargs.pop("request", None) response = copy_kwargs.pop("response", None) diff --git a/fastapi_cache/key_builder.py b/fastapi_cache/key_builder.py index 112f2dc..e751e0b 100644 --- a/fastapi_cache/key_builder.py +++ b/fastapi_cache/key_builder.py @@ -1,18 +1,18 @@ import hashlib -from typing import Optional +from typing import Optional, Callable from starlette.requests import Request from starlette.responses import Response def default_key_builder( - func, + func: Callable, namespace: Optional[str] = "", request: Optional[Request] = None, response: Optional[Response] = None, args: Optional[tuple] = None, kwargs: Optional[dict] = None, -): +) -> str: from fastapi_cache import FastAPICache prefix = f"{FastAPICache.get_prefix()}:{namespace}:"