diff --git a/fastapi_cache/backends/dynamodb.py b/fastapi_cache/backends/dynamodb.py index ec639d1..7a32e89 100644 --- a/fastapi_cache/backends/dynamodb.py +++ b/fastapi_cache/backends/dynamodb.py @@ -1,5 +1,5 @@ import datetime -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple, Union from aiobotocore.client import AioBaseClient from aiobotocore.session import AioSession, get_session @@ -30,7 +30,7 @@ class DynamoBackend(Backend): >> FastAPICache.init(dynamodb) """ - client: DynamoDBClient + client: Union[DynamoDBClient, None] session: AioSession table_name: str region: Optional[str] @@ -46,58 +46,63 @@ class DynamoBackend(Backend): ).__aenter__() async def close(self) -> None: - self.client = await self.client.__aexit__(None, None, None) + if self.client: + await self.client.__aexit__(None, None, None) + self.client = None async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]: - response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) + if self.client: + response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) - if "Item" in response: - value = response["Item"].get("value", {}).get("B") - ttl = response["Item"].get("ttl", {}).get("N") + if "Item" in response: + value = response["Item"].get("value", {}).get("B") + ttl = response["Item"].get("ttl", {}).get("N") - if not ttl: - return -1, value + if not ttl: + return -1, value - # It's only eventually consistent so we need to check ourselves - expire = int(ttl) - int(datetime.datetime.now().timestamp()) - if expire > 0: - return expire, value + # It's only eventually consistent so we need to check ourselves + expire = int(ttl) - int(datetime.datetime.now().timestamp()) + if expire > 0: + return expire, value return 0, None async def get(self, key: str) -> Optional[bytes]: - response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) - if "Item" in response: - return response["Item"].get("value", {}).get("B") + if self.client: + response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) + if "Item" in response: + return response["Item"].get("value", {}).get("B") return None async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None: - ttl = ( - { - "ttl": { - "N": str( - int( - ( - datetime.datetime.now() + datetime.timedelta(seconds=expire) - ).timestamp() + if self.client: + ttl = ( + { + "ttl": { + "N": str( + int( + ( + datetime.datetime.now() + datetime.timedelta(seconds=expire) + ).timestamp() + ) ) - ) + } } - } - if expire - else {} - ) + if expire + else {} + ) - await self.client.put_item( - TableName=self.table_name, - Item={ - **{ - "key": {"S": key}, - "value": {"B": value}, + await self.client.put_item( + TableName=self.table_name, + Item={ + **{ + "key": {"S": key}, + "value": {"B": value}, + }, + **ttl, }, - **ttl, - }, - ) + ) async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError