Fix Nx.slice crash on scalar tensor#1712
Conversation
bin_slice/7 called hd([]) on empty strides list for rank-0 tensors. Added scalar guard clause that returns data unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
nx/test/nx/scalar_slice_test.exs
Outdated
| result = Nx.slice(t, [], []) | ||
| assert_in_delta Nx.to_number(result), 3.14, 1.0e-10 | ||
| end | ||
| end |
There was a problem hiding this comment.
Instead of a separate module, let's add a function to nx_test.exs. I would also add a doctest that says:
Sliding a one-dimensional tensor is a no-op:
iex> Nx.slice(42, [], [])
WDYT?
There was a problem hiding this comment.
sounds good! I'll go through and apply that idea to the other PRs too, just adding em at the end of nx_test.exs
polvalente
left a comment
There was a problem hiding this comment.
I'm not sure we should support slicing scalars.
Furthermore, if we're adding support for this, we should also check how EXLA and Torchx behave.
I think the correct PR here is to fail on scalar slicing
|
Slicing a scalar could always be a no-op if you pass no dimensions (and raise if you pass any) so there is nothing to be implemented on the other backends because there is no operation. |
|
FWIW: import numpy as np
x = np.array(5)
x[()] # returns 5But the current implementation requires indeed each backend to implement it (but the doctest should test them aleady) |
|
gotcha, so if the check for this case is to be applied, it shouldn't just live in binary backend, it should live in Nx.slice, can do! also I changed a test value from 3.15 to 3.14 in that commit, i think a typo from earlier |
a478ea3 to
b1f27e9
Compare
|
ci failure: https://github.com/elixir-nx/nx/actions/runs/23414733273/job/68108024086?pr=1712 internally its a divide by 0 looks like, should be fixed in elixir-nx/complex#29 |
Slicing a scalar tensor is a valid no-op — return the tensor
unchanged when shape is {} and start_indices/lengths are empty.
The check is done in Nx.slice itself (not in BinaryBackend) so all
backends get the fix without needing separate implementations.
NumPy does the same: np.array(5)[()] returns 5.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
b1f27e9 to
52c6c78
Compare
If you run
Nx.sliceon a scalar tensor [tensor with no dimensions, just a single number wrapped in a tensor struct] (aka rank 0), it throws an errorSlicing a scalar is a valid no-op.
Nx.Shape.slicedoes it fine. But hereBinaryBackend.bin_slice/7callshd(strides)on the empty strides list[], which crashes.The fix is to add a scalar guard clause that matches when all lists are empty and returns the data unchanged.
running scalar_slice_test.exs on main
originally part of the closed fuzz test edge case PR #1707