fix #509: fix up linting and tests due to aiobotocore 2.18.0 changes

This commit is contained in:
Gary Gale
2025-01-18 09:52:34 +00:00
parent 7dade61a49
commit 6f4876ff7d

View File

@@ -1,5 +1,5 @@
import datetime import datetime
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple, Union
from aiobotocore.client import AioBaseClient from aiobotocore.client import AioBaseClient
from aiobotocore.session import AioSession, get_session from aiobotocore.session import AioSession, get_session
@@ -30,7 +30,7 @@ class DynamoBackend(Backend):
>> FastAPICache.init(dynamodb) >> FastAPICache.init(dynamodb)
""" """
client: DynamoDBClient client: Union[DynamoDBClient, None]
session: AioSession session: AioSession
table_name: str table_name: str
region: Optional[str] region: Optional[str]
@@ -46,58 +46,63 @@ class DynamoBackend(Backend):
).__aenter__() ).__aenter__()
async def close(self) -> None: async def close(self) -> None:
self.client = await self.client.__aexit__(None, None, None) if self.client:
await self.client.__aexit__(None, None, None)
self.client = None
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]: async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) if self.client:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
if "Item" in response: if "Item" in response:
value = response["Item"].get("value", {}).get("B") value = response["Item"].get("value", {}).get("B")
ttl = response["Item"].get("ttl", {}).get("N") ttl = response["Item"].get("ttl", {}).get("N")
if not ttl: if not ttl:
return -1, value return -1, value
# It's only eventually consistent so we need to check ourselves # It's only eventually consistent so we need to check ourselves
expire = int(ttl) - int(datetime.datetime.now().timestamp()) expire = int(ttl) - int(datetime.datetime.now().timestamp())
if expire > 0: if expire > 0:
return expire, value return expire, value
return 0, None return 0, None
async def get(self, key: str) -> Optional[bytes]: async def get(self, key: str) -> Optional[bytes]:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) if self.client:
if "Item" in response: response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
return response["Item"].get("value", {}).get("B") if "Item" in response:
return response["Item"].get("value", {}).get("B")
return None return None
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None: async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
ttl = ( if self.client:
{ ttl = (
"ttl": { {
"N": str( "ttl": {
int( "N": str(
( int(
datetime.datetime.now() + datetime.timedelta(seconds=expire) (
).timestamp() datetime.datetime.now() + datetime.timedelta(seconds=expire)
).timestamp()
)
) )
) }
} }
} if expire
if expire else {}
else {} )
)
await self.client.put_item( await self.client.put_item(
TableName=self.table_name, TableName=self.table_name,
Item={ Item={
**{ **{
"key": {"S": key}, "key": {"S": key},
"value": {"B": value}, "value": {"B": value},
},
**ttl,
}, },
**ttl, )
},
)
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
raise NotImplementedError raise NotImplementedError