AssertionError from nemo_run/core/runners/fdl_runner.py when using PreTrainingDataModule #104

Open
RachitBansal opened this issue Nov 9, 2024 · 1 comment

@RachitBansal

I am trying to pre-train a Mixtral-style model with NeMo, but I hit an error as soon as I switch to a custom pre-training dataset via PreTrainingDataModule.

My training recipe looks like the following:

recipe = run.Partial(
    fn,
    model=model(seq_length=seq_length),
    trainer=trainer_cfg,
    data=run.Config(
        MockDataModule,
        seq_length=seq_length,
        global_batch_size=global_batch_size,
        micro_batch_size=1,
    ),
    log=default_log(
        dir=dir,
        name=name,
        wandb_logger=wandb_logger(project="moes_optimization", name=name),
    ),
    optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4),
    resume=default_resume(),
)

This works fine. However, when I change data to:

data=run.Config(
    PreTrainingDataModule,
    paths=data_path,
    tokenizer=get_tokenizer(vocab_path, merges_path, tokenizer),
    seq_length=seq_length,
    global_batch_size=global_batch_size,
    micro_batch_size=1,
    rampup_batch_size=None,
    num_workers=8,
    split='90,5,5',
),

It gives the following error:

Traceback (most recent call last):
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/NeMo-Run/src/nemo_run/core/runners/fdl_runner.py", line 73, in <module>
    fdl_runner_app()
  File "/usr/local/lib/python3.10/dist-packages/typer/main.py", line 326, in __call__
    raise e
  File "/usr/local/lib/python3.10/dist-packages/typer/main.py", line 309, in __call__
    return get_command(self)(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/typer/core.py", line 661, in main
    return _main(
  File "/usr/local/lib/python3.10/dist-packages/typer/core.py", line 193, in _main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.10/dist-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/typer/main.py", line 692, in wrapper
    return callback(**use_params)
  File "/opt/NeMo-Run/src/nemo_run/core/runners/fdl_runner.py", line 62, in fdl_direct_run
    fdl_buildable: fdl.Buildable = ZlibJSONSerializer().deserialize(fdl_config)
  File "/opt/NeMo-Run/src/nemo_run/core/serialization/zlib_json.py", line 41, in deserialize
    return serialization.load_json(
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 850, in load_json
    return Deserialization(json.loads(serialized_value), pyref_policy).result
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 736, in __init__
    self._result = self._deserialize(self._root)
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 776, in _deserialize
    return self._deserialize_ref(serialized_object)
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 749, in _deserialize_ref
    deserialized_object = self._deserialize(self._serialized_objects[key])
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 797, in _deserialize
    values = [value for _, value in self._deserialize(serialized_items)]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 766, in _deserialize
    return [self._deserialize(x) for x in serialized_object]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 766, in <listcomp>
    return [self._deserialize(x) for x in serialized_object]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 766, in _deserialize
    return [self._deserialize(x) for x in serialized_object]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 766, in <listcomp>
    return [self._deserialize(x) for x in serialized_object]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 776, in _deserialize
    return self._deserialize_ref(serialized_object)
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 749, in _deserialize_ref
    deserialized_object = self._deserialize(self._serialized_objects[key])
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 797, in _deserialize
    values = [value for _, value in self._deserialize(serialized_items)]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 766, in _deserialize
    return [self._deserialize(x) for x in serialized_object]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 766, in <listcomp>
    return [self._deserialize(x) for x in serialized_object]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 766, in _deserialize
    return [self._deserialize(x) for x in serialized_object]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 766, in <listcomp>
    return [self._deserialize(x) for x in serialized_object]
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 776, in _deserialize
    return self._deserialize_ref(serialized_object)
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 749, in _deserialize_ref
    deserialized_object = self._deserialize(self._serialized_objects[key])
  File "/usr/local/lib/python3.10/dist-packages/fiddle/_src/experimental/serialization.py", line 799, in _deserialize
    return traverser.unflatten(values, metadata)
  File "/opt/NeMo/nemo/lightning/io/mixin.py", line 564, in _io_unflatten_object
    assert hasattr(_thread_local, "output_dir")
AssertionError
@hemildesai
Collaborator

Thanks for the issue. The error message should be more informative on the latest main of NeMo.

Also, the fix would be to wrap the tokenizer as follows:

run.Config(get_tokenizer, vocab_path=vocab_path, merges_path=merges_path, tokenizer=tokenizer)

Can you give it a try?
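
For context on why the wrapping probably helps, here is a minimal sketch. It is an analogy, not NeMo-Run's or Fiddle's actual implementation. The likely cause is that writing get_tokenizer(...) directly inside run.Config builds a live tokenizer object right away, and that object then has to survive the serialize/deserialize round trip that fdl_runner performs. With run.Config(get_tokenizer, ...), only the call is recorded, so the tokenizer can be built after deserialization. ToyConfig, ToyTokenizer, make_data_module, REGISTRY, to_plain and build are hypothetical stand-ins written for this illustration, and the file names are placeholders.

```python
import json


class ToyTokenizer:
    """Hypothetical stand-in for a live tokenizer object."""

    def __init__(self, vocab_path, merges_path):
        self.vocab_path = vocab_path
        self.merges_path = merges_path


def get_tokenizer(vocab_path, merges_path):
    # Stand-in for the user's factory; a real one would load files here.
    return ToyTokenizer(vocab_path, merges_path)


def make_data_module(tokenizer, seq_length):
    # Stand-in for a data module constructor.
    return {"tokenizer": tokenizer, "seq_length": seq_length}


# Lets a serialized config refer to a callable by name.
REGISTRY = {"get_tokenizer": get_tokenizer, "make_data_module": make_data_module}


class ToyConfig:
    """Records a callable and its kwargs without calling it.

    Assumption: this only mimics the idea of a deferred config.
    It is NOT nemo_run.Config.
    """

    def __init__(self, fn, **kwargs):
        self.fn_name = fn.__name__
        self.kwargs = kwargs


def to_plain(value):
    """Turn nested ToyConfig objects into JSON-friendly dicts."""
    if isinstance(value, ToyConfig):
        return {
            "fn": value.fn_name,
            "kwargs": {k: to_plain(v) for k, v in value.kwargs.items()},
        }
    return value  # anything else is passed through untouched


def build(plain):
    """Rebuild objects from the plain form; factories are called only now."""
    if isinstance(plain, dict) and "fn" in plain:
        kwargs = {k: build(v) for k, v in plain["kwargs"].items()}
        return REGISTRY[plain["fn"]](**kwargs)
    return plain


# Eager: the tokenizer already exists when the config is created.
eager = ToyConfig(
    make_data_module,
    tokenizer=get_tokenizer("vocab.json", "merges.txt"),
    seq_length=2048,
)
try:
    json.dumps(to_plain(eager))
    eager_ok = True
except TypeError:
    # json cannot encode the live ToyTokenizer instance.
    eager_ok = False

# Deferred: the tokenizer is itself a config, so only names and kwargs travel.
deferred = ToyConfig(
    make_data_module,
    tokenizer=ToyConfig(
        get_tokenizer, vocab_path="vocab.json", merges_path="merges.txt"
    ),
    seq_length=2048,
)
payload = json.dumps(to_plain(deferred))
module = build(json.loads(payload))
```

In this toy the eager variant already fails at serialization time with a TypeError, whereas the traceback above fails during deserialization with an AssertionError, so the two failures are not identical. The shared point is only that a config tree made of deferred calls round-trips cleanly, while one containing a pre-built object may not.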
