OOM issue with same batch size that was running ok on 0.0.80 #184

Open
salrowili opened this issue Jan 8, 2025 · 22 comments

Comments

@salrowili

Hi,

I've noticed that recent updates are causing the SFT trainer code to throw an OutOfMemory (OOM) error with the same batch size that previously ran without issue on version 0.0.80.

I attempted SFT tuning using bfloat16 (no LoRA) with LLaMA 3.1 8B, max_length=1024, and batch=8 on TPUv4-8, but encountered an OOM error. This fine-tuning setup was working ok with 0.0.80.

@erfanzar
Owner

erfanzar commented Jan 8, 2025

Hello, and thanks for reporting the issue. Can you share the code, please?

@salrowili
Author

Hi @erfanzar,
Please see the code below:

import easydel as ed
from easydel.utils.analyze_memory import SMPMemoryMonitor  # Optional for memory analysis
import jax
from jax import numpy as jnp, sharding, lax, random as jrnd
from huggingface_hub import HfApi
import datasets
from flax.core import FrozenDict
from datasets import load_dataset
import jsonlines ## pip install jsonlines
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
import wandb
wandb.init(project="test")
def train():
	pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct"
	max_length=1024
	PartitionSpec, api = sharding.PartitionSpec, HfApi()
	json_entry=[]
	qa_data = load_dataset("Stanford/web_questions" ,split="train")
	instruction ="You are a helpful AI assistant."
	for item in qa_data:
		question=item["question"]
		answer=item["answers"][0]
		json_entry.append({"messages": [{"role": "system", "content": instruction}, {"role": "user", "content": item["question"]}, {"role": "assistant", "content": answer}]})

	with jsonlines.open('SFT_Train.jsonl', 'w') as writer:
		writer.write_all(json_entry)

	train_dataset = load_dataset("json", data_files={"train" :"/home/big35manf/SFT_Train.jsonl"},split="train")
	sharding_axis_dims = (1, 1, -1, 1)
	new_repo_id = "Test"
	dtype = jnp.bfloat16

	model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
		pretrained_model_name_or_path,
		auto_shard_model=True,
		sharding_axis_dims=sharding_axis_dims,
		config_kwargs=ed.EasyDeLBaseConfigDict(
			use_scan_mlp=False,
			attn_dtype=jnp.bfloat16,
			freq_max_position_embeddings=max_length,
			mask_max_position_embeddings=max_length,
			attn_mechanism=ed.AttentionMechanisms.VANILLA,
		),
		param_dtype=dtype,
		torch_dtype=torch.bfloat16,
		dtype=dtype,
		from_torch=True,
		precision=lax.Precision("fastest"),
	)
	config = model.config
	tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path,trust_remote_code=True)
	tokenizer.pad_token = (tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token)
	tokenizer.padding_side = "right"
	train_arguments = ed.SFTConfig(
		num_train_epochs=10,
		learning_rate=5e-5,
		learning_rate_end=0,
		warmup_steps=100,
		optimizer=ed.EasyDeLOptimizers.ADAMW,
		scheduler=ed.EasyDeLSchedulers.WARM_UP_COSINE,
		weight_decay=0.02,
		total_batch_size=8,
		max_sequence_length=max_length,
		gradient_accumulation_steps=1,
		do_last_save=True,
		model_name=new_repo_id,
		track_memory=False,
		packing=True,
		num_of_sequences=max_length,
		dataset_text_field=None,
		dataset_num_proc=32,
	)

	trainer = ed.SFTTrainer(
		processing_class=tokenizer,
		arguments=train_arguments,
		model=model,
		train_dataset=train_dataset,
		eval_dataset=None,
		formatting_func=lambda x: [tokenizer.apply_chat_template(x["messages"], tokenize=False)]
	)

	output = trainer.train()
	print("Training Done")
	tokenizer.save_pretrained(output.last_save_file_name)
train()

@erfanzar
Owner

erfanzar commented Jan 8, 2025

Can you rerun the code? There was an issue with the loss function, which wasn't using the fused version.

@erfanzar
Owner

erfanzar commented Jan 8, 2025

And since the sharding mechanism you're using is tensor parallel, you can expect OOM, but not at a 1k sequence length.

@erfanzar
Owner

erfanzar commented Jan 8, 2025

In v0.0.80 the trainer automatically used gradient checkpointing. This behavior was removed in 0.1.0, and you should now pass gradient_checkpointing to model_kwargs (I'll take the blame for not having good documentation).
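
As a rough illustration of what gradient checkpointing trades off, here is a minimal plain-JAX sketch. It is not EasyDeL's internal implementation. It assumes that an option like NOTHING_SAVEABLE corresponds to JAX's `jax.checkpoint_policies.nothing_saveable` policy, and `mlp_block` and `make_loss` are hypothetical names used only for this example.

```python
import jax
import jax.numpy as jnp


def mlp_block(w, x):
    # Toy stand-in for a transformer layer (hypothetical, for illustration only).
    return jnp.tanh(x @ w)


# Wrap the block so that no intermediate activations are saved for the backward
# pass; they are recomputed instead, trading extra compute for lower memory use.
remat_block = jax.checkpoint(
    mlp_block, policy=jax.checkpoint_policies.nothing_saveable
)


def make_loss(block):
    # Build a scalar loss that applies the given block twice.
    def loss(w, x):
        h = block(w, x)
        h = block(w, h)
        return jnp.mean(h**2)

    return loss


w = jnp.full((4, 4), 0.1, dtype=jnp.float32)
x = jnp.ones((2, 4), dtype=jnp.float32)

grad_plain = jax.grad(make_loss(mlp_block))(w, x)
grad_remat = jax.grad(make_loss(remat_block))(w, x)

# The gradients should agree; only the memory/compute trade-off differs.
print(jnp.allclose(grad_plain, grad_remat, atol=1e-6))
```

Without checkpointing, activations from every layer are kept until the backward pass, which is one plausible reason the same batch size can fit on one version and run out of memory on another.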

@salrowili
Author

salrowili commented Jan 8, 2025

You are right! In 0.0.80, it was part of the training arguments, as we can see in this example:

gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,

However, it was removed in the recent updates
class TrainingArguments:

After updating the code with:

                config_kwargs=ed.EasyDeLBaseConfigDict(
                        use_scan_mlp=False,
                        attn_dtype=jnp.bfloat16,
                        freq_max_position_embeddings=max_length,
                        mask_max_position_embeddings=max_length,
                        attn_mechanism=ed.AttentionMechanisms.VANILLA,
                        gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
                ),

I was able to run the SFT code with a batch size of 8, but I got a couple of warnings:

[easydel.trainers.base_trainer] Prevent Running Model Due to NaN Loss

FutureWarning: unhashable type: <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>. Attempting to hash a tracer will lead to an error in a future JAX release.

The NaN issue is present even with 0.0.80 whenever I use the meta-llama/Llama-3.1-8B-Instruct model with packing=True. It appears in some runs and not in others. I have worked around it by re-running the script several times, without changing any arguments, until no NaN appeared. With packing=False, the issue disappears. The issue does not occur with the Llama 3.2 models.

One last note on the choice of sharding_axis_dims = (1, 1, -1, 1): this setting gives me 113 TFLOPs (on 0.0.80) with TPUv4-8, versus 98 TFLOPs with the other sharding axis settings. That is why I chose it over the other options.
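
For readers unfamiliar with the -1 entry, the following is a small sketch of how such a tuple could resolve to concrete mesh sizes. It assumes -1 means "use all remaining devices", as in numpy.reshape, and it assumes the axis order ("dp", "fsdp", "tp", "sp"). Neither is confirmed in this thread, and `resolve_axis_dims` is a hypothetical helper, not an EasyDeL function.

```python
import math

# Assumed axis order (an assumption for this sketch, not confirmed in this thread).
AXIS_NAMES = ("dp", "fsdp", "tp", "sp")


def resolve_axis_dims(axis_dims, num_devices):
    """Replace a single -1 entry with the number of devices left over."""
    if sum(1 for d in axis_dims if d == -1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in axis_dims if d != -1)
    if num_devices % known != 0:
        raise ValueError("device count is not divisible by the fixed axes")
    resolved = tuple(num_devices // known if d == -1 else d for d in axis_dims)
    if math.prod(resolved) != num_devices:
        raise ValueError("axis sizes do not multiply to the device count")
    return resolved


# Example device count of 8 (hypothetical); in practice use len(jax.devices()).
for dims in [(1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)]:
    print(dict(zip(AXIS_NAMES, resolve_axis_dims(dims, num_devices=8))))
```

Under these assumptions, (1, 1, -1, 1) places every device on the third axis, which is consistent with the setup being described as tensor parallel earlier in the thread.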

@erfanzar
Owner

erfanzar commented Jan 8, 2025

Look, the FLOPs calculation method changed in the last version. Previously everything was calculated manually, but in this version it is calculated from JAX's analysis, so expect it to be wrong at times. For example, you might actually be running at 160 TFLOPs, but in some parts XLA plays a bit dumb and shows 130 TFLOPs.

Check this out

jax-ml/jax#17912

@salrowili
Author

Thank you for the detailed reply.

I have tested the TFLOPs in terms of runtime speed on both 0.0.80 and 0.1dev, using the same WebQuestions example with Llama 3.1 8B.

As you can see from the results, there is a speed problem with the recent update, even when using different sharding strategies. I let the code run for a while until the s/it metric became stable.

You can see that 0.0.80 is about twice as fast. Also notice that (1,1,-1,1) is the fastest setting on TPUv4-8, as I stated earlier. In fact, the gap between (1,1,-1,1) and (1,1,1,-1) gets worse (almost double) with the new update. I have also noticed that it takes a while (3-5 minutes) before the script starts running with the 0.1.0.dev update, so we should add 3-5 minutes to its runtime for an accurate head-to-head comparison.

0.0.80

(1 ,-1 , 1, 1)

27%|▎| 71/260 [03:12<03:56, 1.25s/it, TFLOPs=101, accuracy=0.8993157, epoch=2, learning_rate=3.5015e-05, loss=0.389, max_grad_norm=1.3, mean_accuracy=0.8596358, mean_grad_norm=0.0757, mean_loss=0.78678674, perplexi

(1,1,-1,1)

46%|▍| 120/260 [03:26<02:21, 1.01s/it, TFLOPs=132, accuracy=0.93218476, epoch=4, learning_rate=4.8624297e-05, loss=0.248, max_grad_norm=2.33, mean_accuracy=0.88241583, mean_grad_norm=0.0815, mean_loss=0.5953737, pe

(1,1,1,-1)

22%|▏| 58/260 [03:06<04:23, 1.30s/it, TFLOPs=97.2, accuracy=0.8931932, epoch=2, learning_rate=2.85215e-05, loss=0.457, max_grad_norm=2.42, mean_accuracy=0.850847, mean_grad_norm=0.0981, mean_loss=0.87310237, perple

-------------------------------------------------------------------------------------------

0.1.0.dev

(1,-1,1,1)

15%|▏| 40/260 [02:48<15:06, 4.12s/it, TFLOPs=5.63e+13, accuracy=0.889, epoch=1, learning_rate=1.9530498e-05, loss=0.494, max_grad_norm=2.92, mean_accuracy=0.83462554, mean_grad_norm=0.109, mean_loss=1.0513492, perp

(1,1,-1,1)

42%|▍| 110/260 [03:53<05:15, 2.10s/it, TFLOPs=1.16e+14, accuracy=0.923, epoch=4, learning_rate=4.9610666e-05, loss=0.267, max_grad_norm=1.88, mean_accuracy=0.87859696, mean_grad_norm=0.0786, mean_loss=0.6243303, pe

(1,1,1,-1)

58%|▌| 150/260 [10:37<07:44, 4.22s/it, TFLOPs=5.58e+13, accuracy=0.94, epoch=5, learning_rate=3.9294693e-05, loss=0.208, max_grad_norm=1.25, mean_accuracy=0.8924935, mean_grad_norm=0.0669, mean_loss=0.52312225, per
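
To make the comparison concrete, here is a short sketch that computes the per-step slowdown from the s/it values in the progress bars above. The last part assumes that the 0.1.0.dev "TFLOPs" field is actually in raw FLOP/s (values such as 1.16e+14), so it divides by 1e12 to compare with the 0.0.80 numbers. That unit interpretation is an assumption, not something confirmed in this thread.

```python
# Seconds per iteration copied from the progress bars above.
step_time_old = {"(1,-1,1,1)": 1.25, "(1,1,-1,1)": 1.01, "(1,1,1,-1)": 1.30}  # 0.0.80
step_time_new = {"(1,-1,1,1)": 4.12, "(1,1,-1,1)": 2.10, "(1,1,1,-1)": 4.22}  # 0.1.0.dev

# Ratio > 1 means 0.1.0.dev needs more time per step than 0.0.80.
slowdown = {dims: step_time_new[dims] / step_time_old[dims] for dims in step_time_old}

for dims, ratio in slowdown.items():
    print(f"{dims}: 0.1.0.dev takes {ratio:.2f}x as long per step")

# Assumption: the 0.1.0.dev "TFLOPs" field holds raw FLOP/s, so scale by 1e12.
reported_new = {"(1,-1,1,1)": 5.63e13, "(1,1,-1,1)": 1.16e14, "(1,1,1,-1)": 5.58e13}
tflops_new = {dims: value / 1e12 for dims, value in reported_new.items()}

for dims, value in tflops_new.items():
    print(f"{dims}: {value:.1f} TFLOPs (rescaled)")
```

Comparing step times rather than the reported FLOPs avoids depending on how each version computes its FLOPs metric.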

-------------------------------------------------------------------------------------------------------

Script to run the code on 0.0.80

import easydel as ed
from easydel.utils.analyze_memory import SMPMemoryMonitor  # Optional for memory analysis
import jax
from transformers import AutoTokenizer
from jax import numpy as jnp, sharding, lax, random as jrnd
from huggingface_hub import HfApi
import datasets
from flax.core import FrozenDict
from datasets import load_dataset
import jsonlines
import torch
import sys
import wandb
wandb.init(project="test",mode='online')

json_entry=[]
qa_data = load_dataset("Stanford/web_questions" ,split="train")
instruction ="You are a helpful AI assistant."
for item in qa_data:
	question=item["question"]
	answer=item["answers"][0]
	json_entry.append({"messages": [{"role": "system", "content": instruction}, {"role": "user", "content": item["question"]}, {"role": "assistant", "content": answer}]})

with jsonlines.open('SFT_Train.jsonl', 'w') as writer:
	writer.write_all(json_entry)
train_dataset = load_dataset("json", data_files={"train" :"/home/big35manf/SFT_Train.jsonl"},split="train")
PartitionSpec, api = sharding.PartitionSpec, HfApi()
sharding_axis_dims = (1, -1, 1, 1)
max_length = 1024
input_shape = (len(jax.devices()), max_length)
pretrained_model_name_or_path = "meta-llama/Llama-3.1-8B-Instruct"
pretrained_model_name_or_path_tokenizer = pretrained_model_name_or_path
new_repo_id = "Test"
dtype = jnp.bfloat16

model, params = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
	pretrained_model_name_or_path,
	input_shape=input_shape,
	auto_shard_params=True,
	sharding_axis_dims=sharding_axis_dims,
	config_kwargs=ed.EasyDeLBaseConfigDict(
		use_scan_mlp=False,
		freq_max_position_embeddings=max_length,
		mask_max_position_embeddings=max_length,
		attn_dtype=jnp.bfloat16,
		attn_mechanism=ed.AttentionMechanisms.VANILLA,
	),
	param_dtype=dtype,
	dtype=dtype,
	torch_dtype=torch.bfloat16,
	from_torch=True,
	precision=lax.Precision("fastest"),
)

config = model.config
model_use_tie_word_embedding = config.tie_word_embeddings
model_parameters = FrozenDict({"params": params})

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path_tokenizer, trust_remote_code=True)
tokenizer.pad_token = (tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token)
tokenizer.padding_side = "right"

train_arguments = ed.TrainingArguments(
	num_train_epochs=10,
	learning_rate=5e-5,
	learning_rate_end=0,
	warmup_steps=100,
	optimizer=ed.EasyDeLOptimizers.ADAMW,
	scheduler=ed.EasyDeLSchedulers.WARM_UP_COSINE,
	weight_decay=0.02,
	total_batch_size=8,
	max_sequence_length=max_length,
	gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
	sharding_array=sharding_axis_dims,
	gradient_accumulation_steps=1,
	init_input_shape=input_shape,
	dtype=dtype,
	do_last_save=False,
	param_dtype=dtype,
	model_name=new_repo_id,
	training_time="70H",
	track_memory=False,
)

trainer = ed.SFTTrainer(
	arguments=train_arguments,
	model=model,
	train_dataset=train_dataset,
	eval_dataset=None,
	tokenizer=tokenizer,
	dataset_text_field=None,
	formatting_func=lambda x: [tokenizer.apply_chat_template(x["messages"], tokenize=False)],
	packing=True,
	num_of_sequences=max_length,
	dataset_num_proc=32,
)

output = trainer.train(model_parameters=model_parameters, state=None)
print("Saving the PyTorch Model")
trainer.save_pretrained(output.state, to_torch=True)

Script to run the code on 0.1.0.dev

import easydel as ed
from easydel.utils.analyze_memory import SMPMemoryMonitor  # Optional for memory analysis
import jax
from jax import numpy as jnp, sharding, lax, random as jrnd
from huggingface_hub import HfApi
import datasets
from flax.core import FrozenDict
from datasets import load_dataset
import jsonlines ## pip install jsonlines
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
import wandb
wandb.init(project="test")
def train():
	pretrained_model_name_or_path="meta-llama/Llama-3.1-8B-Instruct"
	max_length=1024
	PartitionSpec, api = sharding.PartitionSpec, HfApi()
	json_entry=[]
	qa_data = load_dataset("Stanford/web_questions" ,split="train")
	instruction ="You are a helpful AI assistant."
	for item in qa_data:
		question=item["question"]
		answer=item["answers"][0]
		json_entry.append({"messages": [{"role": "system", "content": instruction}, {"role": "user", "content": item["question"]}, {"role": "assistant", "content": answer}]})

	with jsonlines.open('SFT_Train.jsonl', 'w') as writer:
		writer.write_all(json_entry)

	train_dataset = load_dataset("json", data_files={"train" :"/home/big35manf/SFT_Train.jsonl"},split="train")
	sharding_axis_dims = (1, -1, 1, 1)
	new_repo_id = "Test"
	dtype = jnp.bfloat16

	model = ed.AutoEasyDeLModelForCausalLM.from_pretrained(
		pretrained_model_name_or_path,
		auto_shard_model=True,
		sharding_axis_dims=sharding_axis_dims,
		config_kwargs=ed.EasyDeLBaseConfigDict(
			use_scan_mlp=False,
			attn_dtype=jnp.bfloat16,
			freq_max_position_embeddings=max_length,
			mask_max_position_embeddings=max_length,
			attn_mechanism=ed.AttentionMechanisms.VANILLA,
			gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
		),
		param_dtype=dtype,
		torch_dtype=torch.bfloat16,
		dtype=dtype,
		from_torch=True,
		precision=lax.Precision("fastest"),
	)
	config = model.config
	tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path,trust_remote_code=True)
	tokenizer.pad_token = (tokenizer.eos_token if tokenizer.pad_token is None else tokenizer.pad_token)
	tokenizer.padding_side = "right"
	train_arguments = ed.SFTConfig(
		num_train_epochs=10,
		learning_rate=5e-5,
		learning_rate_end=0,
		warmup_steps=100,
		optimizer=ed.EasyDeLOptimizers.ADAMW,
		scheduler=ed.EasyDeLSchedulers.WARM_UP_COSINE,
		weight_decay=0.02,
		total_batch_size=8,
		max_sequence_length=max_length,
		gradient_accumulation_steps=1,
		do_last_save=True,
		model_name=new_repo_id,
		track_memory=False,
		packing=True,
		num_of_sequences=max_length,
		dataset_text_field=None,
		dataset_num_proc=32,
	)

	trainer = ed.SFTTrainer(
		processing_class=tokenizer,
		arguments=train_arguments,
		model=model,
		train_dataset=train_dataset,
		eval_dataset=None,
		formatting_func=lambda x: [tokenizer.apply_chat_template(x["messages"], tokenize=False)]
	)

	output = trainer.train()
	print("Training Done")
	tokenizer.save_pretrained(output.last_save_file_name)
train()

@erfanzar
Owner

Thank you @salrowili for bringing up these issues and for your detailed feedback!
I want to clarify that several of the speed-related issues stem from our flax/NNX integration. I'm currently working on fixing these with @cgarciae's help. Going forward, any easydel issues that are specifically related to flax/NNX will be redirected to the flax/NNX repository with additional context and documentation.

@erfanzar
Owner

@salrowili #185

@salrowili
Author

salrowili commented Jan 14, 2025

Great! Thank you @erfanzar for opening the topic. I have one question. I am planning to start sharing my code in the topic you just opened (#185), but I am still struggling to run my code on the new EasyDeL 0.1dev release. It is very slow compared to 0.0.80, and you told me that this is due to the flax/NNX integration. Inference with the new 0.1dev is fast, but the problem is with the SFT code. Do you have an estimate of when the issue will be fixed? If it will be soon, I will wait until it is fixed and then share my code with the 0.1dev release.

@erfanzar
Owner

Hi @salrowili,

Many performance issues related to the new arguments and the updated base trainer have been resolved. These include fixes for duplicated if statements, redundant code checks, and incorrect caching mechanisms.

You can rerun your benchmark to see if there are any remaining performance issues (avoid using ahead-of-time compilation).

With Qwen-2 7B, batch size 8, and full sequence parallelism, I was able to achieve 6 seconds per iteration. Let me know how it goes!

@salrowili
Author

Hi @erfanzar. That's great news! Can you share the code you used to achieve this performance?

@erfanzar
Owner

@salrowili I'm using tests/trainer_test.py

@erfanzar
Owner

Hi @salrowili,

I hope the performance issue is resolved and training is running smoothly now. If you’re still encountering any other problems, please let me know. I’m currently working on improving the speed of the training process and would be happy to assist you further.

For quicker communication and to resolve any issues more efficiently, feel free to connect with me on Discord. My user ID is citifer. Looking forward to helping you out!

@salrowili
Author

Hi @erfanzar. The issue is still there, and you can verify it by using the SFT code that I posted earlier. Currently, I am doing the SFT training with version 0.0.80. How can I disable AOT compilation?

I thought I should slow down on reporting issues since you seem busy with other EasyDeL issues, and this is why I did not update you that the issue still exists (:

I will try to DM you on Discord soon.

@salrowili
Author

0.0.80 is much faster, especially when you run a larger model. I ran Llama 3.1 8B: 0.1dev gave me 1.9 s/it with a total runtime of ~2 h, against 1.24 s/it and 1 h 15 m with 0.0.80.

@erfanzar
Owner

Hey @salrowili, the bugs are now fixed. I'll push the fix commit today, and 0.1 is now 5~10% faster.

erfanzar added a commit that referenced this issue Jan 20, 2025
…various components, fixes issues related to trainers being slow in #184
@erfanzar
Owner

Hi @salrowili,

Could you please test and confirm whether the execution time and step_times in easydel-v0.1.0-dev are approximately 10–20% faster compared to easydel-0.0.80?

Based on my tests using TPUv3-8:

  • Version 0.0.80 runs at 2 seconds per iteration with your provided code.
  • Version 0.1.0-dev runs at 1.6 seconds per iteration with the same config.

Let me know your findings!

@salrowili
Author

Hi @erfanzar. I think I have figured out the root of the speed issue. I think you are testing the code on TPUv3-8, which has 8 chips with a single core each. In contrast, TPUv4-8 has 4 chips with 2 cores each. So maybe only a single core from each chip is being used on TPUv4-8? This may also affect how FLOPs are calculated.

I have tested 0.1 on both TPUv3-8 and TPUv4-8, and TPUv3-8 is much faster (almost double). Also, the trainer output is messed up: it prints the output vertically, not horizontally as it used to.

Here is Trainer output for TPUv3-8 and TPUv4-8. The code for both is identical including the sharding method. Observe the step time.

TPUv3-8

training process - { 'mlperf/execution_time': 0.37421172,                                                                                                                      
  'mlperf/flops': 8245941370880.0,                                                                                                                                             
  'mlperf/flops_per_sequence': 4173047252.4696355,                                                                                                                             
  'mlperf/flops_per_token': 4075241.4574898784,                                                                                                                                
  'mlperf/flops_pre_second': 21753546211775.117,                                                                                                                               
  'mlperf/flops_sequence_pre_second': 4173047252.4696355,                                                                                                                      
  'mlperf/flops_token_pre_second': 4075241.4574898784,                                                                                                                         
  'mlperf/step_time': 0.3790619373321533,                                                                                                                                      
  'mlperf/throughput': 21611.24395040949,                                                                                                                                      
  'mlperf/total_time': 265.9115765094757,                                                                                                                                      
  'train/accuracy': 0.9125122427940369,                                                                                                                                        
  'train/epoch': 9,                                                                                                                                                            
  'train/grad_norm/lm_head.kernel': 0.35546875,                                                                                                                                
  'train/grad_norm/model.norm.kernel': 0.0031585693359375,                                                                                                                     
  'train/learning_rate': 9.386166652802785e-07,                                                                                                                                
  'train/loss': 0.3316769003868103,                                                                                                                                            
  'train/mean_accuracy': 0.8708844780921936,                                                                                                                                   
  'train/mean_loss': 0.6994993090629578,                                                                                                                                       
  'train/perplexity': 1.3933013677597046,                                                                                                                                      
  'train/step': 246,                                                                                                                                                           
  'train/train/max_grad_norm': 1.3671875,                                                                                                                                      
  'train/train/mean_grad_norm': 0.10693359375,                                                                                                                                 
  'train/visited_tokens': 2023424,                                                                                                                                             
  'train/z_loss': 0.0}  

TPUv4-8

training process - { 'mlperf/execution_time': 0.772636662,                                                                                                                     
  'mlperf/flops': 115442874056704.0,                                                                                                                                           
  'mlperf/flops_per_sequence': 53052791386.35294,                                                                                                                              
  'mlperf/flops_per_token': 51809366.5882353,                                                                                                                                  
  'mlperf/flops_pre_second': 148604274400124.53,                                                                                                                               
  'mlperf/flops_sequence_pre_second': 53052791386.35294,                                                                                                                       
  'mlperf/flops_token_pre_second': 51809366.5882353,                                                                                                                           
  'mlperf/step_time': 0.7768476009368896,                                                                                                                                      
  'mlperf/throughput': 10545.182851978081,                                                                                                                                     
  'mlperf/total_time': 479.89153575897217,                                                                                                                                     
  'train/accuracy': 0.8673020601272583,                                                                                                                                        
  'train/epoch': 0,                                                                                                                                                            
  'train/grad_norm/lm_head.kernel': 0.74609375,                                                                                                                                
  'train/grad_norm/model.norm.kernel': 0.00506591796875,                                                                                                                       
  'train/learning_rate': 4.973118120688014e-05,                                                                                                                                
  'train/loss': 0.5999301671981812,                                                                                                                                            
  'train/mean_accuracy': 0.880715012550354,                                                                                                                                    
  'train/mean_loss': 0.5317749381065369,                                                                                                                                       
  'train/perplexity': 1.8219926357269287,                                                                                                                                      
  'train/step': 271,                                                                                                                                                           
  'train/train/max_grad_norm': 1.609375,                                                                                                                                       
  'train/train/mean_grad_norm': 0.09228515625,                                                                                                                                 
  'train/visited_tokens': 2228224,                                                                                                                                             
  'train/z_loss': 0.0} 
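
As a sanity check on the two dumps above, the sketch below recomputes tokens per step and FLOP/s from the logged values. The relationships it uses (throughput times step time, and FLOPs divided by step time) are inferred from the numbers themselves, not taken from EasyDeL's source, so treat them as an assumption.

```python
# Values copied from the two log dumps above (key names shortened for readability).
logs = {
    "TPUv3-8": {
        "flops": 8245941370880.0,
        "flops_per_second": 21753546211775.117,  # logged as mlperf/flops_pre_second
        "step_time": 0.3790619373321533,
        "throughput": 21611.24395040949,
    },
    "TPUv4-8": {
        "flops": 115442874056704.0,
        "flops_per_second": 148604274400124.53,  # logged as mlperf/flops_pre_second
        "step_time": 0.7768476009368896,
        "throughput": 10545.182851978081,
    },
}

recomputed = {}
for name, m in logs.items():
    recomputed[name] = {
        # Tokens per step implied by throughput (tokens/s) times step time (s).
        "tokens_per_step": m["throughput"] * m["step_time"],
        # FLOP/s implied by FLOPs per step divided by step time.
        "flops_per_second": m["flops"] / m["step_time"],
    }
    print(name, recomputed[name])

# How much larger the reported FLOPs per step are on TPUv4-8 than on TPUv3-8.
flops_ratio = logs["TPUv4-8"]["flops"] / logs["TPUv3-8"]["flops"]
print(f"FLOPs per step, TPUv4-8 / TPUv3-8: {flops_ratio:.1f}")
```

With these values, both runs come out at roughly 8192 tokens per step (8 x 1024, matching total_batch_size and max_sequence_length), while the reported FLOPs per step differ by about 14x between the two devices. That gap suggests the FLOPs figures are not directly comparable across devices, whereas step time and throughput are.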

@erfanzar
Owner

@salrowili
FLOPs are calculated automatically by JAX.

For logging during training, you have the flexibility to customize how you log or add logging hooks. By default, the logging method is set to json, but you can switch to tqdm by specifying it in the TrainingArguments under progress_bar_type.

If the current logging format doesn't suit your preferences, reverting to tqdm as the default might be a better option, as it provides a more intuitive and user-friendly progress display.

@salrowili
Author

salrowili commented Jan 21, 2025

You are right. However, I think it would be better to make tqdm the default progress bar type, because json floods the terminal with log output, especially when you set a smaller logging step.

I have re-run the code and fixed an issue where the TPUv3-8 run used the cached dataset instead of the updated one. From the total runtime we can see that TPUv3-8 is much faster than TPUv4-8, so the issue is not related to how we calculate metrics (e.g. step time).

TPUv3-8

training process: 97%|▉| 252/260 [02:48<00:05, 1.55it/s, mlperf/execution_time=0.378, mlperf/flops=8.25e+12, mlperf/flops_per_sequence=4.07e+9, mlperf/flops_per_token=3.

TPUv4-8

training process: 98%|▉| 256/260 [04:37<00:04, 1.06s/it, mlperf/execution_time=0.773, mlperf/flops=1.15e+14, mlperf/flops_per_sequence=5.61e+10, mlperf/flops_per_token=5.4

One more thing, related to the jax and jaxlib versions: for the meantime, we need to pin the required jax and jaxlib versions to 0.4.35.

EasyDeL/pyproject.toml

Lines 36 to 37 in d13ecbb

jax = ">=0.4.34"
jaxlib = ">=0.4.34"

This is because the function core.new_main was deprecated in jax versions after 0.4.35. This function is used by the fjformer repo:

https://github.com/erfanzar/FJFormer/blob/f9940b097ef34ba5627ce6a259bfe629372899d8/fjformer/core/implicit_array.py#L452

However, removing the new_main function and updating the code would be a better solution, since the new jax 0.5.0 has been released.
https://github.com/jax-ml/jax/releases
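
Until the dependency is updated, a script could guard against an unsupported jax version at startup. The following is a minimal sketch using only the standard library. The bounds come from the pyproject.toml lines and the 0.4.35 suggestion above, and `parse_version` and `jax_version_supported` are hypothetical helpers, not part of EasyDeL or jax.

```python
import re

MIN_JAX = (0, 4, 34)  # lower bound from the pyproject.toml lines quoted above
MAX_JAX = (0, 4, 35)  # upper bound suggested in this comment


def parse_version(version):
    """Turn a string like '0.4.35' into (0, 4, 35), ignoring suffixes such as 'rc1'."""
    numbers = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    while len(numbers) < 3:
        numbers.append(0)
    return tuple(numbers)


def jax_version_supported(version):
    """Return True if the version falls inside the assumed supported range."""
    return MIN_JAX <= parse_version(version) <= MAX_JAX


for candidate in ["0.4.34", "0.4.35", "0.4.36", "0.5.0"]:
    print(candidate, jax_version_supported(candidate))
```

In a real script the version string would come from `jax.__version__`. An upper bound in pyproject.toml would achieve the same thing at install time.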
