Skip to content

Commit

Permalink
Add fix and test for the unsqueeze issue in #7
Browse files Browse the repository at this point in the history
  • Loading branch information
samuela committed Nov 22, 2024
1 parent 2490400 commit 93ed706
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
5 changes: 5 additions & 0 deletions tests/test_all_the_things.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ def poop(self):
return rng_key


def test_t2j_array():
# See https://github.com/samuela/torch2jax/issues/7
aac(t2j(torch.eye(3).unsqueeze(0)), jnp.eye(3)[jnp.newaxis, ...])


def t2j_function_test(f, input_shapes, rng=random.PRNGKey(123), num_tests=5, **assert_kwargs):
for test_rng in random.split(rng, num_tests):
inputs = [random.normal(rng, shape) for rng, shape in zip(random.split(test_rng, len(input_shapes)), input_shapes)]
Expand Down
12 changes: 11 additions & 1 deletion torch2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,19 @@ def t2j_array(torch_array):
# `torch.func.functionalize` in `t2j_function`. For now, we're avoiding `torch.func.functionalize`, but something to
# be wary of in the future.

# RuntimeError: Can't export tensors that require gradient, use tensor.detach()
torch_array = torch_array.detach()

# See https://github.com/google/jax/issues/8082.
torch_array = torch_array.contiguous()
return jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(torch_array))

# At some point between 0.4.28 and 0.4.33 from_dlpack introduced a new
# deprecation notice:
#
# DeprecationWarning: Calling from_dlpack with a DLPack tensor is deprecated. The argument to from_dlpack should be an array from another framework that implements the __dlpack__ protocol.
#
# Very well, PyTorch arrays implement the __dlpack__ protocol, so no need to convert them to dlpack first.
return jax.dlpack.from_dlpack(torch_array)

# Alternative, but copying implementation:
# Note FunctionalTensor.numpy() returns incorrect results, preventing us from using torch.func.functionalize.
Expand Down

0 comments on commit 93ed706

Please sign in to comment.