mirror of
https://github.com/long2ice/fastapi-cache.git
synced 2026-03-25 04:57:54 +00:00
Make backends store bytes instead of strings
This is, for the majority of backends, the native format anyway, and so we save encoding and decoding when using the PickleCodec or if (in future) a orjson Coder was to be added. For the JsonCodec, the only thing that changed is the location where the JSON data is encoded to bytes and decoded back again to a string.
This commit is contained in:
@@ -78,7 +78,7 @@ async def index():
|
|||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup():
|
async def startup():
|
||||||
redis = aioredis.from_url("redis://localhost", encoding="utf8", decode_responses=True)
|
redis = aioredis.from_url("redis://localhost")
|
||||||
FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache")
|
FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache")
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -180,6 +180,13 @@ async def index():
|
|||||||
`InMemoryBackend` store cache data in memory and use lazy delete, which mean if you don't access it after cached, it
|
`InMemoryBackend` store cache data in memory and use lazy delete, which mean if you don't access it after cached, it
|
||||||
will not delete automatically.
|
will not delete automatically.
|
||||||
|
|
||||||
|
|
||||||
|
### RedisBackend
|
||||||
|
|
||||||
|
When using the redis backend, please make sure you pass in a redis client that does [_not_ decode responses][redis-decode] (`decode_responses` **must** be `False`, which is the default). Cached data is stored as `bytes` (binary), decoding these i the redis client would break caching.
|
||||||
|
|
||||||
|
[redis-decode]: https://redis-py.readthedocs.io/en/latest/examples/connection_examples.html#by-default-Redis-return-binary-responses,-to-decode-them-use-decode_responses=True
|
||||||
|
|
||||||
## Tests and coverage
|
## Tests and coverage
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|||||||
@@ -4,15 +4,15 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
class Backend:
|
class Backend:
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get(self, key: str) -> Optional[str]:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
|||||||
@@ -39,11 +39,11 @@ class DynamoBackend(Backend):
|
|||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
self.client = await self.client.__aexit__(None, None, None)
|
self.client = await self.client.__aexit__(None, None, None)
|
||||||
|
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
||||||
|
|
||||||
if "Item" in response:
|
if "Item" in response:
|
||||||
value = response["Item"].get("value", {}).get("S")
|
value = response["Item"].get("value", {}).get("B")
|
||||||
ttl = response["Item"].get("ttl", {}).get("N")
|
ttl = response["Item"].get("ttl", {}).get("N")
|
||||||
|
|
||||||
if not ttl:
|
if not ttl:
|
||||||
@@ -56,12 +56,12 @@ class DynamoBackend(Backend):
|
|||||||
|
|
||||||
return 0, None
|
return 0, None
|
||||||
|
|
||||||
async def get(self, key: str) -> str:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
||||||
if "Item" in response:
|
if "Item" in response:
|
||||||
return response["Item"].get("value", {}).get("S")
|
return response["Item"].get("value", {}).get("B")
|
||||||
|
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
ttl = (
|
ttl = (
|
||||||
{
|
{
|
||||||
"ttl": {
|
"ttl": {
|
||||||
@@ -83,7 +83,7 @@ class DynamoBackend(Backend):
|
|||||||
Item={
|
Item={
|
||||||
**{
|
**{
|
||||||
"key": {"S": key},
|
"key": {"S": key},
|
||||||
"value": {"S": value},
|
"value": {"B": value},
|
||||||
},
|
},
|
||||||
**ttl,
|
**ttl,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from fastapi_cache.backends import Backend
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Value:
|
class Value:
|
||||||
data: str
|
data: bytes
|
||||||
ttl_ts: int
|
ttl_ts: int
|
||||||
|
|
||||||
|
|
||||||
@@ -29,21 +29,21 @@ class InMemoryBackend(Backend):
|
|||||||
return v
|
return v
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
v = self._get(key)
|
v = self._get(key)
|
||||||
if v:
|
if v:
|
||||||
return v.ttl_ts - self._now, v.data
|
return v.ttl_ts - self._now, v.data
|
||||||
return 0, None
|
return 0, None
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[str]:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
v = self._get(key)
|
v = self._get(key)
|
||||||
if v:
|
if v:
|
||||||
return v.data
|
return v.data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._store[key] = Value(value, self._now + (expire or 0))
|
self._store[key] = Value(value, self._now + (expire or 0))
|
||||||
|
|
||||||
|
|||||||
@@ -9,14 +9,14 @@ class MemcachedBackend(Backend):
|
|||||||
def __init__(self, mcache: Client):
|
def __init__(self, mcache: Client):
|
||||||
self.mcache = mcache
|
self.mcache = mcache
|
||||||
|
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
return 3600, await self.mcache.get(key.encode())
|
return 3600, await self.get(key)
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[str]:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
return await self.mcache.get(key, key.encode())
|
return await self.mcache.get(key.encode())
|
||||||
|
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0)
|
await self.mcache.set(key.encode(), value, exptime=expire or 0)
|
||||||
|
|
||||||
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
from redis.asyncio.client import Redis
|
from redis.asyncio.client import Redis
|
||||||
from redis.asyncio.cluster import RedisCluster
|
from redis.asyncio.cluster import RedisCluster
|
||||||
@@ -7,18 +7,18 @@ from fastapi_cache.backends import Backend
|
|||||||
|
|
||||||
|
|
||||||
class RedisBackend(Backend):
|
class RedisBackend(Backend):
|
||||||
def __init__(self, redis: Redis[str] | RedisCluster[str]):
|
def __init__(self, redis: Union[Redis[bytes], RedisCluster[bytes]]):
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
self.is_cluster: bool = isinstance(redis, RedisCluster)
|
self.is_cluster: bool = isinstance(redis, RedisCluster)
|
||||||
|
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
|
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
|
||||||
return await pipe.ttl(key).get(key).execute()
|
return await pipe.ttl(key).get(key).execute()
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[str]:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
return await self.redis.get(key)
|
return await self.redis.get(key)
|
||||||
|
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
return await self.redis.set(key, value, ex=expire)
|
return await self.redis.set(key, value, ex=expire)
|
||||||
|
|
||||||
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import codecs
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import pickle # nosec:B403
|
import pickle # nosec:B403
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, Callable, ClassVar, Dict, TypeVar, overload
|
from typing import Any, Callable, ClassVar, Dict, Optional, TypeVar, Union, overload
|
||||||
|
|
||||||
import pendulum
|
import pendulum
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
@@ -46,11 +45,11 @@ def object_hook(obj: Any) -> Any:
|
|||||||
|
|
||||||
class Coder:
|
class Coder:
|
||||||
@classmethod
|
@classmethod
|
||||||
def encode(cls, value: Any) -> str:
|
def encode(cls, value: Any) -> bytes:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, value: str) -> Any:
|
def decode(cls, value: bytes) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# (Shared) cache for endpoint return types to Pydantic model fields.
|
# (Shared) cache for endpoint return types to Pydantic model fields.
|
||||||
@@ -62,16 +61,16 @@ class Coder:
|
|||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_as_type(cls, value: str, type_: _T) -> _T:
|
def decode_as_type(cls, value: bytes, *, type_: _T) -> _T:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_as_type(cls, value: str, *, type_: None) -> Any:
|
def decode_as_type(cls, value: bytes, *, type_: None) -> Any:
|
||||||
...
|
...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
|
def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Union[_T, Any]:
|
||||||
"""Decode value to the specific given type
|
"""Decode value to the specific given type
|
||||||
|
|
||||||
The default implementation uses the Pydantic model system to convert the value.
|
The default implementation uses the Pydantic model system to convert the value.
|
||||||
@@ -95,29 +94,32 @@ class Coder:
|
|||||||
|
|
||||||
class JsonCoder(Coder):
|
class JsonCoder(Coder):
|
||||||
@classmethod
|
@classmethod
|
||||||
def encode(cls, value: Any) -> str:
|
def encode(cls, value: Any) -> bytes:
|
||||||
if isinstance(value, JSONResponse):
|
if isinstance(value, JSONResponse):
|
||||||
return value.body.decode()
|
return value.body
|
||||||
return json.dumps(value, cls=JsonEncoder)
|
return json.dumps(value, cls=JsonEncoder).encode()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, value: str) -> str:
|
def decode(cls, value: bytes) -> Any:
|
||||||
return json.loads(value, object_hook=object_hook)
|
# explicitly decode from UTF-8 bytes first, as otherwise
|
||||||
|
# json.loads() will first have to detect the correct UTF-
|
||||||
|
# encoding used.
|
||||||
|
return json.loads(value.decode(), object_hook=object_hook)
|
||||||
|
|
||||||
|
|
||||||
class PickleCoder(Coder):
|
class PickleCoder(Coder):
|
||||||
@classmethod
|
@classmethod
|
||||||
def encode(cls, value: Any) -> str:
|
def encode(cls, value: Any) -> bytes:
|
||||||
if isinstance(value, TemplateResponse):
|
if isinstance(value, TemplateResponse):
|
||||||
value = value.body
|
value = value.body
|
||||||
return codecs.encode(pickle.dumps(value), "base64").decode()
|
return pickle.dumps(value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, value: str) -> Any:
|
def decode(cls, value: bytes) -> Any:
|
||||||
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
|
return pickle.loads(value) # nosec:B403,B301
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
|
def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Any:
|
||||||
# Pickle already produces the correct type on decoding, no point
|
# Pickle already produces the correct type on decoding, no point
|
||||||
# in paying an extra performance penalty for pydantic to discover
|
# in paying an extra performance penalty for pydantic to discover
|
||||||
# the same.
|
# the same.
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class PDItem(BaseModel):
|
|||||||
)
|
)
|
||||||
def test_pickle_coder(value: Any) -> None:
|
def test_pickle_coder(value: Any) -> None:
|
||||||
encoded_value = PickleCoder.encode(value)
|
encoded_value = PickleCoder.encode(value)
|
||||||
assert isinstance(encoded_value, str)
|
assert isinstance(encoded_value, bytes)
|
||||||
decoded_value = PickleCoder.decode(encoded_value)
|
decoded_value = PickleCoder.decode(encoded_value)
|
||||||
assert decoded_value == value
|
assert decoded_value == value
|
||||||
|
|
||||||
@@ -55,12 +55,12 @@ def test_pickle_coder(value: Any) -> None:
|
|||||||
)
|
)
|
||||||
def test_json_coder(value: Any, return_type) -> None:
|
def test_json_coder(value: Any, return_type) -> None:
|
||||||
encoded_value = JsonCoder.encode(value)
|
encoded_value = JsonCoder.encode(value)
|
||||||
assert isinstance(encoded_value, str)
|
assert isinstance(encoded_value, bytes)
|
||||||
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
|
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
|
||||||
assert decoded_value == value
|
assert decoded_value == value
|
||||||
|
|
||||||
|
|
||||||
def test_json_coder_validation_error() -> None:
|
def test_json_coder_validation_error() -> None:
|
||||||
invalid = '{"name": "incomplete"}'
|
invalid = b'{"name": "incomplete"}'
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
JsonCoder.decode_as_type(invalid, type_=PDItem)
|
JsonCoder.decode_as_type(invalid, type_=PDItem)
|
||||||
|
|||||||
Reference in New Issue
Block a user