[bug] FP8+PP+Recompute+GA>1, loss = nan #539

Closed · jingjie01ai opened this issue Nov 27, 2023 · 6 comments
Labels: bug
jingjie01ai commented Nov 27, 2023

Describe the bug
FP8 + PP + Recompute + GA>1: loss = nan
FP8 + PP + GA>1: loss is normal
FP8 + PP + Recompute + GA=1: loss is normal
FP8 + TP + Recompute + GA>1: loss is normal
(PP = pipeline parallelism, TP = tensor parallelism, GA = gradient accumulation steps, Recompute = activation recomputation.)

jingjie01ai changed the title from "when training with fp8, pipeline parallel and recompute enabled in megatron (GA > 1), the loss is nan" to "[bug] FP8+PP+Recompute+GA>1, loss = nan" on Nov 28, 2023

jingjie01ai (Author) commented:
Proposed hotfix: hotfix_pp_recompute_nan.diff.txt

PP (1F1B or interleaved) runs a few warm-up forward passes.
During the first micro-batch warm-up forward (is_first_microbatch=True), the layer executes under a no-grad context (recompute phase 1), and the linear module's self.weight_fp8 is not updated because cast_to_fp8 creates and returns a new tensor.
During the second micro-batch warm-up forward (is_first_microbatch=False), the linear layer uses self.weight_fp8 for the GEMM and produces NaN, because self.weight_fp8 is still an empty placeholder tensor.
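
A minimal, self-contained sketch of the caching pattern described above, editor-added for illustration only. The class and the clone() call standing in for cast_to_fp8 are hypothetical stand-ins, not Transformer Engine's actual implementation; the point is that the first micro-batch binds the cast result to a local name while the module-level cache stays empty.

import torch

class ToyCachedLinear(torch.nn.Module):
    """Toy stand-in for the caching pattern described above (not TE's real code)."""

    def __init__(self, n):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(n, n))
        # Cached "FP8" weight placeholder, meant to be filled on the first micro-batch.
        # Zeros here for determinism; in practice it would be uninitialized storage.
        self.weight_fp8 = torch.zeros(n, n)

    def forward(self, x, is_first_microbatch):
        if is_first_microbatch:
            # Bug pattern: the cast result only binds a local name, so the
            # module-level cache self.weight_fp8 is never filled.
            weight_fp8 = self.weight.detach().clone()  # stand-in for cast_to_fp8(...)
        else:
            # Later micro-batches reuse the cache, which is still the empty placeholder.
            weight_fp8 = self.weight_fp8
        return x @ weight_fp8.t()

layer = ToyCachedLinear(4)
x = torch.randn(2, 4)
with torch.no_grad():  # mirrors the no-grad warm-up / recompute phase-1 context
    out1 = layer(x, is_first_microbatch=True)   # correct: uses the freshly casted local
    out2 = layer(x, is_first_microbatch=False)  # wrong: all zeros from the never-filled cache
print(out1)
print(out2)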

ptrendx commented Dec 1, 2023

@jingjie01ai Thank you for the bug report and the proposed fix. @ksivaman could you take a look?

codecaution commented:

I have also run into this problem when using bfloat16.

jinzex commented Jan 22, 2024

@jingjie01ai @codecaution Thank you for reporting the problem. Could you please share a minimal, reproducible example so we can investigate further?

jinzex commented Jan 30, 2024

Confirming that we could reproduce the issue with the script below, with changes from 493e9ef to print out the data_ptr.

repro.py

import torch
from transformer_engine.pytorch import Linear
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.utils import scaled_init_method_normal
from transformer_engine.pytorch.distributed import activation_recompute_forward

num_layers = 2
seq_len = 32
batch_size = 2
hidden_size = 64
num_attention_heads = 2

torch.manual_seed(0)
dtype = torch.float32
sigma = 0.023
output_layer_init_method = scaled_init_method_normal(sigma, num_layers)

block = (
    Linear(
        hidden_size, hidden_size, init_method=output_layer_init_method
    )
    .to(dtype=dtype)
    .cuda()
)

te_inp = torch.randn(
    seq_len, batch_size, hidden_size, dtype=dtype, requires_grad=True
).cuda()

use_fp8 = True
fp8_recipe = recipe.DelayedScaling(0, 1, recipe.Format.HYBRID)

# Two back-to-back forward passes under no_grad with activation recompute
# enabled, mirroring the PP warm-up: the first call (is_first_microbatch=True)
# is expected to cache the FP8-cast weight, the second call reuses that cache.
with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe):
    with torch.no_grad():
        with activation_recompute_forward(activation_recompute=True, recompute_phase=False):
            te_out1 = block(te_inp, is_first_microbatch=True)
            te_out2 = block(te_inp, is_first_microbatch=False)
            print("te_out1", te_out1)
            print("te_out2", te_out2)

output

before data_ptr weight_fp8_bak: 22573943414272, weight_fp8: 22573943414272
after  data_ptr weight_fp8_bak: 22573943414272, weight_fp8: 22573943441920
te_out1 tensor([[[ 0.1578,  0.0800,  0.0902,  ...,  0.0439,  0.0244,  0.0424],
         [-0.0949, -0.0265, -0.0230,  ..., -0.0503,  0.1298,  0.0212]],
        ...,
        [[ 0.0183, -0.0587,  0.1722,  ...,  0.0945, -0.0697, -0.0029],
         [-0.0288, -0.0141, -0.0333,  ..., -0.0888,  0.0073,  0.0359]]],
       device='cuda:0')
te_out2 tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],
        ...,
        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]], device='cuda:0')

te_out2 is zero because cast_to_fp8 returns a new tensor that gets bound to the local variable weight_fp8; the original weight_fp8 tensor passed in as an argument is never updated with the casted output and remains empty. This causes a problem later, when recompute uses this empty weight_fp8 tensor in its calculations.

weight_fp8._data = cast_to_fp8(
    weight,
    fp8_meta["scaling_fwd"],
    tex.FP8FwdTensors.GEMM1_WEIGHT,
    fp8_dtype_forward,
)
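
A framework-free, editor-added illustration of the rebinding pitfall described above; the helper names (fake_cast_to_fp8, forward_rebind, forward_mutate) are hypothetical. Whatever form the actual fix takes, the principle is the same: mutate the shared tensor object rather than rebind the local name.

import torch

def fake_cast_to_fp8(weight):
    # Hypothetical stand-in for cast_to_fp8: always returns a brand-new tensor.
    return weight.clone()

weight = torch.randn(4)
cached_weight_fp8 = torch.zeros(4)  # placeholder cached outside the function

def forward_rebind(weight_fp8):
    # Pitfall: rebinding the local name leaves the caller's cached tensor untouched.
    weight_fp8 = fake_cast_to_fp8(weight)

def forward_mutate(weight_fp8):
    # Mutating the object that was passed in updates the caller's cache.
    weight_fp8.copy_(fake_cast_to_fp8(weight))

forward_rebind(cached_weight_fp8)
print(cached_weight_fp8)  # still zeros

forward_mutate(cached_weight_fp8)
print(cached_weight_fp8)  # now holds the casted values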

ksivaman (Member) commented:

Confirming that #646 fixes this issue with the following output from the repro script:

before data_ptr weight_fp8_bak: 140304877411840, weight_fp8: 140304877411840
after  data_ptr weight_fp8_bak: 140304877411840, weight_fp8: 140304877411840
te_out1 tensor([[[ 0.1578,  0.0800,  0.0902,  ...,  0.0439,  0.0244,  0.0424],
         [-0.0949, -0.0265, -0.0230,  ..., -0.0503,  0.1298,  0.0212]],

        [[-0.0033,  0.0299, -0.0311,  ..., -0.0723, -0.0873, -0.0293],
         [-0.0157, -0.1337, -0.0591,  ..., -0.0443, -0.1072, -0.0702]],

        [[ 0.1131, -0.0860, -0.1050,  ..., -0.0586, -0.0582, -0.0546],
         [ 0.0446, -0.0304,  0.0873,  ..., -0.0381,  0.0410, -0.0685]],

        ...,

        [[ 0.0462, -0.0806,  0.0091,  ...,  0.0105, -0.0785,  0.1738],
         [ 0.0816, -0.0390, -0.1187,  ...,  0.0513, -0.1007, -0.0127]],

        [[-0.0562,  0.1853, -0.0528,  ..., -0.0971, -0.0886, -0.0164],
         [ 0.1194, -0.0728,  0.0328,  ..., -0.1134, -0.0909, -0.1145]],

        [[ 0.0183, -0.0587,  0.1722,  ...,  0.0945, -0.0697, -0.0029],
         [-0.0288, -0.0141, -0.0333,  ..., -0.0888,  0.0073,  0.0359]]],
       device='cuda:0')
te_out2 tensor([[[ 0.1585,  0.0787,  0.0928,  ...,  0.0437,  0.0221,  0.0424],
         [-0.0944, -0.0302, -0.0220,  ..., -0.0452,  0.1324,  0.0208]],

        [[-0.0044,  0.0275, -0.0294,  ..., -0.0762, -0.0922, -0.0330],
         [-0.0151, -0.1364, -0.0602,  ..., -0.0442, -0.1054, -0.0684]],

        [[ 0.1180, -0.0868, -0.1061,  ..., -0.0548, -0.0579, -0.0546],
         [ 0.0389, -0.0377,  0.0840,  ..., -0.0370,  0.0360, -0.0705]],

        ...,

        [[ 0.0404, -0.0819,  0.0097,  ...,  0.0114, -0.0800,  0.1729],
         [ 0.0826, -0.0326, -0.1153,  ...,  0.0528, -0.1023, -0.0123]],

        [[-0.0574,  0.1777, -0.0486,  ..., -0.0939, -0.0868, -0.0146],
         [ 0.1259, -0.0797,  0.0349,  ..., -0.1093, -0.0883, -0.1099]],

        [[ 0.0140, -0.0621,  0.1714,  ...,  0.0969, -0.0645, -0.0081],
         [-0.0265, -0.0140, -0.0352,  ..., -0.0846,  0.0017,  0.0350]]],
       device='cuda:0')
