Conversation
Commits:
- …line in te amp cm
- Te fp8 wrapper
- …ant_mode argument name in python API
- Merge to main
TensorModelParallelArgs.tensor_model_parallel_size = 2
# MixedPrecisionArgs.mixed_precision_dtype = "fp8"
Uncommenting this line should trigger an error at iteration 64 related to NaN loss.
Reproduce command:
PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun --nproc_per_node 2 --master_addr localhost --master_port 6000 pretrain_gr_ranking.py --gin-config-file movielens_ranking_fp8.gin
target_group_size=self._target_group_size,
)
# TODO: Remove this once the attention kernel outputs consistent dtype
@JacoCheung could you double-check why we need this?
When training with fp8 enabled, the model weights can be bf16/fp16 (NetworkArgs.dtype_str), usually bf16, and the activations follow the same dtype. But the HSTU kernel output is fp16, so we need a cast from fp16 to bf16 here. @shijieliu
@esoba do you think the cast needs to move into the kernel, or not?
I think casting fp16 to bf16 wouldn't result in an error, since the dynamic range is larger, but I'd assume some additional quantization error pops up (ideally the model learns around it anyway). For consistency it would probably be better to move the cast into the kernel, but as a workaround casting outside should be fine.
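For reference, a minimal sketch of the outside-the-kernel workaround being discussed, assuming a dense attention output tensor and a bf16 network dtype; the helper name is illustrative, not the repository's API:

```python
import torch

def cast_attn_output(attn_out: torch.Tensor,
                     model_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # The HSTU attention kernel currently emits fp16 regardless of
    # NetworkArgs.dtype_str, so cast back to the activation dtype here.
    # fp16 -> bf16 loses mantissa bits but cannot overflow, since bf16
    # has the wider dynamic range.
    if attn_out.dtype != model_dtype:
        attn_out = attn_out.to(model_dtype)
    return attn_out
```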
I will help review the code.
return jd

def _align_jagged_data_for_fp8(
Hi @esoba, since you pad here, there should be a discarding step in the postprocessor before the loss computation. That is to say:
final_loss = drop_pad_values(final_loss)
final_loss.mean().backward()
Otherwise the padded tokens will affect the backward pass, both the data gradients and the weight gradients, even if the padded values are initialized to 0.
See our loss calculation.
And our post-processor (if it's ranking) will treat the padded tokens as normal tokens.
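A minimal sketch of what the drop_pad_values step above could look like, assuming a boolean mask that marks the real (non-padded) token positions; the name and signature are illustrative, not the repository's actual API:

```python
import torch

def drop_pad_values(per_token_loss: torch.Tensor,
                    valid_mask: torch.Tensor) -> torch.Tensor:
    # Keep only losses at real token positions. Without this, the
    # zero-initialized padded tokens still enter .mean() and therefore
    # contribute to both data and weight gradients in backward.
    return per_token_loss[valid_mask]
```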
I believe I was seeing the issue when I set this to truncate (cut off the last N elements to get to the nearest length divisible by 16); let me double-check whether there is any undefined behavior doing it that way as well. Thanks for the catch!
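For context, a minimal sketch of the padding variant of this alignment, assuming a dense [T, D] activation rather than the actual jagged layout; it pads the token dimension up to a multiple of 16 (the usual fp8 GEMM shape requirement) and returns a mask so the padded positions can be dropped before the loss, as discussed above. All names are hypothetical:

```python
import torch
import torch.nn.functional as F

def pad_tokens_for_fp8(x: torch.Tensor, multiple: int = 16):
    """Zero-pad the token dim of a [T, D] tensor to a multiple of 16
    and return a boolean mask marking the real tokens."""
    pad = (-x.shape[0]) % multiple
    mask = torch.ones(x.shape[0] + pad, dtype=torch.bool, device=x.device)
    if pad:
        x = F.pad(x, (0, 0, 0, pad))  # pad rows (tokens), not features
        mask[-pad:] = False
    return x, mask
```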
FBGEMM_GPU HSTU has been integrated into recsys-example #321. FYI.
Added MixedPrecisionArgs to allow users to configure FP8 usage in TE linear layers and HSTU attention for the Native HSTU layer. Features include:
Minimal working example:
PYTHONPATH=${PYTHONPATH}:$(realpath ../) torchrun --nproc_per_node 2 --master_addr localhost --master_port 6000 pretrain_gr_ranking.py --gin-config-file movielens_ranking_fp8.gin
The setup currently has a bug when both the TE linear layer and HSTU attention are fp8-enabled: NaN loss appears at iteration 64. I have a debugging branch here that tracks the forward pass and the associated fp8 metadata for easier debugging. I tried to repro the issue here with some dummy inputs, and it ran successfully; my hunch is that NaN gradients are flowing back into the embedding table.
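A minimal sketch of one way to test that hunch, assuming a standard torch.nn.Module embedding table; it registers gradient hooks that raise as soon as a NaN reaches the table's parameters, so the failing iteration is caught immediately. The helper is hypothetical, not part of the repository:

```python
import torch

def watch_for_nan_grads(module: torch.nn.Module, name: str = "embedding") -> None:
    # Register a per-parameter backward hook that raises on the first
    # NaN gradient flowing into this module.
    def make_hook(pname: str):
        def hook(grad: torch.Tensor) -> torch.Tensor:
            if torch.isnan(grad).any():
                raise RuntimeError(f"NaN gradient in {name}.{pname}")
            return grad
        return hook
    for pname, p in module.named_parameters():
        p.register_hook(make_hook(pname))
```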
Checklist