compile optimizer #2623


Open: wants to merge 7 commits into base `gh/IvanKobzarev/1/base`

Conversation

@IvanKobzarev commented Apr 22, 2025

Stack from ghstack (oldest at bottom):

Compiling the optimizer improves perf of the Llama4 Scout model:
3.8 tokens_per_second -> 9 tokens_per_second (max tokens per second over the first ~10 iterations)
Peak memory is the same.

```
tune run --nproc_per_node 8 \
  full_finetune_distributed \
  --config recipes/configs/llama4/scout_17B_16E_full.yaml
```

PS:
Compilation in the current repo fails when setting `skip_rope_interval=4,`; have to test with `skip_rope_interval=None,`.

[ghstack-poisoned]

pytorch-bot bot commented Apr 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2623

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit d8b9431 with merge base f3e4747:

NEW FAILURE - The following job has failed:

  • GPU tests / gpu_test (3.9, stable) (gh)
    tests/recipes/test_full_finetune_distributed.py::TestFullFinetuneDistributedRecipe::test_loss_2d_parallel[llama3/8B_full-llama3-tune-4-1-True-2]

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Apr 22, 2025
@joecummings (Contributor) left a comment:

Since we're now compiling several things independently, it might make sense logically to have a section of the recipe where we compile everything after instantiation.

```
@@ -86,3 +86,10 @@ def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module:
    else:
        loss = torch.compile(loss, backend=backend)
    return loss


def compile_optimizer_step(optimizer_step_fn, verbose: bool = True):
```
Contributor:

I appreciate you wanting to keep this similar to how we were already doing things; however, we only needed to do this for the loss function b/c we were doing funky things with chunking.

We should just compile this directly in the recipe. Same goes for the other PR you have up.
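
For illustration, a minimal standalone sketch of compiling the optimizer step directly rather than through a helper; the toy model, optimizer, and backend below are placeholders, not code from this PR or from torchtune:

```python
import torch
import torch.nn as nn

# Minimal sketch (assumptions: a toy model/optimizer, the default inductor backend).
# The point is that optimizer.step is compiled once, up front, and then called
# normally inside the training loop.
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
optimizer.step = torch.compile(optimizer.step, backend="inductor")

for _ in range(3):
    loss = model(torch.randn(4, 16)).sum()
    loss.backward()
    optimizer.step()          # runs the compiled step
    optimizer.zero_grad()
```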

```
@@ -923,7 +923,16 @@ def train(self) -> None:
# If sharded, collect the DTensor here
if isinstance(grad_norm, DTensor):
    grad_norm = grad_norm.full_tensor()
self._optimizer.step()
optimizer_step_fn = self._optimizer.step
```
Contributor:

See comment below, we can just compile the optimizer step in the recipe directly.

@IvanKobzarev (Author):

Yeah, agree, just copied the previous setup. Will move compile to the recipe.

Contributor:

Noob q: is there a reason we need to compile `self._optimizer.step` every step? Why is it different from the model, which we compile one time upfront?

Collaborator:

When I tried this out, I found issues with setting up the LR scheduler, which fails when attempting to wrap the optimizer step fn.
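
For context, one possible ordering workaround (an assumption, not necessarily the fix used in this PR): construct the LR scheduler while `optimizer.step` is still the original bound method, since the scheduler wraps/inspects it at init time, and only compile the step afterwards:

```python
import torch
import torch.nn as nn

# Sketch under the assumption that the failure comes from the scheduler wrapping
# optimizer.step at construction time: build the scheduler first, compile second.
model = nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)  # before compiling
optimizer.step = torch.compile(optimizer.step)                        # after scheduler setup

for _ in range(3):
    model(torch.randn(4, 16)).sum().backward()
    optimizer.step()
    optimizer.zero_grad()
    scheduler.step()
```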

@IvanKobzarev (Author):

Thanks, will check the compile optimizer error with LR scheduler.

@joecummings (Contributor):

Ah sorry, did not mean to approve :)

@codecov-commenter

Codecov Report

Attention: Patch coverage is 0% with 6 lines in your changes missing coverage. Please review.

Project coverage is 65.63%. Comparing base (f3e4747) to head (e12e17f).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| recipes/full_finetune_distributed.py | 0.00% | 6 Missing ⚠️ |
Additional details and impacted files
```
@@                    Coverage Diff                     @@
##           gh/IvanKobzarev/1/base    #2623      +/-   ##
==========================================================
- Coverage                   65.78%   65.63%   -0.16%     
==========================================================
  Files                         396      396              
  Lines                       23764    23769       +5     
==========================================================
- Hits                        15634    15600      -34     
- Misses                       8130     8169      +39     
```
Comment on lines 929 to 933
```python
if self._compile:
    optimizer_step_fn = torch.compile(
        optimizer_step_fn,
        backend=self._compile_backend,
    )
```
Contributor:

Some optimizers might not work with this, if I remember correctly, e.g. torchao/bnb. It may need some testing. The safest option might be to add a compile flag per area, e.g.:

```yaml
compile:
  loss: True
  model: True
  optimizer_step: False
```
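
As a complement to the per-area flag, a sketch of a runtime guard (a hypothetical helper, not code from this PR): compile the step only for plain `torch.optim` optimizers and fall back to eager for low-bit ones such as torchao/bnb variants:

```python
import torch


def maybe_compile_optimizer_step(optimizer, enabled: bool, backend: str = "inductor"):
    """Hypothetical helper: compile optimizer.step only when it is likely safe.

    The module-name check is a heuristic/assumption, not an exhaustive
    compatibility test; torchao / bitsandbytes optimizers stay in eager mode.
    """
    is_plain_torch_optim = type(optimizer).__module__.startswith("torch.optim")
    if enabled and is_plain_torch_optim:
        optimizer.step = torch.compile(optimizer.step, backend=backend)
    return optimizer
```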

@IvanKobzarev (Author) commented Apr 28, 2025

Changed to direct compilation of `self._optimizer.step` and it works :)
Updated the diff.

Just FYI for testing: compilation at the moment needs workarounds for 2 different problems:

1. There is some problem with rng state preservation, which can be worked around:

```diff
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 668353867ab..493883542f9 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -249,8 +249,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
             prior_dtype = torch.get_default_dtype()
             torch_rng_state = torch.random.get_rng_state()
             cuda_rng_state = None
-            if torch.cuda.is_available():
-                cuda_rng_state = torch.cuda.get_rng_state()
+            # if torch.cuda.is_available():
+            #     cuda_rng_state = torch.cuda.get_rng_state()
             allow_tf32 = torch._C._get_cublas_allow_tf32()
             prior_fwd_from_src = torch.fx.graph_module._forward_from_src
             torch.fx.graph_module._forward_from_src = fx_forward_from_src_skip_result
@@ -281,8 +281,8 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
                 )
                 if prior_mobile_allocator_state != curr_mobile_allocator_state:
                     torch._C._unset_default_mobile_cpu_allocator()
-                if cuda_rng_state is not None:
-                    torch.cuda.set_rng_state(cuda_rng_state)
+                # if cuda_rng_state is not None:
+                #     torch.cuda.set_rng_state(cuda_rng_state)
                 torch._C._set_cublas_allow_tf32(allow_tf32)
                 torch.fx.graph_module._forward_from_src = prior_fwd_from_src
                 assert guards.check(), (
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index b75b1d6c39f..7ca67523704 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2110,15 +2110,15 @@ def preserve_rng_state():
     with disable_current_modes(), disable_functorch():
         rng_state = torch.clone(torch.random.get_rng_state())
         skip_frame_if_in_functorch_mode(rng_state)
-        if torch.cuda.is_available():
-            cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
+        # if torch.cuda.is_available():
+        #     cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
     try:
         yield
     finally:
         with torch.utils._python_dispatch._disable_current_modes():
             torch.random.set_rng_state(rng_state)
-            if torch.cuda.is_available():
-                torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
+            # if torch.cuda.is_available():
+            #     torch.cuda.set_rng_state(cuda_rng_state)  # type: ignore[possibly-undefined]
 
 
 def is_jit_model(model0):
```

2. There is an illegal memory access in chunked flex attention x caching.

If /tmp/torchinductor_${USER} is removed before every run, it does not fire (alternatively, disable the pt2 cache).
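
For convenience, a small sketch of that cache-clearing workaround in Python (the default cache location is an assumption based on the path mentioned above; `TORCHINDUCTOR_CACHE_DIR`, when set, overrides it):

```python
import getpass
import os
import shutil
import tempfile


def clear_inductor_cache() -> None:
    # Mirrors `rm -rf /tmp/torchinductor_${USER}` before each run.
    default_dir = os.path.join(
        tempfile.gettempdir(), f"torchinductor_{getpass.getuser()}"
    )
    cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR", default_dir)
    shutil.rmtree(cache_dir, ignore_errors=True)
```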

@joecummings (Contributor) left a comment:

Just one nit on naming, but this looks good!

```
@@ -71,6 +71,11 @@ enable_activation_offloading: False
fsdp_cpu_offload: True
compile: False # torch.compile, set to true for perf/memory improvement

compile_components:
```
Contributor:

nit: could we match the argument to just "compile"? Then valid arguments would be "True", "False", or the specific components. If "True", then we compile everything. If "False", we compile nothing. If the argument has a dictionary with each component, then we follow those instructions.

@IvanKobzarev (Author):

Ok. Agree with this logic, will update to it.
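
For illustration, a minimal sketch of the semantics suggested above (the names and defaults are assumptions for this sketch, not torchtune's actual config handling): accept either a bool or a per-component mapping and normalize it to per-component flags.

```python
from typing import Dict, Union

_COMPONENTS = ("model", "loss", "optimizer_step")


def normalize_compile_arg(compile_cfg: Union[bool, Dict[str, bool]]) -> Dict[str, bool]:
    # True/False compiles everything/nothing; a dict enables components selectively,
    # with unspecified components defaulting to False (a choice made for this sketch).
    if isinstance(compile_cfg, bool):
        return {name: compile_cfg for name in _COMPONENTS}
    return {name: bool(compile_cfg.get(name, False)) for name in _COMPONENTS}


# e.g. normalize_compile_arg(True) -> {"model": True, "loss": True, "optimizer_step": True}
# e.g. normalize_compile_arg({"model": True, "loss": True}) -> optimizer_step stays False
```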
