Support vectorized gradients via boundary devectorization #1717

Closed
blasphemetheus wants to merge 5 commits into elixir-nx:main from
blasphemetheus:fix/1533-vectorized-grad-v2

@blasphemetheus
Contributor

Minimal approach: instead of per-op vectorization adjustments in the backward pass, handle vectorization at the boundary only:

  1. Devectorize the gradient seed (constant 1.0 uses devec output shape)
  2. Devectorize expression args in recur_to_grad (backward pass stays in devectorized space)
  3. Re-vectorize the final gradient in to_grad to match input's vectorized_axes

229 of 232 existing tests pass. The 3 failures are all the mixed-vectorization case (a non-vectorized target in a vectorized context), where unbroadcast sums over the batch dims; this needs further work.
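
The three boundary steps can be sketched outside of `Nx.Defn.Grad` roughly as follows. This is a minimal illustration, not the code in this PR: `BoundaryGradSketch` is a hypothetical module, the Nx version requirement is a guess, and the sketch assumes `fun` is elementwise so that every batch entry is independent and summing over the whole devectorized tensor still yields per-entry gradients.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule BoundaryGradSketch do
  # Hypothetical helper, not the implementation in this PR.
  # Assumes `fun` is elementwise, so batch entries do not interact.
  def grad(x, fun) do
    # 1. Remember the vectorized axis names, then devectorize at the boundary.
    vec_names = Keyword.keys(x.vectorized_axes)
    flat = Nx.devectorize(x, keep_names: false)

    # 2. Differentiate entirely in devectorized space.
    g = Nx.Defn.grad(fn t -> t |> fun.() |> Nx.sum() end).(flat)

    # 3. Re-vectorize the leading axes, one name at a time, in order.
    Enum.reduce(vec_names, g, fn name, acc -> Nx.vectorize(acc, name) end)
  end
end

x = Nx.iota({2, 3}, type: :f32) |> Nx.vectorize(:batch)
g = BoundaryGradSketch.grad(x, &Nx.sin/1)

# The gradient should carry the same vectorized axes as the input.
IO.inspect(g.vectorized_axes)
```

A function that reduces or mixes values across what used to be the batch axis would not be safe under this sketch, which is the kind of situation the mixed-vectorization failures point at.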

blasphemetheus force-pushed the fix/1533-vectorized-grad-v2 branch from 2baec0c to 9e4a22b on March 26, 2026 12:59
Minimal approach: instead of per-op vectorization adjustments in the
backward pass, handle vectorization at the boundary only:

1. Devectorize the gradient seed (constant 1.0 uses devec output shape)
2. Devectorize expression args in recur_to_grad (backward pass stays
   in devectorized space)
3. Re-vectorize the final gradient in to_grad to match input's
   vectorized_axes

229/232 existing tests pass. 3 failures are the mixed-vectorization
case (non-vectorized target in vectorized context) where unbroadcast
sums over batch dims — needs further work.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
blasphemetheus force-pushed the fix/1533-vectorized-grad-v2 branch from 9e4a22b to 8982220 on March 26, 2026 13:16
blasphemetheus and others added 2 commits March 26, 2026 08:36
Thread batch_count through to_grad_ids → update_grads → grad → unbroadcast
instead of using Process.put/get. batch_count is computed once from the
output's vectorized_axes and passed as a proper parameter through the
entire backward pass.

Changes: to_grad_ids gains a 3rd element, grad/4 becomes grad/5,
unbroadcast/3 becomes unbroadcast/4. All 76 grad clauses updated.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
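
The difference between reading `batch_count` from the process dictionary and threading it as a parameter can be illustrated with a small plain-Elixir sketch. Everything here is hypothetical: `BatchCountSketch` and its functions are not part of Nx, and the idea that `batch_count` offsets the axes to reduce is an assumption made only for illustration.

```elixir
defmodule BatchCountSketch do
  # Implicit version: the helper silently depends on process state.
  def axes_implicit(extra_rank) do
    offset = Process.get(:batch_count, 0)
    shift(extra_rank, offset)
  end

  # Explicit version: the dependency is visible in the signature.
  def axes_explicit(extra_rank, batch_count) do
    shift(extra_rank, batch_count)
  end

  # Returns `extra_rank` consecutive axes starting right after the
  # leading batch axes. An empty list is returned when extra_rank is 0.
  defp shift(extra_rank, offset) do
    for axis <- 0..(extra_rank - 1)//1, do: axis + offset
  end
end

# Explicit: same inputs always give the same answer.
IO.inspect(BatchCountSketch.axes_explicit(2, 1))

# Implicit: the answer changes if the key was never set.
IO.inspect(BatchCountSketch.axes_implicit(2))
```

The explicit form costs an extra argument on every clause, but a forgotten `Process.put/2` can no longer change the result without any error.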
When different grad inputs have different vectorized axis names (e.g.,
x with :a, y with :b), raise a clear error instead of producing wrong
results. The cross-product broadcasting pattern creates intermediate
reshapes with mixed batch/broadcast dims that the boundary approach
cannot distinguish.

Same-axis vectorization ({x_vec_a, y_vec_a}) continues to work.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
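
A check of this kind might look like the sketch below. It is only an illustration under stated assumptions: `VectorizedAxesCheck.validate!/1` is a hypothetical function, it takes keyword lists shaped like a tensor's `vectorized_axes` field rather than real tensors, and it assumes that non-vectorized inputs are not treated as a mismatch.

```elixir
defmodule VectorizedAxesCheck do
  # Hypothetical validator, not the function added in this PR.
  # Each element of `inputs_axes` is a keyword list such as [a: 2].
  def validate!(inputs_axes) do
    distinct_names =
      inputs_axes
      |> Enum.map(&Keyword.keys/1)
      # Assumption: inputs without vectorized axes are ignored.
      |> Enum.reject(&(&1 == []))
      |> Enum.uniq()

    case distinct_names do
      [] ->
        :ok

      [_single] ->
        :ok

      many ->
        raise ArgumentError,
              "grad inputs must share the same vectorized axis names, got: " <>
                inspect(many)
    end
  end
end

# Same axis name on both inputs is accepted.
:ok = VectorizedAxesCheck.validate!([[a: 2], [a: 2]])
```

Calling `VectorizedAxesCheck.validate!([[a: 2], [b: 3]])` raises `ArgumentError` instead of letting the cross-product case continue.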
blasphemetheus force-pushed the fix/1533-vectorized-grad-v2 branch from a330990 to 6ea2798 on March 26, 2026 21:08
Bring over tests from the v1 approach:
- vectorized_grad_test.exs: 91 correctness tests (85 pass, 6 skipped)
- grad_test.exs: edge case tests with skip tags for unsupported patterns
- helpers.ex: fix assert_all_close for vectorized tensors

13 tests skipped total — all involve apply_vectorized boundary ops
(QR, cholesky, triangular_solve, cond, take_along_axis) or different
vectorized axis names on different inputs.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
blasphemetheus force-pushed the fix/1533-vectorized-grad-v2 branch from 6ea2798 to d0abc23 on March 26, 2026 22:32
Per polvalente's review on elixir-nx#1697: use Nx.dot(a, [-2], b, [-2]) and
Nx.dot(a, [-1], b, [-2]) instead of Nx.dot(a, [0], b, [0]) and
Nx.dot(a, b) so the grad rules work for batched (3D+) tensors.

Cholesky: replace dot([0], [0]) with dot([-2], [-2]), use Nx.eye(l)
instead of Nx.eye(Nx.shape(l)), add batch_transpose for last-2-dim
transpose (can't use adjoint — would double-conjugate for complex).

QR: replace all Nx.dot(a, b) with Nx.dot(a, [-1], b, [-2]) inline.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
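
The contraction-axis change can be illustrated with a short sketch. For 2D inputs, contracting `[-1]` with `[-2]` matches `Nx.dot(a, b)`, and contracting `[-2]` with `[-2]` matches `Nx.dot(a, [0], b, [0])`; the negative axes count from the end, so they keep pointing at the matrix dimensions when leading axes are present. `DotAxesSketch` is a hypothetical module, the Nx version requirement is a guess, and the vectorized call at the end assumes `Nx.dot/4` accepts vectorized inputs.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule DotAxesSketch do
  # a · b: contract the last axis of `a` with the second-to-last of `b`.
  def matmul(a, b), do: Nx.dot(a, [-1], b, [-2])

  # aᵀ · b: contract the second-to-last axis of both operands.
  def transposed_matmul(a, b), do: Nx.dot(a, [-2], b, [-2])
end

a = Nx.tensor([[1, 2], [3, 4]])
b = Nx.tensor([[5, 6], [7, 8]])

# For 2D inputs these agree with the positional forms.
IO.inspect(DotAxesSketch.matmul(a, b) == Nx.dot(a, b))
IO.inspect(DotAxesSketch.transposed_matmul(a, b) == Nx.dot(a, [0], b, [0]))

# With vectorized inputs the axes refer to the inner (non-vectorized) shape.
va = Nx.iota({2, 2, 2}) |> Nx.vectorize(:batch)
IO.inspect(DotAxesSketch.matmul(va, va).vectorized_axes)
```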
blasphemetheus force-pushed the fix/1533-vectorized-grad-v2 branch from 72b0ca0 to 25f1a71 on March 26, 2026 22:43
@blasphemetheus
Contributor Author

Sorry, I meant to open this against my fork so I can iterate with CI.
