Replies: 1 comment
-
Just for what it's worth, I think the following would work too; would it be the recommended approach?

import jax
import jax.numpy as jnp
from penzai import pz


@pz.pytree_dataclass
class ThreaderLayer(pz.nn.Layer):
  layer: pz.nn.Layer  # the subject layer
  side_layer: pz.nn.Layer  # the layer updating the threaded values

  def __call__(self, argument, /, **side_inputs):
    # The argument is a pair: the main value and the threaded side value.
    arg, side_arg = argument
    out = self.layer(arg, **side_inputs)
    side_out = self.side_layer(argument, **side_inputs)
    return out, side_out


@pz.pytree_dataclass
class Identity(pz.nn.Layer):
  def __call__(self, argument, /, **side_inputs):
    return argument


@pz.pytree_dataclass
class AddPair(pz.nn.Layer):
  def __call__(self, argument, /, **side_inputs):
    arg, side_arg = argument
    return arg + side_arg


@pz.variable_jit
def call_model(model, x):
  return model(x)


if __name__ == "__main__":
  model = pz.nn.Sequential([Identity() for _ in range(5)])
  # Replace every Identity with a ThreaderLayer that also runs AddPair
  # on the threaded pair.
  thread_model = (
      pz.select(model)
      .at_instances_of(Identity)
      .apply(lambda i: ThreaderLayer(i, AddPair()))
  )
  x = jnp.ones((2, 3))
  z = jnp.zeros_like(x)
  x_out, z_out = call_model(thread_model, (x, z))
  print(x_out, z_out)
-
Conceptually, **side_inputs are like a global read-only dictionary of inputs, available to every layer's __call__ function. The global availability is great, but is there any way to combine this idea with a layer that could also modify these? I know this is not possible with the current design, but it would be really useful: something like a layer whose __call__ could also hand updated side-input values to the layers after it.
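For reference, the current read-only behaviour looks roughly like this (the ScaleBySideInput layer and the "scale" key are made up for illustration; it uses the same pz.nn.Layer API as the reply above):

import jax.numpy as jnp
from penzai import pz


@pz.pytree_dataclass
class ScaleBySideInput(pz.nn.Layer):
  """Reads a shared side input; there is no way to hand back an updated value."""

  def __call__(self, argument, /, **side_inputs):
    # Every layer in the model sees the same "scale" entry.
    return argument * side_inputs["scale"]


model = pz.nn.Sequential([ScaleBySideInput(), ScaleBySideInput()])
out = model(jnp.ones((2, 3)), scale=2.0)  # both layers read scale=2.0
print(out)  # every entry is 4.0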
I guess there is some jax'y reason why this sort of design wouldn't jit-compile?
The use case for this is that I want to create a wrapper that can efficiently record and concatenate selected activations from an underlying model as it runs.
One approach could be to insert a "SaveActivation" layer just after each Transformer block which saves the output in a StateVariable, and then have a function which extracts all of these and concatenates them. But to do this requires memory to store the StateVariables, and the same amount of memory again to perform the concatenation. I was hoping to find a method that writes each activation directly into a single pre-allocated buffer during the forward pass, so the second copy is never needed.
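Roughly, the SaveActivation approach I have in mind might look like this (ToyBlock is a stand-in for a real transformer block, and I'm assuming pz.StateVariable can be mutated during the forward pass and read back out afterwards via the selector API):

import jax.numpy as jnp
from penzai import pz


@pz.pytree_dataclass
class ToyBlock(pz.nn.Layer):
  """Stand-in for a real transformer block."""

  def __call__(self, argument, /, **side_inputs):
    return argument + 1.0


@pz.pytree_dataclass
class SaveActivation(pz.nn.Layer):
  """Passes its input through unchanged, storing a copy in a StateVariable."""

  saved: pz.StateVariable

  def __call__(self, argument, /, **side_inputs):
    self.saved.value = argument  # record the activation for later extraction
    return argument


model = pz.nn.Sequential([ToyBlock() for _ in range(3)])

# Insert a recorder after every block.
recording_model = (
    pz.select(model)
    .at_instances_of(ToyBlock)
    .apply(lambda block: pz.nn.Sequential(
        [block, SaveActivation(pz.StateVariable(value=None))]))
)

out = recording_model(jnp.zeros((2, 3)))

# Pull the recorded activations back out and concatenate them; this is the
# step that costs a second copy of everything that was stored.
savers = pz.select(recording_model).at_instances_of(SaveActivation).get_sequence()
stacked = jnp.concatenate([s.saved.value for s in savers], axis=-1)
print(out.shape, stacked.shape)  # (2, 3) and (2, 9)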
Unfortunately, this seems to mean I would have to wrap every layer, converting it to one that accepts the tuple (argument, buffer, offsets) and either ignores buffer and offsets (if it is a non-recording layer) or writes its output into the buffer at the current offset (if it is a recording layer).
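A rough sketch of that buffer-threading idea, again with made-up ToyBlock layers: RecordToBuffer writes into a pre-allocated buffer with jax.lax.dynamic_update_slice_in_dim, while PassThrough is the non-recording wrapper that just forwards buffer and offset untouched.

import jax
import jax.numpy as jnp
from penzai import pz


@pz.pytree_dataclass
class ToyBlock(pz.nn.Layer):
  """Stand-in for a real transformer block."""

  def __call__(self, argument, /, **side_inputs):
    return argument + 1.0


@pz.pytree_dataclass
class PassThrough(pz.nn.Layer):
  """Non-recording wrapper: runs its layer and ignores buffer and offset."""

  layer: pz.nn.Layer

  def __call__(self, argument, /, **side_inputs):
    value, buffer, offset = argument
    return self.layer(value, **side_inputs), buffer, offset


@pz.pytree_dataclass
class RecordToBuffer(pz.nn.Layer):
  """Recording wrapper: also writes its output into the shared buffer."""

  layer: pz.nn.Layer

  def __call__(self, argument, /, **side_inputs):
    value, buffer, offset = argument
    out = self.layer(value, **side_inputs)
    # Copy the activation into the pre-allocated buffer at `offset` along
    # the last axis, then advance the offset for the next recorder.
    buffer = jax.lax.dynamic_update_slice_in_dim(buffer, out, offset, axis=-1)
    return out, buffer, offset + out.shape[-1]


model = pz.nn.Sequential([ToyBlock() for _ in range(3)])
threaded = (
    pz.select(model)
    .at_instances_of(ToyBlock)
    .apply(lambda block: RecordToBuffer(block))
)

x = jnp.zeros((2, 3))
buffer = jnp.zeros((2, 3 * 3))  # one slot per recorded activation
out, buffer, offset = threaded((x, buffer, 0))
print(out.shape, buffer.shape, offset)  # (2, 3) (2, 9) 9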