-
Notifications
You must be signed in to change notification settings - Fork 514
Fixes pytorch/xla#7398 #9047
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Fixes pytorch/xla#7398 #9047
Conversation
torchax/test/test_ops.py
Outdated
@@ -204,6 +203,10 @@ def test_reference_eager(self, device, dtype, op): | |||
|
|||
# print("[DEBUG] sample_input: ", sample_input) | |||
|
|||
if op.name == "cat": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should edit here: https://github.com/pytorch/xla/blob/master/torchax/torchax/ops/jaten.py#L1200 and filter out the empty tensor instead of skipping the test
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. Thanks!
Before the fix:
ERROR: test_reference_eager_cat_cpu_float32 (main.TestOpInfoCPU) [torchax_eval]
Traceback (most recent call last):
File "/home/qianminj/qianminj123/xla/torchax/test/test_ops.py", line 140, in run_export_and_compare
res2 = func(input2, *args2, **kwargs2)
File "/home/qianminj/anaconda3/envs/torchax/lib/python3.10/site-packages/torch/testing/_internal/opinfo/core.py", line 1178, in call
return self.op(*args, **kwargs)
File "/home/qianminj/xla/torchax/torchax/tensor.py", line 266, in torch_function
return func(*args, **(kwargs or {}))
File "/home/qianminj/xla/torchax/torchax/tensor.py", line 289, in torch_dispatch
return self.env.dispatch(func, types, args, kwargs)
File "/home/qianminj/xla/torchax/torchax/tensor.py", line 526, in dispatch
res = op.func(*args, **kwargs)
File "/home/qianminj/xla/torchax/torchax/ops/jaten.py", line 1201, in _aten_cat
return jnp.concatenate(tensors, dims)
File "/home/qianminj/anaconda3/envs/torchax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4638, in concatenate
axis = _canonicalize_axis(axis, np.ndim(arrays[0]))
File "/home/qianminj/anaconda3/envs/torchax/lib/python3.10/site-packages/jax/_src/util.py", line 376, in canonicalize_axis
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
ValueError: axis 1 is out of bounds for array of dimension 1
======================================================================
ERROR: test_reference_eager_cat_cpu_int64 (main.TestOpInfoCPU) [torchax_eval]
Traceback (most recent call last):
File "/home/qianminj/qianminj123/xla/torchax/test/test_ops.py", line 140, in run_export_and_compare
res2 = func(input2, *args2, **kwargs2)
File "/home/qianminj/anaconda3/envs/torchax/lib/python3.10/site-packages/torch/testing/_internal/opinfo/core.py", line 1178, in call
return self.op(*args, **kwargs)
File "/home/qianminj/xla/torchax/torchax/tensor.py", line 266, in torch_function
return func(*args, **(kwargs or {}))
File "/home/qianminj/xla/torchax/torchax/tensor.py", line 289, in torch_dispatch
return self.env.dispatch(func, types, args, kwargs)
File "/home/qianminj/xla/torchax/torchax/tensor.py", line 526, in dispatch
res = op.func(*args, **kwargs)
File "/home/qianminj/xla/torchax/torchax/ops/jaten.py", line 1201, in _aten_cat
return jnp.concatenate(tensors, dims)
File "/home/qianminj/anaconda3/envs/torchax/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 4638, in concatenate
axis = _canonicalize_axis(axis, np.ndim(arrays[0]))
File "/home/qianminj/anaconda3/envs/torchax/lib/python3.10/site-packages/jax/_src/util.py", line 376, in canonicalize_axis
raise ValueError(f"axis {axis} is out of bounds for array of dimension {num_dims}")
ValueError: axis 1 is out of bounds for array of dimension 1
Ran 2 tests in 4.274s
FAILED (errors=2)
After the fix:
Ran 2 tests in 2.877s
OK