
1.0.0 rc Rebase #812

Merged
merged 15 commits into from
Mar 18, 2025


ktangsali (Collaborator) commented Mar 18, 2025

Modulus Pull Request

Description

Rebase of the 1.0.0-rc

CI was run on internal tests.

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

coreyjadams and others added 15 commits March 17, 2025 17:52
* Stashing profiling work

* Torch profiling works but is very slow. The line profiler is not functional at this time.

* Enable the profiling tool with the PyTorch profiler, as a context manager. Several objects are still TBD, but this implementation will capture a torch profile.

* Moving profiling tools into a directory to separate the individual tools more clearly and to make extensions easier.

* Profiling tools work with torch profiler and line_profiler.  nsys has a crash that I haven't resolved yet.

* Fix line profiling construction

* Begin instrumenting figconvnet and adding tutorials on modulus profiling tools

* Remove annotations and force all annotations to conform to nvtx.  Simpler, for now, and the most (only?) useful annotation tool

* Updating profiling tutorial

* Minor updates to profiling interfaces

* only adding some profiling hooks to figconvnet

* Add profiling hooks to mesh graph net.

* Set TELayerNorm to default layer norm in MeshGraphNet

* Nearly finished profiling tutorial and tooling example.  Just need to add images.

* Final (first) draft of the profiling tutorial and clean up profiler code slightly.  Ready for draft PR

* Add tests to the profiler tools to check functionality.  Thanks Cursor!

Some minor updates to the tools themselves to accommodate instance clearing and refreshing.

* Update changelog for profiling tools

* Update profiler files to (hopefully) pass CI checks

* Remove profiling parts from capture.py for later integration

* Update __init__.py

Remove nvtx wrapper

* Add extra line to make linting happy...

* When CUDA is not available (mostly in CI), emit a warning and switch to the native layer norm.

* Make LayerNorm the default so tests will pass. I think the test needs more care around TELayerNorm.

* Very minor fixes per review

* Resolve most comments from PR review.  One to go (profiler state to become a literal)

* Change profiler state tracker to a single state with an enum type.

* Two changes made here:
- the exit stack moves from a class variable to an instance variable
- The double-checked locking mechanism in the registry becomes a single lock and check.

* Make sure the exit stack init is actually in __init__ and not initialize()
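The profiler registry changes described in the commits above (a single enum state, an exit stack owned by each instance, and one lock-and-check instead of double-checked locking) can be illustrated with the standard library alone. This is a minimal, hypothetical sketch: `ProfilerState`, `Profiler`, `get_profiler`, and the `region` hook are invented names, not the Modulus profiling API, and the event list stands in for whatever a real backend (such as NVTX or the PyTorch profiler) would record.

```python
import threading
from contextlib import ExitStack, contextmanager
from enum import Enum, auto
from typing import Dict, Iterator, List, Tuple


class ProfilerState(Enum):
    # One enum value replaces several independent boolean flags (hypothetical states).
    DISABLED = auto()
    ENABLED = auto()
    FINALIZED = auto()


class Profiler:
    def __init__(self, name: str) -> None:
        self.name = name
        self.state = ProfilerState.DISABLED
        # Built in __init__, so each instance owns its own stack;
        # a class-level ExitStack would be shared by every instance.
        self._exit_stack = ExitStack()
        self.events: List[Tuple[str, str]] = []

    def __enter__(self) -> "Profiler":
        self.state = ProfilerState.ENABLED
        # Cleanup callbacks registered here run (LIFO) when the block exits.
        self._exit_stack.callback(self._finalize)
        return self

    def __exit__(self, *exc_info: object) -> None:
        self._exit_stack.close()

    def _finalize(self) -> None:
        self.state = ProfilerState.FINALIZED

    @contextmanager
    def region(self, label: str) -> Iterator[None]:
        # Hypothetical annotation hook: records enter/exit only while enabled.
        if self.state is ProfilerState.ENABLED:
            self.events.append(("enter", label))
        try:
            yield
        finally:
            if self.state is ProfilerState.ENABLED:
                self.events.append(("exit", label))


_registry: Dict[str, Profiler] = {}
_registry_lock = threading.Lock()


def get_profiler(name: str) -> Profiler:
    # Single lock-and-check: always take the lock, then look up or create.
    with _registry_lock:
        profiler = _registry.get(name)
        if profiler is None:
            profiler = Profiler(name)
            _registry[name] = profiler
        return profiler


with get_profiler("demo") as prof:
    with prof.region("forward"):
        pass

print(prof.state, prof.events)
```

Taking the lock on every lookup costs slightly more than double-checked locking, but it is simpler to reason about, and in a sketch like this the overhead is negligible next to the profiling work itself.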

* Enable mesh-based parallelism as the configuration backend, even for simple DDP sharding

* Fix small typo in docstring

* Remove unnecessary functions with new interface

* Adding first implementation of ShardTensor prototype. Several pieces are still WIP, but basic functionality is supported for creation and forward usage.

* Working implementation of ShardTensor, though still somewhat incomplete.

* Adding work-in-progress examples.  Be careful of sharp edges!

* A few more example pieces are needed before natten will work out of the box. Most of the ops have been validated; all that remains is to wrap the na2d function call to ensure it will dispatch properly.

* Fix naming scheme

* Minor name change

* Add monkey patching for na2d operation with shard tensors

* Fix bug in shard tensor inference of global size. Check against sharding in unbind op rules.

* Enable backwards gradients for halo sharding and natten patch

* Convolution 2d backwards works, though it would be better to catch torch.ops.aten.convolution.default.

* Fix missing import and ensure tensors are contiguous before allgather_v

* Clean up and remove unnecessary noise and printouts for debugging

* Unify (and correct!) the sharded convolution implementation.  There was also a minor bug in the backward
pass that got more pronounced with smaller data: grad inputs were failing to properly collect
haloed gradients and add them on the edges.  Now fixed.

* Remove noise from sharding utils.

* For smaller tensors, the alltoall step of halo reductions might be significant overhead.
Here I'm implementing an option to switch to peer-to-peer message passing, since it might
benefit from stream utilization in layers like natten.na2d.

It's a developer choice currently, not a user choice.

* Remove the shard_utils file; it is a subfolder.

* Add modulus ShardTensor api documentation

* Clean up doc strings, type annotations and mesh implementation.  No significant functionality changes in this commit.

* Add significant docstring / type annotation cleanup to ShardTensor.

Add `scatter_tensor` function to ease the transition to shard tensor.
This function allows users to maintain data pipelines (on one rank) and easily
scatter that data to a domain mesh.

* Remove neighborhood attention prototypes

* Remove the rest of these examples since they are outdated and unnecessary

* Mostly, this commit is adding type annotations and doc strings.

But also, this adjusts the shard tensor mechanism for tracking shard info to use
a dict instead of a list of tuples.

* Clean up and document conv patches.
No real code changes applied here.

* clean up and improve documentation and type hints for shard utils worker functions

* Adding basic tests for shard tensor initialization and redistribution.

There appears to be one corner case in redistribute to fix.  TBD.

Tests for grad propagation are coming.

* Add full working example of multilevel parallelism with pytorch
FSDP and modulus ShardTensor

* Add missing type annotations

* Ensure scatter_tensor is available to import from modulus.distributed

* Update changelog and ensure wrapt is an optional dependency

* Update fsdp_and_shard_tensor.rst

Update tutorial based on feedback from @pzharrington

* Update __init__.py

Remove wildcard import.

* Update shard_tensor.py

fix spacing

* This is an essential bug fix for a missing import

* Update branch to pass CI tests.

* This commit provides several pieces:

- First, the ability to transpose the sharding dimensions is supported. For square submeshes, 2x2 for example,
the output sharding will match the input sharding if it's uneven.  This can only be supported if the number of
devices in the output mesh dimension is equal to the input dimension, hence the restriction on square submeshes.
Other scenarios will apply dtensor-like chunk syntax, but return a shard tensor tracking that split.  Comprehensive
tests on 1D and 2D meshes are included here.  No testing is done at this time on 3D sharding / meshes.

- Second, the issues with torch.mean are intercepted and fixed.  This uses a new dispatch intercept (below)
and applies a weight to the mean, and converts the Partial placement to a Partial(sum) with the weight applied.
This has a bug that appears to be present in DTensor too: reductions over non-sharded dimensions appear to falter.
To be fixed in a future release.

- Third, ShardTensor has a new class attribute to accommodate operator interceptions. The only functions intercepted
at this time are variants of aten.mean; however, all monkey patching is expected to be converted to this syntax.

* Update monkey patching to ensure patches get applied by modulus, and don't require
them to trigger elsewhere.  If ShardTensor is used, the patches get applied.

Also, minor updates to docs.

* Codify ShardTensor and FSDP in tutorials.

* Apparently, codify'ing in rst requires double ticks.

* This commit fixes gradient propagation for unevenly sharded tensors.  Tests are coming in the next commit immediately after.

* Add tests for shard tensor: initialization, resharding, and gradient sharding.

Further, fixed an annoying bug in other distributed tests where OS environment variables weren't cleared after testing, and some tests would fail, but only if others ran first.

Now, all distributed tests use a context manager to change OS environment variables locally only.

* Two things done here:
- Enable dynamic (off by default) wrapping of layers by shard tensor. They are turned on automatically when a shard tensor is created.
- Rename the utils to manage env variables.

Tests are failing with unusual CPU errors on ORD.  Moving to github runners ...

* Disable patched operations by default.
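The `aten.mean` fix described in the commits above rests on a simple identity: when a dimension is sharded unevenly, the global mean is the size-weighted sum of the local means, global_mean = sum_k (n_k / N) * mean_k with N = sum_k n_k, not the plain average of the local means. Once each local mean is scaled by n_k / N, the remaining reduction is an ordinary sum, which is why a Partial placement can become a Partial(sum). The pure-Python sketch below checks this; `split_sizes`, `shard`, and `global_mean` are hypothetical helpers, the balanced split (earlier ranks take the remainder) is an assumption rather than ShardTensor's actual chunking rule, and keeping shard sizes in a per-rank dict is only an illustration.

```python
from typing import Dict, List


def split_sizes(total: int, num_shards: int) -> Dict[int, int]:
    # Assumed balanced split: the first `total % num_shards` ranks get one extra element.
    base, extra = divmod(total, num_shards)
    return {rank: base + (1 if rank < extra else 0) for rank in range(num_shards)}


def shard(data: List[float], sizes: Dict[int, int]) -> Dict[int, List[float]]:
    # Slice a flat list into contiguous per-rank pieces.
    pieces, start = {}, 0
    for rank in sorted(sizes):
        pieces[rank] = data[start:start + sizes[rank]]
        start += sizes[rank]
    return pieces


def global_mean(local_means: Dict[int, float], sizes: Dict[int, int]) -> float:
    # Weight each local mean by its share of the elements, then do a plain sum:
    # the "Partial(sum) with the weight applied" view of a mean reduction.
    total = sum(sizes.values())
    return sum(local_means[r] * sizes[r] / total for r in sizes)


data = [float(i) for i in range(10)]      # true global mean is 4.5
sizes = split_sizes(len(data), 3)         # {0: 4, 1: 3, 2: 3}
pieces = shard(data, sizes)
local = {r: sum(p) / len(p) for r, p in pieces.items()}

naive = sum(local.values()) / len(local)  # unweighted average of local means (wrong)
print(sizes, global_mean(local, sizes), naive)
```

With an even split every weight is 1/num_shards and the two approaches agree, which is why the error only shows up for uneven sharding.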
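The test-isolation fix described in the commits above (changing OS environment variables only inside a context manager) follows a common standard-library pattern. This is a minimal sketch that assumes nothing about the repository's actual helper: `modified_environ` is a hypothetical name, and `MASTER_ADDR` / `MASTER_PORT` in the usage are only example variables. On exit it restores previous values and removes variables that did not exist before, so one test cannot leak settings into the next.

```python
import os
from contextlib import contextmanager
from typing import Iterator, Optional


@contextmanager
def modified_environ(**overrides: Optional[str]) -> Iterator[None]:
    """Temporarily set (or, with None, unset) environment variables."""
    # Remember the prior value of every key we touch (None means it was unset).
    saved = {key: os.environ.get(key) for key in overrides}
    try:
        for key, value in overrides.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
        yield
    finally:
        # Restore the original state even if the body raised.
        for key, old in saved.items():
            if old is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old


with modified_environ(MASTER_ADDR="localhost", MASTER_PORT="12355"):
    print(os.environ["MASTER_PORT"])
```

Because restoration happens in `finally`, a failing test still leaves the environment as it found it, which removes the order dependence between tests.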
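The halo-gradient bug described in the commits above (gradients that land in haloed regions were not sent back and added to the owning shard's edges) is easiest to see on a toy 1D case. The following pure-Python sketch simulates three "ranks" as lists and is not ShardTensor's implementation: `halo_pad`, `halo_unpad_grad`, the kernel `WEIGHTS`, and the halo width are hypothetical choices, and every shard is assumed to hold at least `HALO` elements. It checks that the sharded forward and backward passes match a single global convolution.

```python
from typing import List

HALO = 1                   # halo width for a 3-tap stencil (assumption)
WEIGHTS = [1.0, 2.0, 3.0]  # hypothetical convolution kernel


def conv1d_valid(x: List[float]) -> List[float]:
    # "Valid" 3-tap correlation: output i uses x[i], x[i+1], x[i+2].
    return [sum(w * x[i + j] for j, w in enumerate(WEIGHTS))
            for i in range(len(x) - 2 * HALO)]


def conv1d_valid_backward(x_len: int, grad_out: List[float]) -> List[float]:
    # Gradient of conv1d_valid with respect to its input.
    grad_x = [0.0] * x_len
    for i, g in enumerate(grad_out):
        for j, w in enumerate(WEIGHTS):
            grad_x[i + j] += w * g
    return grad_x


def halo_pad(shards: List[List[float]]) -> List[List[float]]:
    # Forward halo exchange: borrow HALO edge values from each neighbour;
    # the global boundary is zero-padded.
    padded = []
    for r, local in enumerate(shards):
        left = shards[r - 1][-HALO:] if r > 0 else [0.0] * HALO
        right = shards[r + 1][:HALO] if r < len(shards) - 1 else [0.0] * HALO
        padded.append(left + local + right)
    return padded


def halo_unpad_grad(padded_grads: List[List[float]]) -> List[List[float]]:
    # Backward halo exchange: gradient that landed in a halo belongs to the
    # neighbour that owns those values, so it is ADDED to that neighbour's edge.
    local = [g[HALO:-HALO] for g in padded_grads]
    for r, g in enumerate(padded_grads):
        if r > 0:
            for k in range(HALO):
                local[r - 1][len(local[r - 1]) - HALO + k] += g[k]
        if r < len(padded_grads) - 1:
            for k in range(HALO):
                local[r + 1][k] += g[len(g) - HALO + k]
    return local


x = [float(v) for v in range(1, 11)]
shards = [x[0:4], x[4:7], x[7:10]]  # uneven shards on three simulated ranks

# Reference: zero-pad the whole array and run one global convolution.
x_pad = [0.0] * HALO + x + [0.0] * HALO
y_ref = conv1d_valid(x_pad)
grad_ref = conv1d_valid_backward(len(x_pad), [1.0] * len(y_ref))[HALO:-HALO]

# Sharded: halo exchange, purely local convolution, then reverse halo exchange.
padded = halo_pad(shards)
y_shards = [conv1d_valid(p) for p in padded]
grad_shards = halo_unpad_grad(
    [conv1d_valid_backward(len(p), [1.0] * len(y)) for p, y in zip(padded, y_shards)])

assert sum(y_shards, []) == y_ref
assert sum(grad_shards, []) == grad_ref
```

Dropping the accumulation step in `halo_unpad_grad` (keeping only each shard's interior gradient) leaves the shard edges too small, which matches the symptom described for smaller data, where edges make up a larger fraction of each shard.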
ktangsali (Collaborator, Author)

/blossom-ci

1 similar comment
Alexey-Kamenev (Collaborator)

/blossom-ci

Alexey-Kamenev (Collaborator) left a comment


LGTM!

ktangsali merged commit 4348562 into main on Mar 18, 2025
1 check failed
ktangsali deleted the 1.0.0-rc-rebase-2 branch on March 20, 2025 at 00:58
4 participants