-
Hello, I am considering NNX for an upcoming R&D project. For Linen I used CommonLoopUtils (CLU) for my looping and metrics; is it still the suggested approach for NNX? Best
Replies: 2 comments
-
Hi @jhn-nt, in NNX we have the CLU-inspired `nnx.metrics` module (see the MNIST Tutorial). It's a little more ergonomic than CLU in my experience, but you can use either with no problem.
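To make this concrete, here is a minimal sketch of the `nnx.MultiMetric` / `nnx.metrics` usage pattern from the MNIST tutorial; the arrays and loss value below are made up purely for illustration:

```python
import jax.numpy as jnp
from flax import nnx

# Bundle several metrics; each metric picks the keyword arguments it needs.
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

# Inside a train/eval step: accumulate statistics for the current batch.
logits = jnp.array([[2.0, 0.1], [0.3, 1.5]])
labels = jnp.array([0, 1])
metrics.update(logits=logits, labels=labels, loss=jnp.array(0.25))

print(metrics.compute())  # {'accuracy': ..., 'loss': ...}
metrics.reset()           # clear accumulated state between epochs / eval runs
```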
-
I guess that you can create your custom metric if necessary, e.g. by subclassing `nnx.metrics.Average`. For regression tasks, this is how I implemented the mean absolute error metric (I am still not sure it is the correct way to do it, but I share it as a solution to be tested/validated):

```python
import jax
import jax.numpy as jnp
from flax import nnx


class MeanAbsoluteError(nnx.metrics.Average):
  """Running mean absolute error built on top of the Average metric."""

  def __init__(self, argname: str = 'values'):
    super().__init__(argname=argname)

  def update(self, *, predictions: jax.Array, targets: jax.Array, **_) -> None:  # type: ignore[override]
    if predictions.shape != targets.shape:
      raise ValueError(
        f'Expected predictions.shape == targets.shape, '
        f'got {predictions.shape} and {targets.shape}'
      )
    # Accumulate the batch MAE through the parent Average metric.
    super().update(values=jnp.abs(predictions - targets).mean())


metrics = nnx.MultiMetric(
    mae=MeanAbsoluteError(),
    loss=nnx.metrics.Average('loss'),
)
```
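A quick, untested sketch of how this could be exercised in a step, continuing the block above (dummy arrays; as far as I can tell `nnx.MultiMetric.update` forwards all keyword arguments to each sub-metric, and each sub-metric keeps only the ones it needs):

```python
preds = jnp.array([1.0, 2.0, 3.0])
tgts = jnp.array([1.5, 2.0, 2.0])

# Each sub-metric consumes its own keyword arguments and ignores the rest.
metrics.update(predictions=preds, targets=tgts, loss=jnp.array(0.1))
print(metrics.compute())  # {'mae': ..., 'loss': ...}
metrics.reset()
```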