
Aten roi align #209

Open · wants to merge 16 commits into base: master
Conversation

bpickrel (Contributor):

Add aten and acc converters for the RoiAlign op.

This change is dependent on MIGraphX PR 3482. It can't be tested or merged until that branch is merged. (Test locally using the MIGraphX branch roialign_fix.)

As of the date of PR creation, the aten converter test test_roialign_d is still incomplete.

@bpickrel bpickrel self-assigned this Oct 18, 2024
@bpickrel bpickrel marked this pull request as ready for review October 21, 2024 21:43
docker/dev.Dockerfile (resolved)
py/torch_migraphx/fx/tracer/acc_tracer/acc_ops.py (resolved)
py/torch_migraphx/fx/converters/acc_ops_converters.py (resolved)
Comment on lines 233 to 236
# batch_indices = batch_indices2
elif boxes_ref.shape().lens()[1] == 4:
# batch_indices3=range(boxes_ref.shape().lens()[0])
# boxes2 = boxes_ref
Collaborator:

Clean up comments please, this just looks like unused debug code?

elif boxes_ref.shape().lens()[1] == 4:
# batch_indices3=range(boxes_ref.shape().lens()[0])
# boxes2 = boxes_ref
# This isn't supported because torchvision roi_align.default() doesn't support it
Collaborator:

I think it's not supported because the list inputs were giving issues? I'd suggest leaving a TODO comment here to try and support it in the future.

Contributor (Author):

The torchvision module threw errors when it was called with that shape of input in the aten test. The only way I see to add support is to avoid testing it against the torchvision base. Other than that, it should work fine with Line 235 commented in.

Collaborator:

I am not sure what you mean; this is the case where the rois are a list of tensors, right? That's not even supported by the aten op. This is only for the acc op case where the rois are defined as a list.

On a second look, that if statement does not work at all: if the rois are defined as a list, your boxes_ref variable won't be a migraphx instruction and the .shape call isn't valid anyway.
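For context on the two rois layouts being discussed: torchvision's packed Tensor form puts the batch index in column 0 of a [K, 5] tensor, while the list form gives one [L, 4] box tensor per image. A plain-Python sketch (function name hypothetical; the real converter would build migraphx instructions rather than lists) of flattening the list form into the packed layout:

```python
def flatten_roi_list(boxes_per_image):
    """Flatten per-image boxes [[x1, y1, x2, y2], ...] into rows of
    [batch_index, x1, y1, x2, y2] -- the Tensor[K, 5] layout that
    torchvision.ops.roi_align accepts. Illustrative plain lists only."""
    rows = []
    for batch_index, boxes in enumerate(boxes_per_image):
        for box in boxes:
            # Prepend the image's batch index to each box.
            rows.append([float(batch_index)] + [float(v) for v in box])
    return rows

# Two images: two boxes in the first, one in the second.
rois = flatten_roi_list([[[0, 0, 4, 4], [1, 1, 3, 3]], [[2, 2, 5, 5]]])
```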

py/torch_migraphx/fx/converters/aten_ops_converters.py (resolved)
py/torch_migraphx/fx/mgx_module.py (resolved)
@@ -21,4 +21,4 @@ cd AMDMIGraphX
rbuild build -d depend -DBUILD_TESTING=Off -DCMAKE_INSTALL_PREFIX=/opt/rocm/ --cxx=/opt/rocm/llvm/bin/clang++ -DGPU_TARGETS=$GPU_ARCH

cd build
make install
make -j88 install
Collaborator:

revert

Contributor (Author):

Is there a reason we don't want to use the multithread flag?

Collaborator:

Well, usually the number you set depends on the system. Also, the rbuild command does the building part here (which I think is multithreaded?); the make install step just copies the files over to /opt/rocm, so multithreading it doesn't make much of a difference.
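If a parallel install is still wanted, a portable alternative to the hardcoded -j88 is to derive the job count from the build host (this sketch assumes GNU coreutils' nproc is available, and echoes the command rather than running make so it is self-contained):

```shell
# Derive the job count from the machine instead of hardcoding -j88.
JOBS="$(nproc)"
# In the actual build script this would be: make -j"${JOBS}" install
echo "make -j${JOBS} install"
```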

tests/fx/converters/test_pooling_fx.py (resolved)
py/torch_migraphx/fx/converters/aten_ops_converters.py (resolved)
@@ -100,7 +100,6 @@ def _initialize(self):
if not self.program.is_compiled():
if self.quantize_fp16:
migraphx.quantize_fp16(self.program)

Collaborator:

revert

Comment on lines +238 to +240
# batch_indices3=range(boxes_ref.shape().lens()[0])
# boxes2 = boxes_ref
# This isn't supported at this time because torchvision roi_align.default() doesn't support it
Collaborator:

Clean up the debug comments. Also, can we change the "isn't supported" comment to a TODO comment? It's not supported at this time because we are having trouble consuming a list input in the test cases atm. "torchvision roi_align.default() doesn't support it" is not a valid reason not to support it in the acc converter.

@@ -33,6 +33,7 @@

import torch
from typing import cast, Iterable, List, Sequence
import torchvision
Collaborator:

Suggested change
import torchvision
try:
import torchvision
except ImportError:
pass

This is not a mandatory prerequisite for torch_migraphx, so wrap it in a try/except.

@@ -373,6 +374,27 @@ def clamp(*, input, min=None, max=None):
return torch.clamp(input=input, min=min, max=max)


@register_acc_op_mapping(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@register_acc_op_mapping(
if 'torchvision' in sys.modules:
@register_acc_op_mapping(

Wrap the roi_align function definition in this if so that it doesn't blow up if torchvision isn't installed.

]),
[7, 6])]
)
def test_roialign(input, boxes, output_size, spatial_scale, sampling_ratio, aligned):
Collaborator:

Can you move this to test_torchvision_models_fx.py and follow the example there to skip if torchvision is not installed?
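A dependency-light sketch of that skip pattern (the project's tests use pytest; this stdlib version expresses the same idea, and the class/test names are placeholders):

```python
import importlib.util
import unittest

# Detect torchvision without importing it, then skip the test class
# when it is absent -- the same effect as a pytest skipif marker in
# test_torchvision_models_fx.py.
HAS_TORCHVISION = importlib.util.find_spec("torchvision") is not None

@unittest.skipIf(not HAS_TORCHVISION, "torchvision not installed")
class TestRoiAlign(unittest.TestCase):
    def test_roialign(self):
        self.assertTrue(True)  # placeholder for the real converter check
```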

[7, 6])
]
)
def test_roialign(op_alias, input, boxes, output_size, spatial_scale, sampling_ratio, aligned):
Collaborator:

Similar to the fx test comment, this should be in a test_torchvision_models_dynamo.py file, and it should be marked to skip in the same way if torchvision is not installed.
