mirror of
https://github.com/long2ice/fastapi-cache.git
synced 2026-03-25 04:57:54 +00:00
Merge pull request #134 from mjpieters/backend_coder_bytes
Make backends store bytes instead of strings
This commit is contained in:
@@ -78,7 +78,7 @@ async def index():
|
|||||||
|
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup():
|
async def startup():
|
||||||
redis = aioredis.from_url("redis://localhost", encoding="utf8", decode_responses=True)
|
redis = aioredis.from_url("redis://localhost")
|
||||||
FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache")
|
FastAPICache.init(RedisBackend(redis), prefix="fastapi-cache")
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -180,6 +180,13 @@ async def index():
|
|||||||
`InMemoryBackend` store cache data in memory and use lazy delete, which mean if you don't access it after cached, it
|
`InMemoryBackend` store cache data in memory and use lazy delete, which mean if you don't access it after cached, it
|
||||||
will not delete automatically.
|
will not delete automatically.
|
||||||
|
|
||||||
|
|
||||||
|
### RedisBackend
|
||||||
|
|
||||||
|
When using the redis backend, please make sure you pass in a redis client that does [_not_ decode responses][redis-decode] (`decode_responses` **must** be `False`, which is the default). Cached data is stored as `bytes` (binary), decoding these i the redis client would break caching.
|
||||||
|
|
||||||
|
[redis-decode]: https://redis-py.readthedocs.io/en/latest/examples/connection_examples.html#by-default-Redis-return-binary-responses,-to-decode-them-use-decode_responses=True
|
||||||
|
|
||||||
## Tests and coverage
|
## Tests and coverage
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
|
|||||||
@@ -4,15 +4,15 @@ from typing import Optional, Tuple
|
|||||||
|
|
||||||
class Backend:
|
class Backend:
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def get(self, key: str) -> Optional[str]:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
|
|||||||
@@ -39,11 +39,11 @@ class DynamoBackend(Backend):
|
|||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
self.client = await self.client.__aexit__(None, None, None)
|
self.client = await self.client.__aexit__(None, None, None)
|
||||||
|
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
||||||
|
|
||||||
if "Item" in response:
|
if "Item" in response:
|
||||||
value = response["Item"].get("value", {}).get("S")
|
value = response["Item"].get("value", {}).get("B")
|
||||||
ttl = response["Item"].get("ttl", {}).get("N")
|
ttl = response["Item"].get("ttl", {}).get("N")
|
||||||
|
|
||||||
if not ttl:
|
if not ttl:
|
||||||
@@ -56,12 +56,12 @@ class DynamoBackend(Backend):
|
|||||||
|
|
||||||
return 0, None
|
return 0, None
|
||||||
|
|
||||||
async def get(self, key: str) -> str:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
|
||||||
if "Item" in response:
|
if "Item" in response:
|
||||||
return response["Item"].get("value", {}).get("S")
|
return response["Item"].get("value", {}).get("B")
|
||||||
|
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
ttl = (
|
ttl = (
|
||||||
{
|
{
|
||||||
"ttl": {
|
"ttl": {
|
||||||
@@ -83,7 +83,7 @@ class DynamoBackend(Backend):
|
|||||||
Item={
|
Item={
|
||||||
**{
|
**{
|
||||||
"key": {"S": key},
|
"key": {"S": key},
|
||||||
"value": {"S": value},
|
"value": {"B": value},
|
||||||
},
|
},
|
||||||
**ttl,
|
**ttl,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from fastapi_cache.backends import Backend
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Value:
|
class Value:
|
||||||
data: str
|
data: bytes
|
||||||
ttl_ts: int
|
ttl_ts: int
|
||||||
|
|
||||||
|
|
||||||
@@ -29,21 +29,21 @@ class InMemoryBackend(Backend):
|
|||||||
return v
|
return v
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
v = self._get(key)
|
v = self._get(key)
|
||||||
if v:
|
if v:
|
||||||
return v.ttl_ts - self._now, v.data
|
return v.ttl_ts - self._now, v.data
|
||||||
return 0, None
|
return 0, None
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[str]:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
v = self._get(key)
|
v = self._get(key)
|
||||||
if v:
|
if v:
|
||||||
return v.data
|
return v.data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._store[key] = Value(value, self._now + (expire or 0))
|
self._store[key] = Value(value, self._now + (expire or 0))
|
||||||
|
|
||||||
|
|||||||
@@ -9,14 +9,14 @@ class MemcachedBackend(Backend):
|
|||||||
def __init__(self, mcache: Client):
|
def __init__(self, mcache: Client):
|
||||||
self.mcache = mcache
|
self.mcache = mcache
|
||||||
|
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[str]]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
return 3600, await self.mcache.get(key.encode())
|
return 3600, await self.get(key)
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[str]:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
return await self.mcache.get(key, key.encode())
|
return await self.mcache.get(key.encode())
|
||||||
|
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
await self.mcache.set(key.encode(), value.encode(), exptime=expire or 0)
|
await self.mcache.set(key.encode(), value, exptime=expire or 0)
|
||||||
|
|
||||||
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
from redis.asyncio.client import Redis
|
from redis.asyncio.client import Redis
|
||||||
from redis.asyncio.cluster import RedisCluster
|
from redis.asyncio.cluster import RedisCluster
|
||||||
@@ -7,18 +7,18 @@ from fastapi_cache.backends import Backend
|
|||||||
|
|
||||||
|
|
||||||
class RedisBackend(Backend):
|
class RedisBackend(Backend):
|
||||||
def __init__(self, redis: Redis[str] | RedisCluster[str]):
|
def __init__(self, redis: Union[Redis[bytes], RedisCluster[bytes]]):
|
||||||
self.redis = redis
|
self.redis = redis
|
||||||
self.is_cluster: bool = isinstance(redis, RedisCluster)
|
self.is_cluster: bool = isinstance(redis, RedisCluster)
|
||||||
|
|
||||||
async def get_with_ttl(self, key: str) -> Tuple[int, str]:
|
async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
|
||||||
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
|
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
|
||||||
return await pipe.ttl(key).get(key).execute()
|
return await pipe.ttl(key).get(key).execute()
|
||||||
|
|
||||||
async def get(self, key: str) -> Optional[str]:
|
async def get(self, key: str) -> Optional[bytes]:
|
||||||
return await self.redis.get(key)
|
return await self.redis.get(key)
|
||||||
|
|
||||||
async def set(self, key: str, value: str, expire: Optional[int] = None) -> None:
|
async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
|
||||||
return await self.redis.set(key, value, ex=expire)
|
return await self.redis.set(key, value, ex=expire)
|
||||||
|
|
||||||
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import codecs
|
|
||||||
import datetime
|
import datetime
|
||||||
import json
|
import json
|
||||||
import pickle # nosec:B403
|
import pickle # nosec:B403
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from typing import Any, Callable, ClassVar, Dict, TypeVar, overload
|
from typing import Any, Callable, ClassVar, Dict, Optional, TypeVar, Union, overload
|
||||||
|
|
||||||
import pendulum
|
import pendulum
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
@@ -46,11 +45,11 @@ def object_hook(obj: Any) -> Any:
|
|||||||
|
|
||||||
class Coder:
|
class Coder:
|
||||||
@classmethod
|
@classmethod
|
||||||
def encode(cls, value: Any) -> str:
|
def encode(cls, value: Any) -> bytes:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, value: str) -> Any:
|
def decode(cls, value: bytes) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# (Shared) cache for endpoint return types to Pydantic model fields.
|
# (Shared) cache for endpoint return types to Pydantic model fields.
|
||||||
@@ -62,16 +61,16 @@ class Coder:
|
|||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_as_type(cls, value: str, type_: _T) -> _T:
|
def decode_as_type(cls, value: bytes, *, type_: _T) -> _T:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_as_type(cls, value: str, *, type_: None) -> Any:
|
def decode_as_type(cls, value: bytes, *, type_: None) -> Any:
|
||||||
...
|
...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_as_type(cls, value: str, *, type_: _T | None) -> _T | Any:
|
def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Union[_T, Any]:
|
||||||
"""Decode value to the specific given type
|
"""Decode value to the specific given type
|
||||||
|
|
||||||
The default implementation uses the Pydantic model system to convert the value.
|
The default implementation uses the Pydantic model system to convert the value.
|
||||||
@@ -95,29 +94,32 @@ class Coder:
|
|||||||
|
|
||||||
class JsonCoder(Coder):
|
class JsonCoder(Coder):
|
||||||
@classmethod
|
@classmethod
|
||||||
def encode(cls, value: Any) -> str:
|
def encode(cls, value: Any) -> bytes:
|
||||||
if isinstance(value, JSONResponse):
|
if isinstance(value, JSONResponse):
|
||||||
return value.body.decode()
|
return value.body
|
||||||
return json.dumps(value, cls=JsonEncoder)
|
return json.dumps(value, cls=JsonEncoder).encode()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, value: str) -> str:
|
def decode(cls, value: bytes) -> Any:
|
||||||
return json.loads(value, object_hook=object_hook)
|
# explicitly decode from UTF-8 bytes first, as otherwise
|
||||||
|
# json.loads() will first have to detect the correct UTF-
|
||||||
|
# encoding used.
|
||||||
|
return json.loads(value.decode(), object_hook=object_hook)
|
||||||
|
|
||||||
|
|
||||||
class PickleCoder(Coder):
|
class PickleCoder(Coder):
|
||||||
@classmethod
|
@classmethod
|
||||||
def encode(cls, value: Any) -> str:
|
def encode(cls, value: Any) -> bytes:
|
||||||
if isinstance(value, TemplateResponse):
|
if isinstance(value, TemplateResponse):
|
||||||
value = value.body
|
value = value.body
|
||||||
return codecs.encode(pickle.dumps(value), "base64").decode()
|
return pickle.dumps(value)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode(cls, value: str) -> Any:
|
def decode(cls, value: bytes) -> Any:
|
||||||
return pickle.loads(codecs.decode(value.encode(), "base64")) # nosec:B403,B301
|
return pickle.loads(value) # nosec:B403,B301
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def decode_as_type(cls, value: str, *, type_: Any) -> Any:
|
def decode_as_type(cls, value: bytes, *, type_: Optional[_T]) -> Any:
|
||||||
# Pickle already produces the correct type on decoding, no point
|
# Pickle already produces the correct type on decoding, no point
|
||||||
# in paying an extra performance penalty for pydantic to discover
|
# in paying an extra performance penalty for pydantic to discover
|
||||||
# the same.
|
# the same.
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ class PDItem(BaseModel):
|
|||||||
)
|
)
|
||||||
def test_pickle_coder(value: Any) -> None:
|
def test_pickle_coder(value: Any) -> None:
|
||||||
encoded_value = PickleCoder.encode(value)
|
encoded_value = PickleCoder.encode(value)
|
||||||
assert isinstance(encoded_value, str)
|
assert isinstance(encoded_value, bytes)
|
||||||
decoded_value = PickleCoder.decode(encoded_value)
|
decoded_value = PickleCoder.decode(encoded_value)
|
||||||
assert decoded_value == value
|
assert decoded_value == value
|
||||||
|
|
||||||
@@ -55,12 +55,12 @@ def test_pickle_coder(value: Any) -> None:
|
|||||||
)
|
)
|
||||||
def test_json_coder(value: Any, return_type) -> None:
|
def test_json_coder(value: Any, return_type) -> None:
|
||||||
encoded_value = JsonCoder.encode(value)
|
encoded_value = JsonCoder.encode(value)
|
||||||
assert isinstance(encoded_value, str)
|
assert isinstance(encoded_value, bytes)
|
||||||
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
|
decoded_value = JsonCoder.decode_as_type(encoded_value, type_=return_type)
|
||||||
assert decoded_value == value
|
assert decoded_value == value
|
||||||
|
|
||||||
|
|
||||||
def test_json_coder_validation_error() -> None:
|
def test_json_coder_validation_error() -> None:
|
||||||
invalid = '{"name": "incomplete"}'
|
invalid = b'{"name": "incomplete"}'
|
||||||
with pytest.raises(ValidationError):
|
with pytest.raises(ValidationError):
|
||||||
JsonCoder.decode_as_type(invalid, type_=PDItem)
|
JsonCoder.decode_as_type(invalid, type_=PDItem)
|
||||||
|
|||||||
Reference in New Issue
Block a user