From 1446b83c8b83780a96443ab69cbd8e1f8fc4640f Mon Sep 17 00:00:00 2001 From: Samuel Ainsworth Date: Mon, 6 Nov 2023 16:30:36 -0800 Subject: [PATCH] test `torch.flatten` --- tests/test_all_the_things.py | 4 ++++ torch2jax/__init__.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_all_the_things.py b/tests/test_all_the_things.py index 7ac3ef0..1a344d0 100644 --- a/tests/test_all_the_things.py +++ b/tests/test_all_the_things.py @@ -129,6 +129,10 @@ def test_oneliners(): t2j_function_test(lambda x, y: torch.cat((x, y)), [(2, 3), (5, 3)]) t2j_function_test(lambda x, y: torch.cat((x, y), dim=-1), [(2, 3), (2, 5)]) + t2j_function_test(torch.flatten, [(2, 3, 5)]) + t2j_function_test(lambda x: torch.flatten(x, start_dim=1), [(2, 3, 5)]) + t2j_function_test(lambda x: torch.flatten(x, start_dim=2), [(2, 3, 5, 7)]) + def test_detach(): t2j_function_test(lambda x: x.detach() ** 2, [()]) diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py index d1474a4..16d44d7 100644 --- a/torch2jax/__init__.py +++ b/torch2jax/__init__.py @@ -178,7 +178,6 @@ def cat(tensors, dim=0): return jnp.concatenate([coerce(x) for x in tensors], axis=dim) -# TODO: test flatten @implements(torch.flatten) def flatten(input, start_dim=0, end_dim=-1): assert end_dim == -1, "TODO: implement end_dim"