I would like to record some model activations in an architecture-invariant way.
In PyTorch, we can do this with forward hooks, by registering a hook on every module that matches some criterion (for example, all modules that are instances of an MLP class).
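For comparison, here is a minimal sketch of the PyTorch pattern, using `register_forward_hook` on every submodule returned by `named_modules()` that passes an `isinstance` check. The toy `nn.Sequential` model, the choice of `nn.Linear` as the matching criterion, and the names `make_hook` and `activations` are illustrative assumptions, not part of the original question.

```python
import torch
from torch import nn

# Toy model; any nn.Module is searched the same way.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
activations = {}  # hypothetical store: module name -> recorded output


def make_hook(name):
    # Forward hooks are called as hook(module, args, output) after forward().
    def hook(module, args, output):
        activations[name] = output.detach()

    return hook


# Register a hook on every submodule matching the criterion (here: nn.Linear).
handles = [
    module.register_forward_hook(make_hook(name))
    for name, module in model.named_modules()
    if isinstance(module, nn.Linear)
]

out = model(torch.randn(3, 4))

# Remove the hooks once recording is done.
for handle in handles:
    handle.remove()
```

The hooks are attached to the existing module objects in place, which is the part that has no direct counterpart in Equinox, where modules are immutable pytrees.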
Is there an equivalent strategy in Equinox?
One idea is to create a class `Wrapper(eqx.Module)` that simply wraps a module and calls some callback in `__call__` with the underlying module's activations, and then somehow replace the relevant submodules of an Equinox model with these wrappers.
```python
from typing import Callable
import equinox as eqx

class Wrapper(eqx.Module):
    wrapped: eqx.Module
    callback: Callable

    def __init__(self, module, callback):
        self.wrapped = module
        self.callback = callback

    def __call__(self, *args, **kwargs):
        outs = self.wrapped(*args, **kwargs)
        self.callback(outs)  # this would save to disk or something
        return outs
```
Then in the main script, I could do something like:
Is there a better/more obvious way to do this?