Skip to content

Commit

Permalink
test torch.flatten
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Ainsworth committed Nov 7, 2023
1 parent bd949dc commit 1446b83
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 1 deletion.
4 changes: 4 additions & 0 deletions tests/test_all_the_things.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def test_oneliners():
t2j_function_test(lambda x, y: torch.cat((x, y)), [(2, 3), (5, 3)])
t2j_function_test(lambda x, y: torch.cat((x, y), dim=-1), [(2, 3), (2, 5)])

t2j_function_test(torch.flatten, [(2, 3, 5)])
t2j_function_test(lambda x: torch.flatten(x, start_dim=1), [(2, 3, 5)])
t2j_function_test(lambda x: torch.flatten(x, start_dim=2), [(2, 3, 5, 7)])


def test_detach():
t2j_function_test(lambda x: x.detach() ** 2, [()])
Expand Down
1 change: 0 additions & 1 deletion torch2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def cat(tensors, dim=0):
return jnp.concatenate([coerce(x) for x in tensors], axis=dim)


# TODO: test flatten
@implements(torch.flatten)
def flatten(input, start_dim=0, end_dim=-1):
assert end_dim == -1, "TODO: implement end_dim"
Expand Down

0 comments on commit 1446b83

Please sign in to comment.