Skip to content

Commit

Permalink
feat: Cache JIT methods between executions
Browse files Browse the repository at this point in the history
This is quite a dangerous change, but it's pretty much the only thing we can do to speedup serialization of long lists of nested Schemas without modifying Marshmallow's core code. We don't want to do that.
  • Loading branch information
mLupine committed Oct 6, 2023
1 parent fb23b22 commit 5247c0d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 16 deletions.
31 changes: 20 additions & 11 deletions deepfriedmarshmallow/log.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
import logging
import os

if os.getenv("DFM_LOG_LEVEL", None) is not None:
logger = logging.getLogger("DeepFriedMarshmallow")

logger = logging.getLogger("DeepFriedMarshmallow")
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s] [Deep-Fried Marshmallow] %(message)s"))
if logger.level == logging.NOTSET:
try:
logger.setLevel(os.getenv("DFM_LOG_LEVEL", logging.WARN))
except ValueError:
logger.setLevel(logging.WARN)

if logger.level == logging.NOTSET:
try:
logger.setLevel(os.getenv("DFM_LOG_LEVEL", logging.WARN))
sh.setLevel(os.getenv("DFM_LOG_LEVEL", logging.WARN))
except ValueError:
logger.setLevel(logging.WARN)
sh.setLevel(logging.WARN)
sh = logging.StreamHandler()
sh.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s] [Deep-Fried Marshmallow] %(message)s"))
sh.setLevel(logging.DEBUG)
logger.addHandler(sh)
else:

logger.addHandler(sh)
class DummyLogger:
def __getattr__(self, item):
def noop(*args, **kwargs):
pass

return noop

logger = DummyLogger()
24 changes: 19 additions & 5 deletions deepfriedmarshmallow/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,35 @@ def __init__(self, schema, method):
def __call__(self, obj, many=False, **kwargs): # noqa: FBT002
self._ensure_jit_method()

logger.debug(f"JIT method called with {obj=}")
logger.debug(f"JIT method called with {obj.__class__.__name__}")
try:
result = self._jit_method(obj, many=many)
logger.debug(f"JIT method succeeded for {obj=} with {result=}")
logger.debug(f"JIT method succeeded for {obj.__class__.__name__}")
except Exception as e:
logger.warning(f"JIT method failed, falling back to non-JIT method: {e}", exc_info=e)
result = self._method(obj, many=many, **kwargs)
logger.debug(f"Fallback method succeeded for {obj=} with {result=}")
logger.debug(f"Fallback method succeeded for {obj.__class__.__name__}")

return result

def _ensure_jit_method(self):
if self._jit_method is None:
logger.debug(f"Generating JIT method {self._method} for {self._schema}")
self._jit_method = self.generate_jit_method(self._schema, JitContext())
if "_dfm_jit_cache" not in globals():
globals()["_dfm_jit_cache"] = {}

cache_key = (
self._schema.__class__.__name__,
self._schema.many,
id(self._schema.__class__),
self._method.__name__,
)
if cache_key not in globals()["_dfm_jit_cache"]:
logger.debug(f"Generating JIT method {cache_key=}")
globals()["_dfm_jit_cache"][cache_key] = self.generate_jit_method(self._schema, JitContext())
else:
logger.debug(f"Using cached JIT method {cache_key}")

self._jit_method = globals()["_dfm_jit_cache"][cache_key]

def generate_jit_method(self, schema, context):
raise NotImplementedError
Expand Down

0 comments on commit 5247c0d

Please sign in to comment.