logging with loguru inside jitted function #25685

Answered by jakevdp
Qazalbash asked this question in Q&A

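Hi, the issue here is that logging is a side effect, and in general functions used with jax.jit must be pure. Python side effects inside a jitted function run only once, while JAX traces the function, and at that point they see abstract tracers rather than the runtime values.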
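
A minimal sketch of that tracing behavior, assuming a plain Python list (hypothetical, standing in for a logger) instead of loguru so that it runs with JAX alone:

```python
import jax

# Hypothetical stand-in for a log sink: records each time the Python body runs.
calls = []

@jax.jit
def add(a, b):
    # Plain Python side effect: executes only while JAX traces the function,
    # and `a` here is an abstract tracer, not the runtime value.
    calls.append(repr(a))
    return a + b

add(1.0, 2.0)        # first call: traced and compiled, so the list grows
out = add(3.0, 4.0)  # same shapes/dtypes: cached executable, body not re-run

print(len(calls))  # 1
print(calls[0])    # a tracer repr such as "Traced<...>", not 1.0
```

The second call reuses the compiled executable, so nothing is appended and the runtime values 3.0 and 4.0 are never seen by the Python code.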

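If you want to log runtime values within your JIT-compiled function, you can do so via jax.debug.callback or jax.experimental.io_callback, depending on the intent: jax.debug.callback is meant for debugging side effects, while io_callback is meant for I/O that is part of the program's behavior (it can also return values and be kept in program order). See External Callbacks for a discussion of these.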

Modifying your example, it might look like this:

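```python
import jax
from jax import numpy as jnp
from loguru import logger

def set_log_level(log_level: str):
    logger.remove()
    # loguru sink: each formatted log record is passed to this callable
    logger.add(jax.debug.print, level=log_level)
    logger.debug(f"Setting LogLevel to {log_level}")

def log_callback(a, b, r):
    # Runs on the host with concrete values, so loguru can format them
    logger.debug("Adding {a} and {b} = {r}", a=a, b=b, r=r)

@jax.jit
def add(a, b):
    # (body assumed: the original snippet is cut off at this point)
    r = a + b
    # Defer the logging to runtime via a host callback
    jax.debug.callback(log_callback, a, b, r)
    return r
```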
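
For comparison, a minimal sketch of the jax.experimental.io_callback route mentioned above, assuming you want the log call treated as program I/O and kept in program order (ordered=True). To stay self-contained it appends to a plain Python list (log_lines, a hypothetical stand-in for the loguru sink) rather than calling loguru; host_log is likewise a hypothetical name.

```python
import jax
from jax.experimental import io_callback

# Hypothetical stand-in for a loguru sink.
log_lines = []

def host_log(a, b, r):
    # Runs on the host with concrete array values.
    log_lines.append(f"Adding {a} and {b} = {r}")

@jax.jit
def add(a, b):
    r = a + b
    # io_callback needs the callback's result shape/dtype; None means the
    # callback returns nothing. ordered=True keeps calls in program order.
    io_callback(host_log, None, a, b, r, ordered=True)
    return r

result = add(1, 2)
# Wait for pending side-effecting callbacks before inspecting the log.
jax.effects_barrier()
print(log_lines)
```

Unlike jax.debug.callback, io_callback requires declaring the callback's outputs up front, which is what allows it to return values into the traced computation when needed.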

Answer selected by Qazalbash