
Commit

rename JAXishify_output -> Torchishify_output
Samuel Ainsworth committed Nov 7, 2023
1 parent 4ed99f5 commit d86f851
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torch2jax/__init__.py
@@ -136,11 +136,11 @@ def div_(self, other):
 coerce = lambda x: Torchish(x).value


-def implements(torch_function, JAXishify_output=True):
+def implements(torch_function, Torchishify_output=True):
     """Register a torch function override for Torchish"""

     def decorator(func):
-        func1 = (lambda *args, **kwargs: Torchish(func(*args, **kwargs))) if JAXishify_output else func
+        func1 = (lambda *args, **kwargs: Torchish(func(*args, **kwargs))) if Torchishify_output else func
         functools.update_wrapper(func1, torch_function)
         HANDLED_FUNCTIONS[torch_function] = func1
         return func1
@@ -571,7 +571,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.

 # NOTE: the "torch.Tensor" type annotations here are a lie, or at least an approximation: In reality, they can be
 # anything coerce-able.
-@implements(torch.nn.functional.multi_head_attention_forward, JAXishify_output=False)
+@implements(torch.nn.functional.multi_head_attention_forward, Torchishify_output=False)
 def multi_head_attention_forward(
     query: torch.Tensor,
     key: torch.Tensor,
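For context on what the renamed keyword does, here is a minimal sketch, not the library's actual implementation: `implements` registers an override in HANDLED_FUNCTIONS, which `Torchish.__torch_function__` presumably consults when a torch function is called on a Torchish value, and when Torchishify_output is left at True the override's return value is re-wrapped as a Torchish. The Torchish internals and the torch.sin override below are hypothetical illustrations, assuming Torchish stores a JAX array in .value, as the coerce lambda in the diff suggests.

import functools

import jax.numpy as jnp
import torch

HANDLED_FUNCTIONS = {}


class Torchish:
    """Hypothetical stand-in: wraps a JAX array so torch functions can be intercepted."""

    def __init__(self, value):
        # Unwrap nested Torchish values so Torchish(Torchish(x)).value == x.
        self.value = value.value if isinstance(value, Torchish) else value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # torch routes calls like torch.sin(Torchish(...)) here; dispatch to the registered override.
        kwargs = kwargs or {}
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


coerce = lambda x: Torchish(x).value


def implements(torch_function, Torchishify_output=True):
    """Register a torch function override for Torchish (as in the diff above)."""

    def decorator(func):
        # Torchishify_output=True: re-wrap the override's result as a Torchish.
        # Torchishify_output=False (e.g. multi_head_attention_forward, which returns
        # a tuple): the override manages its own return value.
        func1 = (lambda *args, **kwargs: Torchish(func(*args, **kwargs))) if Torchishify_output else func
        functools.update_wrapper(func1, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func1
        return func1

    return decorator


# Hypothetical override: torch.sin on a Torchish computes jnp.sin and re-wraps the result.
@implements(torch.sin)
def sin(x):
    return jnp.sin(coerce(x))


y = torch.sin(Torchish(jnp.array([0.0, 1.0])))  # y is a Torchish; y.value is a jnp array

Read this way, the rename is a naming correction rather than a behavior change: the flag never converted anything to a JAX value; it controls whether the output is wrapped back into a Torchish, which is what "Torchishify" says.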
