Files
fastapi-cache/fastapi_cache/coder.py

75 lines
2.0 KiB
Python
Raw Normal View History

2022-11-05 13:45:16 +04:00
import codecs
import datetime
2020-08-26 18:04:57 +08:00
import json
import pickle # nosec:B403
from decimal import Decimal
2022-11-05 13:45:16 +04:00
from typing import Any
2020-08-26 18:04:57 +08:00
2021-09-17 10:19:56 +08:00
import pendulum
from fastapi.encoders import jsonable_encoder
2022-11-04 17:31:37 +08:00
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
def _parse_exact(text: str) -> Any:
    """Parse a date/datetime string with pendulum, keeping the exact parsed type."""
    return pendulum.parse(text, exact=True)


# Maps the ``_spec_type`` tag written by JsonEncoder to the callable that
# rebuilds the original value from its ``val`` string.
CONVERTERS = {
    "date": _parse_exact,
    "datetime": _parse_exact,
    "decimal": Decimal,
}
class JsonEncoder(json.JSONEncoder):
    """JSON encoder that tags dates, datetimes and decimals for round-tripping.

    Such values are written as ``{"val": str(obj), "_spec_type": <tag>}`` so
    that ``object_hook`` can rebuild them on decode. Any other object the
    stdlib encoder cannot handle is delegated to FastAPI's ``jsonable_encoder``.
    """

    def default(self, obj: Any) -> Any:
        # Order matters: datetime.datetime is a subclass of datetime.date, so
        # it must be checked first or datetimes would be tagged as "date".
        tagged_types = (
            (datetime.datetime, "datetime"),
            (datetime.date, "date"),
            (Decimal, "decimal"),
        )
        for kind, tag in tagged_types:
            if isinstance(obj, kind):
                return {"val": str(obj), "_spec_type": tag}
        return jsonable_encoder(obj)
2022-10-22 20:59:37 +04:00
def object_hook(obj: Any) -> Any:
    """Rebuild a value tagged by ``JsonEncoder``; pass other objects through.

    Used as the ``object_hook`` of ``json.loads`` in ``JsonCoder.decode``, so it
    receives every decoded JSON object (as a dict).

    Raises:
        TypeError: if ``_spec_type`` is set to a tag not present in CONVERTERS.
    """
    spec_type = obj.get("_spec_type")
    # Untagged (or empty-tagged) objects are ordinary JSON dicts.
    if not spec_type:
        return obj

    if spec_type not in CONVERTERS:
        raise TypeError("Unknown {}".format(spec_type))
    convert = CONVERTERS[spec_type]
    return convert(obj["val"])  # type: ignore
2020-08-26 18:04:57 +08:00
class Coder:
    """Base interface for cache value serializers.

    Subclasses (e.g. JsonCoder, PickleCoder) override both class methods.
    """

    @classmethod
    def encode(cls, value: Any) -> str:
        """Serialize ``value`` for storage in a cache backend."""
        raise NotImplementedError

    @classmethod
    def decode(cls, value: str) -> Any:
        """Deserialize a value previously produced by ``encode``."""
        raise NotImplementedError
class JsonCoder(Coder):
    """Serialize cache values as JSON text.

    Dates, datetimes and decimals are tagged by ``JsonEncoder`` and restored by
    ``object_hook`` on decode.
    """

    @classmethod
    def encode(cls, value: Any) -> str:
        """Return ``value`` encoded as a JSON string.

        A ``JSONResponse`` is already rendered, so its body is reused instead of
        being re-serialized. On decode it comes back as the parsed JSON content,
        not as a response object.
        """
        if isinstance(value, JSONResponse):
            # ``JSONResponse.body`` is bytes. Decode it so this method honours
            # its ``-> str`` contract like the ``json.dumps`` branch below; JSON
            # payloads are UTF-8.
            return value.body.decode("utf-8")
        return json.dumps(value, cls=JsonEncoder)

    @classmethod
    def decode(cls, value: str) -> Any:
        """Parse JSON text, rebuilding values tagged by ``JsonEncoder``.

        The result can be any JSON value (dict, list, str, number, ...), so it
        is typed ``Any`` rather than ``str``.
        """
        return json.loads(value, object_hook=object_hook)
2020-08-26 18:04:57 +08:00
class PickleCoder(Coder):
    """Serialize cache values with pickle, wrapped in base64 text.

    Only decode data this application wrote itself: ``pickle.loads`` on
    attacker-controlled bytes can execute arbitrary code.
    """

    @classmethod
    def encode(cls, value: Any) -> str:
        """Pickle ``value`` and return it as base64-encoded text.

        For a ``TemplateResponse`` only the rendered body is stored, so
        ``decode`` returns that body rather than a response object.
        """
        payload = value.body if isinstance(value, TemplateResponse) else value
        pickled = pickle.dumps(payload)
        return codecs.encode(pickled, "base64").decode()

    @classmethod
    def decode(cls, value: str) -> Any:
        """Reverse ``encode``: base64-decode the text, then unpickle it."""
        pickled = codecs.decode(value.encode(), "base64")
        return pickle.loads(pickled)  # nosec:B403,B301