Hi, I want to create multiple instances of an `nnx.Module` (each initialized with a different key). Outside of NNX, I would write:

```python
def make(rng):
    m = my_module.init(rng, dummy_input)
    return ...

rngs = jax.random.split(jax.random.PRNGKey(0), num=5)
models = jax.vmap(make)(rngs)
```

How can I achieve the same with NNX? I tried:

```python
def make_model(rngs):
    return nnx.Sequential(
        nnx.Linear(..., rngs=rngs),
        ...
    )

init_keys = jax.random.split(jax.random.PRNGKey(0), num=5)
rngs = nnx.Rngs(init_keys)
model = jax.vmap(task.make_model)(rngs)
```

But I get an error.
EDIT: Updating to use the new APIs.

Hey @JeyRunner! You can use `nnx.split_rngs` to automatically split the `Rngs` before going into `nnx.vmap`:

```python
from flax import nnx

@nnx.split_rngs(splits=5)
@nnx.vmap
def make_model(rngs):
    return nnx.Linear(2, 3, rngs=rngs)

model = make_model(nnx.Rngs(0))
print(model)
```

Output:

```text
Linear(
  bias=Param(
    value=Array(shape=(5, 3), dtype=float32)
  ),
  bias_init=<function zeros at 0x11ee95f30>,
  dot_general=<function dot_general at 0x11e933910>,
  dtype=None,
  in_features=2,
  kernel=Param(
    value=Array(shape=(5, 2, 3), dtype=float32)
  ),
  kernel_init=<function variance_scaling.<locals>.init at 0x11fa8fe20>,
  out_features=3,
  param_dtype=<class 'jax.numpy.float32'>,
  precision=None,
  use_bias=True
)
```
Hi @cgarciae. In the example you gave above (`def make_model(rngs): ...`, then `rngs = nnx.Rngs(0)` and `print(model)`), how would you do it if you wanted the dimensions of the `Linear` in `make_model` to be inputs?
@cgarciae I'm trying to understand how to use `vmap` so that I have multiple models along more than one axis. Let's say that we have data restructured to have both a batch dimension (B) and a head dimension (H), and we want to parallelize over both. Is this a question of nesting vmaps, like so:

```python
@nnx.split_rngs(splits=batch_size)
@nnx.vmap(in_axes=0)
def make_batch_vmap_model(rngs):
    return nnx.Linear(din, dout, rngs=rngs)

@nnx.split_rngs(splits=d_head)
@nnx.vmap(in_axes=0)
def make_d_head_vmap_model(rngs):
    return make_batch_vmap_model(rngs)
```

Or perhaps this is completely incorrect and I should actually be using something else?

Thanks in advance for the help!