test: get highest checkpoint instead of hard coded path #383

Open
wants to merge 3 commits into base: main

Conversation

@anhuong (Collaborator) commented Oct 31, 2024

Description of the change

Unit tests were failing on transformers v4.46 with errors:

  • Resume-training tests failing with assert 9.666666666666666 == (3.6666666666666665 + 5), where the logged epochs did not match.
  • Other tests failing because the checkpoint path in use was hardcoded to checkpoint-5, but only 3 checkpoints were being saved, which did not match the number of epochs.

Because of this I updated our TrainingArguments parameters to use gradient_accumulation_steps=1, to create the number of checkpoints expected for the small dataset we are using, and logging_strategy="epoch", so that the loss logs print once per full epoch. When gradient_accumulation_steps>1 is set, the loss values are logged at fractional epochs. I suspect that because the logging was happening at partial epochs, the checkpoints being saved may have also been off.

This is more consistent with the save_strategy="epoch" that we already have set, and I removed logging_steps=1 since we are now logging based on epochs.
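
For context, a minimal sketch of the relevant training-argument settings after this change (illustrative only; the real TRAIN_ARGS fixture in tests/test_sft_trainer.py configures more than shown here):

from transformers import TrainingArguments

# Illustrative sketch, not the exact test fixture.
train_args = TrainingArguments(
    output_dir="/tmp/test-run",          # hypothetical path
    num_train_epochs=5,
    per_device_train_batch_size=4,
    learning_rate=1e-5,
    gradient_accumulation_steps=1,       # changed from >1 so checkpoints == epochs
    logging_strategy="epoch",            # log loss once per full epoch
    save_strategy="epoch",               # unchanged; save once per epoch
)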

The loss logging, together with a print statement I added for os.listdir(output_dir), showed:

{'loss': 20.4898, 'grad_norm': 23.858930587768555, 'learning_rate': 1e-05, 'epoch': 1.67}
{'loss': 20.9744, 'grad_norm': 23.15357780456543, 'learning_rate': 8.535533905932739e-06, 'epoch': 3.67}
{'train_runtime': 1.649, 'train_samples_per_second': 30.322, 'train_steps_per_second': 3.032, 'train_tokens_per_second': 1825.376, 'train_loss': 25.917576789855957, 'epoch': 3.67}
['checkpoint-1', 'training_logs.jsonl', 'checkpoint-0', 'checkpoint-2']

As you can see, only 3 checkpoints are saved, including one for checkpoint-0, likely because of the partial epochs being logged.

With these changes, the loss logging and os.listdir(output_dir) show:

{'loss': 10.5682, 'grad_norm': 41.685245513916016, 'learning_rate': 9.504844339512096e-06, 'epoch': 1.0}
{'loss': 10.3005, 'grad_norm': 43.35952377319336, 'learning_rate': 7.169418695587791e-06, 'epoch': 2.0}
{'loss': 9.8799, 'grad_norm': 35.97984313964844, 'learning_rate': 3.887395330218429e-06, 'epoch': 3.0}
{'loss': 9.8383, 'grad_norm': 40.168182373046875, 'learning_rate': 1.0908425876598516e-06, 'epoch': 4.0}
{'loss': 9.8219, 'grad_norm': 35.47007751464844, 'learning_rate': 0.0, 'epoch': 5.0}
{'train_runtime': 3.0789, 'train_samples_per_second': 16.24, 'train_steps_per_second': 4.872, 'train_tokens_per_second': 977.63, 'train_loss': 10.081750996907552, 'epoch': 5.0}
['checkpoint-6', 'training_logs.jsonl', 'checkpoint-9', 'checkpoint-3', 'checkpoint-12', 'checkpoint-15']

I'm not sure how the number in checkpoint-<number> is determined, but it's better not to hardcode checkpoint-5, so instead I refactored our existing code to find the highest checkpoint and return it.
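
As a rough illustration of the refactor (the actual helper name and location in this PR may differ), finding the highest checkpoint could look like:

import os
import re

def _get_highest_checkpoint(dir_path):
    """Return the checkpoint-<n> directory with the largest <n>, or None.

    Hypothetical helper for illustration; the PR's actual function may differ.
    """
    checkpoints = [
        d for d in os.listdir(dir_path)
        if re.fullmatch(r"checkpoint-\d+", d)
        and os.path.isdir(os.path.join(dir_path, d))
    ]
    if not checkpoints:
        return None
    latest = max(checkpoints, key=lambda d: int(d.rsplit("-", 1)[1]))
    return os.path.join(dir_path, latest)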

Related issue number

How to verify the PR

I verified that unit tests pass with transformers v4.45 and v4.46.

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass


Thanks for making a pull request! 😃
One of the maintainers will review and advise on the next steps.

@github-actions bot added the test label Oct 31, 2024
@Abhishek-TAMU (Collaborator) commented Oct 31, 2024

Thank you @anhuong for this PR to fix the unit tests. Definitely good to have the get-checkpoint functions refactored!

EDIT: gradient_accumulation_steps = 1 looks good to me unless gradient_accumulation_steps>1 is specifically tested.

@fabianlim (Collaborator)

@anhuong my understanding is that if you set save_strategy="epoch", then every epoch you are guaranteed to call a save. This is due to this line here.

And as far as I know, logging_strategy does not affect save_strategy, so I can't really understand why you have missing checkpoints.

@Abhishek-TAMU (Collaborator) commented Nov 1, 2024

@fabianlim I was curious whether the number of checkpoints saved is affected when save_strategy="epoch" and gradient_accumulation_steps > micro-batches per epoch (sample size is 10 and micro-batches per epoch is 3). Do checkpoints still get saved after each epoch, with the optimizer step (weight update) being performed?

For example (testing with logging at each step):

python tuning/sft_trainer.py  \
--model_name_or_path Maykeye/TinyLLama-v0  \
--training_data_path tests/data/twitter_complaints_small.jsonl  \
--output_dir outputs/full-tuning  \
--num_train_epochs 5  \
--per_device_train_batch_size 4  \
--gradient_accumulation_steps 4  \
--learning_rate 1e-5  \
--response_template "\n### Label:"  \
--dataset_text_field "output" \
--torch_dtype "float32" \
--logging_strategy "steps" \
--logging_steps 1 \
--save_strategy "epoch"

The sample dataset size is 10 and per_device_train_batch_size is 4, so micro-batches per epoch is ~3. With gradient_accumulation_steps = 4, the logs look like below and only 3 checkpoints are saved:

{"data": {"epoch": 1.0, "step": 1, "timestamp": "2024-10-31T21:44:08.164493", "value": 7.909}, "name": "training_loss"}
{"data": {"epoch": 1.33, "step": 2, "timestamp": "2024-10-31T21:44:08.529254", "value": 2.5304}, "name": "training_loss"}
{"data": {"epoch": 2.0, "step": 3, "timestamp": "2024-10-31T21:44:08.854672", "value": 5.1889}, "name": "training_loss"}
{"data": {"epoch": 2.67, "step": 4, "timestamp": "2024-10-31T21:44:09.377428", "value": 5.0882}, "name": "training_loss"}
{"data": {"epoch": 3.0, "step": 5, "timestamp": "2024-10-31T21:44:09.556685", "value": 2.3755}, "name": "training_loss"}

os.listdir(output): ['checkpoint-1', 'training_logs.jsonl', 'checkpoint-3', 'checkpoint-5']

Logs with gradient_accumulation_steps = 1 and the rest of the parameters the same; it saves 5 checkpoints:

{"data": {"epoch": 0.33, "step": 1, "timestamp": "2024-10-31T21:50:10.247606", "value": 10.5645}, "name": "training_loss"}
{"data": {"epoch": 0.67, "step": 2, "timestamp": "2024-10-31T21:50:10.418498", "value": 9.8469}, "name": "training_loss"}
{"data": {"epoch": 1.0, "step": 3, "timestamp": "2024-10-31T21:50:10.589923", "value": 10.9906}, "name": "training_loss"}
{"data": {"epoch": 1.33, "step": 4, "timestamp": "2024-10-31T21:50:10.945404", "value": 9.9601}, "name": "training_loss"}
{"data": {"epoch": 1.67, "step": 5, "timestamp": "2024-10-31T21:50:11.114054", "value": 9.721}, "name": "training_loss"}
{"data": {"epoch": 2.0, "step": 6, "timestamp": "2024-10-31T21:50:11.286315", "value": 10.6751}, "name": "training_loss"}
{"data": {"epoch": 2.33, "step": 7, "timestamp": "2024-10-31T21:50:11.661657", "value": 10.2685}, "name": "training_loss"}
{"data": {"epoch": 2.67, "step": 8, "timestamp": "2024-10-31T21:50:11.837999", "value": 9.573}, "name": "training_loss"}
{"data": {"epoch": 3.0, "step": 9, "timestamp": "2024-10-31T21:50:12.014647", "value": 9.2539}, "name": "training_loss"}
{"data": {"epoch": 3.33, "step": 10, "timestamp": "2024-10-31T21:50:12.397118", "value": 9.781}, "name": "training_loss"}
{"data": {"epoch": 3.67, "step": 11, "timestamp": "2024-10-31T21:50:12.445898", "value": 9.574}, "name": "training_loss"}
{"data": {"epoch": 4.0, "step": 12, "timestamp": "2024-10-31T21:50:12.625669", "value": 9.8564}, "name": "training_loss"}
{"data": {"epoch": 4.33, "step": 13, "timestamp": "2024-10-31T21:50:12.993576", "value": 9.9461}, "name": "training_loss"}
{"data": {"epoch": 4.67, "step": 14, "timestamp": "2024-10-31T21:50:13.177089", "value": 9.5162}, "name": "training_loss"}
{"data": {"epoch": 5.0, "step": 15, "timestamp": "2024-10-31T21:50:13.375828", "value": 9.9178}, "name": "training_loss"}

os.listdir(output): ['checkpoint-3', 'checkpoint-6', 'checkpoint-9', 'training_logs.jsonl', 'checkpoint-12', 'checkpoint-15']

@fabianlim (Collaborator) commented Nov 1, 2024

@Abhishek-TAMU @anhuong I cannot reproduce the problem. I ran the exact same command that was provided above, and I have 5 checkpoints. According to my reading of the code for transformers==4.45.2 (https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/trainer.py#L2406):

  • 10 samples, GA=4, so steps_in_epoch=2
  • for each epoch, on the second step (i.e. step=1), is_last_step_and_steps_less_than_grad_acc will be set to True. Thus it will go into the if block and the optimizer will step. In other words, GA will only accumulate over 2 steps.
  • after the two steps of each epoch, it will come to _maybe_log_save_evaluate, and in there a checkpoint will be saved per epoch.

Also, I cannot reproduce your loss values; for small models I need to use a much higher learning rate, like 1e-3, to get down to ~2 in 5 epochs.

You can turn on TRANSFORMERS_VERBOSITY=info to see more verbose logs.

$ python tuning/sft_trainer.py      --model_name_or_path Maykeye/TinyLLama-v0      --training_data_path tests/data/twitter_complaints_small.jsonl      --output_dir outputs/full-tuning      --num_train_epochs 5      --per_device_train_batch_size 4      --gradient_accumulation_steps 4      --learning_rate 1e-3      --response_template "\n### Label:"      --dataset_text_field "output"     --torch_dtype "float32"     --use_flash_attn False     --logging_strategy "steps"     --logging_steps 1     --save_strategy "epoch"

....

***** Running training *****
  Num examples = 10
  Num Epochs = 5
  Instantaneous batch size per device = 4
  Training with DataParallel so batch size has been adjusted to: 8
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 4
  Total optimization steps = 5
  Number of trainable parameters = 4,621,504
  0%|                                                                                                                                                                                                   | 0/5 [00:00<?, ?it/s]/workspace/fms/run-benches/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
{'loss': 5.3779, 'grad_norm': 16.19341468811035, 'learning_rate': 0.0008, 'epoch': 1.0}
 20%|█████████████████████████████████████▍                                                                                                                                                     | 1/5 [00:01<00:05,  1.35s/it]Saving model checkpoint to outputs/full-tuning/checkpoint-1
Configuration saved in outputs/full-tuning/checkpoint-1/config.json
Configuration saved in outputs/full-tuning/checkpoint-1/generation_config.json
Model weights saved in outputs/full-tuning/checkpoint-1/model.safetensors
tokenizer config file saved in outputs/full-tuning/checkpoint-1/tokenizer_config.json
Special tokens file saved in outputs/full-tuning/checkpoint-1/special_tokens_map.json
/workspace/fms/run-benches/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
{'loss': 4.3423, 'grad_norm': 8.393450736999512, 'learning_rate': 0.0006, 'epoch': 2.0}
 40%|██████████████████████████████████████████████████████████████████████████▊                                                                                                                | 2/5 [00:01<00:02,  1.48it/s]Saving model checkpoint to outputs/full-tuning/checkpoint-2
Configuration saved in outputs/full-tuning/checkpoint-2/config.json
Configuration saved in outputs/full-tuning/checkpoint-2/generation_config.json
Model weights saved in outputs/full-tuning/checkpoint-2/model.safetensors
tokenizer config file saved in outputs/full-tuning/checkpoint-2/tokenizer_config.json
Special tokens file saved in outputs/full-tuning/checkpoint-2/special_tokens_map.json
/workspace/fms/run-benches/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
{'loss': 3.2777, 'grad_norm': 9.944791793823242, 'learning_rate': 0.0004, 'epoch': 3.0}
 60%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                          | 3/5 [00:01<00:00,  2.16it/s]Saving model checkpoint to outputs/full-tuning/checkpoint-3
Configuration saved in outputs/full-tuning/checkpoint-3/config.json
Configuration saved in outputs/full-tuning/checkpoint-3/generation_config.json
Model weights saved in outputs/full-tuning/checkpoint-3/model.safetensors
tokenizer config file saved in outputs/full-tuning/checkpoint-3/tokenizer_config.json
Special tokens file saved in outputs/full-tuning/checkpoint-3/special_tokens_map.json
/workspace/fms/run-benches/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
{'loss': 3.0204, 'grad_norm': 7.56746244430542, 'learning_rate': 0.0002, 'epoch': 4.0}
 80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                     | 4/5 [00:01<00:00,  2.81it/s]Saving model checkpoint to outputs/full-tuning/checkpoint-4
Configuration saved in outputs/full-tuning/checkpoint-4/config.json
Configuration saved in outputs/full-tuning/checkpoint-4/generation_config.json
Model weights saved in outputs/full-tuning/checkpoint-4/model.safetensors
tokenizer config file saved in outputs/full-tuning/checkpoint-4/tokenizer_config.json
Special tokens file saved in outputs/full-tuning/checkpoint-4/special_tokens_map.json
/workspace/fms/run-benches/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
  warnings.warn('Was asked to gather along dimension 0, but all '
{'loss': 2.7926, 'grad_norm': 4.190311908721924, 'learning_rate': 0.0, 'epoch': 5.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  3.36it/s]Saving model checkpoint to outputs/full-tuning/checkpoint-5
Configuration saved in outputs/full-tuning/checkpoint-5/config.json
Configuration saved in outputs/full-tuning/checkpoint-5/generation_config.json
Model weights saved in outputs/full-tuning/checkpoint-5/model.safetensors
tokenizer config file saved in outputs/full-tuning/checkpoint-5/tokenizer_config.json
Special tokens file saved in outputs/full-tuning/checkpoint-5/special_tokens_map.json
Saving model checkpoint to outputs/full-tuning/checkpoint-5
Configuration saved in outputs/full-tuning/checkpoint-5/config.json
Configuration saved in outputs/full-tuning/checkpoint-5/generation_config.json
Model weights saved in outputs/full-tuning/checkpoint-5/model.safetensors
tokenizer config file saved in outputs/full-tuning/checkpoint-5/tokenizer_config.json
Special tokens file saved in outputs/full-tuning/checkpoint-5/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 2.3713, 'train_samples_per_second': 21.085, 'train_steps_per_second': 2.109, 'train_loss': 3.7621720314025877, 'epoch': 5.0}

@anhuong (Collaborator, Author) commented Nov 1, 2024

@fabianlim the error only occurs on transformers v4.46, so agreed, I'm not sure how our configuration would change the number of checkpoints saved when transformers is upgraded. That makes sense that logging wouldn't affect it; I just thought it would be easier to read and more consistent to set logging_strategy and save_strategy to the same value. I also suspected it could have been something with the gradient_accumulation that caused the issue, but I see, Fabian, you're saying that save_strategy="epoch" should still save on each epoch... hmm.

@anhuong (Collaborator, Author) commented Nov 1, 2024

Our unit tests passed with transformers v4.45 but only started failing with transformers v4.46, because it looks like something changed in how checkpoints are saved, as described. You can see this recent run of unit tests from 8 hours ago: https://github.com/foundation-model-stack/fms-hf-tuning/actions/runs/11619798675/job/32360230590

I can recreate this unit test failure locally when running tox, but in a cluster I am not able to recreate the error. Even when specifying no GPUs on a pod, I see 5 checkpoints getting created, instead of the fewer checkpoints seen when running unit tests.

This solution is better in that it does not hard-code checkpoint-5, but it would still be good to check that 5 checkpoints exist. That is also why only certain tests fail and not all of them; the FT ones, for example, don't fail because they just check that a checkpoint exists rather than looking for a specific one.
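
For illustration, a test could keep the count assertion and then resolve the newest checkpoint instead of hard-coding its name (sketch only; the real helpers in tests/test_sft_trainer.py may differ):

import os

def _assert_and_get_latest_checkpoint(dir_path, expected_num):
    # Same count check as _validate_num_checkpoints in the test suite ...
    checkpoints = [d for d in os.listdir(dir_path) if d.startswith("checkpoint")]
    assert len(checkpoints) == expected_num
    # ... then pick the highest-numbered checkpoint rather than "checkpoint-5".
    latest = max(checkpoints, key=lambda d: int(d.rsplit("-", 1)[1]))
    return os.path.join(dir_path, latest)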

@anhuong (Collaborator, Author) commented Nov 1, 2024

I verified when running unit tests that it all comes down to the gradient_accumulation setting. When it is set to 1, the expected number of checkpoints is created. When GA>1, the number of checkpoints is less than the number of epochs even though save_strategy="epoch" is set. But I could only reproduce this when running unit tests; when running on the cluster I get the expected number of checkpoints, which is very odd...

@fabianlim (Collaborator) commented Nov 1, 2024

@anhuong I cannot reproduce this on transformers==4.46.1 either, in the cluster. Can you try source /path/to/tox/env/bin/activate and then pip freeze inside to check the versions?

NOTE: when I run tox I see an upper bound on the transformers version, so I'm a little confused how this could be due to the new version.

py: install_package_deps> python -I -m pip install 'accelerate!=0.34,<1.1,>=0.20.3' 'datasets<3.0,>=2.15.0' 'numpy<2.0,>=1.26.4' 'peft<0.14,>=0.8.0' 'protobuf<6.0.0,>=5.28.0' 'sentencepiece<0.3,>=0.1.99' 'simpleeval<1.0,>=0.9.13' 'tokenizers<1.0,>=0.13.3' 'torch<2.5,>=2.2.0' 'tqdm<5.0,>=4.66.2' 'transformers<4.50,>4.41' 'trl<1.0,>=0.9.3'

Update: my test passed in .tox

$ tox --workdir $HOME/fms -e py \
    -- tests/test_sft_trainer.py::test_resume_training_from_checkpoint
.pkg: _optional_hooks> python /workspace/.local/lib/python3.10/site-packages/pyproject_api/_backend.py True setuptools.build_meta
.pkg: get_requires_for_build_sdist> python /workspace/.local/lib/python3.10/site-packages/pyproject_api/_backend.py True setuptools.build_meta
.pkg: build_sdist> python /workspace/.local/lib/python3.10/site-packages/pyproject_api/_backend.py True setuptools.build_meta
py: install_package> python -I -m pip install --force-reinstall --no-deps /workspace/fms/.tmp/package/4/fms_hf_tuning-2.1.1rc1.tar.gz
py: commands[0]> pytest tests/test_sft_trainer.py::test_resume_training_from_checkpoint
==================================================================================================== test session starts =====================================================================================================
platform linux -- Python 3.10.12, pytest-8.3.3, pluggy-1.5.0
cachedir: /workspace/fms/py/.pytest_cache
rootdir: /data/flim/fms-hf-tuning
configfile: pytest.ini
collected 1 item

tests/test_sft_trainer.py .                                                                                                                                                                                            [100%]

====================================================================================================== warnings summary ======================================================================================================
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/transformers/training_args.py:2041: FutureWarning: `--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--hub_token` instead.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length, dataset_kwargs. Will not be supported from version '1.0.0'.

  Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
    warnings.warn(message, FutureWarning)

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:283: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:321: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:327: UserWarning: You passed a `dataset_kwargs` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:396: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:401: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `SFTTrainer.__init__`. Use `processing_class` instead.
    super().__init__(

tests/test_sft_trainer.py: 10 warnings
  /workspace/fms/py/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.
    warnings.warn('Was asked to gather along dimension 0, but all '

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/transformers/trainer.py:3347: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
    torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /workspace/fms/py/lib/python3.10/site-packages/transformers/trainer.py:3026: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
    checkpoint_rng_state = torch.load(rng_file)

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================================== 1 passed, 26 warnings in 10.10s ===============================================================================================
  py: OK (17.06=setup[5.74]+cmd[11.32] seconds)
  congratulations :) (17.13 seconds)

@Abhishek-TAMU (Collaborator) commented Nov 1, 2024

Sharing my test log with the library versions of accelerate, transformers, and trl in the test environment:

tox -e py -- tests/test_sft_trainer.py::test_resume_training_from_checkpoint

LOG
.pkg: _optional_hooks> python /usr/local/lib/python3.11/site-packages/pyproject_api/_backend.py True setuptools.build_meta
.pkg: get_requires_for_build_sdist> python /usr/local/lib/python3.11/site-packages/pyproject_api/_backend.py True setuptools.build_meta
.pkg: build_sdist> python /usr/local/lib/python3.11/site-packages/pyproject_api/_backend.py True setuptools.build_meta
py: install_package> python -I -m pip install --force-reinstall --no-deps /data/abhishek/compile_test/fms-hf-tuning/.tox/.tmp/package/5/fms_hf_tuning-0.1.dev368+g1f109fb.d20241101.tar.gz
py: commands[0]> pip show transformers
Name: transformers
Version: 4.46.1
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: [email protected]
License: Apache 2.0 License
Location: /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: fms-hf-tuning, peft, trl
py: commands[1]> pip show trl
Name: trl
Version: 0.11.4
Summary: Train transformer language models with reinforcement learning.
Home-page: https://github.com/huggingface/trl
Author: Leandro von Werra
Author-email: [email protected]
License: Apache 2.0
Location: /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages
Requires: accelerate, datasets, numpy, torch, transformers, tyro
Required-by: fms-hf-tuning
py: commands[2]> pip show accelerate
Name: accelerate
Version: 1.0.1
Summary: Accelerate
Home-page: https://github.com/huggingface/accelerate
Author: The HuggingFace team
Author-email: [email protected]
License: Apache
Location: /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages
Requires: huggingface-hub, numpy, packaging, psutil, pyyaml, safetensors, torch
Required-by: fms-hf-tuning, peft, trl
py: commands[3]> pytest tests/test_sft_trainer.py::test_resume_training_from_checkpoint
============================================================================================================================= test session starts ==============================================================================================================================
platform linux -- Python 3.11.7, pytest-8.3.3, pluggy-1.5.0
cachedir: .tox/py/.pytest_cache
rootdir: /data/abhishek/compile_test/fms-hf-tuning
configfile: pytest.ini
collected 1 item                                                                                                                                                                                                                                                               

tests/test_sft_trainer.py F                                                                                                                                                                                                                                              [100%]

=================================================================================================================================== FAILURES ===================================================================================================================================
_____________________________________________________________________________________________________________________ test_resume_training_from_checkpoint _____________________________________________________________________________________________________________________

    def test_resume_training_from_checkpoint():
        """
        Test tuning resumes from the latest checkpoint, creating new checkpoints and the
        checkpoints created before resuming tuning is not affected.
        """
        with tempfile.TemporaryDirectory() as tempdir:
            train_args = copy.deepcopy(TRAIN_ARGS)
            train_args.output_dir = tempdir
            train_args.gradient_accumulation_steps=4
            sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
            _validate_training(tempdir)
>           _validate_num_checkpoints(tempdir, train_args.num_train_epochs)

tests/test_sft_trainer.py:94: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

dir_path = '/tmp/tmpyyxkle2s', expected_num = 5

    def _validate_num_checkpoints(dir_path, expected_num):
        checkpoints = [d for d in os.listdir(dir_path) if d.startswith("checkpoint")]
>       assert len(checkpoints) == expected_num
E       AssertionError: assert 3 == 5
E        +  where 3 = len(['checkpoint-0', 'checkpoint-1', 'checkpoint-2'])

tests/test_sft_trainer.py:824: AssertionError
----------------------------------------------------------------------------------------------------------------------------- Captured stdout call -----------------------------------------------------------------------------------------------------------------------------
{'loss': 20.4898, 'grad_norm': 67.86002349853516, 'learning_rate': 1e-05, 'epoch': 1.67}
{'loss': 20.9744, 'grad_norm': 64.71253967285156, 'learning_rate': 8.535533905932739e-06, 'epoch': 3.67}
{'train_runtime': 1.782, 'train_samples_per_second': 28.058, 'train_steps_per_second': 2.806, 'train_tokens_per_second': 1689.106, 'train_loss': 25.864973068237305, 'epoch': 3.67}
----------------------------------------------------------------------------------------------------------------------------- Captured stderr call -----------------------------------------------------------------------------------------------------------------------------
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
WARNING:sft_trainer.py:max_seq_length 4096 exceeds tokenizer.model_max_length             2048, using tokenizer.model_max_length 2048
WARNING:sft_trainer.py:PAD token set to default, missing in tokenizer
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
 40%|██████████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                                             | 2/5 [00:01<00:02,  1.12it/s]
=============================================================================================================================== warnings summary ===============================================================================================================================
tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages/transformers/training_args.py:2041: FutureWarning: `--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `--hub_token` instead.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages/huggingface_hub/utils/_deprecation.py:100: FutureWarning: Deprecated argument(s) used in '__init__': dataset_text_field, max_seq_length, dataset_kwargs. Will not be supported from version '1.0.0'.
  
  Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
    warnings.warn(message, FutureWarning)

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages/trl/trainer/sft_trainer.py:283: UserWarning: You passed a `max_seq_length` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages/trl/trainer/sft_trainer.py:321: UserWarning: You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages/trl/trainer/sft_trainer.py:327: UserWarning: You passed a `dataset_kwargs` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages/trl/trainer/sft_trainer.py:396: UserWarning: You passed a tokenizer with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `tokenizer.padding_side = 'right'` to your code.
    warnings.warn(

tests/test_sft_trainer.py::test_resume_training_from_checkpoint
  /data/abhishek/compile_test/fms-hf-tuning/.tox/py/lib/python3.11/site-packages/trl/trainer/sft_trainer.py:401: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `SFTTrainer.__init__`. Use `processing_class` instead.
    super().__init__(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================================================================================================== short test summary info ============================================================================================================================
FAILED tests/test_sft_trainer.py::test_resume_training_from_checkpoint - AssertionError: assert 3 == 5
======================================================================================================================== 1 failed, 7 warnings in 9.36s =========================================================================================================================
py: exit 1 (10.41 seconds) /data/abhishek/compile_test/fms-hf-tuning> pytest tests/test_sft_trainer.py::test_resume_training_from_checkpoint pid=2753
  py: FAIL code 1 (15.87=setup[4.39]+cmd[0.37,0.35,0.34,10.41] seconds)
  evaluation failed :( (15.93 seconds)

@Abhishek-TAMU (Collaborator) commented Nov 1, 2024

@fabianlim

Sharing pip freeze output:
accelerate==1.0.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.8
aiosignal==1.3.1
attrs==24.2.0
bitsandbytes==0.44.1
cachetools==5.5.0
certifi==2024.8.30
chardet==5.2.0
charset-normalizer==3.3.2
colorama==0.4.6
datasets==2.21.0
dill==0.3.8
distlib==0.3.9
docstring_parser==0.16
einops==0.8.0
filelock==3.16.1
flash-attn==2.6.3
fms-acceleration==0.4.0
fms-acceleration-aadp==0.1.1
fms-acceleration-foak==0.3.0
fms-acceleration-peft==0.3.0
-e git+https://github.com/anhuong/fms-hf-tuning.git@1f109fb6f893ab20a381929bf0ab74d08fc36bd6#egg=fms_hf_tuning
frozenlist==1.4.1
fsspec==2024.6.1
huggingface-hub==0.25.1
idna==3.10
Jinja2==3.1.4
llvmlite==0.43.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.3
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
packaging==24.1
pandas==2.2.3
peft==0.12.0
platformdirs==4.3.6
pluggy==1.5.0
protobuf==5.28.2
psutil==6.0.0
pyarrow==17.0.0
Pygments==2.18.0
pyproject-api==1.8.0
pyproject_hooks==1.2.0
python-dateutil==2.9.0.post0
pytz==2024.2
PyYAML==6.0.2
regex==2024.9.11
requests==2.32.3
rich==13.9.1
safetensors==0.4.5
sentencepiece==0.2.0
shtab==1.7.1
simpleeval==0.9.13
six==1.16.0
sympy==1.13.3
threadpoolctl==3.5.0
tokenizers==0.20.1
torch==2.4.1
tox==4.23.2
tqdm==4.66.5
transformers==4.46.1
triton==3.0.0
trl==0.11.4
typing_extensions==4.12.2
tyro==0.8.11
tzdata==2024.2
urllib3==2.2.3
virtualenv==20.27.1
xxhash==3.5.0
yarl==1.13.1

@anhuong (Collaborator, Author) commented Nov 1, 2024

Running tox -e py with GA=4, logging_steps=1, logging_strategy not set, and running only test_sft_trainer.py, I get 12 failed unit tests where the number of checkpoints does not match the number of epochs as expected:

======================================== short test summary info ========================================
FAILED tests/test_sft_trainer.py::test_resume_training_from_checkpoint - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_resume_training_from_checkpoint_with_flag_true - assert 9.666666666666666 == (3.6666666666666665 + 5)
FAILED tests/test_sft_trainer.py::test_resume_training_from_checkpoint_with_flag_false - assert 0 == 1
FAILED tests/test_sft_trainer.py::test_resume_training_from_checkpoint_with_flag_checkpoint_path_lora - assert 29959084032.0 == 22686858240.0
FAILED tests/test_sft_trainer.py::test_run_causallm_pt_and_inference - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_lora_and_inference[default] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_lora_and_inference[custom_target_modules] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_lora_and_inference[all_linear_target_modules] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_ft_and_inference[/Users/anhuong/github.com/anhuong/fms-hf-tuning/tests/data/twitter_complaints_small.jsonl] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_ft_and_inference[/Users/anhuong/github.com/anhuong/fms-hf-tuning/tests/data/twitter_complaints_small.json] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_ft_pretokenized[/Users/anhuong/github.com/anhuong/fms-hf-tuning/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl] - AssertionError: assert 4 == 6
FAILED tests/test_sft_trainer.py::test_run_causallm_ft_pretokenized[/Users/anhuong/github.com/anhuong/fms-hf-tuning/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json] - AssertionError: assert 4 == 6
================== 12 failed, 37 passed, 2 skipped, 336 warnings in 121.46s (0:02:01) ===================

You can see the checkpoints with print statements:

{'loss': 17.8934, 'grad_norm': 22.092439651489258, 'learning_rate': 1e-05, 'epoch': 1.67}
{'loss': 17.9788, 'grad_norm': 22.23282241821289, 'learning_rate': 8.535533905932739e-06, 'epoch': 3.67}
{'train_runtime': 1.7212, 'train_samples_per_second': 29.049, 'train_steps_per_second': 2.905, 'train_tokens_per_second': 1661.61, 'train_loss': 22.376200675964355, 'epoch': 3.67}
['checkpoint-1', 'training_logs.jsonl', 'checkpoint-0', 'checkpoint-2']
The pip freeze for this run:
accelerate==1.0.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
attrs==24.2.0
certifi==2024.8.30
charset-normalizer==3.4.0
datasets==2.21.0
dill==0.3.8
docstring_parser==0.16
filelock==3.16.1
fms-hf-tuning @ file:///Users/anhuong/github.com/anhuong/fms-hf-tuning/.tox/.tmp/package/426/fms-hf-tuning-2.1.0rc2.dev5%2Bg1f109fb.d20241101.tar.gz#sha256=38060f61732e4963b5cd6ab3d5d8709c1c916266b548397cefd77b8f55c62961
frozenlist==1.5.0
fsspec==2024.6.1
huggingface-hub==0.26.2
idna==3.10
iniconfig==2.0.0
Jinja2==3.1.4
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
numpy==1.26.4
packaging==24.1
pandas==2.2.3
peft==0.13.2
pluggy==1.5.0
propcache==0.2.0
protobuf==5.28.3
psutil==6.1.0
pyarrow==18.0.0
Pygments==2.18.0
pytest==8.3.3
python-dateutil==2.9.0.post0
pytz==2024.2
PyYAML==6.0.2
regex==2024.9.11
requests==2.32.3
rich==13.9.3
safetensors==0.4.5
sentencepiece==0.2.0
shtab==1.7.1
simpleeval==0.9.13
six==1.16.0
sympy==1.13.3
tokenizers==0.20.1
torch==2.2.2
tqdm==4.66.6
transformers==4.46.1
trl==0.11.4
typing_extensions==4.12.2
tyro==0.8.14
tzdata==2024.2
urllib3==2.2.3
xxhash==3.5.0
yarl==1.17.1

Compared to a run with "transformers>4.41,<4.46", where the only diff between the two pip freezes is transformers, the only difference is in the checkpoint numbers, but the number-of-checkpoints check still fails for 9 tests.

======================================== short test summary info ========================================
FAILED tests/test_sft_trainer.py::test_resume_training_from_checkpoint - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_pt_and_inference - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_lora_and_inference[default] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_lora_and_inference[custom_target_modules] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_lora_and_inference[all_linear_target_modules] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_ft_and_inference[/Users/anhuong/github.com/anhuong/fms-hf-tuning/tests/data/twitter_complaints_small.jsonl] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_ft_and_inference[/Users/anhuong/github.com/anhuong/fms-hf-tuning/tests/data/twitter_complaints_small.json] - AssertionError: assert 3 == 5
FAILED tests/test_sft_trainer.py::test_run_causallm_ft_pretokenized[/Users/anhuong/github.com/anhuong/fms-hf-tuning/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.jsonl] - AssertionError: assert 4 == 6
FAILED tests/test_sft_trainer.py::test_run_causallm_ft_pretokenized[/Users/anhuong/github.com/anhuong/fms-hf-tuning/tests/data/twitter_complaints_tokenized_with_maykeye_tinyllama_v0.json] - AssertionError: assert 4 == 6
=================== 9 failed, 40 passed, 2 skipped, 282 warnings in 118.76s (0:01:58) ===================

Print checkpoints:

{'loss': 6.8061, 'grad_norm': 8.93274211883545, 'learning_rate': 1e-05, 'epoch': 1.0}
{'loss': 3.3769, 'grad_norm': 5.7982940673828125, 'learning_rate': 5e-06, 'epoch': 2.0}
{'loss': 3.3822, 'grad_norm': 4.714280605316162, 'learning_rate': 0.0, 'epoch': 3.0}
{'train_runtime': 1.8403, 'train_samples_per_second': 27.169, 'train_steps_per_second': 2.717, 'train_tokens_per_second': 1554.079, 'train_loss': 4.064878463745117, 'epoch': 3.0}
['checkpoint-1', 'training_logs.jsonl', 'checkpoint-5', 'checkpoint-3']

And so previously it only succeeded because checkpoint-5 exists, whereas now it is checkpoint-3, but the number of checkpoints remains the same.

Before, when I ran in the cluster, I was using a larger dataset with sample size 50 and did not see this behavior occur; 5 checkpoints were always saved.

My cluster runs:

$ python -m tuning.sft_trainer  --model_name_or_path Maykeye/TinyLLama-v0  --training_data_path /app/twitter_complaints_small.json  --output_dir /tmp/test-transformers-446-lora  --num_train_epochs 5  --per_device_train_batch_size 4  --gradient_accumulation_steps 4  --learning_rate 1e-5  --response_template "\n### Label:"  --dataset_text_field "output" --torch_dtype "float32" --logging_steps 1 --save_strategy "epoch" --use_flash_attn false --per_device_eval_batch_size 4 --weight_decay 0 --warmup_ratio 0.03 --lr_scheduler_type "cosine" --include_tokens_per_second true --packing false --peft_method lora

# GA=1, dataset_sample_size=10
/tmp/test-transformers-446-lora-ga-1:
added_tokens_info.json	checkpoint-15  checkpoint-6  training_logs.jsonl
checkpoint-12		checkpoint-3   checkpoint-9

# GA=2, dataset_sample_size=10
/tmp/test-transformers-446-lora-ga-2:
added_tokens_info.json	checkpoint-1  checkpoint-3  checkpoint-4  checkpoint-5	training_logs.jsonl

# GA=4, dataset_sample_size=10
/tmp/test-transformers-446-lora-no-logging:
added_tokens_info.json	checkpoint-0  checkpoint-1  checkpoint-2  training_logs.jsonl

# GA=4, dataset_sample_size=50
/tmp/test-transformers-446-lora-no-logging-50-samples:
added_tokens_info.json	checkpoint-15  checkpoint-6  training_logs.jsonl
checkpoint-13		checkpoint-3   checkpoint-9

In addition, we see the training end early when GA=4 and dataset sample size=10.

Logs
Currently training with a batch size of: 4
***** Running training *****
  Num examples = 10
  Num Epochs = 5
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 4
  Total optimization steps = 5
  Number of trainable parameters = 16,384
  0%|                                                                              | 0/5 [00:00<?, ?it/s]Saving model checkpoint to /tmp/test-transformers-446-lora-no-logging/checkpoint-0
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--Maykeye--TinyLLama-v0/snapshots/8c7ff07ec91bbe08ba42634a8611deb028a77896/config.json
Model config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 4,
  "hidden_act": "silu",
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 16,
  "num_hidden_layers": 8,
  "num_key_value_heads": 16,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32000
}

/home/tuning/.local/lib/python3.11/site-packages/peft/utils/save_and_load.py:257: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
  warnings.warn(
tokenizer config file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-0/tokenizer_config.json
Special tokens file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-0/special_tokens_map.json
{'loss': 20.4898, 'grad_norm': 20.289756774902344, 'learning_rate': 1e-05, 'epoch': 1.67}                
 20%|██████████████                                                        | 1/5 [00:10<00:41, 10.41s/it]Saving model checkpoint to /tmp/test-transformers-446-lora-no-logging/checkpoint-1
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--Maykeye--TinyLLama-v0/snapshots/8c7ff07ec91bbe08ba42634a8611deb028a77896/config.json
Model config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 4,
  "hidden_act": "silu",
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 16,
  "num_hidden_layers": 8,
  "num_key_value_heads": 16,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32000
}

/home/tuning/.local/lib/python3.11/site-packages/peft/utils/save_and_load.py:257: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
  warnings.warn(
tokenizer config file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-1/tokenizer_config.json
Special tokens file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-1/special_tokens_map.json
Saving model checkpoint to /tmp/test-transformers-446-lora-no-logging/checkpoint-1
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--Maykeye--TinyLLama-v0/snapshots/8c7ff07ec91bbe08ba42634a8611deb028a77896/config.json
Model config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 4,
  "hidden_act": "silu",
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 16,
  "num_hidden_layers": 8,
  "num_key_value_heads": 16,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32000
}

/home/tuning/.local/lib/python3.11/site-packages/peft/utils/save_and_load.py:257: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
  warnings.warn(
tokenizer config file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-1/tokenizer_config.json
Special tokens file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-1/special_tokens_map.json
{'loss': 20.9744, 'grad_norm': 19.00687599182129, 'learning_rate': 8.535533905932739e-06, 'epoch': 3.67} 
 40%|████████████████████████████                                          | 2/5 [00:10<00:13,  4.50s/it]Saving model checkpoint to /tmp/test-transformers-446-lora-no-logging/checkpoint-2
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--Maykeye--TinyLLama-v0/snapshots/8c7ff07ec91bbe08ba42634a8611deb028a77896/config.json
Model config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 4,
  "hidden_act": "silu",
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 16,
  "num_hidden_layers": 8,
  "num_key_value_heads": 16,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32000
}

/home/tuning/.local/lib/python3.11/site-packages/peft/utils/save_and_load.py:257: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
  warnings.warn(
tokenizer config file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-2/tokenizer_config.json
Special tokens file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-2/special_tokens_map.json
Saving model checkpoint to /tmp/test-transformers-446-lora-no-logging/checkpoint-2
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
DEBUG:connectionpool.py:https://huggingface.co:443 "HEAD /Maykeye/TinyLLama-v0/resolve/main/config.json HTTP/11" 200 0
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--Maykeye--TinyLLama-v0/snapshots/8c7ff07ec91bbe08ba42634a8611deb028a77896/config.json
Model config LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 4,
  "hidden_act": "silu",
  "hidden_size": 64,
  "initializer_range": 0.02,
  "intermediate_size": 256,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 16,
  "num_hidden_layers": 8,
  "num_key_value_heads": 16,
  "pad_token_id": 0,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32000
}

/home/tuning/.local/lib/python3.11/site-packages/peft/utils/save_and_load.py:257: UserWarning: Setting `save_embedding_layers` to `True` as the embedding layer has been resized during finetuning.
  warnings.warn(
tokenizer config file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-2/tokenizer_config.json
Special tokens file saved in /tmp/test-transformers-446-lora-no-logging/checkpoint-2/special_tokens_map.json


Training completed. Do not forget to share your model on huggingface.co/models =)


{'train_runtime': 11.1269, 'train_samples_per_second': 4.494, 'train_steps_per_second': 0.449, 'train_tokens_per_second': 270.516, 'train_loss': 25.917766571044922, 'epoch': 3.67}
 40%|████████████████████████████                                          | 2/5 [00:11<00:16,  5.56s/it]
Pip freeze of the 2.1.1-rc1 image running in the cluster:
accelerate==1.0.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
attrs==24.2.0
bitsandbytes==0.43.3
certifi==2024.8.30
charset-normalizer==3.4.0
datasets==2.21.0
dill==0.3.8
docstring_parser==0.16
einops==0.8.0
filelock==3.16.1
flash-attn==2.6.3
fms-acceleration==0.4.0
fms-acceleration-aadp==0.1.1
fms-acceleration-foak==0.3.3
fms-acceleration-peft==0.3.4
fms-hf-tuning @ file:///tmp/fms_hf_tuning-0.1.dev86%2Bg22d2323.d20241031-py3-none-any.whl#sha256=52452b594a5392d16d1f5feef1c4d4eb73cbd9fb375b193451e23123f5ace627
frozenlist==1.5.0
fsspec==2024.6.1
huggingface-hub==0.26.2
idna==3.10
Jinja2==3.1.4
llvmlite==0.43.0
markdown-it-py==3.0.0
MarkupSafe==3.0.2
mdurl==0.1.2
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
networkx==3.4.2
numba==0.60.0
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
packaging==24.1
pandas==2.2.3
peft==0.13.2
propcache==0.2.0
protobuf==5.28.3
psutil==6.1.0
pyarrow==18.0.0
Pygments==2.18.0
pyproject_hooks==1.2.0
python-dateutil==2.9.0.post0
pytz==2024.2
PyYAML==6.0.2
regex==2024.9.11
requests==2.32.3
rich==13.9.3
safetensors==0.4.5
sentencepiece==0.2.0
shtab==1.7.1
simpleeval==0.9.13
six==1.16.0
sympy==1.13.3
threadpoolctl==3.5.0
tokenizers==0.20.1
torch==2.4.1
tqdm==4.66.6
transformers==4.46.1
triton==3.0.0
trl==0.11.4
typing_extensions==4.12.2
tyro==0.8.14
tzdata==2024.2
urllib3==2.2.3
xxhash==3.5.0
yarl==1.17.1

@Abhishek-TAMU (Collaborator) commented Nov 1, 2024

@Abhishek-TAMU @anhuong I cannot reproduce the problem. I ran the exact same command #383 (comment), and I have 5 checkpoints. According to my reading of the code for transformers==4.45.2 (https://github.com/huggingface/transformers/blob/v4.45.2/src/transformers/trainer.py#L2406):

10 samples, GA=4, so steps_in_epoch=2
for each epoch, on the second step (i.e. step=1), is_last_step_and_steps_less_than_grad_acc will be set to True. Thus it will go into the if block and the optimizer will step. In other words, GA will only accumulate over 2 steps.
after the two steps of each epoch, it will come to _maybe_log_save_evaluate, and in there a checkpoint will be saved per epoch.

@fabianlim With logging statements I found that steps_in_epoch is 3 and not 2 (10 samples, GA=4, batch_size=4).
But num_batches = 2, since the subsequent variable values are remainder = 2 and total_updates = 1, so there are only 2 batch_samples in this loop, where step takes the values 0 and 1. Hence is_last_step_and_steps_less_than_grad_acc never becomes True for any step value in some epochs (3 out of 5), which is why self._maybe_log_save_evaluate is not called for those epochs and no checkpoint is saved for them.
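
To make that arithmetic concrete, here is a minimal sketch that just reproduces the values reported above (it is not the transformers source; the variable names simply mirror the ones mentioned in this thread):

import math

num_samples = 10
per_device_train_batch_size = 4
gradient_accumulation_steps = 4

# Micro-batches per epoch: ceil(10 / 4) = 3, matching the logged steps_in_epoch.
steps_in_epoch = math.ceil(num_samples / per_device_train_batch_size)

# Values reported from the logging statements above.
remainder = 2
total_updates = 1

# The inner loop only sees num_batches = 2 micro-batches (step = 0 and 1),
# so a condition keyed to the final micro-batch of the epoch (index 2) never
# fires, _maybe_log_save_evaluate is skipped, and no checkpoint is saved
# for that epoch.
num_batches = remainder  # 2
print(steps_in_epoch, num_batches, total_updates)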
