diff --git a/README.md b/README.md index f0a9b6d..b445a5d 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ async def index(): @app.on_event("startup") async def startup(): - redis = aioredis.from_url("redis://localhost", encoding="utf8", decode_responses=True) + redis = aioredis.from_url("redis://localhost") FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache") ``` @@ -180,6 +180,13 @@ async def index(): `InMemoryBackend` store cache data in memory and use lazy delete, which mean if you don't access it after cached, it will not delete automatically. + +### RedisBackend + +When using the redis backend, please make sure you pass in a redis client that does [_not_ decode responses][redis-decode] (`decode_responses` **must** be `False`, which is the default). Cached data is stored as `bytes` (binary), decoding these i the redis client would break caching. + +[redis-decode]: https://redis-py.readthedocs.io/en/latest/examples/connection_examples.html#by-default-Redis-return-binary-responses,-to-decode-them-use-decode_responses=True + ## Tests and coverage ```shell diff --git a/fastapi_cache/backends/__init__.py b/fastapi_cache/backends/__init__.py index 2aa3433..4519aa9 100644 --- a/fastapi_cache/backends/__init__.py +++ b/fastapi_cache/backends/__init__.py @@ -4,15 +4,15 @@ from typing import Optional, Tuple class Backend: @abc.abstractmethod - async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]: raise NotImplementedError @abc.abstractmethod - async def get(self, key: str) -> Optional[str]: + async def get(self, key: str) -> Optional[bytes]: raise NotImplementedError @abc.abstractmethod - async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: + async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None: raise NotImplementedError @abc.abstractmethod diff --git a/fastapi_cache/backends/dynamodb.py b/fastapi_cache/backends/dynamodb.py index 34cc366..547cfb8 100644 --- a/fastapi_cache/backends/dynamodb.py +++ b/fastapi_cache/backends/dynamodb.py @@ -39,11 +39,11 @@ class DynamoBackend(Backend): async def close(self) -> None: self.client = await self.client.__aexit__(None, None, None) - async def get_with_ttl(self, key: str) -> Tuple[int, str]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]: response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) if "Item" in response: - value = response["Item"].get("value", {}).get("S") + value = response["Item"].get("value", {}).get("B") ttl = response["Item"].get("ttl", {}).get("N") if not ttl: @@ -56,12 +56,12 @@ class DynamoBackend(Backend): return 0, None - async def get(self, key: str) -> str: + async def get(self, key: str) -> Optional[bytes]: response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}}) if "Item" in response: - return response["Item"].get("value", {}).get("S") + return response["Item"].get("value", {}).get("B") - async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: + async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None: ttl = ( { "ttl": { @@ -83,7 +83,7 @@ class DynamoBackend(Backend): Item={ **{ "key": {"S": key}, - "value": {"S": value}, + "value": {"B": value}, }, **ttl, }, diff --git a/fastapi_cache/backends/inmemory.py b/fastapi_cache/backends/inmemory.py index 4b05f06..219dd25 100644 --- a/fastapi_cache/backends/inmemory.py +++ b/fastapi_cache/backends/inmemory.py @@ -8,7 +8,7 @@ from fastapi_cache.backends import Backend @dataclass class Value: - data: str + data: bytes ttl_ts: int @@ -29,21 +29,21 @@ class InMemoryBackend(Backend): return v return None - async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]: async with self._lock: v = self._get(key) if v: return v.ttl_ts - self._now, v.data return 0, None - async def get(self, key: str) -> Optional[str]: + async def get(self, key: str) -> Optional[bytes]: async with self._lock: v = self._get(key) if v: return v.data return None - async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: + async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None: async with self._lock: self._store[key] = Value(value, self._now + (expire or 0)) diff --git a/fastapi_cache/backends/memcached.py b/fastapi_cache/backends/memcached.py index 31f1a35..22e201f 100644 --- a/fastapi_cache/backends/memcached.py +++ b/fastapi_cache/backends/memcached.py @@ -9,14 +9,14 @@ class MemcachedBackend(Backend): def __init__(self, mcache: Client): self.mcache = mcache - async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]: - return 3600, await self.mcache.get(key.encode()) + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]: + return 3600, await self.get(key) - async def get(self, key: str) -> Optional[str]: - return await self.mcache.get(key, key.encode()) + async def get(self, key: str) -> Optional[bytes]: + return await self.mcache.get(key.encode()) - async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: - await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0) + async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None: + await self.mcache.set(key.encode(), value, exptime=expire or 0) async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: raise NotImplementedError diff --git a/fastapi_cache/backends/redis.py b/fastapi_cache/backends/redis.py index a40bde4..928af88 100644 --- a/fastapi_cache/backends/redis.py +++ b/fastapi_cache/backends/redis.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Union from redis.asyncio.client import Redis from redis.asyncio.cluster import RedisCluster @@ -7,18 +7,18 @@ from fastapi_cache.backends import Backend class RedisBackend(Backend): - def __init__(self, redis: Redis[str] | RedisCluster[str]): + def __init__(self, redis: Union[Redis[bytes], RedisCluster[bytes]]): self.redis = redis self.is_cluster: bool = isinstance(redis, RedisCluster) - async def get_with_ttl(self, key: str) -> Tuple[int, str]: + async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]: async with self.redis.pipeline(transaction=not self.is_cluster) as pipe: return await pipe.ttl(key).get(key).execute() - async def get(self, key: str) -> Optional[str]: + async def get(self, key: str) -> Optional[bytes]: return await self.redis.get(key) - async def set(self, key: str, value: str, expire: Optional[int] = None) -> None: + async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None: return await self.redis.set(key, value, ex=expire) async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int: diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index 6581967..30c0ee2 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -1,9 +1,8 @@ -import codecs import datetime import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any, Callable, ClassVar, Dict, TypeVar, overload +from typing import Any, Callable, ClassVar, Dict, Optional, TypeVar, Union, overload import pendulum from fastapi.encoders import jsonable_encoder @@ -46,11 +45,11 @@ def object_hook(obj: Any) -> Any: class Coder: @classmethod - def encode(cls, value: Any) -> str: + def encode(cls, value: Any) -> bytes: raise NotImplementedError @classmethod - def decode(cls, value: str) -> Any: + def decode(cls, value: bytes) -> Any: raise NotImplementedError # (Shared) cache for endpoint return types to Pydantic model fields. @@ -62,16 +61,16 @@ class Coder: @overload @classmethod - def decode_as_type(cls, value: str, type_: _T) -> _T: + def decode_as_type(cls, value: bytes, *, type_: _T) -> _T: ... @overload @classmethod - def decode_as_type(cls, value: str, *, type_: None) -> Any: + def decode_as_type(cls, value: bytes, *, type_: None) -> Any: ... @classmethod - def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any: + def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Union[_T, Any]: """Decode value to the specific given type The default implementation uses the Pydantic model system to convert the value. @@ -95,29 +94,32 @@ class Coder: class JsonCoder(Coder): @classmethod - def encode(cls, value: Any) -> str: + def encode(cls, value: Any) -> bytes: if isinstance(value, JSONResponse): - return value.body.decode() - return json.dumps(value, cls=JsonEncoder) + return value.body + return json.dumps(value, cls=JsonEncoder).encode() @classmethod - def decode(cls, value: str) -> str: - return json.loads(value, object_hook=object_hook) + def decode(cls, value: bytes) -> Any: + # explicitly decode from UTF-8 bytes first, as otherwise + # json.loads() will first have to detect the correct UTF- + # encoding used. + return json.loads(value.decode(), object_hook=object_hook) class PickleCoder(Coder): @classmethod - def encode(cls, value: Any) -> str: + def encode(cls, value: Any) -> bytes: if isinstance(value, TemplateResponse): value = value.body - return codecs.encode(pickle.dumps(value), "base64").decode() + return pickle.dumps(value) @classmethod - def decode(cls, value: str) -> Any: - return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301 + def decode(cls, value: bytes) -> Any: + return pickle.loads(value) # nosec:B403,B301 @classmethod - def decode_as_type(cls, value: str, *, type_: Any) -> Any: + def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Any: # Pickle already produces the correct type on decoding, no point # in paying an extra performance penalty for pydantic to discover # the same. diff --git a/tests/test_codecs.py b/tests/test_codecs.py index 8120977..ff1dc23 100644 --- a/tests/test_codecs.py +++ b/tests/test_codecs.py @@ -36,7 +36,7 @@ class PDItem(BaseModel): ) def test_pickle_coder(value: Any) -> None: encoded_value = PickleCoder.encode(value) - assert isinstance(encoded_value, str) + assert isinstance(encoded_value, bytes) decoded_value = PickleCoder.decode(encoded_value) assert decoded_value == value @@ -55,12 +55,12 @@ def test_pickle_coder(value: Any) -> None: ) def test_json_coder(value: Any, return_type) -> None: encoded_value = JsonCoder.encode(value) - assert isinstance(encoded_value, str) + assert isinstance(encoded_value, bytes) decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type) assert decoded_value == value def test_json_coder_validation_error() -> None: - invalid = '{"name": "incomplete"}' + invalid = b'{"name": "incomplete"}' with pytest.raises(ValidationError): JsonCoder.decode_as_type(invalid, type_=PDItem)