Merge pull request #122 from mjpieters/type-hinting

Complete type hints
This commit is contained in:
long2ice
2023-04-28 14:28:49 +08:00
committed by GitHub
8 changed files with 63 additions and 47 deletions

View File

@@ -121,7 +121,7 @@ take effect globally.
```python ```python
def my_key_builder( def my_key_builder(
func, func,
namespace: Optional[str] = "", namespace: str = "",
request: Request = None, request: Request = None,
response: Response = None, response: Response = None,
*args, *args,

View File

@@ -1,18 +1,19 @@
from typing import Callable, Optional, Type from typing import ClassVar, Optional, Type
from fastapi_cache.backends import Backend from fastapi_cache.backends import Backend
from fastapi_cache.coder import Coder, JsonCoder from fastapi_cache.coder import Coder, JsonCoder
from fastapi_cache.key_builder import default_key_builder from fastapi_cache.key_builder import default_key_builder
from fastapi_cache.types import KeyBuilder
class FastAPICache: class FastAPICache:
_backend: Optional[Backend] = None _backend: ClassVar[Optional[Backend]] = None
_prefix: Optional[str] = None _prefix: ClassVar[Optional[str]] = None
_expire: Optional[int] = None _expire: ClassVar[Optional[int]] = None
_init = False _init: ClassVar[bool] = False
_coder: Optional[Type[Coder]] = None _coder: ClassVar[Optional[Type[Coder]]] = None
_key_builder: Optional[Callable] = None _key_builder: ClassVar[Optional[KeyBuilder]] = None
_enable = True _enable: ClassVar[bool] = True
@classmethod @classmethod
def init( def init(
@@ -21,7 +22,7 @@ class FastAPICache:
prefix: str = "", prefix: str = "",
expire: Optional[int] = None, expire: Optional[int] = None,
coder: Type[Coder] = JsonCoder, coder: Type[Coder] = JsonCoder,
key_builder: Callable = default_key_builder, key_builder: KeyBuilder = default_key_builder,
enable: bool = True, enable: bool = True,
) -> None: ) -> None:
if cls._init: if cls._init:
@@ -64,7 +65,7 @@ class FastAPICache:
return cls._coder return cls._coder
@classmethod @classmethod
def get_key_builder(cls) -> Callable: def get_key_builder(cls) -> KeyBuilder:
assert cls._key_builder, "You must call init first!" # nosec: B101 assert cls._key_builder, "You must call init first!" # nosec: B101
return cls._key_builder return cls._key_builder

View File

@@ -2,7 +2,7 @@ import datetime
from typing import Optional, Tuple from typing import Optional, Tuple
from aiobotocore.client import AioBaseClient from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session from aiobotocore.session import get_session, AioSession
from fastapi_cache.backends import Backend from fastapi_cache.backends import Backend
@@ -26,7 +26,7 @@ class DynamoBackend(Backend):
""" """
def __init__(self, table_name: str, region: Optional[str] = None) -> None: def __init__(self, table_name: str, region: Optional[str] = None) -> None:
self.session = get_session() self.session: AioSession = get_session()
self.client: Optional[AioBaseClient] = None # Needs async init self.client: Optional[AioBaseClient] = None # Needs async init
self.table_name = table_name self.table_name = table_name
self.region = region self.region = region

View File

@@ -1,15 +1,15 @@
from typing import Optional, Tuple from typing import Optional, Tuple
from redis.asyncio.client import AbstractRedis from redis.asyncio.client import Redis
from redis.asyncio.cluster import AbstractRedisCluster from redis.asyncio.cluster import RedisCluster
from fastapi_cache.backends import Backend from fastapi_cache.backends import Backend
class RedisBackend(Backend): class RedisBackend(Backend):
def __init__(self, redis: AbstractRedis): def __init__(self, redis: Redis[str] | RedisCluster[str]):
self.redis = redis self.redis = redis
self.is_cluster = isinstance(redis, AbstractRedisCluster) self.is_cluster: bool = isinstance(redis, RedisCluster)
async def get_with_ttl(self, key: str) -> Tuple[int, str]: async def get_with_ttl(self, key: str) -> Tuple[int, str]:
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe: async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:

View File

@@ -3,14 +3,14 @@ import datetime
import json import json
import pickle # nosec:B403 import pickle # nosec:B403
from decimal import Decimal from decimal import Decimal
from typing import Any from typing import Any, Callable
import pendulum import pendulum
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from starlette.responses import JSONResponse from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse from starlette.templating import _TemplateResponse as TemplateResponse
CONVERTERS = { CONVERTERS: dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True), "date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True), "datetime": lambda x: pendulum.parse(x, exact=True),
"decimal": Decimal, "decimal": Decimal,
@@ -35,7 +35,7 @@ def object_hook(obj: Any) -> Any:
return obj return obj
if _spec_type in CONVERTERS: if _spec_type in CONVERTERS:
return CONVERTERS[_spec_type](obj["val"]) # type: ignore return CONVERTERS[_spec_type](obj["val"])
else: else:
raise TypeError("Unknown {}".format(_spec_type)) raise TypeError("Unknown {}".format(_spec_type))
@@ -54,7 +54,7 @@ class JsonCoder(Coder):
@classmethod @classmethod
def encode(cls, value: Any) -> str: def encode(cls, value: Any) -> str:
if isinstance(value, JSONResponse): if isinstance(value, JSONResponse):
return value.body return value.body.decode()
return json.dumps(value, cls=JsonEncoder) return json.dumps(value, cls=JsonEncoder)
@classmethod @classmethod

View File

@@ -2,7 +2,7 @@ import inspect
import logging import logging
import sys import sys
from functools import wraps from functools import wraps
from typing import Any, Awaitable, Callable, Optional, Type, TypeVar from typing import Awaitable, Callable, Optional, Type, TypeVar
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
from typing import ParamSpec from typing import ParamSpec
@@ -15,8 +15,9 @@ from starlette.responses import Response
from fastapi_cache import FastAPICache from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder from fastapi_cache.coder import Coder
from fastapi_cache.types import KeyBuilder
logger = logging.getLogger(__name__) logger: logging.Logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler()) logger.addHandler(logging.NullHandler())
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
@@ -25,7 +26,7 @@ R = TypeVar("R")
def cache( def cache(
expire: Optional[int] = None, expire: Optional[int] = None,
coder: Optional[Type[Coder]] = None, coder: Optional[Type[Coder]] = None,
key_builder: Optional[Callable[..., Any]] = None, key_builder: Optional[KeyBuilder] = None,
namespace: Optional[str] = "", namespace: Optional[str] = "",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
""" """
@@ -115,24 +116,17 @@ def cache(
key_builder = key_builder or FastAPICache.get_key_builder() key_builder = key_builder or FastAPICache.get_key_builder()
backend = FastAPICache.get_backend() backend = FastAPICache.get_backend()
if inspect.iscoroutinefunction(key_builder): cache_key = key_builder(
cache_key = await key_builder( func,
func, namespace,
namespace, request=request,
request=request, response=response,
response=response, args=args,
args=args, kwargs=copy_kwargs,
kwargs=copy_kwargs, )
) if inspect.isawaitable(cache_key):
else: cache_key = await cache_key
cache_key = key_builder(
func,
namespace,
request=request,
response=response,
args=args,
kwargs=copy_kwargs,
)
try: try:
ttl, ret = await backend.get_with_ttl(cache_key) ttl, ret = await backend.get_with_ttl(cache_key)
except Exception: except Exception:

View File

@@ -1,17 +1,17 @@
import hashlib import hashlib
from typing import Callable, Optional from typing import Any, Callable, Optional
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
def default_key_builder( def default_key_builder(
func: Callable, func: Callable[..., Any],
namespace: Optional[str] = "", namespace: str = "",
request: Optional[Request] = None, request: Optional[Request] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
args: Optional[tuple] = None, args: Optional[tuple[Any, ...]] = None,
kwargs: Optional[dict] = None, kwargs: Optional[dict[str, Any]] = None,
) -> str: ) -> str:
from fastapi_cache import FastAPICache from fastapi_cache import FastAPICache

21
fastapi_cache/types.py Normal file
View File

@@ -0,0 +1,21 @@
from typing import Any, Awaitable, Callable, Optional, Protocol, Union
from starlette.requests import Request
from starlette.responses import Response
_Func = Callable[..., Any]
class KeyBuilder(Protocol):
def __call__(
self,
_function: _Func,
_namespace: str = ...,
*,
request: Optional[Request] = ...,
response: Optional[Response] = ...,
args: Optional[tuple[Any, ...]] = ...,
kwargs: Optional[dict[str, Any]] = ...,
) -> Union[Awaitable[str], str]:
...