Reduced CPU overhead in precompute_float8_dynamic_scale_for_fsdp (#331)

Summary:
Pull Request resolved: #331

**Description**
For Llama3-8B on 8xH100, profiling with `with_stack=True` (which itself adds overhead), the CPU time of `precompute_float8_dynamic_scale_for_fsdp` decreases from 24 ms to 15 ms.

Before:
<img width="600" alt="Screenshot 2024-07-25 at 10 16 38 AM" src="https://github.com/user-attachments/assets/5d2384a0-6864-4bdc-91db-90cae809c702">

After:
<img width="638" alt="Screenshot 2024-07-25 at 10 17 00 AM" src="https://github.com/user-attachments/assets/1dbf3b2e-a576-4cdf-ac4f-06ae96020c38">

**Test Plan**
```
(pytorch-3.10) [[email protected] /data/users/andgu/float8_experimental (precompute_float8)]$ pytest test/test_fsdp2/test_fsdp2.py
========================================================= test session starts =========================================================
platform linux -- Python 3.10.13, pytest-7.3.2, pluggy-1.3.0
rootdir: /data/users/andgu/float8_experimental
plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, shard-0.1.2, rerunfailures-13.0, flakefinder-1.1.0, cpp-2.3.0
collected 8 items
Running 8 items in this shard

test/test_fsdp2/test_fsdp2.py ........                                                                                          [100%]

========================================================== warnings summary ===========================================================
test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_multi_module_parity
test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_single_module_parity
  /data/users/andgu/float8_experimental/float8_experimental/float8_linear_utils.py:272: FutureWarning: The combination of ranks + tag as process group identifier has been deprecated. Please switch to using ProcessGroup, DeviceMesh, or group name instead.
    all_reduced_amax_tensor = all_reduce(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================== 8 passed, 2 warnings in 121.90s (0:02:01) ==============================================
```

imported-using-ghimport

Test Plan: Imported from OSS

Reviewed By: weifengpy

Differential Revision: D60236258

Pulled By: awgu

fbshipit-source-id: 7b1e48d431dac25d534a77d64d1e5571ad3ad807
awgu authored and facebook-github-bot committed Jul 25, 2024
1 parent a6cef5a commit 701647b
Showing 1 changed file with 4 additions and 6 deletions.

float8_experimental/fsdp_utils.py

```diff
@@ -57,18 +57,16 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:

     # inf-norm is equivalent to max(abs(w))
     max_weights = torch._foreach_norm(weights, ord=math.inf)  # Partial
-    amax_tensor = torch.vstack(max_weights)  # Partial
+    amax_tensor = torch.stack(max_weights)  # Partial
     # clamp is dispatched through DTensor
     # it will issue a single all-reduce
     amax_tensor = torch.clamp(amax_tensor, EPS)  # Replicate
     scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor  # Replicate
     if amax_tensor.dtype is torch.float16:
         scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
-    scales = torch.split(scale_tensor, 1)  # Replicate
-    for scale, float8_linear in zip(scales, float8_linears):
-        float8_linear.weight._local_tensor._precomputed_scale = (
-            scale._local_tensor.squeeze()
-        )
+    local_scale_tensor = scale_tensor.to_local()
+    for i, float8_linear in enumerate(float8_linears):
+        float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]


 # FSDP pads its local tensor on dim-0. The subclass should be preserved such
```
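
For context, a hedged sketch of where this function typically sits in a training loop; the import path simply mirrors the file changed in this commit, and the placement after `optimizer.step()` follows the usual float8_experimental dynamic-scaling recipe rather than anything in this diff. `model`, `optimizer`, and `batch` are placeholders:

```python
import torch.nn as nn
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp

def train_step(model: nn.Module, optimizer, batch):
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # One stacked amax computation and a single all-reduce cover every
    # Float8Linear weight; the resulting per-weight scales are cached on the
    # weights' local tensors for reuse in the next forward.
    precompute_float8_dynamic_scale_for_fsdp(model)
    return loss
```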
