diff --git a/fastapi_cache/__init__.py b/fastapi_cache/__init__.py index 4248d22..b64eefe 100644 --- a/fastapi_cache/__init__.py +++ b/fastapi_cache/__init__.py @@ -1,28 +1,29 @@ -from typing import Callable +from typing import Callable, Optional, Type +from fastapi_cache.backends import Backend from fastapi_cache.coder import Coder, JsonCoder from fastapi_cache.key_builder import default_key_builder class FastAPICache: - _backend = None - _prefix = None - _expire = None + _backend: Optional[Backend] = None + _prefix: Optional[str] = None + _expire: Optional[int] = None _init = False - _coder = None - _key_builder = None + _coder: Optional[Type[Coder]] = None + _key_builder: Optional[Callable] = None _enable = True @classmethod def init( cls, - backend, + backend: Backend, prefix: str = "", - expire: int = None, - coder: Coder = JsonCoder, + expire: Optional[int] = None, + coder: Type[Coder] = JsonCoder, key_builder: Callable = default_key_builder, enable: bool = True, - ): + ) -> None: if cls._init: return cls._init = True @@ -34,31 +35,45 @@ class FastAPICache: cls._enable = enable @classmethod - def get_backend(cls): + def reset(cls) -> None: + cls._init = False + cls._backend = None + cls._prefix = None + cls._expire = None + cls._coder = None + cls._key_builder = None + cls._enable = True + + @classmethod + def get_backend(cls) -> Backend: assert cls._backend, "You must call init first!" # nosec: B101 return cls._backend @classmethod - def get_prefix(cls): + def get_prefix(cls) -> str: + assert cls._prefix is not None, "You must call init first!" # nosec: B101 return cls._prefix @classmethod - def get_expire(cls): + def get_expire(cls) -> Optional[int]: return cls._expire @classmethod - def get_coder(cls): + def get_coder(cls) -> Type[Coder]: + assert cls._coder, "You must call init first!" # nosec: B101 return cls._coder @classmethod - def get_key_builder(cls): + def get_key_builder(cls) -> Callable: + assert cls._key_builder, "You must call init first!" # nosec: B101 return cls._key_builder @classmethod - def get_enable(cls): + def get_enable(cls) -> bool: return cls._enable @classmethod - async def clear(cls, namespace: str = None, key: str = None): + async def clear(cls, namespace: Optional[str] = None, key: Optional[str] = None) -> int: + assert cls._backend and cls._prefix is not None, "You must call init first!" # nosec: B101 namespace = cls._prefix + (":" + namespace if namespace else "") return await cls._backend.clear(namespace, key) diff --git a/fastapi_cache/backends/__init__.py b/fastapi_cache/backends/__init__.py index 8adaa31..7b3d070 100644 --- a/fastapi_cache/backends/__init__.py +++ b/fastapi_cache/backends/__init__.py @@ -1,20 +1,20 @@ import abc -from typing import Tuple +from typing import Tuple, Optional class Backend: @abc.abstractmethod - async def get_with_ttl(self, key: str) -> Tuple[int, str]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: raise NotImplementedError @abc.abstractmethod - async def get(self, key: str) -> str: + async def get(self, key: str) -> Optional[str]: raise NotImplementedError @abc.abstractmethod - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: raise NotImplementedError @abc.abstractmethod - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/dynamodb.py b/fastapi_cache/backends/dynamodb.py index 1a6aa2c..fc3650b 100644 --- a/fastapi_cache/backends/dynamodb.py +++ b/fastapi_cache/backends/dynamodb.py @@ -1,6 +1,7 @@ import datetime -from typing import Tuple +from typing import Tuple, Optional +from aiobotocore.client import AioBaseClient from aiobotocore.session import get_session from fastapi_cache.backends import Backend @@ -24,18 +25,18 @@ class DynamoBackend(Backend): >> FastAPICache.init(dynamodb) """ - def __init__(self, table_name, region=None): + def __init__(self, table_name: str, region: Optional[str] = None) -> None: self.session = get_session() - self.client = None # Needs async init + self.client: Optional[AioBaseClient] = None # Needs async init self.table_name = table_name self.region = region - async def init(self): + async def init(self) -> None: self.client = await self.session.create_client( "dynamodb", region_name=self.region ).__aenter__() - async def close(self): + async def close(self) -> None: self.client = await self.client.__aexit__(None, None, None) async def get_with_ttl(self, key: str) -> Tuple[int, str]: @@ -55,12 +56,12 @@ class DynamoBackend(Backend): return 0, None - async def get(self, key) -> str: + async def get(self, key: str) -> str: response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) if "Item" in response: return response["Item"].get("value", {}).get("S") - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: ttl = ( { "ttl": { @@ -88,5 +89,5 @@ class DynamoBackend(Backend): }, ) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/inmemory.py b/fastapi_cache/backends/inmemory.py index 4238b24..4b05f06 100644 --- a/fastapi_cache/backends/inmemory.py +++ b/fastapi_cache/backends/inmemory.py @@ -20,13 +20,14 @@ class InMemoryBackend(Backend): def _now(self) -> int: return int(time.time()) - def _get(self, key: str): + def _get(self, key: str) -> Optional[Value]: v = self._store.get(key) if v: if v.ttl_ts < self._now: del self._store[key] else: return v + return None async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: async with self._lock: @@ -35,17 +36,18 @@ class InMemoryBackend(Backend): return v.ttl_ts - self._now, v.data return 0, None - async def get(self, key: str) -> str: + async def get(self, key: str) -> Optional[str]: async with self._lock: v = self._get(key) if v: return v.data + return None - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: async with self._lock: self._store[key] = Value(value, self._now + (expire or 0)) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: count = 0 if namespace: keys = list(self._store.keys()) diff --git a/fastapi_cache/backends/memcached.py b/fastapi_cache/backends/memcached.py index 8702f70..3c96ef4 100644 --- a/fastapi_cache/backends/memcached.py +++ b/fastapi_cache/backends/memcached.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional from aiomcache import Client @@ -9,14 +9,14 @@ class MemcachedBackend(Backend): def __init__(self, mcache: Client): self.mcache = mcache - async def get_with_ttl(self, key: str) -> Tuple[int, str]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: return 3600, await self.mcache.get(key.encode()) - async def get(self, key: str): + async def get(self, key: str) -> Optional[str]: return await self.mcache.get(key, key.encode()) - async def set(self, key: str, value: str, expire: int = None): - return await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: + await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) - async def clear(self, namespace: str = None, key: str = None): + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/redis.py b/fastapi_cache/backends/redis.py index b5147dc..20b3538 100644 --- a/fastapi_cache/backends/redis.py +++ b/fastapi_cache/backends/redis.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import Tuple, Optional from redis.asyncio.client import Redis @@ -13,15 +13,16 @@ class RedisBackend(Backend): async with self.redis.pipeline(transaction=True) as pipe: return await (pipe.ttl(key).get(key).execute()) - async def get(self, key) -> str: + async def get(self, key: str) -> Optional[str]: return await self.redis.get(key) - async def set(self, key: str, value: str, expire: int = None): + async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: return await self.redis.set(key, value, ex=expire) - async def clear(self, namespace: str = None, key: str = None) -> int: + async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: if namespace: lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end" return await self.redis.eval(lua, numkeys=0) elif key: return await self.redis.delete(key) + return 0 diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index daeac83..b6cd683 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -2,7 +2,7 @@ import datetime import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any +from typing import Any, Dict, Union import pendulum from fastapi.encoders import jsonable_encoder @@ -16,7 +16,7 @@ CONVERTERS = { class JsonEncoder(json.JSONEncoder): - def default(self, obj): + def default(self, obj: Any) -> Any: if isinstance(obj, datetime.datetime): return {"val": str(obj), "_spec_type": "datetime"} elif isinstance(obj, datetime.date): @@ -27,44 +27,44 @@ class JsonEncoder(json.JSONEncoder): return jsonable_encoder(obj) -def object_hook(obj): +def object_hook(obj: Any) -> Any: _spec_type = obj.get("_spec_type") if not _spec_type: return obj if _spec_type in CONVERTERS: - return CONVERTERS[_spec_type](obj["val"]) + return CONVERTERS[_spec_type](obj["val"]) # type: ignore else: raise TypeError("Unknown {}".format(_spec_type)) class Coder: @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> str: raise NotImplementedError @classmethod - def decode(cls, value: Any): + def decode(cls, value: Any) -> Any: raise NotImplementedError class JsonCoder(Coder): @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> str: return json.dumps(value, cls=JsonEncoder) @classmethod - def decode(cls, value: Any): + def decode(cls, value: Any) -> str: return json.loads(value, object_hook=object_hook) class PickleCoder(Coder): @classmethod - def encode(cls, value: Any): + def encode(cls, value: Any) -> str: if isinstance(value, TemplateResponse): value = value.body - return pickle.dumps(value) + return str(pickle.dumps(value)) @classmethod - def decode(cls, value: Any): - return pickle.loads(value) # nosec:B403,B301 + def decode(cls, value: Any) -> Any: + return pickle.loads(bytes(value)) # nosec:B403,B301 diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index 0dd0bc7..0aea981 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -1,7 +1,8 @@ import inspect import sys from functools import wraps -from typing import Any, Awaitable, Callable, Optional, TypeVar +from typing import Any, Awaitable, Callable, Optional, TypeVar, Type + if sys.version_info >= (3, 10): from typing import ParamSpec else: @@ -21,7 +22,7 @@ R = TypeVar("R") def cache( expire: Optional[int] = None, - coder: Optional[Coder] = None, + coder: Optional[Type[Coder]] = None, key_builder: Optional[Callable[..., Any]] = None, namespace: Optional[str] = "", ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: @@ -92,11 +93,9 @@ def cache( # see above why we have to await even although caller also awaits. return await run_in_threadpool(func, *args, **kwargs) - copy_kwargs = kwargs.copy() - request = copy_kwargs.pop("request", None) - response = copy_kwargs.pop("response", None) - + request: Optional[Request] = copy_kwargs.pop("request", None) + response: Optional[Response] = copy_kwargs.pop("response", None) if ( request and request.headers.get("Cache-Control") in ("no-store", "no-cache") ) or not FastAPICache.get_enable(): diff --git a/fastapi_cache/key_builder.py b/fastapi_cache/key_builder.py index 112f2dc..e751e0b 100644 --- a/fastapi_cache/key_builder.py +++ b/fastapi_cache/key_builder.py @@ -1,18 +1,18 @@ import hashlib -from typing import Optional +from typing import Optional, Callable from starlette.requests import Request from starlette.responses import Response def default_key_builder( - func, + func: Callable, namespace: Optional[str] = "", request: Optional[Request] = None, response: Optional[Response] = None, args: Optional[tuple] = None, kwargs: Optional[dict] = None, -): +) -> str: from fastapi_cache import FastAPICache prefix = f"{FastAPICache.get_prefix()}:{namespace}:" diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 29db05d..6d54aef 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,13 +1,24 @@ import time +from typing import Generator import pendulum +import pytest + from fastapi_cache import FastAPICache from starlette.testclient import TestClient from examples.in_memory.main import app +from fastapi_cache.backends.inmemory import InMemoryBackend -def test_datetime(): +@pytest.fixture(autouse=True) +def init_cache() -> Generator: + FastAPICache.init(InMemoryBackend()) + yield + FastAPICache.reset() + + +def test_datetime() -> None: with TestClient(app) as client: response = client.get("/datetime") now = response.json().get("now") @@ -23,7 +34,8 @@ def test_datetime(): assert now != now_ assert now == pendulum.now().replace(microsecond=0) -def test_date(): + +def test_date() -> None: """Test path function without request or response arguments.""" with TestClient(app) as client: @@ -40,7 +52,8 @@ def test_date(): assert pendulum.parse(response.json()) == pendulum.today() FastAPICache._enable = True -def test_sync(): + +def test_sync() -> None: """Ensure that sync function support works.""" with TestClient(app) as client: response = client.get("/sync-me")