Merge pull request #122 from mjpieters/type-hinting

Complete type hints
This commit is contained in:
long2ice
2023-04-28 14:28:49 +08:00
committed by GitHub
8 changed files with 63 additions and 47 deletions

View File

@@ -121,7 +121,7 @@ take effect globally.
```python
def my_key_builder(
func,
namespace: Optional[str] = "",
namespace: str = "",
request: Request = None,
response: Response = None,
*args,

View File

@@ -1,18 +1,19 @@
from typing import Callable, Optional, Type
from typing import ClassVar, Optional, Type
from fastapi_cache.backends import Backend
from fastapi_cache.coder import Coder, JsonCoder
from fastapi_cache.key_builder import default_key_builder
from fastapi_cache.types import KeyBuilder
class FastAPICache:
_backend: Optional[Backend] = None
_prefix: Optional[str] = None
_expire: Optional[int] = None
_init = False
_coder: Optional[Type[Coder]] = None
_key_builder: Optional[Callable] = None
_enable = True
_backend: ClassVar[Optional[Backend]] = None
_prefix: ClassVar[Optional[str]] = None
_expire: ClassVar[Optional[int]] = None
_init: ClassVar[bool] = False
_coder: ClassVar[Optional[Type[Coder]]] = None
_key_builder: ClassVar[Optional[KeyBuilder]] = None
_enable: ClassVar[bool] = True
@classmethod
def init(
@@ -21,7 +22,7 @@ class FastAPICache:
prefix: str = "",
expire: Optional[int] = None,
coder: Type[Coder] = JsonCoder,
key_builder: Callable = default_key_builder,
key_builder: KeyBuilder = default_key_builder,
enable: bool = True,
) -> None:
if cls._init:
@@ -64,7 +65,7 @@ class FastAPICache:
return cls._coder
@classmethod
def get_key_builder(cls) -> Callable:
def get_key_builder(cls) -> KeyBuilder:
assert cls._key_builder, "You must call init first!" # nosec: B101
return cls._key_builder

View File

@@ -2,7 +2,7 @@ import datetime
from typing import Optional, Tuple
from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session
from aiobotocore.session import get_session, AioSession
from fastapi_cache.backends import Backend
@@ -26,7 +26,7 @@ class DynamoBackend(Backend):
"""
def __init__(self, table_name: str, region: Optional[str] = None) -> None:
self.session = get_session()
self.session: AioSession = get_session()
self.client: Optional[AioBaseClient] = None # Needs async init
self.table_name = table_name
self.region = region

View File

@@ -1,15 +1,15 @@
from typing import Optional, Tuple
from redis.asyncio.client import AbstractRedis
from redis.asyncio.cluster import AbstractRedisCluster
from redis.asyncio.client import Redis
from redis.asyncio.cluster import RedisCluster
from fastapi_cache.backends import Backend
class RedisBackend(Backend):
def __init__(self, redis: AbstractRedis):
def __init__(self, redis: Redis[str] | RedisCluster[str]):
self.redis = redis
self.is_cluster = isinstance(redis, AbstractRedisCluster)
self.is_cluster: bool = isinstance(redis, RedisCluster)
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:

View File

@@ -3,14 +3,14 @@ import datetime
import json
import pickle # nosec:B403
from decimal import Decimal
from typing import Any
from typing import Any, Callable
import pendulum
from fastapi.encoders import jsonable_encoder
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
CONVERTERS = {
CONVERTERS: dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True),
"decimal": Decimal,
@@ -35,7 +35,7 @@ def object_hook(obj: Any) -> Any:
return obj
if _spec_type in CONVERTERS:
return CONVERTERS[_spec_type](obj["val"]) # type: ignore
return CONVERTERS[_spec_type](obj["val"])
else:
raise TypeError("Unknown {}".format(_spec_type))
@@ -54,7 +54,7 @@ class JsonCoder(Coder):
@classmethod
def encode(cls, value: Any) -> str:
if isinstance(value, JSONResponse):
return value.body
return value.body.decode()
return json.dumps(value, cls=JsonEncoder)
@classmethod

View File

@@ -2,7 +2,7 @@ import inspect
import logging
import sys
from functools import wraps
from typing import Any, Awaitable, Callable, Optional, Type, TypeVar
from typing import Awaitable, Callable, Optional, Type, TypeVar
if sys.version_info >= (3, 10):
from typing import ParamSpec
@@ -15,8 +15,9 @@ from starlette.responses import Response
from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder
from fastapi_cache.types import KeyBuilder
logger = logging.getLogger(__name__)
logger: logging.Logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
P = ParamSpec("P")
R = TypeVar("R")
@@ -25,7 +26,7 @@ R = TypeVar("R")
def cache(
expire: Optional[int] = None,
coder: Optional[Type[Coder]] = None,
key_builder: Optional[Callable[..., Any]] = None,
key_builder: Optional[KeyBuilder] = None,
namespace: Optional[str] = "",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
"""
@@ -115,24 +116,17 @@ def cache(
key_builder = key_builder or FastAPICache.get_key_builder()
backend = FastAPICache.get_backend()
if inspect.iscoroutinefunction(key_builder):
cache_key = await key_builder(
func,
namespace,
request=request,
response=response,
args=args,
kwargs=copy_kwargs,
)
else:
cache_key = key_builder(
func,
namespace,
request=request,
response=response,
args=args,
kwargs=copy_kwargs,
)
cache_key = key_builder(
func,
namespace,
request=request,
response=response,
args=args,
kwargs=copy_kwargs,
)
if inspect.isawaitable(cache_key):
cache_key = await cache_key
try:
ttl, ret = await backend.get_with_ttl(cache_key)
except Exception:

View File

@@ -1,17 +1,17 @@
import hashlib
from typing import Callable, Optional
from typing import Any, Callable, Optional
from starlette.requests import Request
from starlette.responses import Response
def default_key_builder(
func: Callable,
namespace: Optional[str] = "",
func: Callable[..., Any],
namespace: str = "",
request: Optional[Request] = None,
response: Optional[Response] = None,
args: Optional[tuple] = None,
kwargs: Optional[dict] = None,
args: Optional[tuple[Any, ...]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> str:
from fastapi_cache import FastAPICache

21
fastapi_cache/types.py Normal file
View File

@@ -0,0 +1,21 @@
from typing import Any, Awaitable, Callable, Optional, Protocol, Union
from starlette.requests import Request
from starlette.responses import Response
_Func = Callable[..., Any]
class KeyBuilder(Protocol):
def __call__(
self,
_function: _Func,
_namespace: str = ...,
*,
request: Optional[Request] = ...,
response: Optional[Response] = ...,
args: Optional[tuple[Any, ...]] = ...,
kwargs: Optional[dict[str, Any]] = ...,
) -> Union[Awaitable[str], str]:
...