From 4c6abcf7868a0e0bf4d6b4a67a491ebb35954a61 Mon Sep 17 00:00:00 2001 From: Ivan Moiseev Date: Sat, 22 Oct 2022 20:59:37 +0400 Subject: [PATCH 1/6] feat: add more type hints --- fastapi_cache/__init__.py | 35 +++++++++++++++-------------- fastapi_cache/backends/__init__.py | 10 ++++----- fastapi_cache/backends/dynamodb.py | 17 +++++++------- fastapi_cache/backends/inmemory.py | 10 +++++---- fastapi_cache/backends/memcached.py | 12 +++++----- fastapi_cache/backends/redis.py | 9 ++++---- fastapi_cache/coder.py | 20 ++++++++--------- fastapi_cache/decorator.py | 19 ++++++++-------- fastapi_cache/key_builder.py | 6 ++--- 9 files changed, 71 insertions(+), 67 deletions(-) diff --git a/fastapi_cache/__init__.py b/fastapi_cache/__init__.py index 4248d22..eaa65f4 100644 --- a/fastapi_cache/__init__.py +++ b/fastapi_cache/__init__.py @@ -1,28 +1,29 @@ -from typing import Callable +from typing import Callable, Optional, Type +from fastapi_cache.backends import Backend from fastapi_cache.coder import Coder, JsonCoder from fastapi_cache.key_builder import default_key_builder class FastAPICache: - _backend = None - _prefix = None - _expire = None + _backend: Optional[Backend] = None + _prefix: Optional[str] = None + _expire: Optional[int] = None _init = False - _coder = None - _key_builder = None + _coder: Optional[Type[Coder]] = None + _key_builder: Optional[Callable] = None _enable = True @classmethod def init( cls, - backend, + backend: Backend, prefix: str = "", - expire: int = None, - coder: Coder = JsonCoder, + expire: Optional[int] = None, + coder: Type[Coder] = JsonCoder, key_builder: Callable = default_key_builder, enable: bool = True, - ): + ) -> None: if cls._init: return cls._init = True @@ -34,31 +35,31 @@ class FastAPICache: cls._enable = enable @classmethod - def get_backend(cls): + def get_backend(cls) -> Backend: assert cls._backend, "You must call init first!" # nosec: B101 return cls._backend @classmethod - def get_prefix(cls): + def get_prefix(cls) -> Optional[str]: return cls._prefix @classmethod - def get_expire(cls): + def get_expire(cls) -> Optional[int]: return cls._expire @classmethod - def get_coder(cls): + def get_coder(cls) -> Optional[Type[Coder]]: return cls._coder @classmethod - def get_key_builder(cls): + def get_key_builder(cls) -> Optional[Callable]: return cls._key_builder @classmethod - def get_enable(cls): + def get_enable(cls) -> bool: return cls._enable @classmethod - async def clear(cls, namespace: str = None, key: str = None): + async def clear(cls, namespace: Optional[str] = None, key: Optional[str] = None) -> int: namespace = cls._prefix + (":" + namespace if namespace else "") return await cls._backend.clear(namespace, key) diff --git a/fastapi_cache/backends/__init__.py b/fastapi_cache/backends/__init__.py index 8adaa31..7b3d070 100644 --- a/fastapi_cache/backends/__init__.py +++ b/fastapi_cache/backends/__init__.py @@ -1,20 +1,20 @@ import abc -from typing import Tuple +from typing import Tuple, Optional class Backend: @abc.abstractmethod - async def get_with_ttl(self, key: str) -> Tuple[int, str]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: raise NotImplementedError @abc.abstractmethod - async def get(self, key: str) -> str: + async def get(self, key: str) -> Optional[str]: raise NotImplementedError @abc.abstractmethod - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: raise NotImplementedError @abc.abstractmethod - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/dynamodb.py b/fastapi_cache/backends/dynamodb.py index 1a6aa2c..fc3650b 100644 --- a/fastapi_cache/backends/dynamodb.py +++ b/fastapi_cache/backends/dynamodb.py @@ -1,6 +1,7 @@ import datetime -from typing import Tuple +from typing import Tuple, Optional +from aiobotocore.client import AioBaseClient from aiobotocore.session import get_session from fastapi_cache.backends import Backend @@ -24,18 +25,18 @@ class DynamoBackend(Backend): >> FastAPICache.init(dynamodb) """ - def __init__(self, table_name, region=None): + def __init__(self, table_name: str, region: Optional[str] = None) -> None: self.session = get_session() - self.client = None # Needs async init + self.client: Optional[AioBaseClient] = None # Needs async init self.table_name = table_name self.region = region - async def init(self): + async def init(self) -> None: self.client = await self.session.create_client( "dynamodb", region_name=self.region ).__aenter__() - async def close(self): + async def close(self) -> None: self.client = await self.client.__aexit__(None, None, None) async def get_with_ttl(self, key: str) -> Tuple[int, str]: @@ -55,12 +56,12 @@ class DynamoBackend(Backend): return 0, None - async def get(self, key) -> str: + async def get(self, key: str) -> str: response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) if "Item" in response: return response["Item"].get("value", {}).get("S") - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: ttl = ( { "ttl": { @@ -88,5 +89,5 @@ class DynamoBackend(Backend): }, ) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/inmemory.py b/fastapi_cache/backends/inmemory.py index 4238b24..94a4e09 100644 --- a/fastapi_cache/backends/inmemory.py +++ b/fastapi_cache/backends/inmemory.py @@ -20,13 +20,14 @@ class InMemoryBackend(Backend): def _now(self) -> int: return int(time.time()) - def _get(self, key: str): + def _get(self, key: str) -> Value | None: v = self._store.get(key) if v: if v.ttl_ts < self._now: del self._store[key] else: return v + return None async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: async with self._lock: @@ -35,17 +36,18 @@ class InMemoryBackend(Backend): return v.ttl_ts - self._now, v.data return 0, None - async def get(self, key: str) -> str: + async def get(self, key: str) -> Optional[str]: async with self._lock: v = self._get(key) if v: return v.data + return None - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: async with self._lock: self._store[key] = Value(value, self._now + (expire or 0)) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: count = 0 if namespace: keys = list(self._store.keys()) diff --git a/fastapi_cache/backends/memcached.py b/fastapi_cache/backends/memcached.py index 8702f70..3c96ef4 100644 --- a/fastapi_cache/backends/memcached.py +++ b/fastapi_cache/backends/memcached.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional from aiomcache import Client @@ -9,14 +9,14 @@ class MemcachedBackend(Backend): def __init__(self, mcache: Client): self.mcache = mcache - async def get_with_ttl(self, key: str) -> Tuple[int, str]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: return 3600, await self.mcache.get(key.encode()) - async def get(self, key: str): + async def get(self, key: str) -> Optional[str]: return await self.mcache.get(key, key.encode()) - async def set(self, key: str, value: str, expire: int = None): - return await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: + await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) - async def clear(self, namespace: str = None, key: str = None): + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/redis.py b/fastapi_cache/backends/redis.py index b5147dc..20b3538 100644 --- a/fastapi_cache/backends/redis.py +++ b/fastapi_cache/backends/redis.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional from redis.asyncio.client import Redis @@ -13,15 +13,16 @@ class RedisBackend(Backend): async with self.redis.pipeline(transaction=True) as pipe: return await (pipe.ttl(key).get(key).execute()) - async def get(self, key) -> str: + async def get(self, key: str) -> Optional[str]: return await self.redis.get(key) - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: return await self.redis.set(key, value, ex=expire) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: if namespace: lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end" return await self.redis.eval(lua, numkeys=0) elif key: return await self.redis.delete(key) + return 0 diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index 0d45079..5d7e4a5 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -2,7 +2,7 @@ import datetime import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any +from typing import Any, Dict, Union import pendulum from fastapi.encoders import jsonable_encoder @@ -15,7 +15,7 @@ CONVERTERS = { class JsonEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: Any) -> Any: if isinstance(obj, datetime.datetime): return {"val": str(obj), "_spec_type": "datetime"} elif isinstance(obj, datetime.date): @@ -26,42 +26,42 @@ class JsonEncoder(json.JSONEncoder): return jsonable_encoder(obj) -def object_hook(obj): +def object_hook(obj: Any) -> Any: _spec_type = obj.get("_spec_type") if not _spec_type: return obj if _spec_type in CONVERTERS: - return CONVERTERS[_spec_type](obj["val"]) + return CONVERTERS[_spec_type](obj["val"]) # type: ignore else: raise TypeError("Unknown {}".format(_spec_type)) class Coder: @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> Union[str, bytes]: raise NotImplementedError @classmethod - def decode(cls, value: Any): + def decode(cls, value: Any) -> Any: raise NotImplementedError class JsonCoder(Coder): @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> str: return json.dumps(value, cls=JsonEncoder) @classmethod - def decode(cls, value: Any): + def decode(cls, value: Any) -> str: return json.loads(value, object_hook=object_hook) class PickleCoder(Coder): @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> Union[str, bytes]: return pickle.dumps(value) @classmethod - def decode(cls, value: Any): + def decode(cls, value: Any) -> Any: return pickle.loads(value) # nosec:B403,B301 diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index d185e6e..1ac5f13 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -1,6 +1,6 @@ import inspect from functools import wraps -from typing import Callable, Optional, Type +from typing import Callable, Optional, Type, Any from fastapi.concurrency import run_in_threadpool from starlette.requests import Request @@ -11,11 +11,11 @@ from fastapi_cache.coder import Coder def cache( - expire: int = None, - coder: Type[Coder] = None, - key_builder: Callable = None, + expire: Optional[int] = None, + coder: Optional[Type[Coder]] = None, + key_builder: Optional[Callable] = None, namespace: Optional[str] = "", -): +) -> Callable: """ cache all function :param namespace: @@ -26,7 +26,7 @@ def cache( :return: """ - def wrapper(func): + def wrapper(func: Callable) -> Callable: signature = inspect.signature(func) request_param = next( (param for param in signature.parameters.values() if param.annotation is Request), @@ -55,15 +55,15 @@ def cache( ) if parameters: signature = signature.replace(parameters=parameters) - func.__signature__ = signature + func.__signature__ = signature # type: ignore @wraps(func) - async def inner(*args, **kwargs): + async def inner(*args: Any, **kwargs: Any) -> Any: nonlocal coder nonlocal expire nonlocal key_builder - async def ensure_async_func(*args, **kwargs): + async def ensure_async_func(*args: Any, **kwargs: Any) -> Any: """Run cached sync functions in thread pool just like FastAPI.""" # if the wrapped function does NOT have request or response in its function signature, # make sure we don't pass them in as keyword arguments @@ -83,7 +83,6 @@ def cache( # see above why we have to await even although caller also awaits. return await run_in_threadpool(func, *args, **kwargs) - copy_kwargs = kwargs.copy() request = copy_kwargs.pop("request", None) response = copy_kwargs.pop("response", None) diff --git a/fastapi_cache/key_builder.py b/fastapi_cache/key_builder.py index 112f2dc..e751e0b 100644 --- a/fastapi_cache/key_builder.py +++ b/fastapi_cache/key_builder.py @@ -1,18 +1,18 @@ import hashlib -from typing import Optional +from typing import Optional, Callable from starlette.requests import Request from starlette.responses import Response def default_key_builder( - func, + func: Callable, namespace: Optional[str] = "", request: Optional[Request] = None, response: Optional[Response] = None, args: Optional[tuple] = None, kwargs: Optional[dict] = None, -): +) -> str: from fastapi_cache import FastAPICache prefix = f"{FastAPICache.get_prefix()}:{namespace}:" From 68ef94f2dbe03802af791040095b36f1b7f2ae68 Mon Sep 17 00:00:00 2001 From: Ivan Moiseev Date: Sat, 22 Oct 2022 21:05:43 +0400 Subject: [PATCH 2/6] feat: add more asserts for FastAPICache init --- fastapi_cache/__init__.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/fastapi_cache/__init__.py b/fastapi_cache/__init__.py index eaa65f4..d9acd9e 100644 --- a/fastapi_cache/__init__.py +++ b/fastapi_cache/__init__.py @@ -40,7 +40,8 @@ class FastAPICache: return cls._backend @classmethod - def get_prefix(cls) -> Optional[str]: + def get_prefix(cls) -> str: + assert cls._prefix, "You must call init first!" # nosec: B101 return cls._prefix @classmethod @@ -48,11 +49,13 @@ class FastAPICache: return cls._expire @classmethod - def get_coder(cls) -> Optional[Type[Coder]]: + def get_coder(cls) -> Type[Coder]: + assert cls._coder, "You must call init first!" # nosec: B101 return cls._coder @classmethod - def get_key_builder(cls) -> Optional[Callable]: + def get_key_builder(cls) -> Callable: + assert cls._key_builder, "You must call init first!" # nosec: B101 return cls._key_builder @classmethod @@ -61,5 +64,6 @@ class FastAPICache: @classmethod async def clear(cls, namespace: Optional[str] = None, key: Optional[str] = None) -> int: + assert cls._backend and cls._prefix, "You must call init first!" # nosec: B101 namespace = cls._prefix + (":" + namespace if namespace else "") return await cls._backend.clear(namespace, key) From e842d6408e78e17f7ae96b075e3039dd3315fcf1 Mon Sep 17 00:00:00 2001 From: Ivan Moiseev Date: Sat, 22 Oct 2022 21:06:38 +0400 Subject: [PATCH 3/6] feat: make PickleCoder compatible with backends --- fastapi_cache/coder.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index 5d7e4a5..3bc2aa0 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -39,7 +39,7 @@ def object_hook(obj: Any) -> Any: class Coder: @classmethod - def encode(cls, value: Any) -> Union[str, bytes]: + def encode(cls, value: Any) -> str: raise NotImplementedError @classmethod @@ -59,9 +59,9 @@ class JsonCoder(Coder): class PickleCoder(Coder): @classmethod - def encode(cls, value: Any) -> Union[str, bytes]: - return pickle.dumps(value) + def encode(cls, value: Any) -> str: + return str(pickle.dumps(value)) @classmethod def decode(cls, value: Any) -> Any: - return pickle.loads(value) # nosec:B403,B301 + return pickle.loads(bytes(value)) # nosec:B403,B301 From c6bd8483a402af75d25b840f53534574d5df2acd Mon Sep 17 00:00:00 2001 From: Ivan Moiseev Date: Sat, 22 Oct 2022 21:12:04 +0400 Subject: [PATCH 4/6] feat: fix tests and add FastAPICache init in tests. --- fastapi_cache/__init__.py | 14 ++++++++++++-- tests/test_decorator.py | 19 ++++++++++++++++--- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/fastapi_cache/__init__.py b/fastapi_cache/__init__.py index d9acd9e..b64eefe 100644 --- a/fastapi_cache/__init__.py +++ b/fastapi_cache/__init__.py @@ -34,6 +34,16 @@ class FastAPICache: cls._key_builder = key_builder cls._enable = enable + @classmethod + def reset(cls) -> None: + cls._init = False + cls._backend = None + cls._prefix = None + cls._expire = None + cls._coder = None + cls._key_builder = None + cls._enable = True + @classmethod def get_backend(cls) -> Backend: assert cls._backend, "You must call init first!" # nosec: B101 @@ -41,7 +51,7 @@ class FastAPICache: @classmethod def get_prefix(cls) -> str: - assert cls._prefix, "You must call init first!" # nosec: B101 + assert cls._prefix is not None, "You must call init first!" # nosec: B101 return cls._prefix @classmethod @@ -64,6 +74,6 @@ class FastAPICache: @classmethod async def clear(cls, namespace: Optional[str] = None, key: Optional[str] = None) -> int: - assert cls._backend and cls._prefix, "You must call init first!" # nosec: B101 + assert cls._backend and cls._prefix is not None, "You must call init first!" # nosec: B101 namespace = cls._prefix + (":" + namespace if namespace else "") return await cls._backend.clear(namespace, key) diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 29db05d..6d54aef 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,13 +1,24 @@ import time +from typing import Generator import pendulum +import pytest + from fastapi_cache import FastAPICache from starlette.testclient import TestClient from examples.in_memory.main import app +from fastapi_cache.backends.inmemory import InMemoryBackend -def test_datetime(): +@pytest.fixture(autouse=True) +def init_cache() -> Generator: + FastAPICache.init(InMemoryBackend()) + yield + FastAPICache.reset() + + +def test_datetime() -> None: with TestClient(app) as client: response = client.get("/datetime") now = response.json().get("now") @@ -23,7 +34,8 @@ def test_datetime(): assert now != now_ assert now == pendulum.now().replace(microsecond=0) -def test_date(): + +def test_date() -> None: """Test path function without request or response arguments.""" with TestClient(app) as client: @@ -40,7 +52,8 @@ def test_date(): assert pendulum.parse(response.json()) == pendulum.today() FastAPICache._enable = True -def test_sync(): + +def test_sync() -> None: """Ensure that sync function support works.""" with TestClient(app) as client: response = client.get("/sync-me") From 71a77f6b39b06d341afb6f6e31c53b4ab4ad0e34 Mon Sep 17 00:00:00 2001 From: Ivan Moiseev Date: Sun, 30 Oct 2022 11:03:16 +0400 Subject: [PATCH 5/6] fix: request and response type hints --- fastapi_cache/decorator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index be50d30..8a5e0a6 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -1,7 +1,8 @@ import inspect import sys from functools import wraps -from typing import Any, Awaitable, Callable, Optional, TypeVar +from typing import Any, Awaitable, Callable, Optional, TypeVar, Type + if sys.version_info >= (3, 10): from typing import ParamSpec else: @@ -21,7 +22,7 @@ R = TypeVar("R") def cache( expire: Optional[int] = None, - coder: Optional[Coder] = None, + coder: Optional[Type[Coder]] = None, key_builder: Optional[Callable[..., Any]] = None, namespace: Optional[str] = "", ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: @@ -92,10 +93,9 @@ def cache( # see above why we have to await even although caller also awaits. return await run_in_threadpool(func, *args, **kwargs) - copy_kwargs = kwargs.copy() - request = copy_kwargs.pop("request", None) - response = copy_kwargs.pop("response", None) + request: Optional[Request] = copy_kwargs.pop("request", None) + response: Optional[Response] = copy_kwargs.pop("response", None) if ( request and request.headers.get("Cache-Control") == "no-store" ) or not FastAPICache.get_enable(): From 10f819483c4e25dc01a699c6f1c053bf2c5752f8 Mon Sep 17 00:00:00 2001 From: Ivan Moiseev Date: Thu, 3 Nov 2022 15:49:58 +0400 Subject: [PATCH 6/6] fix: replace pipe for Optional --- fastapi_cache/backends/inmemory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastapi_cache/backends/inmemory.py b/fastapi_cache/backends/inmemory.py index 94a4e09..4b05f06 100644 --- a/fastapi_cache/backends/inmemory.py +++ b/fastapi_cache/backends/inmemory.py @@ -20,7 +20,7 @@ class InMemoryBackend(Backend): def _now(self) -> int: return int(time.time()) - def _get(self, key: str) -> Value | None: + def _get(self, key: str) -> Optional[Value]: v = self._store.get(key) if v: if v.ttl_ts < self._now: