Hi, I am currently working with JAX and loguru, and I am trying to log values from inside a `jax.jit`-compiled function. Instead of the actual array values, the log messages show tracer objects. I would appreciate any help and assistance that can be provided. Thank you.

```python
import jax
from jax import numpy as jnp
from loguru import logger


def set_log_level(log_level: str):
    logger.remove()
    logger.add(jax.debug.print, level=log_level)
    logger.debug(f"Setting LogLevel to {log_level}")


@jax.jit
def add(a: int, b: int) -> int:
    r = a + b
    logger.debug("Adding {a} and {b} = {r}", a=a, b=b, r=r)
    return r


if __name__ == "__main__":
    set_log_level("DEBUG")
    add(jnp.ones((1, 2)), jnp.zeros((2, 1)))
    add(jnp.ones(2), jnp.zeros(1))
    add(jnp.ones(2), jnp.zeros(1))
```

Output:

```
2024-12-25 00:33:24.977 | DEBUG | __main__:set_log_level:9 - Setting LogLevel to DEBUG
2024-12-25 00:33:25.385 | DEBUG | __main__:add:15 - Adding Traced<ShapedArray(float32[1,2])>with<DynamicJaxprTrace(level=1/0)> and Traced<ShapedArray(float32[2,1])>with<DynamicJaxprTrace(level=1/0)> = Traced<ShapedArray(float32[2,2])>with<DynamicJaxprTrace(level=1/0)>
2024-12-25 00:33:25.485 | DEBUG | __main__:add:15 - Adding Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)> and Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)> = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
2024-12-25 00:33:25.485 | DEBUG | __main__:add:15 - Adding Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)> and Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)> = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
```

Expectation:

```
2024-12-25 00:33:44.540 | DEBUG | __main__:set_log_level:9 - Setting LogLevel to DEBUG
2024-12-25 00:33:44.636 | DEBUG | __main__:add:15 - Adding [[1. 1.]] and [[0.]
 [0.]] = [[1. 1.]
 [1. 1.]]
2024-12-25 00:33:44.682 | DEBUG | __main__:add:15 - Adding [1. 1.] and [0.] = [1. 1.]
2024-12-25 00:33:44.683 | DEBUG | __main__:add:15 - Adding [1. 1.] and [0.] = [1. 1.]
```
Hi, the issue here is that logging is a side effect, and in general functions used with `jax.jit` must be pure. If you want to log runtime values within your JIT-compiled function, you can do so via `jax.debug.callback` or `jax.experimental.io_callback`, depending on the intent; see External Callbacks for a discussion of these.

Modifying your example, it might look like this:

```python
import jax
from jax import numpy as jnp
from loguru import logger


def set_log_level(log_level: str):
    logger.remove()
    logger.add(jax.debug.print, level=log_level)
    logger.debug(f"Setting LogLevel to {log_level}")


def log_callback(a, b, r):
    # Runs on the host at execution time, with concrete array values.
    logger.debug("Adding {a} and {b} = {r}", a=a, b=b, r=r)


@jax.jit
def add(a: int, b: int) -> int:
    r = a + b
    # Stage a host callback into the compiled function so the values
    # are logged on every call, not just once while tracing.
    jax.debug.callback(log_callback, a, b, r)
    return r


if __name__ == "__main__":
    set_log_level("DEBUG")
    add(jnp.ones((1, 2)), jnp.zeros((2, 1)))
    add(jnp.ones(2), jnp.zeros(1))
    add(jnp.ones(2), jnp.zeros(1))
```

Output:
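To make the purity point concrete, here is a minimal sketch in plain JAX (loguru is left out so it runs on its own; `trace_time`, `run_time` and `record` are hypothetical names). A Python-level side effect inside a jitted function runs only while JAX traces it, and it sees a tracer rather than an array, which is why the original `logger.debug` call logs `Traced<...>` objects. A `jax.debug.callback`, by contrast, is staged into the compiled computation and runs on every call with concrete values.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_time = []  # appended to by plain Python code inside the jitted function
run_time = []    # appended to by the host callback on every execution


def record(r):
    # Host-side callback: receives concrete array values at runtime.
    run_time.append(np.asarray(r))


@jax.jit
def add(a, b):
    r = a + b
    # Plain Python side effect: runs only while JAX traces `add`,
    # and `r` is an abstract tracer here, not a concrete array.
    trace_time.append(type(r).__name__)
    # Staged into the compiled function, so it runs on every call.
    jax.debug.callback(record, r)
    return r


for _ in range(3):
    add(jnp.ones(2), jnp.zeros(1))  # same shapes/dtypes: traced once, then cached

jax.effects_barrier()  # wait for any pending debug callbacks
print(len(trace_time), len(run_time))  # 1 3
```

Since all three calls use the same input shapes and dtypes, `add` is traced once and the compiled version is reused, so the plain append fires once while the callback fires three times.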
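On the "depending on the intent" part: `jax.debug.callback` is aimed at debugging and does not return a value into the computation, whereas `jax.experimental.io_callback` is intended for impure host functions and can return values. A minimal sketch, assuming a hypothetical in-memory `log` list instead of loguru: `ordered=True` asks JAX to run the callbacks in program order, and the callback's return value (here, the running entry count) must have its shape and dtype declared up front with `jax.ShapeDtypeStruct`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

log = []  # hypothetical stand-in for a real log sink


def write_log(a, b, r):
    # Runs on the host with concrete values; convert to NumPy for formatting.
    a, b, r = np.asarray(a), np.asarray(b), np.asarray(r)
    log.append(f"Adding {a} and {b} = {r}")
    # Return the number of entries written so far; shape and dtype must
    # match the ShapeDtypeStruct declared at the call site below.
    return np.array(len(log), dtype=np.int32)


@jax.jit
def add(a, b):
    r = a + b
    count = io_callback(
        write_log,
        jax.ShapeDtypeStruct((), jnp.int32),  # declared result shape/dtype
        a, b, r,
        ordered=True,  # keep callbacks in program order
    )
    return r, count


r, count = add(jnp.ones(2), jnp.zeros(1))
r, count = add(jnp.ones(2), jnp.zeros(1))
print(int(count))  # 2
```

Because `count` is the callback's own output, reading it forces the callback to have run, so after two calls `log` holds two identical entries.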