[bug] FP8+PP+Recompute+GA>1, loss = nan #539
Comments
hotfix_pp_recompute_nan.diff.txt
PP (1F1B or interleave) will run a few warmup forward passes.
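For context on why pipeline parallelism matters here: in a 1F1B schedule each stage runs several warmup forward passes before its first backward, so the recompute-phase forward of a microbatch can be separated from its original forward by other microbatches. A minimal sketch of the standard 1F1B warmup count (the function name and formula here are illustrative, not taken from this repository):

```python
# Hypothetical sketch: number of warmup forward passes a stage runs in a
# 1F1B pipeline schedule before entering the steady 1-forward-1-backward
# phase. Earlier stages run more warmup forwards.
def warmup_forwards(num_microbatches: int, num_stages: int, stage_rank: int) -> int:
    return min(num_microbatches, num_stages - stage_rank - 1)

# Example: 4 stages, 8 microbatches. Stage 0 runs 3 warmup forwards,
# while the last stage runs 0 and goes straight into steady state.
print([warmup_forwards(8, 4, r) for r in range(4)])  # → [3, 2, 1, 0]
```

With GA>1 those warmup forwards run with differing `is_first_microbatch` values before any backward, which is exactly the regime the failing configuration exercises.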
@jingjie01ai Thank you for the bug report and the proposed fix. @ksivaman could you take a look?
I have also hit this problem when using bfloat16.
@jingjie01ai @codecaution Thank you for reporting the problem. Could you please share a minimal, reproducible example so we can investigate further?
Confirming we could reproduce the issue with the attached script, with changes from 493e9ef to print out the data_ptr.

repro.py:

```python
import torch
from transformer_engine.pytorch import Linear
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.utils import scaled_init_method_normal
from transformer_engine.pytorch.distributed import activation_recompute_forward

num_layers = 2
seq_len = 32
batch_size = 2
hidden_size = 64
num_attention_heads = 2

torch.manual_seed(0)
dtype = torch.float32
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, num_layers)

block = (
    Linear(hidden_size, hidden_size, init_method=output_layer_init_method)
    .to(dtype=dtype)
    .cuda()
)

te_inp = torch.randn(
    seq_len, batch_size, hidden_size, dtype=dtype, requires_grad=True
).cuda()

use_fp8 = True
fp8_recipe = recipe.DelayedScaling(0, 1, recipe.Format.HYBRID)
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
    with torch.no_grad():
        with activation_recompute_forward(
            activation_recompute=True, recompute_phase=False
        ):
            te_out1 = block(te_inp, is_first_microbatch=True)
            te_out2 = block(te_inp, is_first_microbatch=False)

print("te_out1", te_out1)
print("te_out2", te_out2)
```
The relevant code is at TransformerEngine/transformer_engine/pytorch/module/linear.py, lines 168 to 173 at commit b5e13a1.
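The pattern at that location casts the weight to FP8 only on the first microbatch and reuses the cached copy afterwards. A toy illustration of that caching pattern (plain Python, illustrative names only, not Transformer Engine's actual API):

```python
# Illustrative toy of the is_first_microbatch weight-cache pattern:
# an expensive cast is done once, and later microbatches reuse the
# cached result. If the recompute-phase forward makes a different
# caching decision than the original forward, it can read a buffer
# that was never (re)initialized, which surfaces as NaN.
class CachedCastLinear:
    def __init__(self, weight):
        self.weight = weight
        self.weight_fp8 = None  # cached "casted" copy

    def forward(self, x, is_first_microbatch):
        if is_first_microbatch or self.weight_fp8 is None:
            # Stand-in for the FP8 cast done on the first microbatch.
            self.weight_fp8 = list(self.weight)
        # Non-first microbatches reuse the cached copy.
        return [xi * wi for xi, wi in zip(x, self.weight_fp8)]

lin = CachedCastLinear([1.0, 2.0])
first = lin.forward([3.0, 4.0], is_first_microbatch=True)
second = lin.forward([3.0, 4.0], is_first_microbatch=False)
print(first == second)  # → True: the cached weight is reused
```

In the real module the cached tensor lives on the GPU, which is why the repro above prints `data_ptr` to check whether both microbatches see the same buffer.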
Confirming that #646 fixes this issue with the following output from the repro script:
Describe the bug

- FP8+PP+Recompute+GA>1: loss = NaN
- FP8+PP+GA>1: loss is normal
- FP8+PP+Recompute+GA=1: loss is normal
- FP8+TP+Recompute+GA>1: loss is normal
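The GA=1 row is consistent with the caching explanation: with a gradient-accumulation factor of 1, every step is a "first microbatch", so the cached-weight reuse path is never exercised. A trivial sketch of that bookkeeping (the helper below is illustrative, not from the library):

```python
# Hypothetical helper: a step is a "first microbatch" whenever it starts
# a new gradient-accumulation window. With grad_accum == 1 that is every
# step, so the stale-cache path implicated in this bug can never trigger.
def is_first_microbatch(step: int, grad_accum: int) -> bool:
    return step % grad_accum == 0

print([is_first_microbatch(s, 1) for s in range(4)])  # → all True
print([is_first_microbatch(s, 4) for s in range(4)])  # → True only at step 0
```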