Update aten operators (#46)
* Update aten.hardtanh support

Add a converter for hardtanh, similar to that of clamp (see the clamp/hardtanh sketch at the end of this message).

* Add converter for hardswish activation

Had to add my own MIGraphX converter, as the order of values returned to clip would fail the test (the decomposition is sketched at the end of this message).

* fixup! Add converter for hardswish activation

* Fix names for selu & softsign acc_ops converters

This was breaking when trying to use the elu converter for the aten elu op, since the selu and softsign converters were also named acc_ops_elu.

* Add elu aten converter op

- Tests for both the default and parameterized alpha values

* Add aten.max for max operator

Handle getting the max value from the input tensor (the call forms handled are sketched at the end of this message).

* fixup! Add aten.max for max operator

* fixup! fixup! Add aten.max for max operator

* Add support and aten op for torch.min

Similar to the ONNX reduce_min operator; the implementation mirrors that of mean and max and maps to the corresponding reduce ops in MIGraphX.

This one is a freebie once max is done.

* Add changes for stack op (lowered as unsqueeze + cat; see the sketch at the end of this message)

* Add fx/dynamo changes for argmin op

Required if we want to support min() operator down the road.

Added converters for fx and dynamo in a similar vein to the argmax function.

Also added unit tests

* Fix acc op and converters for max/min ops

Updated unit tests as well

* Fix aten op for max/min

Added changes to unit tests and fixed the operators to handle multi-input args correctly.

* Changes based on review comments

- remove print in stack
- merge argmin argmax in testing
- rename test_leaky_relu -> test_single_param_activation_funcs
- Add default for min dim=None
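
A minimal PyTorch sketch (an editorial illustration, not part of this diff) of the identity the hardtanh converter relies on: aten.hardtanh with its default bounds computes the same thing as a clamp to [-1, 1], so the existing clamp conversion path can be reused with those defaults.

import torch
import torch.nn.functional as F

x = torch.randn(4, 2, 7)

# hardtanh with default min_val/max_val matches clamp(x, -1, 1), which is
# why the aten.hardtanh converter can fall through to the clamp converter.
assert torch.equal(F.hardtanh(x), torch.clamp(x, min=-1.0, max=1.0))

# Explicit bounds map the same way.
assert torch.equal(F.hardtanh(x, min_val=-0.5, max_val=0.5),
                   torch.clamp(x, min=-0.5, max=0.5))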
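
A similar sketch (illustrative only) of the decomposition the new hardswish converter emits, namely hard_sigmoid followed by a multiply, checked against torch.nn.functional.hardswish:

import torch
import torch.nn.functional as F

x = torch.randn(8, 16)

# hardswish(x) == x * hardsigmoid(x); the converter lowers this as
# acc_ops_hard_sigmoid followed by acc_ops_mul.
assert torch.allclose(F.hardswish(x), x * F.hardsigmoid(x))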
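
A sketch (illustrative shapes, not from the diff) of the two torch.max/torch.min call forms the new converters handle: without dim the result is a single fully reduced tensor, while with dim the result is a (values, indices) pair, which is why the converters pair reduce_max/reduce_min with argmax/argmin.

import torch

x = torch.randn(10, 2, 12)

# No dim: full reduction to one value (converter: reduce over all axes,
# then squeeze them away).
overall_max = torch.max(x)
assert overall_max.ndim == 0

# With dim: returns (values, indices), so the converter emits a reduce_max
# plus an argmax over the same axis.
values, indices = torch.max(x, dim=1)
assert values.shape == (10, 12) and indices.shape == (10, 12)

# keepdim=True keeps the reduced axis in both outputs; min mirrors max.
values_k, indices_k = torch.min(x, dim=1, keepdim=True)
assert values_k.shape == (10, 1, 12) and indices_k.shape == (10, 1, 12)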
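
Finally, a sketch of the aten.stack lowering used here (unsqueeze every input at the stack dim, then concatenate along it), checked against torch.stack; again purely illustrative:

import torch

def stack_via_unsqueeze_cat(tensors, dim=0):
    # Mirrors the converter: insert the new axis on each input, then cat.
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

ts = [torch.randn(3, 4) for _ in range(5)]
for d in (0, 1, 2):
    assert torch.equal(torch.stack(ts, dim=d),
                       stack_via_unsqueeze_cat(ts, dim=d))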

---------

Co-authored-by: Ted Themistokleous <[email protected]>
TedThemistokleous authored Nov 15, 2023
1 parent d2cc58f commit ff8ac13
Showing 9 changed files with 349 additions and 14 deletions.
72 changes: 70 additions & 2 deletions py/torch_migraphx/fx/converters/acc_ops_converters.py
@@ -399,7 +399,7 @@ def acc_ops_elu(mgx_module, node, args, kwargs):


@migraphx_converter(acc_ops.selu)
def acc_ops_elu(mgx_module, node, args, kwargs):
def acc_ops_selu(mgx_module, node, args, kwargs):

inp = kwargs['input']
dtype = get_arg_dtype(inp)
@@ -442,7 +442,7 @@ def acc_ops_elu(mgx_module, node, args, kwargs):


@migraphx_converter(acc_ops.softsign)
def acc_ops_elu(mgx_module, node, args, kwargs):
def acc_ops_softsign(mgx_module, node, args, kwargs):

inp = kwargs['input']
dtype = get_arg_dtype(inp)
@@ -787,6 +787,26 @@ def acc_ops_argmax(mgx_module, node, args, kwargs):
return out


@migraphx_converter(acc_ops.argmin)
def acc_ops_argmin(mgx_module, node, args, kwargs):
inp = kwargs['input']
dim = kwargs["dim"]
keepdim = kwargs["keepdim"]

if dim is None:
assert not keepdim, "keepdim cannot be true when dim is None"
inp = acc_ops_flatten(mgx_module, node, (), {"input": inp})
dim = 0

out = mgx_module.add_instruction(migraphx.op('argmin', axis=dim), [inp])

if not keepdim:
out = mgx_module.add_instruction(migraphx.op('squeeze', axes=[dim]),
[out])

return out


@migraphx_converter(acc_ops.embedding)
def acc_ops_embedding(mgx_module, node, args, kwargs):
inp = kwargs['input']
@@ -980,6 +1000,54 @@ def acc_ops_maximum(mgx_module, node, args, kwargs):
return mgx_module.add_instruction(migraphx.op('max'), [inp, other])


@migraphx_converter(acc_ops.max)
def acc_ops_max(mgx_module, node, args, kwargs):
inp = kwargs['input']
in_shape = inp.shape().lens()

if 'dim' not in kwargs:
dims = list(range(len(in_shape)))
max_ = mgx_module.add_instruction(
migraphx.op('reduce_max', axes=dims), [inp])
return mgx_module.add_instruction(migraphx.op('squeeze', axes=dims), [max_])
else:
dims = kwargs['dim']
indices = acc_ops_argmax(mgx_module, node, args, kwargs)
max_ = mgx_module.add_instruction(
migraphx.op('reduce_max', axes=[dims]), [inp])

if 'keepdim' in kwargs and kwargs['keepdim']:
return [max_, indices]

return [mgx_module.add_instruction(migraphx.op('squeeze', axes=[dims]), [max_]), indices]


@migraphx_converter(acc_ops.min)
def acc_ops_min(mgx_module, node, args, kwargs):
inp = kwargs['input']
in_shape = inp.shape().lens()

if 'dim' not in kwargs:
dims = list(range(len(in_shape)))
min_ = mgx_module.add_instruction(
migraphx.op('reduce_min', axes=dims), [inp])
return mgx_module.add_instruction(migraphx.op('squeeze', axes=dims), [min_])
else:
dims = kwargs['dim']
indices = acc_ops_argmin(mgx_module, node, args, kwargs)
min_ = mgx_module.add_instruction(
migraphx.op('reduce_min', axes=[dims]), [inp])

if 'keepdim' in kwargs and kwargs['keepdim']:
return [min_, indices]

return [mgx_module.add_instruction(migraphx.op('squeeze', axes=[dims]), [min_]), indices]


@migraphx_converter(acc_ops.mean)
def acc_ops_mean(mgx_module, node, args, kwargs):

109 changes: 105 additions & 4 deletions py/torch_migraphx/fx/converters/aten_ops_converters.py
@@ -29,6 +29,7 @@

import migraphx
import torch
from typing import cast, Iterable, List, Sequence
from ..converter_registry import migraphx_converter
from torch_migraphx.fx.converters import acc_ops_converters
from ..utils import torch_dtype_to_mgx_enum
@@ -248,15 +249,19 @@ def aten_ops_split(mgx_module, node, args, kwargs):
return slice_nodes


@migraphx_converter(torch.ops.aten.hardtanh.default)
@migraphx_converter(torch.ops.aten.clamp.default)
def aten_ops_clamp(mgx_module, node, args, kwargs):
assert len(args) >= 1
min_, max_ = None, None
if node.target == torch.ops.aten.hardtanh.default:
min_, max_ = -1, 1

acc_kwargs = {
"input": args[0],
"min": args[1] if len(args) >= 2 else None,
"max": args[2] if len(args) == 3 else None
"min": args[1] if len(args) >= 2 else min_,
"max": args[2] if len(args) == 3 else max_
}

return acc_ops_converters.acc_ops_clamp(mgx_module, node, (), acc_kwargs)


@@ -276,6 +281,16 @@ def aten_ops_tanh(mgx_module, node, args, kwargs):
return acc_ops_converters.acc_ops_tanh(mgx_module, node, (), acc_kwargs)


@migraphx_converter(torch.ops.aten.elu.default)
def aten_ops_elu(mgx_module, node, args, kwargs):
assert len(args) >= 1
inp = args[0]
alpha = 1.0 if len(args) < 2 else args[1]

acc_kwargs = {'input': inp, 'alpha': alpha}
return acc_ops_converters.acc_ops_elu(mgx_module, node, (), acc_kwargs)


@migraphx_converter(torch.ops.aten.leaky_relu.default)
def aten_ops_leaky_relu(mgx_module, node, args, kwargs):
assert len(args) >= 1
@@ -287,6 +302,17 @@ def aten_ops_leaky_relu(mgx_module, node, args, kwargs):
acc_kwargs)


@migraphx_converter(torch.ops.aten.hardswish.default)
def aten_ops_hardswish(mgx_module, node, args, kwargs):
assert len(args) == 1
acc_kwargs = {"input": args[0]}

hard_sig = acc_ops_converters.acc_ops_hard_sigmoid(mgx_module, node, (), acc_kwargs)

mul_kwargs = {"input": args[0], "other": hard_sig}
return acc_ops_converters.acc_ops_mul(mgx_module, node, (), mul_kwargs)


@migraphx_converter(torch.ops.aten.hardsigmoid.default)
def aten_ops_hardsigmoid(mgx_module, node, args, kwargs):
assert len(args) == 1
@@ -718,7 +744,6 @@ def aten_ops_embedding(mgx_module, node, args, kwargs):
return acc_ops_converters.acc_ops_embedding(mgx_module, node, (),
acc_kwargs)


@migraphx_converter(torch.ops.aten.argmax.default)
def aten_ops_argmax(mgx_module, node, args, kwargs):
assert len(args) >= 1
@@ -731,7 +756,83 @@ def aten_ops_argmax(mgx_module, node, args, kwargs):

return acc_ops_converters.acc_ops_argmax(mgx_module, node, (), acc_kwargs)

@migraphx_converter(torch.ops.aten.argmin.default)
def aten_ops_argmin(mgx_module, node, args, kwargs):
assert len(args) >= 1

acc_kwargs = {
"input": args[0],
"dim": args[1] if len(args) >= 2 else None,
"keepdim": args[2] if len(args) >= 3 else False
}

return acc_ops_converters.acc_ops_argmin(mgx_module, node, (), acc_kwargs)

@migraphx_converter(torch.ops.aten.max.default)
@migraphx_converter(torch.ops.aten.max.dim)
def aten_ops_max(mgx_module, node, args, kwargs):
assert len(args) >= 1

acc_kwargs = {
"input": args[0],
"keepdim": False,
}

if len(args) >= 2:
acc_kwargs["dim"] = args[1]

if len(args) >= 3:
acc_kwargs["keepdim"] = args[2]

return acc_ops_converters.acc_ops_max(mgx_module, node, (), acc_kwargs)

@migraphx_converter(torch.ops.aten.min.default)
@migraphx_converter(torch.ops.aten.min.dim)
def aten_ops_min(mgx_module, node, args, kwargs):
assert len(args) >= 1

acc_kwargs = {
"input": args[0],
"keepdim": False
}

if len(args) >= 2:
acc_kwargs["dim"] = args[1]

if len(args) >= 3:
acc_kwargs["keepdim"] = args[2]

return acc_ops_converters.acc_ops_min(mgx_module, node, (), acc_kwargs)

@migraphx_converter(torch.ops.aten.stack.default)
def aten_ops_stack(mgx_module, node, args, kwargs):
"""Map aten.stack to unsqueeze + cat acc ops."""
assert len(args) >= 1

inputs = args[0]
assert isinstance(inputs, Sequence)

dims = args[1] if len(args) > 1 else 0

unsqueeze_kwargs={
"dim": dims
}
cat_kwargs={
"dim": dims
}

unsqueeze_nodes = []
for i, t in enumerate(inputs):
unsqueeze_kwargs["input"] = t
unsq_res = acc_ops_converters.acc_ops_unsqueeze(mgx_module, node, (), unsqueeze_kwargs)
unsqueeze_nodes.append(unsq_res)

cat_kwargs["tensors"] = unsqueeze_nodes
return acc_ops_converters.acc_ops_cat(mgx_module, node, (), cat_kwargs)


@migraphx_converter(torch.ops.aten.as_strided.default)
def aten_ops_as_strided(mgx_module, node, args, kwargs):
assert len(args) >= 3
56 changes: 56 additions & 0 deletions py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py
@@ -166,6 +166,54 @@ def maximum(*, input, other):
return torch.maximum(input=input, other=other)


@register_acc_op_mapping(
op_and_target=("call_method", "max"),
arg_replacement_tuples=[
("input", "input"),
("dim", "dim", this_arg_is_optional),
("keepdim", "keepdim", this_arg_is_optional),
],
)
@register_acc_op_mapping(
op_and_target=("call_function", torch.max),
arg_replacement_tuples=[
("input", "input"),
("dim", "dim", this_arg_is_optional),
("keepdim", "keepdim", this_arg_is_optional),
],
)
@register_acc_op
def max(*, input, dim=None, keepdim=False):
if dim is not None:
return torch.max(input, dim=dim, keepdim=keepdim)
else:
return torch.max(input)


@register_acc_op_mapping(
op_and_target=("call_method", "min"),
arg_replacement_tuples=[
("input", "input"),
("dim", "dim", this_arg_is_optional),
("keepdim", "keepdim", this_arg_is_optional),
],
)
@register_acc_op_mapping(
op_and_target=("call_function", torch.min),
arg_replacement_tuples=[
("input", "input"),
("dim", "dim", this_arg_is_optional),
("keepdim", "keepdim", this_arg_is_optional),
],
)
@register_acc_op
def min(*, input, dim=None, keepdim=False):
if dim is not None:
return torch.min(input, dim=dim, keepdim=keepdim)
else:
return torch.min(input)


@register_acc_op_mapping(op_and_target=("call_function", operator.getitem))
@register_acc_op
def getitem(*, input, idx):
@@ -914,6 +962,14 @@ def argmax(*, input, dim, keepdim):
return torch.argmax(input=input, dim=dim, keepdim=keepdim)


@register_acc_op_properties(AccOpProperty.unary)
@register_acc_op_mapping(op_and_target=("call_function", torch.argmin))
@register_acc_op_mapping(op_and_target=("call_method", "argmin"))
@register_acc_op
def argmin(*, input, dim, keepdim):
return torch.argmin(input=input, dim=dim, keepdim=keepdim)


@register_acc_op_mapping(op_and_target=("call_function",
nn.functional.embedding))
@register_acc_op
12 changes: 9 additions & 3 deletions tests/dynamo/converters/test_activations_dynamo.py
@@ -7,7 +7,8 @@
pytest.skip(allow_module_level=True)


@pytest.mark.parametrize('op_alias', [torch.ops.aten.clamp.default])
@pytest.mark.parametrize('op_alias', [torch.ops.aten.clamp.default,
torch.ops.aten.hardtanh.default])
@pytest.mark.parametrize('inp_size', [(4, 2, 7), (128, 2048),
(1, 3, 6, 128, 128)])
def test_clamp(op_alias, inp_size):
@@ -20,8 +21,10 @@ def test_clamp(op_alias, inp_size):

@pytest.mark.parametrize('op_alias', [
torch.ops.aten.relu.default,
torch.ops.aten.elu.default,
torch.ops.aten.tanh.default,
torch.ops.aten.hardsigmoid.default,
torch.ops.aten.hardswish.default,
torch.ops.aten.sigmoid.default,
torch.ops.aten.gelu.default,
torch.ops.aten.silu.default,
@@ -33,13 +36,16 @@ def test_noparam_activation_funcs(op_alias):
verify_outputs(mod, mgx_mod, inp)


@pytest.mark.parametrize('op_alias', [torch.ops.aten.leaky_relu.default])
@pytest.mark.parametrize('op_alias', [
torch.ops.aten.elu.default,
torch.ops.aten.leaky_relu.default,
])
@pytest.mark.parametrize('inp_size, alpha', [
((11, 3, 9), 0.1),
((6, 12, 32, 6), 0.05),
((2, ), 0),
])
def test_leaky_relu(op_alias, inp_size, alpha):
def test_single_param_activation_funcs(op_alias, inp_size, alpha):
inp = torch.randn(inp_size).cuda()
mod = FuncModule(op_alias, alpha).cuda()
mgx_mod = convert_to_mgx(mod, [inp])
9 changes: 6 additions & 3 deletions tests/dynamo/converters/test_maxmin_dynamo.py
@@ -7,14 +7,17 @@
pytest.skip(allow_module_level=True)


@pytest.mark.parametrize('op_alias', [torch.ops.aten.argmax.default])
@pytest.mark.parametrize('op_alias', [
torch.ops.aten.argmax.default,
torch.ops.aten.argmin.default,
])
@pytest.mark.parametrize('dim, keepdim', [
(2, True),
(-1, False),
(0, False),
])
def test_argmax(op_alias, dim, keepdim):
def test_argmax_argmin(op_alias, dim, keepdim):
inp = torch.randn(10, 2, 12, 8, 14).cuda()
mod = FuncModule(torch.argmax, dim, keepdim)
mod = FuncModule(op_alias, dim, keepdim)
mgx_mod = convert_to_mgx(mod, [inp])
verify_outputs(mod, mgx_mod, inp)
(Diffs for the remaining 4 changed files are not shown here.)
