Skip to content

Commit

Permalink
Add converter for harswish activation
Browse files Browse the repository at this point in the history
Had to add my owwn migraphx converter as the order of value return to clip would fail the test
  • Loading branch information
TedThemistokleous committed Oct 25, 2023
1 parent e237c6e commit e9fe731
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 0 deletions.
26 changes: 26 additions & 0 deletions py/torch_migraphx/fx/converters/acc_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,32 @@ def acc_ops_hard_sigmoid(mgx_module, node, args, kwargs):
return mgx_module.add_instruction(migraphx.op('clip'), [add, zeros, ones])


@migraphx_converter(acc_ops.hardswish)
def acc_ops_hard_swish(mgx_module, node, args, kwargs):

inp = kwargs['input']
dtype = get_arg_dtype(inp)
shape = inp.shape().lens()

alpha = mgx_module.add_instruction(
migraphx.op('multibroadcast', out_lens=shape),
[mgx_module.add_literal(torch.tensor([1 / 6], dtype=dtype).numpy())])

beta = mgx_module.add_instruction(
migraphx.op('multibroadcast', out_lens=shape),
[mgx_module.add_literal(torch.tensor([1 / 2], dtype=dtype).numpy())])

zeros = mgx_module.add_instruction(
migraphx.op('multibroadcast', out_lens=shape),
[mgx_module.add_literal(torch.tensor([0], dtype=dtype).numpy())])

mul = mgx_module.add_instruction(migraphx.op('mul'), [alpha, inp])
add = mgx_module.add_instruction(migraphx.op('add'), [beta, mul])

mul2 = mgx_module.add_instruction(migraphx.op('mul'), [add, inp])
return mgx_module.add_instruction(migraphx.op('clip'), [zeros, inp, mul2])


@migraphx_converter(acc_ops.softmax)
def acc_ops_softmax(mgx_module, node, args, kwargs):

Expand Down
8 changes: 8 additions & 0 deletions py/torch_migraphx/fx/converters/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,14 @@ def aten_ops_leaky_relu(mgx_module, node, args, kwargs):
return acc_ops_converters.acc_ops_leaky_relu(mgx_module, node, (),
acc_kwargs)

@migraphx_converter(torch.ops.aten.hardswish.default)
def aten_ops_hardswish(mgx_module, node, args, kwargs):
assert len(args) == 1
acc_kwargs = {"input": args[0]}

return acc_ops_converters.acc_ops_hard_swish(mgx_module, node, (),
acc_kwargs)


@migraphx_converter(torch.ops.aten.hardsigmoid.default)
def aten_ops_hardsigmoid(mgx_module, node, args, kwargs):
Expand Down
5 changes: 5 additions & 0 deletions py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,11 @@ def hardsigmoid(*, input):
return nn.functional.hardsigmoid(input)


@register_acc_op
def hardswish(*, input):
return nn.functional.hardswish(input)


@register_acc_op_mapping(
op_and_target=("call_method", "softmax"),
arg_replacement_tuples=[
Expand Down
1 change: 1 addition & 0 deletions tests/dynamo/converters/test_activations_dynamo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_clamp(op_alias, inp_size):
torch.ops.aten.relu.default,
torch.ops.aten.tanh.default,
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
torch.ops.aten.sigmoid.default,
torch.ops.aten.gelu.default,
torch.ops.aten.silu.default,
Expand Down

0 comments on commit e9fe731

Please sign in to comment.