From 7c304029078a3ae8e222282d277ab0aceb459482 Mon Sep 17 00:00:00 2001 From: Martijn Pieters Date: Tue, 9 May 2023 12:31:19 +0100 Subject: [PATCH] Cache pydantic model fields for faster decoding In `timeit` tests, 10.000 calls to `ModelField()` could take up to half a second on my Macbook Pro M1, depending on the type annotation used. Given that the method is called for every cache hit, this can really add up. The number of different return types for endpoints is very much finite however, so caching is a definite win here. --- fastapi_cache/coder.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/fastapi_cache/coder.py b/fastapi_cache/coder.py index a7373da..6581967 100644 --- a/fastapi_cache/coder.py +++ b/fastapi_cache/coder.py @@ -3,7 +3,7 @@ import datetime import json import pickle # nosec:B403 from decimal import Decimal -from typing import Any, Callable, TypeVar, overload +from typing import Any, Callable, ClassVar, Dict, TypeVar, overload import pendulum from fastapi.encoders import jsonable_encoder @@ -53,6 +53,13 @@ class Coder: def decode(cls, value: str) -> Any: raise NotImplementedError + # (Shared) cache for endpoint return types to Pydantic model fields. + # Note that subclasses share this cache! If a subclass overrides the + # decode_as_type method and then stores a different kind of field for a + # given type, do make sure that the subclass provides its own class + # attribute for this cache. + _type_field_cache: ClassVar[Dict[Any, fields.ModelField]] = {} + @overload @classmethod def decode_as_type(cls, value: str, type_: _T) -> _T: @@ -72,9 +79,12 @@ class Coder: """ result = cls.decode(value) if type_ is not None: - field = fields.ModelField( - name="body", type_=type_, class_validators=None, model_config=BaseConfig - ) + try: + field = cls._type_field_cache[type_] + except KeyError: + field = cls._type_field_cache[type_] = fields.ModelField( + name="body", type_=type_, class_validators=None, model_config=BaseConfig + ) result, errors = field.validate(result, {}, loc=()) if errors is not None: if not isinstance(errors, list):