Skip to content

Commit

Permalink
test torch.cat
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Ainsworth committed Nov 7, 2023
1 parent 13aa5d6 commit bd949dc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
3 changes: 3 additions & 0 deletions tests/test_all_the_things.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def test_oneliners():
t2j_function_test(lambda x: torch.transpose(x, 0, 2), [(2, 3, 5)])
t2j_function_test(lambda x: torch.transpose(x, 2, 1), [(2, 3, 5)])

t2j_function_test(lambda x, y: torch.cat((x, y)), [(2, 3), (5, 3)])
t2j_function_test(lambda x, y: torch.cat((x, y), dim=-1), [(2, 3), (2, 5)])


def test_detach():
t2j_function_test(lambda x: x.detach() ** 2, [()])
Expand Down
1 change: 0 additions & 1 deletion torch2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,6 @@ def fn(*args, **kwargs):
connect(torch.Tensor.mul, jnp.multiply)


# TODO: test
@implements(torch.cat)
def cat(tensors, dim=0):
return jnp.concatenate([coerce(x) for x in tensors], axis=dim)
Expand Down

0 comments on commit bd949dc

Please sign in to comment.