Merge pull request #93 from Mrreadiness/feat/type-hints-covering

Feat/type hints covering
This commit is contained in:
long2ice
2022-11-04 08:51:21 +08:00
committed by GitHub
10 changed files with 99 additions and 68 deletions

View File

@@ -1,28 +1,29 @@
from typing import Callable from typing import Callable, Optional, Type
from fastapi_cache.backends import Backend
from fastapi_cache.coder import Coder, JsonCoder from fastapi_cache.coder import Coder, JsonCoder
from fastapi_cache.key_builder import default_key_builder from fastapi_cache.key_builder import default_key_builder
class FastAPICache: class FastAPICache:
_backend = None _backend: Optional[Backend] = None
_prefix = None _prefix: Optional[str] = None
_expire = None _expire: Optional[int] = None
_init = False _init = False
_coder = None _coder: Optional[Type[Coder]] = None
_key_builder = None _key_builder: Optional[Callable] = None
_enable = True _enable = True
@classmethod @classmethod
def init( def init(
cls, cls,
backend, backend: Backend,
prefix: str = "", prefix: str = "",
expire: int = None, expire: Optional[int] = None,
coder: Coder = JsonCoder, coder: Type[Coder] = JsonCoder,
key_builder: Callable = default_key_builder, key_builder: Callable = default_key_builder,
enable: bool = True, enable: bool = True,
): ) -> None:
if cls._init: if cls._init:
return return
cls._init = True cls._init = True
@@ -34,31 +35,45 @@ class FastAPICache:
cls._enable = enable cls._enable = enable
@classmethod @classmethod
def get_backend(cls): def reset(cls) -> None:
cls._init = False
cls._backend = None
cls._prefix = None
cls._expire = None
cls._coder = None
cls._key_builder = None
cls._enable = True
@classmethod
def get_backend(cls) -> Backend:
assert cls._backend, "You must call init first!" # nosec: B101 assert cls._backend, "You must call init first!" # nosec: B101
return cls._backend return cls._backend
@classmethod @classmethod
def get_prefix(cls): def get_prefix(cls) -> str:
assert cls._prefix is not None, "You must call init first!" # nosec: B101
return cls._prefix return cls._prefix
@classmethod @classmethod
def get_expire(cls): def get_expire(cls) -> Optional[int]:
return cls._expire return cls._expire
@classmethod @classmethod
def get_coder(cls): def get_coder(cls) -> Type[Coder]:
assert cls._coder, "You must call init first!" # nosec: B101
return cls._coder return cls._coder
@classmethod @classmethod
def get_key_builder(cls): def get_key_builder(cls) -> Callable:
assert cls._key_builder, "You must call init first!" # nosec: B101
return cls._key_builder return cls._key_builder
@classmethod @classmethod
def get_enable(cls): def get_enable(cls) -> bool:
return cls._enable return cls._enable
@classmethod @classmethod
async def clear(cls, namespace: str = None, key: str = None): async def clear(cls, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
assert cls._backend and cls._prefix is not None, "You must call init first!" # nosec: B101
namespace = cls._prefix + (":" + namespace if namespace else "") namespace = cls._prefix + (":" + namespace if namespace else "")
return await cls._backend.clear(namespace, key) return await cls._backend.clear(namespace, key)

View File

@@ -1,20 +1,20 @@
import abc import abc
from typing import Tuple from typing import Tuple, Optional
class Backend: class Backend:
@abc.abstractmethod @abc.abstractmethod
async def get_with_ttl(self, key: str) -> Tuple[int, str]: async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def get(self, key: str) -> str: async def get(self, key: str) -> Optional[str]:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
async def clear(self, namespace: str = None, key: str = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError raise NotImplementedError

View File

@@ -1,6 +1,7 @@
import datetime import datetime
from typing import Tuple from typing import Tuple, Optional
from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session from aiobotocore.session import get_session
from fastapi_cache.backends import Backend from fastapi_cache.backends import Backend
@@ -24,18 +25,18 @@ class DynamoBackend(Backend):
>> FastAPICache.init(dynamodb) >> FastAPICache.init(dynamodb)
""" """
def __init__(self, table_name, region=None): def __init__(self, table_name: str, region: Optional[str] = None) -> None:
self.session = get_session() self.session = get_session()
self.client = None # Needs async init self.client: Optional[AioBaseClient] = None # Needs async init
self.table_name = table_name self.table_name = table_name
self.region = region self.region = region
async def init(self): async def init(self) -> None:
self.client = await self.session.create_client( self.client = await self.session.create_client(
"dynamodb", region_name=self.region "dynamodb", region_name=self.region
).__aenter__() ).__aenter__()
async def close(self): async def close(self) -> None:
self.client = await self.client.__aexit__(None, None, None) self.client = await self.client.__aexit__(None, None, None)
async def get_with_ttl(self, key: str) -> Tuple[int, str]: async def get_with_ttl(self, key: str) -> Tuple[int, str]:
@@ -55,12 +56,12 @@ class DynamoBackend(Backend):
return 0, None return 0, None
async def get(self, key) -> str: async def get(self, key: str) -> str:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
if "Item" in response: if "Item" in response:
return response["Item"].get("value", {}).get("S") return response["Item"].get("value", {}).get("S")
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
ttl = ( ttl = (
{ {
"ttl": { "ttl": {
@@ -88,5 +89,5 @@ class DynamoBackend(Backend):
}, },
) )
async def clear(self, namespace: str = None, key: str = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError raise NotImplementedError

View File

@@ -20,13 +20,14 @@ class InMemoryBackend(Backend):
def _now(self) -> int: def _now(self) -> int:
return int(time.time()) return int(time.time())
def _get(self, key: str): def _get(self, key: str) -> Optional[Value]:
v = self._store.get(key) v = self._store.get(key)
if v: if v:
if v.ttl_ts < self._now: if v.ttl_ts < self._now:
del self._store[key] del self._store[key]
else: else:
return v return v
return None
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
async with self._lock: async with self._lock:
@@ -35,17 +36,18 @@ class InMemoryBackend(Backend):
return v.ttl_ts - self._now, v.data return v.ttl_ts - self._now, v.data
return 0, None return 0, None
async def get(self, key: str) -> str: async def get(self, key: str) -> Optional[str]:
async with self._lock: async with self._lock:
v = self._get(key) v = self._get(key)
if v: if v:
return v.data return v.data
return None
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
async with self._lock: async with self._lock:
self._store[key] = Value(value, self._now + (expire or 0)) self._store[key] = Value(value, self._now + (expire or 0))
async def clear(self, namespace: str = None, key: str = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
count = 0 count = 0
if namespace: if namespace:
keys = list(self._store.keys()) keys = list(self._store.keys())

View File

@@ -1,4 +1,4 @@
from typing import Tuple from typing import Tuple, Optional
from aiomcache import Client from aiomcache import Client
@@ -9,14 +9,14 @@ class MemcachedBackend(Backend):
def __init__(self, mcache: Client): def __init__(self, mcache: Client):
self.mcache = mcache self.mcache = mcache
async def get_with_ttl(self, key: str) -> Tuple[int, str]: async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
return 3600, await self.mcache.get(key.encode()) return 3600, await self.mcache.get(key.encode())
async def get(self, key: str): async def get(self, key: str) -> Optional[str]:
return await self.mcache.get(key, key.encode()) return await self.mcache.get(key, key.encode())
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
return await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0)
async def clear(self, namespace: str = None, key: str = None): async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError raise NotImplementedError

View File

@@ -1,4 +1,4 @@
from typing import Tuple from typing import Tuple, Optional
from redis.asyncio.client import Redis from redis.asyncio.client import Redis
@@ -13,15 +13,16 @@ class RedisBackend(Backend):
async with self.redis.pipeline(transaction=True) as pipe: async with self.redis.pipeline(transaction=True) as pipe:
return await (pipe.ttl(key).get(key).execute()) return await (pipe.ttl(key).get(key).execute())
async def get(self, key) -> str: async def get(self, key: str) -> Optional[str]:
return await self.redis.get(key) return await self.redis.get(key)
async def set(self, key: str, value: str, expire: int = None): async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
return await self.redis.set(key, value, ex=expire) return await self.redis.set(key, value, ex=expire)
async def clear(self, namespace: str = None, key: str = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
if namespace: if namespace:
lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end" lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end"
return await self.redis.eval(lua, numkeys=0) return await self.redis.eval(lua, numkeys=0)
elif key: elif key:
return await self.redis.delete(key) return await self.redis.delete(key)
return 0

View File

@@ -2,7 +2,7 @@ import datetime
import json import json
import pickle # nosec:B403 import pickle # nosec:B403
from decimal import Decimal from decimal import Decimal
from typing import Any from typing import Any, Dict, Union
import pendulum import pendulum
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
@@ -16,7 +16,7 @@ CONVERTERS = {
class JsonEncoder(json.JSONEncoder): class JsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj: Any) -> Any:
if isinstance(obj, datetime.datetime): if isinstance(obj, datetime.datetime):
return {"val": str(obj), "_spec_type": "datetime"} return {"val": str(obj), "_spec_type": "datetime"}
elif isinstance(obj, datetime.date): elif isinstance(obj, datetime.date):
@@ -27,44 +27,44 @@ class JsonEncoder(json.JSONEncoder):
return jsonable_encoder(obj) return jsonable_encoder(obj)
def object_hook(obj): def object_hook(obj: Any) -> Any:
_spec_type = obj.get("_spec_type") _spec_type = obj.get("_spec_type")
if not _spec_type: if not _spec_type:
return obj return obj
if _spec_type in CONVERTERS: if _spec_type in CONVERTERS:
return CONVERTERS[_spec_type](obj["val"]) return CONVERTERS[_spec_type](obj["val"]) # type: ignore
else: else:
raise TypeError("Unknown {}".format(_spec_type)) raise TypeError("Unknown {}".format(_spec_type))
class Coder: class Coder:
@classmethod @classmethod
def encode(cls, value: Any): def encode(cls, value: Any) -> str:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def decode(cls, value: Any): def decode(cls, value: Any) -> Any:
raise NotImplementedError raise NotImplementedError
class JsonCoder(Coder): class JsonCoder(Coder):
@classmethod @classmethod
def encode(cls, value: Any): def encode(cls, value: Any) -> str:
return json.dumps(value, cls=JsonEncoder) return json.dumps(value, cls=JsonEncoder)
@classmethod @classmethod
def decode(cls, value: Any): def decode(cls, value: Any) -> str:
return json.loads(value, object_hook=object_hook) return json.loads(value, object_hook=object_hook)
class PickleCoder(Coder): class PickleCoder(Coder):
@classmethod @classmethod
def encode(cls, value: Any): def encode(cls, value: Any) -> str:
if isinstance(value, TemplateResponse): if isinstance(value, TemplateResponse):
value = value.body value = value.body
return pickle.dumps(value) return str(pickle.dumps(value))
@classmethod @classmethod
def decode(cls, value: Any): def decode(cls, value: Any) -> Any:
return pickle.loads(value) # nosec:B403,B301 return pickle.loads(bytes(value)) # nosec:B403,B301

View File

@@ -1,7 +1,8 @@
import inspect import inspect
import sys import sys
from functools import wraps from functools import wraps
from typing import Any, Awaitable, Callable, Optional, TypeVar from typing import Any, Awaitable, Callable, Optional, TypeVar, Type
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
from typing import ParamSpec from typing import ParamSpec
else: else:
@@ -21,7 +22,7 @@ R = TypeVar("R")
def cache( def cache(
expire: Optional[int] = None, expire: Optional[int] = None,
coder: Optional[Coder] = None, coder: Optional[Type[Coder]] = None,
key_builder: Optional[Callable[..., Any]] = None, key_builder: Optional[Callable[..., Any]] = None,
namespace: Optional[str] = "", namespace: Optional[str] = "",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: ) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
@@ -92,11 +93,9 @@ def cache(
# see above why we have to await even although caller also awaits. # see above why we have to await even although caller also awaits.
return await run_in_threadpool(func, *args, **kwargs) return await run_in_threadpool(func, *args, **kwargs)
copy_kwargs = kwargs.copy() copy_kwargs = kwargs.copy()
request = copy_kwargs.pop("request", None) request: Optional[Request] = copy_kwargs.pop("request", None)
response = copy_kwargs.pop("response", None) response: Optional[Response] = copy_kwargs.pop("response", None)
if ( if (
request and request.headers.get("Cache-Control") in ("no-store", "no-cache") request and request.headers.get("Cache-Control") in ("no-store", "no-cache")
) or not FastAPICache.get_enable(): ) or not FastAPICache.get_enable():

View File

@@ -1,18 +1,18 @@
import hashlib import hashlib
from typing import Optional from typing import Optional, Callable
from starlette.requests import Request from starlette.requests import Request
from starlette.responses import Response from starlette.responses import Response
def default_key_builder( def default_key_builder(
func, func: Callable,
namespace: Optional[str] = "", namespace: Optional[str] = "",
request: Optional[Request] = None, request: Optional[Request] = None,
response: Optional[Response] = None, response: Optional[Response] = None,
args: Optional[tuple] = None, args: Optional[tuple] = None,
kwargs: Optional[dict] = None, kwargs: Optional[dict] = None,
): ) -> str:
from fastapi_cache import FastAPICache from fastapi_cache import FastAPICache
prefix = f"{FastAPICache.get_prefix()}:{namespace}:" prefix = f"{FastAPICache.get_prefix()}:{namespace}:"

View File

@@ -1,13 +1,24 @@
import time import time
from typing import Generator
import pendulum import pendulum
import pytest
from fastapi_cache import FastAPICache from fastapi_cache import FastAPICache
from starlette.testclient import TestClient from starlette.testclient import TestClient
from examples.in_memory.main import app from examples.in_memory.main import app
from fastapi_cache.backends.inmemory import InMemoryBackend
def test_datetime(): @pytest.fixture(autouse=True)
def init_cache() -> Generator:
FastAPICache.init(InMemoryBackend())
yield
FastAPICache.reset()
def test_datetime() -> None:
with TestClient(app) as client: with TestClient(app) as client:
response = client.get("/datetime") response = client.get("/datetime")
now = response.json().get("now") now = response.json().get("now")
@@ -23,7 +34,8 @@ def test_datetime():
assert now != now_ assert now != now_
assert now == pendulum.now().replace(microsecond=0) assert now == pendulum.now().replace(microsecond=0)
def test_date():
def test_date() -> None:
"""Test path function without request or response arguments.""" """Test path function without request or response arguments."""
with TestClient(app) as client: with TestClient(app) as client:
@@ -40,7 +52,8 @@ def test_date():
assert pendulum.parse(response.json()) == pendulum.today() assert pendulum.parse(response.json()) == pendulum.today()
FastAPICache._enable = True FastAPICache._enable = True
def test_sync():
def test_sync() -> None:
"""Ensure that sync function support works.""" """Ensure that sync function support works."""
with TestClient(app) as client: with TestClient(app) as client:
response = client.get("/sync-me") response = client.get("/sync-me")