Merge branch 'master' into master

This commit is contained in:
Lucca Marques
2022-08-08 12:40:33 -03:00
committed by GitHub
18 changed files with 996 additions and 459 deletions

View File

@@ -11,6 +11,7 @@ class FastAPICache:
_init = False
_coder = None
_key_builder = None
_enable = True
@classmethod
def init(
@@ -20,6 +21,7 @@ class FastAPICache:
expire: int = None,
coder: Coder = JsonCoder,
key_builder: Callable = default_key_builder,
enable: bool = True,
):
if cls._init:
return
@@ -29,6 +31,7 @@ class FastAPICache:
cls._expire = expire
cls._coder = coder
cls._key_builder = key_builder
cls._enable = enable
@classmethod
def get_backend(cls):
@@ -51,6 +54,10 @@ class FastAPICache:
def get_key_builder(cls):
return cls._key_builder
@classmethod
def get_enable(cls):
return cls._enable
@classmethod
async def clear(cls, namespace: str = None, key: str = None):
namespace = cls._prefix + ":" + namespace if namespace else None

View File

@@ -0,0 +1,92 @@
import datetime
from typing import Tuple
from aiobotocore.session import get_session
from fastapi_cache.backends import Backend
class DynamoBackend(Backend):
"""
Amazon DynamoDB backend provider
This backend requires an existing table within your AWS environment to be passed during
backend init. If ttl is going to be used, this needs to be manually enabled on the table
using the `ttl` key. Dynamo will take care of deleting outdated objects, but this is not
instant so don't be alarmed when they linger around for a bit.
As with all AWS clients, credentials will be taken from the environment. Check the AWS SDK
for more information.
Usage:
>> dynamodb = DynamoBackend(table_name="your-cache", region="eu-west-1")
>> await dynamodb.init()
>> FastAPICache.init(dynamodb)
"""
def __init__(self, table_name, region=None):
self.session = get_session()
self.client = None # Needs async init
self.table_name = table_name
self.region = region
async def init(self):
self.client = await self.session.create_client(
"dynamodb", region_name=self.region
).__aenter__()
async def close(self):
self.client = await self.client.__aexit__(None, None, None)
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
if "Item" in response:
value = response["Item"].get("value", {}).get("S")
ttl = response["Item"].get("ttl", {}).get("N")
if not ttl:
return -1, value
# It's only eventually consistent so we need to check ourselves
expire = int(ttl) - int(datetime.datetime.now().timestamp())
if expire > 0:
return expire, value
return 0, None
async def get(self, key) -> str:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
if "Item" in response:
return response["Item"].get("value", {}).get("S")
async def set(self, key: str, value: str, expire: int = None):
ttl = (
{
"ttl": {
"N": str(
int(
(
datetime.datetime.now() + datetime.timedelta(seconds=expire)
).timestamp()
)
)
}
}
if expire
else {}
)
await self.client.put_item(
TableName=self.table_name,
Item={
**{
"key": {"S": key},
"value": {"S": value},
},
**ttl,
},
)
async def clear(self, namespace: str = None, key: str = None) -> int:
raise NotImplementedError

View File

@@ -43,7 +43,7 @@ class InMemoryBackend(Backend):
async def set(self, key: str, value: str, expire: int = None):
async with self._lock:
self._store[key] = Value(value, self._now + expire)
self._store[key] = Value(value, self._now + (expire or 0))
async def clear(self, namespace: str = None, key: str = None) -> int:
count = 0

View File

@@ -1,6 +1,6 @@
from typing import Tuple
from aioredis import Redis
from redis.asyncio.client import Redis
from fastapi_cache.backends import Backend

View File

@@ -4,12 +4,12 @@ import pickle # nosec:B403
from decimal import Decimal
from typing import Any
import dateutil.parser
import pendulum
from fastapi.encoders import jsonable_encoder
CONVERTERS = {
"date": dateutil.parser.parse,
"datetime": dateutil.parser.parse,
"date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True),
"decimal": Decimal,
}
@@ -17,12 +17,9 @@ CONVERTERS = {
class JsonEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, datetime.datetime):
if obj.tzinfo:
return {"val": obj.strftime("%Y-%m-%d %H:%M:%S%z"), "_spec_type": "datetime"}
else:
return {"val": obj.strftime("%Y-%m-%d %H:%M:%S"), "_spec_type": "datetime"}
return {"val": str(obj), "_spec_type": "datetime"}
elif isinstance(obj, datetime.date):
return {"val": obj.strftime("%Y-%m-%d"), "_spec_type": "date"}
return {"val": str(obj), "_spec_type": "date"}
elif isinstance(obj, Decimal):
return {"val": str(obj), "_spec_type": "decimal"}
else:
@@ -67,4 +64,4 @@ class PickleCoder(Coder):
@classmethod
def decode(cls, value: Any):
return pickle.loads(value) # nosec:B403
return pickle.loads(value) # nosec:B403,B301

View File

@@ -1,15 +1,21 @@
from functools import wraps
from typing import Callable, Optional, Type
import asyncio
from functools import wraps, partial
import inspect
from typing import TYPE_CHECKING, Callable, Optional, Type
from fastapi_cache import FastAPICache
from fastapi_cache.coder import Coder
if TYPE_CHECKING:
import concurrent.futures
def cache(
expire: int = None,
coder: Type[Coder] = None,
key_builder: Callable = None,
namespace: Optional[str] = "",
executor: Optional["concurrent.futures.Executor"] = None,
):
"""
cache all function
@@ -17,6 +23,8 @@ def cache(
:param expire:
:param coder:
:param key_builder:
:param executor:
:return:
"""
@@ -29,7 +37,10 @@ def cache(
copy_kwargs = kwargs.copy()
request = copy_kwargs.pop("request", None)
response = copy_kwargs.pop("response", None)
if request and request.headers.get("Cache-Control") in ("no-store", "no-cache"):
if (
request and request.headers.get("Cache-Control") in ("no-store", "no-cache")
) or not FastAPICache.get_enable():
return await func(*args, **kwargs)
coder = coder or FastAPICache.get_coder()
@@ -61,7 +72,12 @@ def cache(
response.headers["ETag"] = etag
return coder.decode(ret)
ret = await func(*args, **kwargs)
if inspect.iscoroutinefunction(func):
ret = await func(*args, **kwargs)
else:
loop = asyncio.get_event_loop()
ret = await loop.run_in_executor(executor, partial(func, *args, **kwargs))
await backend.set(cache_key, coder.encode(ret), expire or FastAPICache.get_expire())
return ret