This repository has been archived by the owner on Aug 7, 2024. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reduced CPU overhead in
precompute_float8_dynamic_scale_for_fsdp
(#331
) Summary: Pull Request resolved: #331 **Description** For Llama3-8B on 8xH100 profiling with `with_stack=True` (which does add overhead), the `precompute_float8_dynamic_scale_for_fsdp` CPU time decreases from 24 ms to 15 ms. Before: <img width="600" alt="Screenshot 2024-07-25 at 10 16 38 AM" src="https://github.com/user-attachments/assets/5d2384a0-6864-4bdc-91db-90cae809c702"> After: <img width="638" alt="Screenshot 2024-07-25 at 10 17 00 AM" src="https://github.com/user-attachments/assets/1dbf3b2e-a576-4cdf-ac4f-06ae96020c38"> **Test Plan** ``` (pytorch-3.10) [[email protected] /data/users/andgu/float8_experimental (precompute_float8)]$ pytest test/test_fsdp2/test_fsdp2.py ========================================================= test session starts ========================================================= platform linux -- Python 3.10.13, pytest-7.3.2, pluggy-1.3.0 rootdir: /data/users/andgu/float8_experimental plugins: xdoctest-1.1.0, hypothesis-5.35.1, xdist-3.3.1, shard-0.1.2, rerunfailures-13.0, flakefinder-1.1.0, cpp-2.3.0 collected 8 items Running 8 items in this shard test/test_fsdp2/test_fsdp2.py ........ [100%] ========================================================== warnings summary =========================================================== test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_multi_module_parity test/test_fsdp2/test_fsdp2.py::TestFloat8MultiThread::test_fp32_fp8_single_module_parity /data/users/andgu/float8_experimental/float8_experimental/float8_linear_utils.py:272: FutureWarning: The combination of ranks + tag as process group identifier has been deprecated. Please switch to using ProcessGroup, DeviceMesh, or group name instead. all_reduced_amax_tensor = all_reduce( -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html ============================================== 8 passed, 2 warnings in 121.90s (0:02:01) ============================================== ``` imported-using-ghimport Test Plan: Imported from OSS Reviewed By: weifengpy Differential Revision: D60236258 Pulled By: awgu fbshipit-source-id: 7b1e48d431dac25d534a77d64d1e5571ad3ad807
- Loading branch information