diff --git a/README.md b/README.md index 4f0b31d..ce038b0 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ take effect globally. ```python def my_key_builder( func, - namespace: Optional[str] = "", + namespace: str = "", request: Request = None, response: Response = None, *args, diff --git a/fastapi_cache/__init__.py b/fastapi_cache/__init__.py index b64eefe..2393916 100644 --- a/fastapi_cache/__init__.py +++ b/fastapi_cache/__init__.py @@ -1,18 +1,19 @@ -from typing import Callable, Optional, Type +from typing import ClassVar, Optional, Type from fastapi_cache.backends import Backend from fastapi_cache.coder import Coder, JsonCoder from fastapi_cache.key_builder import default_key_builder +from fastapi_cache.types import KeyBuilder class FastAPICache: - _backend: Optional[Backend] = None - _prefix: Optional[str] = None - _expire: Optional[int] = None - _init = False - _coder: Optional[Type[Coder]] = None - _key_builder: Optional[Callable] = None - _enable = True + _backend: ClassVar[Optional[Backend]] = None + _prefix: ClassVar[Optional[str]] = None + _expire: ClassVar[Optional[int]] = None + _init: ClassVar[bool] = False + _coder: ClassVar[Optional[Type[Coder]]] = None + _key_builder: ClassVar[Optional[KeyBuilder]] = None + _enable: ClassVar[bool] = True @classmethod def init( @@ -21,7 +22,7 @@ class FastAPICache: prefix: str = "", expire: Optional[int] = None, coder: Type[Coder] = JsonCoder, - key_builder: Callable = default_key_builder, + key_builder: KeyBuilder = default_key_builder, enable: bool = True, ) -> None: if cls._init: @@ -64,7 +65,7 @@ class FastAPICache: return cls._coder @classmethod - def get_key_builder(cls) -> Callable: + def get_key_builder(cls) -> KeyBuilder: assert cls._key_builder, "You must call init first!" # nosec: B101 return cls._key_builder diff --git a/fastapi_cache/backends/dynamodb.py b/fastapi_cache/backends/dynamodb.py index 46df6d2..34cc366 100644 --- a/fastapi_cache/backends/dynamodb.py +++ b/fastapi_cache/backends/dynamodb.py @@ -2,7 +2,7 @@ import datetime from typing import Optional, Tuple from aiobotocore.client import AioBaseClient -from aiobotocore.session import get_session +from aiobotocore.session import get_session, AioSession from fastapi_cache.backends import Backend @@ -26,7 +26,7 @@ class DynamoBackend(Backend): """ def __init__(self, table_name: str, region: Optional[str] = None) -> None: - self.session = get_session() + self.session: AioSession = get_session() self.client: Optional[AioBaseClient] = None # Needs async init self.table_name = table_name self.region = region diff --git a/fastapi_cache/backends/redis.py b/fastapi_cache/backends/redis.py index 0bff5c7..a40bde4 100644 --- a/fastapi_cache/backends/redis.py +++ b/fastapi_cache/backends/redis.py @@ -1,15 +1,15 @@ from typing import Optional, Tuple -from redis.asyncio.client import AbstractRedis -from redis.asyncio.cluster import AbstractRedisCluster +from redis.asyncio.client import Redis +from redis.asyncio.cluster import RedisCluster from fastapi_cache.backends import Backend class RedisBackend(Backend): - def __init__(self, redis: AbstractRedis): + def __init__(self, redis: Redis[str] | RedisCluster[str]): self.redis = redis - self.is_cluster = isinstance(redis, AbstractRedisCluster) + self.is_cluster: bool = isinstance(redis, RedisCluster) async def get_with_ttl(self, key: str) -> Tuple[int, str]: async with self.redis.pipeline(transaction=not self.is_cluster) as pipe: diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index f3df837..d698db2 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -3,14 +3,14 @@ import datetime import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any +from typing import Any, Callable import pendulum from fastapi.encoders import jsonable_encoder from starlette.responses import JSONResponse from starlette.templating import _TemplateResponse as TemplateResponse -CONVERTERS = { +CONVERTERS: dict[str, Callable[[str], Any]] = { "date": lambda x: pendulum.parse(x, exact=True), "datetime": lambda x: pendulum.parse(x, exact=True), "decimal": Decimal, @@ -35,7 +35,7 @@ def object_hook(obj: Any) -> Any: return obj if _spec_type in CONVERTERS: - return CONVERTERS[_spec_type](obj["val"]) # type: ignore + return CONVERTERS[_spec_type](obj["val"]) else: raise TypeError("Unknown {}".format(_spec_type)) @@ -54,7 +54,7 @@ class JsonCoder(Coder): @classmethod def encode(cls, value: Any) -> str: if isinstance(value, JSONResponse): - return value.body + return value.body.decode() return json.dumps(value, cls=JsonEncoder) @classmethod diff --git a/fastapi_cache/decorator.py b/fastapi_cache/decorator.py index b0601aa..2c06ad2 100644 --- a/fastapi_cache/decorator.py +++ b/fastapi_cache/decorator.py @@ -2,7 +2,7 @@ import inspect import logging import sys from functools import wraps -from typing import Any, Awaitable, Callable, Optional, Type, TypeVar +from typing import Awaitable, Callable, Optional, Type, TypeVar if sys.version_info >= (3, 10): from typing import ParamSpec @@ -15,8 +15,9 @@ from starlette.responses import Response from fastapi_cache import FastAPICache from fastapi_cache.coder import Coder +from fastapi_cache.types import KeyBuilder -logger = logging.getLogger(__name__) +logger: logging.Logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) P = ParamSpec("P") R = TypeVar("R") @@ -25,7 +26,7 @@ R = TypeVar("R") def cache( expire: Optional[int] = None, coder: Optional[Type[Coder]] = None, - key_builder: Optional[Callable[..., Any]] = None, + key_builder: Optional[KeyBuilder] = None, namespace: Optional[str] = "", ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: """ @@ -115,24 +116,17 @@ def cache( key_builder = key_builder or FastAPICache.get_key_builder() backend = FastAPICache.get_backend() - if inspect.iscoroutinefunction(key_builder): - cache_key = await key_builder( - func, - namespace, - request=request, - response=response, - args=args, - kwargs=copy_kwargs, - ) - else: - cache_key = key_builder( - func, - namespace, - request=request, - response=response, - args=args, - kwargs=copy_kwargs, - ) + cache_key = key_builder( + func, + namespace, + request=request, + response=response, + args=args, + kwargs=copy_kwargs, + ) + if inspect.isawaitable(cache_key): + cache_key = await cache_key + try: ttl, ret = await backend.get_with_ttl(cache_key) except Exception: diff --git a/fastapi_cache/key_builder.py b/fastapi_cache/key_builder.py index 852e632..c588889 100644 --- a/fastapi_cache/key_builder.py +++ b/fastapi_cache/key_builder.py @@ -1,17 +1,17 @@ import hashlib -from typing import Callable, Optional +from typing import Any, Callable, Optional from starlette.requests import Request from starlette.responses import Response def default_key_builder( - func: Callable, - namespace: Optional[str] = "", + func: Callable[..., Any], + namespace: str = "", request: Optional[Request] = None, response: Optional[Response] = None, - args: Optional[tuple] = None, - kwargs: Optional[dict] = None, + args: Optional[tuple[Any, ...]] = None, + kwargs: Optional[dict[str, Any]] = None, ) -> str: from fastapi_cache import FastAPICache diff --git a/fastapi_cache/types.py b/fastapi_cache/types.py new file mode 100644 index 0000000..a05a37e --- /dev/null +++ b/fastapi_cache/types.py @@ -0,0 +1,21 @@ +from typing import Any, Awaitable, Callable, Optional, Protocol, Union + +from starlette.requests import Request +from starlette.responses import Response + + +_Func = Callable[..., Any] + + +class KeyBuilder(Protocol): + def __call__( + self, + _function: _Func, + _namespace: str = ..., + *, + request: Optional[Request] = ..., + response: Optional[Response] = ..., + args: Optional[tuple[Any, ...]] = ..., + kwargs: Optional[dict[str, Any]] = ..., + ) -> Union[Awaitable[str], str]: + ...