mirror of
https://github.com/long2ice/fastapi-cache.git
synced 2026-03-24 20:47:54 +00:00
Merge pull request #134 from mjpieters/backend_coder_bytes
Make backends store bytes instead of strings
This commit is contained in:
@@ -78,7 +78,7 @@ async def index():
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
redis = aioredis.from_url("redis://localhost", encoding="utf8", decode_responses=True)
|
||||
redis = aioredis.from_url("redis://localhost")
|
||||
FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache")
|
||||
|
||||
```
|
||||
@@ -180,6 +180,13 @@ async def index():
|
||||
`InMemoryBackend` store cache data in memory and use lazy delete, which mean if you don't access it after cached, it
|
||||
will not delete automatically.
|
||||
|
||||
|
||||
### RedisBackend
|
||||
|
||||
When using the redis backend, please make sure you pass in a redis client that does [_not_ decode responses][redis-decode] (`decode_responses` **must** be `False`, which is the default). Cached data is stored as `bytes` (binary), decoding these i the redis client would break caching.
|
||||
|
||||
[redis-decode]: https://redis-py.readthedocs.io/en/latest/examples/connection_examples.html#by-default-Redis-return-binary-responses,-to-decode-them-use-decode_responses=True
|
||||
|
||||
## Tests and coverage
|
||||
|
||||
```shell
|
||||
|
||||
@@ -4,15 +4,15 @@ from typing import Optional, Tuple
|
||||
|
||||
class Backend:
|
||||
@abc.abstractmethod
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async def get(self, key: str) -> Optional[bytes]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
||||
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -39,11 +39,11 @@ class DynamoBackend(Backend):
|
||||
async def close(self) -> None:
|
||||
self.client = await self.client.__aexit__(None, None, None)
|
||||
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
||||
|
||||
if "Item" in response:
|
||||
value = response["Item"].get("value", {}).get("S")
|
||||
value = response["Item"].get("value", {}).get("B")
|
||||
ttl = response["Item"].get("ttl", {}).get("N")
|
||||
|
||||
if not ttl:
|
||||
@@ -56,12 +56,12 @@ class DynamoBackend(Backend):
|
||||
|
||||
return 0, None
|
||||
|
||||
async def get(self, key: str) -> str:
|
||||
async def get(self, key: str) -> Optional[bytes]:
|
||||
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
||||
if "Item" in response:
|
||||
return response["Item"].get("value", {}).get("S")
|
||||
return response["Item"].get("value", {}).get("B")
|
||||
|
||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
||||
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||
ttl = (
|
||||
{
|
||||
"ttl": {
|
||||
@@ -83,7 +83,7 @@ class DynamoBackend(Backend):
|
||||
Item={
|
||||
**{
|
||||
"key": {"S": key},
|
||||
"value": {"S": value},
|
||||
"value": {"B": value},
|
||||
},
|
||||
**ttl,
|
||||
},
|
||||
|
||||
@@ -8,7 +8,7 @@ from fastapi_cache.backends import Backend
|
||||
|
||||
@dataclass
|
||||
class Value:
|
||||
data: str
|
||||
data: bytes
|
||||
ttl_ts: int
|
||||
|
||||
|
||||
@@ -29,21 +29,21 @@ class InMemoryBackend(Backend):
|
||||
return v
|
||||
return None
|
||||
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||
async with self._lock:
|
||||
v = self._get(key)
|
||||
if v:
|
||||
return v.ttl_ts - self._now, v.data
|
||||
return 0, None
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async def get(self, key: str) -> Optional[bytes]:
|
||||
async with self._lock:
|
||||
v = self._get(key)
|
||||
if v:
|
||||
return v.data
|
||||
return None
|
||||
|
||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
||||
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||
async with self._lock:
|
||||
self._store[key] = Value(value, self._now + (expire or 0))
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@ class MemcachedBackend(Backend):
|
||||
def __init__(self, mcache: Client):
|
||||
self.mcache = mcache
|
||||
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
||||
return 3600, await self.mcache.get(key.encode())
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||
return 3600, await self.get(key)
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
return await self.mcache.get(key, key.encode())
|
||||
async def get(self, key: str) -> Optional[bytes]:
|
||||
return await self.mcache.get(key.encode())
|
||||
|
||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
||||
await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0)
|
||||
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||
await self.mcache.set(key.encode(), value, exptime=expire or 0)
|
||||
|
||||
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from redis.asyncio.client import Redis
|
||||
from redis.asyncio.cluster import RedisCluster
|
||||
@@ -7,18 +7,18 @@ from fastapi_cache.backends import Backend
|
||||
|
||||
|
||||
class RedisBackend(Backend):
|
||||
def __init__(self, redis: Redis[str] | RedisCluster[str]):
|
||||
def __init__(self, redis: Union[Redis[bytes], RedisCluster[bytes]]):
|
||||
self.redis = redis
|
||||
self.is_cluster: bool = isinstance(redis, RedisCluster)
|
||||
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
|
||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
|
||||
return await pipe.ttl(key).get(key).execute()
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async def get(self, key: str) -> Optional[bytes]:
|
||||
return await self.redis.get(key)
|
||||
|
||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
||||
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||
return await self.redis.set(key, value, ex=expire)
|
||||
|
||||
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import codecs
|
||||
import datetime
|
||||
import json
|
||||
import pickle # nosec:B403
|
||||
from decimal import Decimal
|
||||
from typing import Any, Callable, ClassVar, Dict, TypeVar, overload
|
||||
from typing import Any, Callable, ClassVar, Dict, Optional, TypeVar, Union, overload
|
||||
|
||||
import pendulum
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
@@ -46,11 +45,11 @@ def object_hook(obj: Any) -> Any:
|
||||
|
||||
class Coder:
|
||||
@classmethod
|
||||
def encode(cls, value: Any) -> str:
|
||||
def encode(cls, value: Any) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def decode(cls, value: str) -> Any:
|
||||
def decode(cls, value: bytes) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
# (Shared) cache for endpoint return types to Pydantic model fields.
|
||||
@@ -62,16 +61,16 @@ class Coder:
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def decode_as_type(cls, value: str, type_: _T) -> _T:
|
||||
def decode_as_type(cls, value: bytes, *, type_: _T) -> _T:
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def decode_as_type(cls, value: str, *, type_: None) -> Any:
|
||||
def decode_as_type(cls, value: bytes, *, type_: None) -> Any:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
|
||||
def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Union[_T, Any]:
|
||||
"""Decode value to the specific given type
|
||||
|
||||
The default implementation uses the Pydantic model system to convert the value.
|
||||
@@ -95,29 +94,32 @@ class Coder:
|
||||
|
||||
class JsonCoder(Coder):
|
||||
@classmethod
|
||||
def encode(cls, value: Any) -> str:
|
||||
def encode(cls, value: Any) -> bytes:
|
||||
if isinstance(value, JSONResponse):
|
||||
return value.body.decode()
|
||||
return json.dumps(value, cls=JsonEncoder)
|
||||
return value.body
|
||||
return json.dumps(value, cls=JsonEncoder).encode()
|
||||
|
||||
@classmethod
|
||||
def decode(cls, value: str) -> str:
|
||||
return json.loads(value, object_hook=object_hook)
|
||||
def decode(cls, value: bytes) -> Any:
|
||||
# explicitly decode from UTF-8 bytes first, as otherwise
|
||||
# json.loads() will first have to detect the correct UTF-
|
||||
# encoding used.
|
||||
return json.loads(value.decode(), object_hook=object_hook)
|
||||
|
||||
|
||||
class PickleCoder(Coder):
|
||||
@classmethod
|
||||
def encode(cls, value: Any) -> str:
|
||||
def encode(cls, value: Any) -> bytes:
|
||||
if isinstance(value, TemplateResponse):
|
||||
value = value.body
|
||||
return codecs.encode(pickle.dumps(value), "base64").decode()
|
||||
return pickle.dumps(value)
|
||||
|
||||
@classmethod
|
||||
def decode(cls, value: str) -> Any:
|
||||
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
|
||||
def decode(cls, value: bytes) -> Any:
|
||||
return pickle.loads(value) # nosec:B403,B301
|
||||
|
||||
@classmethod
|
||||
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
|
||||
def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Any:
|
||||
# Pickle already produces the correct type on decoding, no point
|
||||
# in paying an extra performance penalty for pydantic to discover
|
||||
# the same.
|
||||
|
||||
@@ -36,7 +36,7 @@ class PDItem(BaseModel):
|
||||
)
|
||||
def test_pickle_coder(value: Any) -> None:
|
||||
encoded_value = PickleCoder.encode(value)
|
||||
assert isinstance(encoded_value, str)
|
||||
assert isinstance(encoded_value, bytes)
|
||||
decoded_value = PickleCoder.decode(encoded_value)
|
||||
assert decoded_value == value
|
||||
|
||||
@@ -55,12 +55,12 @@ def test_pickle_coder(value: Any) -> None:
|
||||
)
|
||||
def test_json_coder(value: Any, return_type) -> None:
|
||||
encoded_value = JsonCoder.encode(value)
|
||||
assert isinstance(encoded_value, str)
|
||||
assert isinstance(encoded_value, bytes)
|
||||
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
|
||||
assert decoded_value == value
|
||||
|
||||
|
||||
def test_json_coder_validation_error() -> None:
|
||||
invalid = '{"name": "incomplete"}'
|
||||
invalid = b'{"name": "incomplete"}'
|
||||
with pytest.raises(ValidationError):
|
||||
JsonCoder.decode_as_type(invalid, type_=PDItem)
|
||||
|
||||
Reference in New Issue
Block a user