Skip to content

Commit

Permalink
rename connect -> auto_implements
Browse files Browse the repository at this point in the history
  • Loading branch information
Samuel Ainsworth committed Nov 7, 2023
1 parent 593b7e1 commit bd7bd9c
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions torch2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def decorator(func):
return decorator


def connect(torch_function, jax_function, dont_coerce_argnums=()):
def auto_implements(torch_function, jax_function, dont_coerce_argnums=()):
@implements(torch_function)
def fn(*args, **kwargs):
# NOTE: we don't coerce kwargs! So far this has not been problematic.
Expand All @@ -157,18 +157,18 @@ def fn(*args, **kwargs):
)


connect(torch.add, jnp.add)
connect(torch.exp, jnp.exp)
connect(torch.nn.functional.gelu, jax.nn.gelu)
connect(torch.mean, jnp.mean)
connect(torch.mul, jnp.multiply)
connect(torch.permute, jnp.transpose, dont_coerce_argnums=(1, 2))
connect(torch.pow, jnp.power)
connect(torch.sigmoid, jax.nn.sigmoid)
connect(torch.sqrt, jnp.sqrt)
connect(torch.sum, jnp.sum)
connect(torch.tanh, jnp.tanh)
connect(torch.transpose, jnp.swapaxes)
auto_implements(torch.add, jnp.add)
auto_implements(torch.exp, jnp.exp)
auto_implements(torch.nn.functional.gelu, jax.nn.gelu)
auto_implements(torch.mean, jnp.mean)
auto_implements(torch.mul, jnp.multiply)
auto_implements(torch.permute, jnp.transpose, dont_coerce_argnums=(1, 2))
auto_implements(torch.pow, jnp.power)
auto_implements(torch.sigmoid, jax.nn.sigmoid)
auto_implements(torch.sqrt, jnp.sqrt)
auto_implements(torch.sum, jnp.sum)
auto_implements(torch.tanh, jnp.tanh)
auto_implements(torch.transpose, jnp.swapaxes)


@implements(torch.cat)
Expand Down Expand Up @@ -541,7 +541,7 @@ def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode

@implements(torch.nn.functional.relu)
def relu(x, inplace=False):
# Can't use `connect` since jax.nn.relu does not have an `inplace` option.
# Can't use `auto_implements` since jax.nn.relu does not have an `inplace` option.
if inplace:
assert isinstance(x, Torchish)
x.value = jax.nn.relu(x.value)
Expand Down

0 comments on commit bd7bd9c

Please sign in to comment.