Skip to content

Fix Nx.slice crash on scalar tensor#1712

Merged
polvalente merged 5 commits intoelixir-nx:mainfrom
blasphemetheus:fork/fix/scalar-slice
Mar 24, 2026
Merged

Fix Nx.slice crash on scalar tensor#1712
polvalente merged 5 commits intoelixir-nx:mainfrom
blasphemetheus:fork/fix/scalar-slice

Conversation

@blasphemetheus
Copy link
Copy Markdown
Contributor

@blasphemetheus blasphemetheus commented Mar 20, 2026

If you run Nx.slice on a scalar tensor [tensor with no dimensions, just a single number wrapped in a tensor struct] (aka rank 0), it throws an error

Nx.slice(Nx.tensor(42), [], [])
# ** (ArgError) 1st argument: not a nonempty list
#      :erlang.hd([])
...

Slicing a scalar is a valid no-op. Nx.Shape.slice does it fine. But here BinaryBackend.bin_slice/7 calls hd(strides) on the empty strides list [], which crashes.

The fix is to add a scalar guard clause that matches when all lists are empty and returns the data unchanged.

running scalar_slice_test.exs on main

     Nx.ScalarSliceTest [test/nx/scalar_slice_test.exs]                                 
       * test slice of scalar tensor returns scalar [L#4]                               
       * test slice of scalar tensor returns scalar (20.5ms) [L#4]                      
                                                                                        
       1) test slice of scalar tensor returns scalar (Nx.ScalarSliceTest)               
          test/nx/scalar_slice_test.exs:4                                               
          ** (ArgumentError) errors were found at the given arguments:                  
                                                                                        
            * 1st argument: not a nonempty list                                         
                                                                                        
          code: result = Nx.slice(t, [], [])                                            
          stacktrace:                                                                   
            :erlang.hd([])                                                              
            (nx 0.11.0) lib/nx/binary_backend.ex:1855: Nx.BinaryBackend.bin_slice/7     
            (nx 0.11.0) lib/nx/binary_backend.ex:1848: Nx.BinaryBackend.slice/5         
            (nx 0.11.0) lib/nx.ex:13738: Nx.slice/4                                     
            test/nx/scalar_slice_test.exs:6: (test)                                     
...

originally part of the closed fuzz test edge case PR #1707

bin_slice/7 called hd([]) on empty strides list for rank-0 tensors.
Added scalar guard clause that returns data unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
result = Nx.slice(t, [], [])
assert_in_delta Nx.to_number(result), 3.14, 1.0e-10
end
end
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of a separate module, let's add a function to nx_test.exs. I would also add a doctest that says:

Sliding a one-dimensional tensor is a no-op:

iex> Nx.slice(42, [], [])

WDYT?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good! I'll go through and apply that idea to the other PRs too, just adding em at the end of nx_test.exs

Copy link
Copy Markdown
Contributor

@polvalente polvalente left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should support slicing scalars.
Furthermore, if we're adding support for this, we should also check how EXLA and Torchx behave.

I think the correct PR here is to fail on scalar slicing

@josevalim
Copy link
Copy Markdown
Contributor

Slicing a scalar could always be a no-op if you pass no dimensions (and raise if you pass any) so there is nothing to be implemented on the other backends because there is no operation.

@josevalim
Copy link
Copy Markdown
Contributor

FWIW:

import numpy as np
x = np.array(5)
x[()]  # returns 5

But the current implementation requires indeed each backend to implement it (but the doctest should test them aleady)

@blasphemetheus
Copy link
Copy Markdown
Contributor Author

blasphemetheus commented Mar 22, 2026

gotcha, so if the check for this case is to be applied, it shouldn't just live in binary backend, it should live in Nx.slice, can do!

also I changed a test value from 3.15 to 3.14 in that commit, i think a typo from earlier

@blasphemetheus blasphemetheus force-pushed the fork/fix/scalar-slice branch from a478ea3 to b1f27e9 Compare March 22, 2026 23:07
@blasphemetheus
Copy link
Copy Markdown
Contributor Author

blasphemetheus commented Mar 22, 2026

ci failure:


  1) test pinv does not raise for 0 singular values (Nx.LinAlgTest)
Error:      test/nx/lin_alg_test.exs:1025
     ** (ArithmeticError) bad argument in arithmetic expression
     stacktrace:
       (complex 0.6.0) lib/complex.ex:579: Complex.divide/2
       (nx 0.11.0) lib/nx/binary_backend.ex:707: anonymous fn/11 in Nx.BinaryBackend.element_wise_bin_op/4
       (elixir 1.18.4) lib/enum.ex:4507: Enum.reduce/3
       (nx 0.11.0) lib/nx/binary_backend.ex:698: Nx.BinaryBackend.element_wise_bin_op/4
       (nx 0.11.0) lib/nx/defn/evaluator.ex:426: Nx.Defn.Evaluator.eval_apply/5
       (nx 0.11.0) lib/nx/defn/evaluator.ex:257: Nx.Defn.Evaluator.eval/3
       (elixir 1.18.4) lib/enum.ex:1840: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
       (elixir 1.18.4) lib/enum.ex:1840: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
       (nx 0.11.0) lib/nx/defn/evaluator.ex:404: Nx.Defn.Evaluator.eval_apply/5
       (nx 0.11.0) lib/nx/defn/evaluator.ex:257: Nx.Defn.Evaluator.eval/3
       (elixir 1.18.4) lib/enum.ex:1840: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
       (nx 0.11.0) lib/nx/defn/evaluator.ex:404: Nx.Defn.Evaluator.eval_apply/5
       (nx 0.11.0) lib/nx/defn/evaluator.ex:257: Nx.Defn.Evaluator.eval/3
       (elixir 1.18.4) lib/enum.ex:1840: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
       (elixir 1.18.4) lib/enum.ex:1840: Enum."-map_reduce/3-lists^mapfoldl/2-0-"/3
       (nx 0.11.0) lib/nx/defn/evaluator.ex:404: Nx.Defn.Evaluator.eval_apply/5
       (nx 0.11.0) lib/nx/defn/evaluator.ex:257: Nx.Defn.Evaluator.eval/3
       (nx 0.11.0) lib/nx/defn/tree.ex:235: Nx.Defn.Tree.apply_args/4
       (nx 0.11.0) lib/nx/defn/evaluator.ex:404: Nx.Defn.Evaluator.eval_apply/5
       (nx 0.11.0) lib/nx/defn/evaluator.ex:257: Nx.Defn.Evaluator.eval/3

https://github.com/elixir-nx/nx/actions/runs/23414733273/job/68108024086?pr=1712

internally its a divide by 0 looks like, should be fixed in elixir-nx/complex#29

Slicing a scalar tensor is a valid no-op — return the tensor
unchanged when shape is {} and start_indices/lengths are empty.

The check is done in Nx.slice itself (not in BinaryBackend) so all
backends get the fix without needing separate implementations.

NumPy does the same: np.array(5)[()] returns 5.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@blasphemetheus blasphemetheus force-pushed the fork/fix/scalar-slice branch from b1f27e9 to 52c6c78 Compare March 23, 2026 05:09
@polvalente polvalente merged commit 674d3bb into elixir-nx:main Mar 24, 2026
7 of 9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants