From 5247c0d9ca7ebe933924c2f59126c300cce995b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maciej=20Wilczy=C5=84ski?= Date: Fri, 6 Oct 2023 16:26:59 +0200 Subject: [PATCH] feat: Cache JIT methods between executions This is quite a dangerous change, but it's pretty much the only thing we can do to speed up serialization of long lists of nested Schemas without modifying Marshmallow's core code, which we don't want to do. --- deepfriedmarshmallow/log.py | 31 +++++++++++++++++++----------- deepfriedmarshmallow/serializer.py | 24 ++++++++++++++++++----- 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/deepfriedmarshmallow/log.py b/deepfriedmarshmallow/log.py index 7e5de69..1dbd849 100644 --- a/deepfriedmarshmallow/log.py +++ b/deepfriedmarshmallow/log.py @@ -1,17 +1,26 @@ import logging import os +if os.getenv("DFM_LOG_LEVEL", None) is not None: + logger = logging.getLogger("DeepFriedMarshmallow") -logger = logging.getLogger("DeepFriedMarshmallow") -sh = logging.StreamHandler() -sh.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s] [Deep-Fried Marshmallow] %(message)s")) + if logger.level == logging.NOTSET: + try: + logger.setLevel(os.getenv("DFM_LOG_LEVEL", logging.WARN)) + except ValueError: + logger.setLevel(logging.WARN) -if logger.level == logging.NOTSET: - try: - logger.setLevel(os.getenv("DFM_LOG_LEVEL", logging.WARN)) - sh.setLevel(os.getenv("DFM_LOG_LEVEL", logging.WARN)) - except ValueError: - logger.setLevel(logging.WARN) - sh.setLevel(logging.WARN) + sh = logging.StreamHandler() + sh.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s] [Deep-Fried Marshmallow] %(message)s")) + sh.setLevel(logging.DEBUG) + logger.addHandler(sh) +else: -logger.addHandler(sh) + class DummyLogger: + def __getattr__(self, item): + def noop(*args, **kwargs): + pass + + return noop + + logger = DummyLogger() diff --git a/deepfriedmarshmallow/serializer.py b/deepfriedmarshmallow/serializer.py index 1c9d518..7d0d750 100644 --- 
a/deepfriedmarshmallow/serializer.py +++ b/deepfriedmarshmallow/serializer.py @@ -17,21 +17,35 @@ def __init__(self, schema, method): def __call__(self, obj, many=False, **kwargs): # noqa: FBT002 self._ensure_jit_method() - logger.debug(f"JIT method called with {obj=}") + logger.debug(f"JIT method called with {obj.__class__.__name__}") try: result = self._jit_method(obj, many=many) - logger.debug(f"JIT method succeeded for {obj=} with {result=}") + logger.debug(f"JIT method succeeded for {obj.__class__.__name__}") except Exception as e: logger.warning(f"JIT method failed, falling back to non-JIT method: {e}", exc_info=e) result = self._method(obj, many=many, **kwargs) - logger.debug(f"Fallback method succeeded for {obj=} with {result=}") + logger.debug(f"Fallback method succeeded for {obj.__class__.__name__}") return result def _ensure_jit_method(self): if self._jit_method is None: - logger.debug(f"Generating JIT method {self._method} for {self._schema}") - self._jit_method = self.generate_jit_method(self._schema, JitContext()) + if "_dfm_jit_cache" not in globals(): + globals()["_dfm_jit_cache"] = {} + + cache_key = ( + self._schema.__class__.__name__, + self._schema.many, + id(self._schema.__class__), + self._method.__name__, + ) + if cache_key not in globals()["_dfm_jit_cache"]: + logger.debug(f"Generating JIT method {cache_key=}") + globals()["_dfm_jit_cache"][cache_key] = self.generate_jit_method(self._schema, JitContext()) + else: + logger.debug(f"Using cached JIT method {cache_key}") + + self._jit_method = globals()["_dfm_jit_cache"][cache_key] def generate_jit_method(self, schema, context): raise NotImplementedError