Support vectorized gradients via boundary devectorization#1717
Closed
blasphemetheus wants to merge 5 commits intoelixir-nx:mainfrom
Closed
Support vectorized gradients via boundary devectorization#1717blasphemetheus wants to merge 5 commits intoelixir-nx:mainfrom
blasphemetheus wants to merge 5 commits intoelixir-nx:mainfrom
Conversation
2baec0c to
9e4a22b
Compare
Minimal approach: instead of per-op vectorization adjustments in the backward pass, handle vectorization at the boundary only: 1. Devectorize the gradient seed (constant 1.0 uses devec output shape) 2. Devectorize expression args in recur_to_grad (backward pass stays in devectorized space) 3. Re-vectorize the final gradient in to_grad to match input's vectorized_axes 229/232 existing tests pass. 3 failures are the mixed-vectorization case (non-vectorized target in vectorized context) where unbroadcast sums over batch dims — needs further work. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
9e4a22b to
8982220
Compare
Thread batch_count through to_grad_ids → update_grads → grad → unbroadcast instead of using Process.put/get. batch_count is computed once from the output's vectorized_axes and passed as a proper parameter through the entire backward pass. Changes: to_grad_ids gains a 3rd element, grad/4 becomes grad/5, unbroadcast/3 becomes unbroadcast/4. All 76 grad clauses updated. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When different grad inputs have different vectorized axis names (e.g.,
x with :a, y with :b), raise a clear error instead of producing wrong
results. The cross-product broadcasting pattern creates intermediate
reshapes with mixed batch/broadcast dims that the boundary approach
cannot distinguish.
Same-axis vectorization ({x_vec_a, y_vec_a}) continues to work.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
a330990 to
6ea2798
Compare
Bring over tests from the v1 approach: - vectorized_grad_test.exs: 91 correctness tests (85 pass, 6 skipped) - grad_test.exs: edge case tests with skip tags for unsupported patterns - helpers.ex: fix assert_all_close for vectorized tensors 13 tests skipped total — all involve apply_vectorized boundary ops (QR, cholesky, triangular_solve, cond, take_along_axis) or different vectorized axis names on different inputs. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
6ea2798 to
d0abc23
Compare
Per polvalente's review on elixir-nx#1697: use Nx.dot(a, [-2], b, [-2]) and Nx.dot(a, [-1], b, [-2]) instead of Nx.dot(a, [0], b, [0]) and Nx.dot(a, b) so the grad rules work for batched (3D+) tensors. Cholesky: replace dot([0], [0]) with dot([-2], [-2]), use Nx.eye(l) instead of Nx.eye(Nx.shape(l)), add batch_transpose for last-2-dim transpose (can't use adjoint — would double-conjugate for complex). QR: replace all Nx.dot(a, b) with Nx.dot(a, [-1], b, [-2]) inline. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
72b0ca0 to
25f1a71
Compare
Contributor
Author
|
sorry meant to open that to my fork so i can iterate with ci |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Minimal approach: instead of per-op vectorization adjustments in the backward pass, handle vectorization at the boundary only:
229/232 existing tests pass. 3 failures are the mixed-vectorization case (non-vectorized target in vectorized context) where unbroadcast sums over batch dims — needs further work.