compile optimizer #2623
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2623
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Unrelated Failures
As of commit d8b9431 with merge base f3e4747:
NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Compiling the optimizer improves perf of the Llama4 Scout model: 3.8 tokens_per_second -> 9 tokens_per_second (max value of tokens per second in the first ~10 iterations); peak memory is the same.

```
tune run --nproc_per_node 8 \
  full_finetune_distributed \
  --config recipes/configs/llama4/scout_17B_16E_full.yaml
```

PS: Current repo compilation fails if `skip_rope_interval=4,` is set; have to test with `skip_rope_interval=None,`.

[ghstack-poisoned]
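For context, a standalone toy sketch of the technique in this PR: wrapping an optimizer's step in `torch.compile`. The model and optimizer here are stand-ins, not the recipe's actual objects.

```
# Toy sketch: compile the optimizer's bound step method once; later calls
# reuse the compiled artifact instead of re-tracing.
import torch

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

compiled_step = torch.compile(optimizer.step)

loss = model(torch.randn(4, 16)).sum()
loss.backward()
compiled_step()
optimizer.zero_grad()
```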
Since we're now compiling several things independently, it might make sense logically to have a section of the recipe where we compile everything after instantiation.
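For illustration, a rough sketch of what such a consolidated section might look like. The method name `_compile_components` and the attribute names are assumptions, not torchtune's actual API.

```
# Illustrative only: one place, right after setup, where every
# independently compiled piece gets wrapped.
import torch

def _compile_components(self) -> None:
    if not self._compile:
        return
    backend = self._compile_backend
    self._model = torch.compile(self._model, backend=backend)
    self._loss_fn = torch.compile(self._loss_fn, backend=backend)
    self._optimizer.step = torch.compile(self._optimizer.step, backend=backend)
```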
torchtune/training/_compile.py
Outdated
```
@@ -86,3 +86,10 @@ def compile_loss(loss: nn.Module, verbose: bool = True) -> nn.Module:
     else:
         loss = torch.compile(loss, backend=backend)
     return loss
+
+
+def compile_optimizer_step(optimizer_step_fn, verbose: bool = True):
```
I appreciate you wanting to keep this similar to how we're currently doing things; however, we only needed to do this for the loss function b/c we were doing funky things with chunking.
We should just compile this directly in the recipe. Same goes for the other PR you have up.
recipes/full_finetune_distributed.py
Outdated
```
@@ -923,7 +923,16 @@ def train(self) -> None:
     # If sharded, collect the DTensor here
     if isinstance(grad_norm, DTensor):
         grad_norm = grad_norm.full_tensor()
-    self._optimizer.step()
+    optimizer_step_fn = self._optimizer.step
```
See comment below, we can just compile the optimizer step in the recipe directly.
Yeah, agree, just copied the previous setup. Will move compile to the recipe.
Noob q: is there a reason we need to compile `self._optimizer.step` every step? Why is it different from the model, which we compile one time upfront?
When I tried this out, I found issues with setting up the LR scheduler, which fails when attempting to wrap the optimizer step fn.
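A plausible cause, hedged: `LRScheduler` wraps `optimizer.step` at construction time (to warn when `scheduler.step()` is called in the wrong order), and that wrapping expects a bound method; if `optimizer.step` has already been swapped for a compiled plain callable, scheduler construction can break. Under that assumption, one workaround is to create the scheduler first and compile the step afterwards:

```
# Hedged sketch: assumes the LR-scheduler failure comes from LRScheduler
# wrapping optimizer.step (a bound method) at construction time.
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# 1) Build the scheduler while optimizer.step is still an ordinary bound method.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# 2) Only afterwards swap in the compiled step.
optimizer.step = torch.compile(optimizer.step)
```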
Thanks, will check the compile optimizer error with LR scheduler.
Ah sorry, did not mean to approve :)
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

```
@@                 Coverage Diff                  @@
##           gh/IvanKobzarev/1/base    #2623      +/-   ##
==========================================================
- Coverage                   65.78%   65.63%   -0.16%
==========================================================
  Files                         396      396
  Lines                       23764    23769       +5
==========================================================
- Hits                        15634    15600      -34
- Misses                       8130     8169      +39
```

☔ View full report in Codecov by Sentry.
recipes/full_finetune_distributed.py
Outdated
```
if self._compile:
    optimizer_step_fn = torch.compile(
        optimizer_step_fn,
        backend=self._compile_backend,
    )
```
Some optimizers might not work with this if I remember correctly, e.g. torchao/bnb. May need some testing. The safest option might be to add a compile flag per area, e.g.:

```
compile:
  loss: True
  model: True
  optimizer_step: False
```
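If per-area flags landed, the recipe could gate each compile on its own flag. A rough sketch (the flag keys mirror the YAML above; everything else is illustrative):

```
# Purely illustrative: apply per-component compile flags such as
# {"loss": True, "model": True, "optimizer_step": False}.
import torch

def apply_compile_flags(model, loss_fn, optimizer, flags, backend="inductor"):
    if flags.get("model", False):
        model = torch.compile(model, backend=backend)
    if flags.get("loss", False):
        loss_fn = torch.compile(loss_fn, backend=backend)
    if flags.get("optimizer_step", False):
        optimizer.step = torch.compile(optimizer.step, backend=backend)
    return model, loss_fn, optimizer
```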
Changed to direct compilation of `self._optimizer.step` and it works :)
Just FYI for testing: compilation at the moment needs workarounds for 2 different problems:
If you remove /tmp/torchinductor_${USER} before every run, then the failure does not fire (or disable the pt2 cache).
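For anyone reproducing, a small helper for that workaround (assumes clearing the local inductor cache dir is what avoids the failure, per the comment above):

```
# Clear the local inductor cache before a run, per the workaround above.
import getpass
import shutil

cache_dir = f"/tmp/torchinductor_{getpass.getuser()}"
shutil.rmtree(cache_dir, ignore_errors=True)
```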
Just one nit on naming, but this looks good!
```
@@ -71,6 +71,11 @@ enable_activation_offloading: False
 fsdp_cpu_offload: True
 compile: False  # torch.compile, set to true for perf/memory improvement
+
+compile_components:
```
nit: could we match the argument to just `compile`? Then valid arguments would be `True`, `False`, or the specific components. If `True`, we compile everything; if `False`, we compile nothing; if the argument is a dictionary with each component, we follow those instructions.
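One way that resolution could look, sketched (the helper name and component keys are assumptions):

```
# Hypothetical handling of a `compile` field that is either a bool or a
# per-component mapping, per the suggestion above.
from typing import Dict, Union

COMPONENTS = ("model", "loss", "optimizer_step")

def resolve_compile(compile_cfg: Union[bool, Dict[str, bool]]) -> Dict[str, bool]:
    if isinstance(compile_cfg, bool):
        # True compiles everything, False compiles nothing.
        return {c: compile_cfg for c in COMPONENTS}
    # Dict form: follow per-component instructions; unspecified -> False.
    return {c: bool(compile_cfg.get(c, False)) for c in COMPONENTS}
```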
Ok. Agree with this logic, will update to it.