Slow training step occasionally due to slow graph flatten #4336
Thanks for posting this @kriscao-cohere! We do have a global context that keeps track of graph references during flattening (line 740 in e4dad9c). Hopefully the code at line 1046 in e4dad9c …

I'll look into it. But if this is blocking you, consider using regular `jax.jit` …
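For reference, here is a minimal sketch of what such a JAX-only pattern can look like (my reconstruction rather than an official recipe; it assumes the `MLP` class from the repro below): flatten the module once with `nnx.split` outside the loop, jit a pure function over the resulting pytrees with plain `jax.jit`, and rebuild the module inside the trace with `nnx.merge`, so no NNX graph traversal runs on the host per step.

```python
# Sketch only, assuming the MLP class defined in the repro below.
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)

# Flatten once: graphdef is static structure, params/others are pytrees.
graphdef, params, others = nnx.split(model, nnx.Param, ...)
opt_state = tx.init(params)

@jax.jit  # plain jax.jit over pytrees; NNX is only used inside the trace
def train_step(params, others, opt_state, batch):
  x, y = batch

  def loss_fn(params):
    m = nnx.merge(graphdef, params, others)  # rebuild the module
    y_pred = m(x)
    return jnp.mean((y - y_pred) ** 2)

  grads = jax.grad(loss_fn)(params)
  updates, opt_state = tx.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  return params, opt_state

# Usage: carry the pytrees through the loop, and merge back at the end if
# you want the stateful module again. Note that the Count mutation inside
# __call__ is discarded in this sketch.
# params, opt_state = train_step(params, others, opt_state, batch)
# model = nnx.merge(graphdef, params, others)
```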
Just as a sanity check, I ran this simple training code:

```python
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import nnx

# Toy regression dataset.
X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)

def dataset(batch_size):
  while True:
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]

class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    return x @ self.w.value + self.b.value

class Count(nnx.Variable):
  pass

class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.count = Count(jnp.array(0))  # mutable non-param state
    self.linear1 = Linear(din, dhidden, rngs=rngs)
    self.linear2 = Linear(dhidden, dout, rngs=rngs)

  def __call__(self, x):
    self.count.value += 1
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x

model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)

@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
  x, y = batch

  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  grads: nnx.State = nnx.grad(loss_fn)(model)
  optimizer.update(grads)

@nnx.jit
def test_step(model: MLP, batch):
  x, y = batch
  y_pred = model(x)
  loss = jnp.mean((y - y_pred) ** 2)
  return {'loss': loss}

total_steps = 10_000
for step, batch in enumerate(dataset(32)):
  train_step(model, optimizer, batch)
  print(nnx.graph.GRAPH_CONTEXT)  # inspect NNX's global graph context each step

  if step % 1000 == 0:
    logs = test_step(model, (X, Y))
    print(f"step: {step}, loss: {logs['loss']}")

  if step >= total_steps - 1:
    break

print('times called:', model.count.value)

y_pred = model(X)
plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()
```
Hi @cgarciae, thanks for the updates! And thanks for pointing me to the jax-only pattern for using NNX; I tried it and it eliminated the wait time (it also sped up my experiment loop a lot, though I am doing small-model experiments). As for a repro, I only saw the issue with certain model architectures (mainly Transformers past a certain layer depth), and I tried to repro the slow …
I'm using NNX for a toy transformer on Wikitext-103, and I'm observing that roughly one in every ~100 steps takes much, much longer than the rest (on the order of 2 seconds vs 0.02 seconds). I've managed to track down the culprit with a profile, and it seems that some NNX-internal machinery in `nnx.split` is taking the bulk of the time. Is there anything NNX-related that could be causing this to take so long?
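In case it helps anyone reproduce this, here is a sketch of the host-side timing that surfaces the occasional slow step (it assumes the `dataset`, `train_step`, `model`, and `optimizer` from the repro above):

```python
import time

import jax

# Time each step on the host. Blocking on an array produced by the step
# keeps JAX's async dispatch from hiding where the time actually goes.
durations = []
for step, batch in enumerate(dataset(32)):
  t0 = time.perf_counter()
  train_step(model, optimizer, batch)
  jax.block_until_ready(model.count.value)  # wait for the step's outputs
  durations.append(time.perf_counter() - t0)
  if step >= 999:
    break

durations.sort()
print('median step time:', durations[len(durations) // 2])
print('slowest step time:', durations[-1])
# A slowest step ~100x the median, with identical compute every step,
# points at host-side overhead (e.g. graph flatten) rather than device work.
```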