
[DISCUSSION] fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor) #326

Draft
weifengpy wants to merge 10 commits into main

Conversation

@weifengpy (Contributor) commented Jul 24, 2024

Drafting this PR for discussion before having something landable.

We see two problems with float8 all-gather in FSDP2 + TP:

  • FSDP2 all-gather happens in bf16, but we expect float8
  • TP all-reduces amax for the weight, but we expect all-reduce only for the input

The crux is how we dispatch torch.chunk, which is called from distribute_tensor during TP init (see the sketch after this list):

  • without this PR, torch.chunk returns a plain Tensor. FSDP2 is applied after TP, so it only sees Float8Linear(weight=DTensor(_local_tensor=Tensor))
  • with this PR, torch.chunk returns WeightWithDynamicFloat8CastTensor, so FSDP2 sees Float8Linear(weight=DTensor(_local_tensor=WeightWithDynamicFloat8CastTensor))
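For concreteness, a minimal sketch of the path in question (hedged: `lin` stands for some Float8Linear whose weight is wrapped, and `tp_mesh` for a 1-D DeviceMesh; both are assumed to exist):

```python
from torch.distributed._tensor import Shard, distribute_tensor

# TP init shards the weight with distribute_tensor, which calls torch.chunk
# (aten.split.Tensor) on the WeightWithDynamicFloat8CastTensor
dtensor_weight = distribute_tensor(lin.weight, tp_mesh, [Shard(0)])

# without this PR: plain torch.Tensor                -> FSDP2 later all-gathers in bf16
# with this PR:    WeightWithDynamicFloat8CastTensor -> float8 all-gather is preserved
print(type(dtensor_weight._local_tensor))
```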

Profiler trace without this PR: AR (all-reduce) for input -> AG (all-gather) -> 4 ARs for wq/wk/wv/wo -> 1 AR for input. The 4 ARs for wq/wk/wv/wo should not happen if we precompute amax/scales for model.parameters() after opt.step() (sketched below the screenshot).

[screenshot: profiler trace, 2024-07-24]
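The precompute I mean is this repo's precompute_float8_dynamic_scale_for_fsdp helper, called once per optimizer step; roughly like the loop below (a sketch with an assumed model/optimizer/dataloader; import path assumed from fsdp_utils):

```python
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # precompute amax/scales for all float8 weights after the optimizer step,
    # so per-weight ARs for wq/wk/wv/wo do not show up in the forward
    precompute_float8_dynamic_scale_for_fsdp(model)
```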

weifengpy and others added 9 commits July 17, 2024 15:11
@facebook-github-bot added the CLA Signed label Jul 24, 2024
self.assertTrue(
    isinstance(colwise_param, DTensor)
    and isinstance(
        colwise_param._local_tensor, WeightWithDynamicFloat8CastTensor
    )
)
@weifengpy (Contributor, Author) commented Jul 24, 2024


edited: without this PR, torch.chunk returns a bf16 tensor. FSDP2 is applied after TP, so it only sees Float8Linear(weight=DTensor(_local_tensor=Tensor)).
with this PR, torch.chunk returns WeightWithDynamicFloat8CastTensor

Contributor:

Can you explain where the bf16 came from?

weifengpy (Contributor, Author):

Correcting my wording to be accurate: without this PR, torch.chunk returns a plain Tensor (can be fp32 or bf16) instead of a WeightWithDynamicFloat8CastTensor.

@@ -81,6 +81,8 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
     torch.ops.aten.as_strided.default,
     torch.ops.aten._to_copy.default,
     torch.ops.aten._pin_memory.default,
+    torch.ops.aten.split.Tensor,
@weifengpy (Contributor, Author) commented Jul 24, 2024

aten.split comes from torch.chunk when it is called from distribute_tensor during TP init; a quick way to confirm the dispatch path is sketched below.
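(A generic sketch, not repo code; `LogAtenOps` is a hypothetical helper that just prints whatever aten ops reach __torch_dispatch__.)

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class LogAtenOps(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(func)  # prints aten.split.Tensor for an evenly divisible chunk
        return func(*args, **(kwargs or {}))

w = torch.randn(8, 4)
with LogAtenOps():
    torch.chunk(w, 2, dim=0)  # same decomposition distribute_tensor hits during TP init
```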

edited: @awgu, curious if you still remember the reason for returning a plain Tensor from torch.chunk instead of WeightWithDynamicFloat8CastTensor. Is it for padding? Any concerns if I change torch.chunk to return WeightWithDynamicFloat8CastTensor?

Contributor:

> @awgu curious if you still remember the reason to return bf16 from torch.chunk.

I thought that the dtype and whether it is a WeightWithDynamicFloat8CastTensor are orthogonal. Do you mean the latter (whether it is a WeightWithDynamicFloat8CastTensor or not)?

I think originally I only added the ops that I saw I needed. Adding aten.split and aten.clone seems okay to me.

weifengpy (Contributor, Author):

> whether it is a WeightWithDynamicFloat8CastTensor or not

Exactly, whether it is a WeightWithDynamicFloat8CastTensor or not is the key. I edited my previous comments to say that right now torch.chunk returns a plain Tensor.

> I think originally I only added the ops that I saw I needed

Changing torch.chunk affects both TP and FSDP2; I will double-check FSDP2 after the change, roughly along the lines of the sketch below.
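(Hedged sketch; `model` is assumed to already have TP applied and Float8Linear weights, and fully_shard is FSDP2's entry point.)

```python
from torch.distributed._composable.fsdp import fully_shard

# FSDP2 is applied after TP; with this change the local shard of each weight
# should still be a WeightWithDynamicFloat8CastTensor so float8 all-gather kicks in
fully_shard(model)

for name, param in model.named_parameters():
    # FSDP2 sharded params are DTensors; inspect the local tensor's type
    print(name, type(param._local_tensor))
```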

elif isinstance(out, DTensor) and isinstance(
    out._local_tensor, Float8Tensor
):
    # propagate the precomputed scale onto the Float8Tensor backing the DTensor
    out._local_tensor._scale = scale
weifengpy (Contributor, Author):

Not sure about this change yet; I just want to have something sketchy to discuss first.

@weifengpy weifengpy marked this pull request as draft July 24, 2024 08:11
@weifengpy weifengpy changed the title fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor) [DISCUSSION] fix float8 all-gather in FSDP2 + TP: DTensor(WeightWithDynamicFloat8CastTensor) Jul 24, 2024
@weifengpy weifengpy requested review from awgu, wanchaol and yifuwang July 24, 2024 08:13