Slow training step occasionally due to slow graph flatten #4336

Open
kriscao-cohere opened this issue Oct 25, 2024 · 3 comments
@kriscao-cohere

I'm using NNX for a toy transformer on Wikitext-103, and I'm observing that roughly one in every ~100 steps takes much longer than the rest (on the order of 2 seconds vs 0.02 seconds). I've managed to track down the culprit with a profile, and it seems that some NNX internal machinery in nnx.split is taking the bulk of the time:

[Profiler screenshot: most of the slow step is spent inside nnx.split / graph flattening]

Is there anything NNX-related that could be causing this to take a long time?
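
For reference, a minimal sketch of how such an outlier step can be caught without a full profiler. The names train_step, model, optimizer and batches are placeholders for the actual training loop, and the step is assumed to return the loss:

import time
import jax

# Placeholder loop: train_step, model, optimizer and batches stand in for the
# real training code; the step is assumed to return the loss.
slow_steps = []
for step, batch in enumerate(batches):
  t0 = time.perf_counter()
  loss = train_step(model, optimizer, batch)
  jax.block_until_ready(loss)  # don't let async dispatch hide device time
  dt = time.perf_counter() - t0
  if dt > 0.5:  # flag steps far above the usual ~0.02s
    slow_steps.append((step, dt))
print(slow_steps)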

@cgarciae
Collaborator

Thanks for posting this @kriscao-cohere! We do have a global context that keeps track of graph references during jit and all other transforms. That might be the first place I would look.

GRAPH_CONTEXT = GraphContext()

Hopefully the update_context context manager is not messing anything up:

def update_context(tag: str):

I'll look into it. But if this is blocking you, consider using regular jax.jit as detailed in Performance Considerations for NNX.
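
A minimal sketch of that jax.jit pattern, adapted to a model and nnx.Optimizer like the ones in the reproduction script further down (the split/merge placement follows the performance guide, but the exact code here is an illustration, not a copy of the docs):

import jax
import jax.numpy as jnp
from flax import nnx

# Assumes `model`, `optimizer` (an nnx.Optimizer) and `dataset` already exist,
# as in the reproduction script below. Split once, outside the training loop,
# so the Python-side graph traversal stays off the hot path.
graphdef, state = nnx.split((model, optimizer))

@jax.jit
def jax_train_step(graphdef, state, batch):
  # merge runs only at trace time under jax.jit
  model, optimizer = nnx.merge(graphdef, state)
  x, y = batch

  def loss_fn(model):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  grads = nnx.grad(loss_fn)(model)
  optimizer.update(grads)

  # return the updated state as plain pytrees
  _, state = nnx.split((model, optimizer))
  return state

for step, batch in enumerate(dataset(32)):
  state = jax_train_step(graphdef, state, batch)
  if step >= 1000:
    break

# write the final state back into the original objects
nnx.update((model, optimizer), state)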

@cgarciae
Collaborator

Just as a sanity check, I ran this simple training code, printing nnx.graph.GRAPH_CONTEXT after each step, but all the context stacks are empty. Maybe the Python garbage collector is hitting a spike?

training code
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

from flax import nnx

X = np.linspace(0, 1, 100)[:, None]
Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)


def dataset(batch_size):
  while True:
    idx = np.random.choice(len(X), size=batch_size)
    yield X[idx], Y[idx]


class Linear(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
    self.b = nnx.Param(jnp.zeros((dout,)))

  def __call__(self, x):
    return x @ self.w.value + self.b.value


class Count(nnx.Variable):
  pass


class MLP(nnx.Module):
  def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
    self.count = Count(jnp.array(0))
    self.linear1 = Linear(din, dhidden, rngs=rngs)
    self.linear2 = Linear(dhidden, dout, rngs=rngs)

  def __call__(self, x):
    self.count.value += 1
    x = self.linear1(x)
    x = jax.nn.relu(x)
    x = self.linear2(x)
    return x


model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
tx = optax.sgd(1e-3)
optimizer = nnx.Optimizer(model, tx)


@nnx.jit
def train_step(model: MLP, optimizer: nnx.Optimizer, batch):
  x, y = batch

  def loss_fn(model: MLP):
    y_pred = model(x)
    return jnp.mean((y - y_pred) ** 2)

  grads: nnx.State = nnx.grad(loss_fn)(model)
  optimizer.update(grads)


@nnx.jit
def test_step(model: MLP, batch):
  x, y = batch
  y_pred = model(x)
  loss = jnp.mean((y - y_pred) ** 2)
  return {'loss': loss}


total_steps = 10_000
for step, batch in enumerate(dataset(32)):
  train_step(model, optimizer, batch)
  print(nnx.graph.GRAPH_CONTEXT)

  if step % 1000 == 0:
    logs = test_step(model, (X, Y))
    print(f"step: {step}, loss: {logs['loss']}")

  if step >= total_steps - 1:
    break

print('times called:', model.count.value)

y_pred = model(X)

plt.scatter(X, Y, color='blue')
plt.plot(X, y_pred, color='black')
plt.show()
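
One way to test the garbage-collector theory, sketched with only the standard gc module (the logging format is illustrative):

import gc
import time

# Log every automatic collection so its timestamps can be lined up against the
# slow steps; gc.callbacks receives ('start' | 'stop', info) for each collection.
def gc_logger(phase, info):
  if phase == 'start':
    gc_logger.t0 = time.perf_counter()
  else:
    dt = time.perf_counter() - gc_logger.t0
    print(f'gc gen {info["generation"]}: {dt * 1000:.1f} ms')

gc.callbacks.append(gc_logger)

# Alternatively, freeze long-lived objects and disable automatic collection,
# then check whether the occasional slow step disappears:
# gc.freeze(); gc.disable()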

@kriscao-cohere
Author

Hi @cgarciae, thanks for the updates! And thanks for pointing me to the jax-only pattern for using NNX; I tried it and it eliminated the wait time (and also sped up my experiment loop a lot, though I am doing small-model experiments).

As for a repro, I only saw it with certain model architectures (mainly Transformers past a certain layer depth). I tried to repro the slow nnx.split by running nnx.split by itself 100 times, but there were no unexpected slowdowns. It was only when I did 100 model forward passes (even on dummy data) that I saw the occasional slow step caused by NNX graph traversal, as in the sketch below.
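
A minimal sketch of that interleaved repro attempt; model and dummy_batch are placeholders for the actual Transformer and its inputs:

import time
from flax import nnx

# Interleave forward passes with timed nnx.split calls; `model` and
# `dummy_batch` are placeholders for the actual Transformer and its inputs.
split_times = []
for _ in range(100):
  _ = model(dummy_batch)
  t0 = time.perf_counter()
  graphdef, state = nnx.split(model)
  split_times.append(time.perf_counter() - t0)

print('max split time:', max(split_times))
print('median split time:', sorted(split_times)[len(split_times) // 2])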

@cgarciae cgarciae added the Priority: P1 - soon Response within 5 business days. Resolution within 30 days. (Assignee required) label Nov 1, 2024
@cgarciae cgarciae self-assigned this Nov 1, 2024