Flax NNX checkpointing is not that straightforward #1105

Open
prabhudavidsheryl opened this issue Aug 27, 2024 · 9 comments

Labels: checkpoint, type:feature (New feature or request)
@prabhudavidsheryl

I have been trying to use Orbax to checkpoint Flax NNX models, and getting checkpointing to work for models with Dropout layers, which hold JAX RNG keys, is not very straightforward. After various attempts, this was the only way I could get it to work:

import flax.nnx as nnx
import jax.numpy as jnp
import orbax.checkpoint as ocp


class Test(nnx.Module):
    def __init__(self, dim, rngs: nnx.Rngs):
        self.layer1 = nnx.Linear(dim, dim, rngs=rngs)
        self.layer2 = nnx.Dropout(0.1, rngs=rngs)
        self.layer3 = nnx.Linear(dim, dim, rngs=rngs)

    def __call__(self, x):
        x = self.layer3(self.layer2(self.layer1(x)))
        return x


checkpoint_manager = ocp.CheckpointManager(
    ocp.test_utils.erase_and_create_empty(
        "/var/local/ML/TRAIN/STAGE/troubleshoot_checkpointing/"
    ),
    options=ocp.CheckpointManagerOptions(
        max_to_keep=2,
        keep_checkpoints_without_metrics=False,
        create=True,
    ),
    item_names=("state", "layer2_dropout_key"),
)

model = Test(10, nnx.Rngs(0))
model_state = nnx.state(model).flat_state()
layer2_dropout_key = model_state[("layer2", "rngs", "default", "key")].value

# The Dropout layers RNG key had to be replaced with a dummy to
# allow checkpoint saving
# Error seen
# TypeError: Cannot interpret 'key<fry>' as a data type
model_state[("layer2", "rngs", "default", "key")] = nnx.VariableState(
    type=nnx.Param, value=jnp.array([])
)

# The RNG key had to be saved with its special checkpointer
checkpoint_manager.save(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardSave(nnx.State.from_flat_path(model_state)),
        layer2_dropout_key=ocp.args.JaxRandomKeySave(layer2_dropout_key),
    ),
)
checkpoint_manager.wait_until_finished()

abs_model = nnx.eval_shape(lambda: nnx.State.from_flat_path(model_state))

# Restoring both items in a single Composite call does not work either;
# the two items have to be restored separately
restored = checkpoint_manager.restore(
    0,
    args=ocp.args.Composite(
        state=ocp.args.StandardRestore(abs_model),
    ),
)

restored_key = checkpoint_manager.restore(
    0,
    args=ocp.args.Composite(
        # state=ocp.args.StandardRestore(abs_model),
        layer2_dropout_key=ocp.args.JaxRandomKeyRestore(),
    ),
)

# Model restoration is equally not straightforward
restored_model_state = restored["state"].flat_state()
restored_model_state[("layer2", "rngs", "default", "key")] = nnx.VariableState(
    type=nnx.Param, value=restored_key["layer2_dropout_key"]
)

abs_graph_def, _ = nnx.split(Test(10, nnx.Rngs(0)))
restored_model = nnx.merge(
    abs_graph_def, nnx.State.from_flat_path(restored_model_state)
)

print(nnx.state(restored_model).flat_state().keys())

The comments in the code describe the issues faced.

It would be good to improve ease of use for models that contain Dropout layers.

@cnguyen10

I am having a similar issue with orbax-checkpoint in flax.linen. I followed the flax.linen tutorial for Dropout, where the TrainState includes an extra attribute key, which is the PRNG key for dropout. When saving that TrainState with orbax-checkpoint, I got the same error: TypeError: Cannot interpret 'key<fry>' as a data type.

Hence, this is not specific to flax.nnx; it affects flax.linen as well.
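
A minimal sketch of the setup being described, assuming a TrainState subclass that carries the dropout key (the class and field names here are illustrative, not taken from the tutorial); the key_data/wrap_key_data conversion discussed further down in this thread applies here too:

import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state


class TrainState(train_state.TrainState):
    key: jax.Array  # extra attribute holding the dropout PRNG key


model = nn.Dense(4)
variables = model.init(jax.random.key(0), jnp.ones((1, 4)))
state = TrainState.create(
    apply_fn=model.apply,
    params=variables["params"],
    tx=optax.adam(1e-3),
    key=jax.random.key(1),  # typed PRNG key with dtype key<fry>
)

# Saving `state` as-is with orbax-checkpoint fails with
# "TypeError: Cannot interpret 'key<fry>' as a data type".
# Converting the key to its uint32 representation first works around it:
saveable = state.replace(key=jax.random.key_data(state.key))
# after restoring, wrap it back:
# state = restored.replace(key=jax.random.wrap_key_data(restored.key))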

@JulienRaynal

I encounter the same problem on a dropout layer in Flax NNX: nnx.Dropout(rate=0.3, rngs=rngs).
It would really be nice to fix this; it would greatly simplify model saving and allow more users to rely on JAX and Flax.

@benstonezhang

benstonezhang commented Dec 6, 2024

I hit the same issue too. Based on your code, I came up with a slightly different approach which includes an Adam optimizer.
I ran into an infinite loop when continuing training from the restored model; after copying the metadata over, the issue was gone.

import flax.nnx as nnx
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp

ckpt_dir = '/tmp/checkpoints'

class Test(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, *, dropout: float, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x: jax.Array):
        return self.linear2(nnx.gelu(self.dropout(self.linear1(x))))

def loss_fn(model, x, y):
    y_pred = model(x)
    return ((y_pred - y) ** 2).mean(), y_pred

sample = np.random.random(10)
target = np.random.random(2)

model = Test(10, 10, 2, dropout=0.5, rngs=nnx.Rngs(jax.random.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=0.001))

grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
_, grads = grad_fn(model, sample, target)
optimizer.update(grads)

with ocp.CheckpointManager(
        ckpt_dir,
        item_names=('state', 'opt_state', 'dropout_key',),
) as ckpt_mgr:
    print('save checkpoint')

    _, state = nnx.split(model)
    # The Dropout layer's RNG key has to be replaced with a dummy to allow
    # checkpoint saving; the original key is kept aside and saved separately below
    dropout_key = state['dropout']['rngs']['default']['key']
    state['dropout']['rngs']['default']['key'] = nnx.VariableState(
        type=nnx.Param,
        value=jnp.zeros(1),
        **dropout_key.get_metadata()
    )
    _, opt_state = nnx.split(optimizer.opt_state)

    ckpt_mgr.save(
        1,
        args=ocp.args.Composite(
            state=ocp.args.StandardSave(state),
            opt_state=ocp.args.StandardSave(opt_state),
            dropout_key=ocp.args.JaxRandomKeySave(dropout_key.value),
        ),
    )
    ckpt_mgr.wait_until_finished()

    print('restore from checkpoint')

    model = Test(10, 10, 2, dropout=0.5, rngs=nnx.Rngs(0))
    graph_def, state = nnx.split(model)
    # replace the Dropout layer's RNG key with a dummy so the abstract state
    # matches the saved checkpoint
    state['dropout']['rngs']['default']['key'] = nnx.VariableState(
        type=nnx.Param,
        value=jnp.zeros(1),
        **state['dropout']['rngs']['default']['key'].get_metadata()
    )

    opt_def, opt_state = nnx.split(optimizer.opt_state)

    restored = ckpt_mgr.restore(
        1,
        args=ocp.args.Composite(
            state=ocp.args.StandardRestore(nnx.eval_shape(lambda: state)),
            opt_state=ocp.args.StandardRestore(nnx.eval_shape(lambda: opt_state)),
            dropout_key=ocp.args.JaxRandomKeyRestore(),
        ),
    )
    # put the restored RNG key back into the state with its original RngKey type
    restored['state']['dropout']['rngs']['default']['key'] = nnx.VariableState(
        type=nnx.rnglib.RngKey,
        value=restored['dropout_key'],
        **restored['state']['dropout']['rngs']['default']['key'].get_metadata()
    )

    model = nnx.merge(graph_def, restored['state'])
    optimizer.opt_state = nnx.merge(opt_def, restored['opt_state'])

@cgarciae

cgarciae commented Dec 6, 2024

I think orbax should support key<fry> arrays, but in the meantime you can manually convert the data back and forth to regular uint32 arrays using jax.random.key_data and jax.random.wrap_key_data:

class Model(nnx.Module):
  def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
    self.linear = nnx.Linear(din, dmid, rngs=rngs)
    self.bn = nnx.BatchNorm(dmid, rngs=rngs)
    self.dropout = nnx.Dropout(0.2, rngs=rngs)
    self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

  def __call__(self, x):
    x = nnx.relu(self.dropout(self.bn(self.linear(x))))
    return self.linear_out(x)


# saving
model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization

keys, state = nnx.state(model, nnx.RngKey, ...)
keys = jax.tree.map(jax.random.key_data, keys)
...  # save keys and state checkpoints

# loading
...  # load keys and state checkpoints
keys = jax.tree.map(jax.random.wrap_key_data, keys)
nnx.update(model, keys, state)

@benstonezhang

Looks like the code became slightly simpler, but it needs a small helper function to merge the RNG keys back into the state.

import flax.nnx as nnx
import jax
import numpy as np
import optax
import orbax.checkpoint as ocp

ckpt_dir = '/tmp/checkpoints'

# Recursively copy the leaves of `src` into `dst`, overwriting the matching
# entries of the destination nnx.State in place.
def merge_state(dst: nnx.State, src: nnx.State):
    for k, v in src.items():
        if isinstance(v, nnx.State):
            merge_state(dst[k], v)
        else:
            dst[k] = v

class Test(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, *, dropout: float, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(din, dmid, rngs=rngs)
        self.dropout = nnx.Dropout(rate=dropout, rngs=rngs)
        self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x: jax.Array):
        return self.linear2(nnx.gelu(self.dropout(self.linear1(x))))

def loss_fn(model, x, y):
    y_pred = model(x)
    return ((y_pred - y) ** 2).mean(), y_pred

sample = np.random.random(10)
target = np.random.random(2)

model = Test(10, 10, 2, dropout=0.5, rngs=nnx.Rngs(jax.random.key(1)))
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=0.001))

grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
_, grads = grad_fn(model, sample, target)
optimizer.update(grads)

with ocp.CheckpointManager(
        ckpt_dir,
        item_names=('state', 'opt_state'),
) as ckpt_mgr:
    print('save checkpoint')

    _, state = nnx.split(model)
    # The RNG keys have to be converted to plain uint32 arrays to allow checkpoint saving
    rngs_key = jax.tree.map(jax.random.key_data, state.filter(nnx.RngKey))
    merge_state(state, rngs_key)

    _, opt_state = nnx.split(optimizer.opt_state)

    ckpt_mgr.save(
        1,
        args=ocp.args.Composite(
            state=ocp.args.StandardSave(state),
            opt_state=ocp.args.StandardSave(opt_state),
        ),
    )
    ckpt_mgr.wait_until_finished()

    print('restore from checkpoint')

    model = Test(10, 10, 2, dropout=0.5, rngs=nnx.Rngs(0))
    graph_def, state = nnx.split(model)
    # convert the RNG keys to uint32 so the abstract state matches the checkpoint
    rngs_key = jax.tree.map(jax.random.key_data, state.filter(nnx.RngKey))
    merge_state(state, rngs_key)

    opt_def, opt_state = nnx.split(optimizer.opt_state)

    restored = ckpt_mgr.restore(
        1,
        args=ocp.args.Composite(
            state=ocp.args.StandardRestore(nnx.eval_shape(lambda: state)),
            opt_state=ocp.args.StandardRestore(nnx.eval_shape(lambda: opt_state)),
        ),
    )
    # wrap the restored uint32 arrays back into typed RNG keys
    rngs_key = jax.tree.map(jax.random.wrap_key_data, restored['state'].filter(nnx.RngKey))
    merge_state(restored['state'], rngs_key)

    model = nnx.merge(graph_def, restored['state'])
    optimizer.opt_state = nnx.merge(opt_def, restored['opt_state'])

@ChromeHearts (Collaborator)

The Orbax team is looking into it and planning to add support for RNG keys in early Jan. Thanks!

@kailukowiak

Thanks for working on this. Looking forward to being able to easily save models with dropout layers.

@tobiaswuerth

Waiting for better support as well :)

Currently, when trying to save/restore a simple module like:

import orbax.checkpoint as ocp
checkpoint = ocp.PyTreeCheckpointer()
checkpoint.save("models/renderer-nnx", nnx.state(renderer))

I get:
TypeError: JAX array with PRNGKey dtype cannot be converted to a NumPy array. Use jax.random.key_data(arr) if you wish to extract the underlying integer array.

I think this is related to this issue. Happy to see an update soon, thanks!
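
A minimal sketch of how the key_data conversion from earlier in the thread could be applied here, assuming renderer is the module from the snippet above (the two save paths are made up for illustration):

import jax
import flax.nnx as nnx
import orbax.checkpoint as ocp

# split off the typed RNG keys and convert them to plain uint32 arrays
keys, state = nnx.state(renderer, nnx.RngKey, ...)
keys = jax.tree.map(jax.random.key_data, keys)

checkpoint = ocp.PyTreeCheckpointer()
checkpoint.save("models/renderer-nnx-state", state)
checkpoint.save("models/renderer-nnx-keys", keys)

# when restoring, wrap the keys back before updating the module:
# keys = jax.tree.map(jax.random.wrap_key_data, restored_keys)
# nnx.update(renderer, keys, state)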

@SteamedGit

I am facing a related issue. Following the NNX scan-over-layers example, I have a simple MLP module containing other modules. Unfortunately, after attempting the workaround from @cgarciae for checkpointing the model, keys, state = nnx.state(model, nnx.RngKey, ...), the keys for the middle_block are missing (a quick diagnostic sketch follows the module definition below).

import jax
import flax.nnx as nnx
from jax.typing import ArrayLike  # the annotations below rely on this import


class LinearDropLayer(nnx.Module):
    def __init__(self, input_dim: int, output_dim: int, p_drop: float, rngs: nnx.Rngs):
        self.linear = nnx.Linear(input_dim, output_dim, rngs=rngs)
        self.dropout = nnx.Dropout(p_drop, rngs=rngs)

    def __call__(self, x: ArrayLike) -> jax.Array:
        return self.dropout(nnx.relu(self.linear(x)))


class MCMLP(nnx.Module):
    def __init__(
        self,
        input_dim: int,
        path_length: int,
        num_layers: int,
        hidden_dim: int,
        p_drop: float,
        rngs: nnx.Rngs,
    ):
        assert num_layers >= 3, "MCMLP must have at least 3 layers"
        self.input_layer = LinearDropLayer(input_dim, hidden_dim, p_drop, rngs)
        self.output_layer = nnx.Linear(hidden_dim, path_length, rngs=rngs)

        @nnx.split_rngs(splits=num_layers - 2)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_middle_block(rngs):
            return LinearDropLayer(hidden_dim, hidden_dim, p_drop, rngs=rngs)

        self.middle_block = create_middle_block(rngs)
        self.num_layers = num_layers

    def __call__(self, x: ArrayLike) -> jax.Array:
        @nnx.split_rngs(splits=self.num_layers - 2)
        @nnx.scan(in_axes=(nnx.Carry, 0), out_axes=nnx.Carry)
        def mid_forward(x, model):
            x = model(x)
            return x

        x = self.input_layer(x)
        x = mid_forward(x, self.middle_block)

        return nnx.relu(self.output_layer(x))
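
A small diagnostic sketch for the above (sizes are made up; MCMLP as defined in the snippet) that prints which RNG entries the nnx.RngKey filter actually captures, confirming that the middle_block keys are absent:

import flax.nnx as nnx

model = MCMLP(
    input_dim=8,
    path_length=4,
    num_layers=4,
    hidden_dim=16,
    p_drop=0.1,
    rngs=nnx.Rngs(0),
)
keys, state = nnx.state(model, nnx.RngKey, ...)
# flat paths of all captured RngKey variables; in the scenario described
# above, no middle_block entries show up here
print(list(keys.flat_state().keys()))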
