Developers Guide

Torch-MIGraphX is a python library that aims to integrate MIGraphX into PyTorch workflows as seamlessly as possible. This means being able to take models that have been built and trained using torch and, through our APIs, produce a compiled MIGraphX program.

The high level workflow for doing this is as follows:

[torch_migraphx workflow diagram]

Below we will use a simple example to explore this full process. The best way to follow along is to set up an environment and run the provided code.

Environment

Generally a good starting point for working with torch_migraphx is to use a base docker image from rocm/pytorch or rocm/pytorch-nightly.

For development, there is also a dev.Dockerfile in the docker directory of the torch_migraphx repo for convenience. Follow the steps under Development to set up a container with torch_migraphx in develop mode and all other prerequisites (including MIGraphX) already installed.

Torch-MIGraphX Usage and Entrypoints

From a user perspective, there are two main APIs that convert a torch model to a Torch-MIGraphX model in a single call.

FX Tracing

# 'model' is a torch.nn.Module object, and 'sample_inputs' is a list of input tensors with the same shapes as real inputs.
from torch_migraphx.fx import lower_to_mgx

mgx_model = lower_to_mgx(model, sample_inputs)

For full usage examples refer to: FX Examples

Dynamo

import torch_migraphx  # importing the library registers the "migraphx" dynamo backend

mgx_model = torch.compile(model, backend="migraphx")
# Note that compilation actually happens when the model is first executed, and the model is recompiled any time it is called with different input sizes
result = mgx_model(*sample_inputs)

Technical Walkthrough

It's good to know how the library is intended to be used, but that doesn't tell us much about how it works or how to develop on it. To that end, let's break down and understand each step in the above workflow diagram using real code.

Torch Model

Before diving into the core torch_migraphx codebase, take some time to understand what a torch model is, how it's created and used. Below is a definition of a simple custom module that we will use to understand the full workflow. Refer to Official PyTorch Docs for a more detailed explanation on the fundamental data structures used in torch.

import torch  

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param1 = torch.nn.Parameter(torch.rand(3, 4))
        self.param2 = torch.nn.Parameter(torch.rand(3, 5))
        self.linear = torch.nn.Linear(4, 5)
    
    def forward(self, x):
        x = x + self.param1
        x = self.linear(x)
        x = x.mul(self.param2)
        return torch.nn.functional.relu(x)

Key things to note here are:

  • Generally speaking, a "PyTorch Model" is anything that is of type (or inherits from) torch.nn.Module. A model class can consist of many submodules/layers that are themselves of type torch.nn.Module. In our example, MyModule contains one such layer, torch.nn.Linear, which is also derived from the parent class torch.nn.Module.
  • torch.nn.Module requires that a forward method be defined that determines which operations are performed and in what order. In our example we take an input x, add a constant param1, pass the result through a Linear layer, elementwise multiply that result by another constant param2, and finally apply the relu function to produce the model output.
  • A torch.nn.Module is callable, meaning that to run the model in eager mode you simply call it with an input x, which invokes the forward method.

Create a model using this class definition and run it in eager mode using some random input. "Eager mode" in this context means that the lines in the forward method are executed in sequence, as written, when the model is called with an input x. Note that in reality a model like this would first have to be trained, but training is outside the scope of MIGraphX, so we will assume it is a pretrained model.

# It's good practice to put torch models in eval mode for our purposes, since some layers behave differently in eval mode vs training mode
mod = MyModule().eval()
in_x = torch.randn(3, 4)
out = mod(in_x)

Torch Graphs

The model we created in the previous section is nothing more than a Python object. There is no meaningful way for a graph optimizer to consume the model in this form, so we first need to transform it into a graph format. For this we use APIs provided by the torch library. Torch provides two methods to transform a model into a graphical representation, and the data structure used to implement this representation is torch.fx.GraphModule.

FX Tracing

The first method for generating a graph is to use the Tracer provided by the FX Toolkit. Let's start by tracing our custom module with the base symbolic tracer provided by this toolkit.

from torch.fx import symbolic_trace
symbolic_traced = symbolic_trace(mod)
print(symbolic_traced.graph) 
# symbolic_traced.graph.print_tabular() # Feel free to use this style of print if you find it easier to read
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param1 : [num_users=1] = get_attr[target=param1]
    %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param1), kwargs = {})
    %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %param2 : [num_users=1] = get_attr[target=param2]
    %mul : [num_users=1] = call_method[target=mul](args = (%linear, %param2), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%mul,), kwargs = {inplace: False})
    return relu

Take a moment to relate this graph back to the original module and convince yourself that it is indeed performing the same set of operations on input x just written in graph format. There are a few key things to understand here:

  • There are 6 types of nodes here (the node type is referred to as the opcode); a snippet for inspecting them programmatically follows this list
    • placeholder: These are model inputs
    • get_attr: These are generally constants (torch parameters are constant in eval mode) or any other model attributes with the name defined by target 
    • call_function: Calls the target function with args and kwargs 
    • call_method: Calls the target method with args and kwargs 
    • call_module: Calls the target torch.nn.Module (ie. the forward method of that module) with args and kwargs 
    • output: This defines the model output(s). In the above format this just shows up as a return statement, but if you use the tabular print function, you will see the opcode listed as output 
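
To inspect these opcodes programmatically, you can iterate over the nodes of the traced graph directly using the standard torch.fx node attributes:

for node in symbolic_traced.graph.nodes:
    # node.op is the opcode; node.target is the function, module, method or attribute it refers to
    print(node.op, node.target, node.args, node.kwargs)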

We will see how we deal with each of these types of nodes when translating to MIGraphX IR, but before that, let's consider the model below.

class MyModule2(torch.nn.Module):
    def __init__(self, w, b, param1, param2):
        super().__init__()
        self.param1 = torch.nn.Parameter(param1)
        self.param2 = torch.nn.Parameter(param2)
        self.w = torch.nn.Parameter(w)
        self.b = torch.nn.Parameter(b)
        self.relu = torch.nn.ReLU()
    
    def forward(self, x):
        x = torch.add(x, self.param1)
        x = torch.nn.functional.linear(x, self.w, self.b)
        x = x*self.param2
        return self.relu(x)

If you look carefully, this model is functionally identical to our original module. Convince yourself further by feeding both models the same input and printing the respective outputs.

mod = MyModule().eval()
# Feed in the same constants to make sure the outputs are comparable
mod2 = MyModule2(mod.linear.weight, mod.linear.bias, mod.param1, mod.param2).eval()
in_x = torch.randn(3, 4)
print(mod(in_x))
print(mod2(in_x))

Let's look at the graph for this second model.

symbolic_traced2 = symbolic_trace(mod2)
print(symbolic_traced2.graph)
# symbolic_traced2.graph.print_tabular()
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param1 : [num_users=1] = get_attr[target=param1]
    %add : [num_users=1] = call_function[target=torch.add](args = (%x, %param1), kwargs = {})
    %w : [num_users=1] = get_attr[target=w]
    %b : [num_users=1] = get_attr[target=b]
    %linear : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%add, %w, %b), kwargs = {})
    %param2 : [num_users=1] = get_attr[target=param2]
    %mul : [num_users=1] = call_function[target=operator.mul](args = (%linear, %param2), kwargs = {})
    %relu : [num_users=1] = call_module[target=relu](args = (%mul,), kwargs = {})
    return relu

Notice some key differences:

  • The target for the add node changed from operator.add to torch.add 
  • The linear layer is now a call_function rather than a call_module 
  • Similarly the mul and relu nodes are also different opcodes

Yet the model is mathematically the exact same. This can lead to a lot of duplication when implementing our translation layer, so to deal with these kinds of variations in torch models we implement our own derived tracer that normalizes them. We call this the acc_tracer and the relevant code can be found here. Let's trace both of these models using our derived tracer.

import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
acc_traced = acc_tracer.trace(mod, [in_x])
print(acc_traced.graph)
# acc_traced.graph.print_tabular()
graph():
    %x : [num_users=1] = placeholder[target=x]
    %param1 : [num_users=1] = get_attr[target=param1]
    %add_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.add](args = (), kwargs = {input: %x, other: %param1})
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %linear_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.linear](args = (), kwargs = {input: %add_1, weight: %linear_weight, bias: %linear_bias})
    %param2 : [num_users=1] = get_attr[target=param2]
    %mul_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.mul](args = (), kwargs = {input: %linear_1, other: %param2})
    %relu_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %mul_1, inplace: False})
    return relu_1

As an exercise, try and use this tracer on the second model and see if the graph is the same. 

Key things to note here about the normalization we do:

  • There are no more call_module and call_method nodes
    • call_module nodes are replaced by their corresponding functional calls because we configure our tracer to trace inside of nested modules. I.e. acc_tracer traces into the torch.nn.Linear module and finds the torch.nn.functional.linear call in its forward method.
    • call_method nodes are mapped to their corresponding functional call
  • All of the call_function targets now point to functions in the acc_ops namespace
  • All of the args are empty and all function parameters are defined as keyword arguments in kwargs

Take some time to examine a few of the function mappings defined in acc_ops.py.

  • Mappings are defined using decorators, specifically register_acc_op_mapping. This tells our tracer that any node that targets a particular torch method or function should instead target our acc_ops function. Note that these functions are merely wrappers: we do not want to modify their core functionality, we only want our graph to use these wrappers as the call_function targets (see the sketch after this list).
  • In some cases, instead of defining a new acc_op, we want to simply replace an op by writing it in terms of another op (which is usually a more generalized version of the same op). For example, examine the transpose_mapper: instead of creating a new acc_op called transpose, we simply remap this node to a permute node that performs the same operation. Now we only have to worry about implementing a translator for the permute operation. Such remappings are defined using the decorator register_custom_acc_mapper_fn.
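
As a rough sketch of the registration pattern (illustrative only: the additional register_acc_op decorator and the op_and_target argument shown below are assumptions based on the general pattern, and the exact definitions live in acc_ops.py), an acc_op wrapper might look something like:

# Illustrative sketch, not the actual torch_migraphx source.
# The wrapper is registered as the acc_ops target for torch.add-style nodes;
# the computation itself is unchanged.
@register_acc_op_mapping(op_and_target=("call_function", torch.add))
@register_acc_op
def add(*, input, other):
    return input + other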

At this point the torch graph is ready to be translated to MIGraphX instructions. Before we explore the translation process, we will look at the other method provided for generating functional graphs in torch.

Dynamo (torch.compile)

This is a feature available since the release of PyTorch 2.0 and is under active development. It is intended to support model compilation natively in torch, but we can use the machinery it provides to perform the compilation with MIGraphX instead. The main advantage of this method over FX Tracing is that it can compile models that have data-dependent control flow (i.e. different sets of ops need to be executed depending on the values of tensors at runtime). This is not allowed in FX Tracing, where the tracer fails when it encounters such models. Read more about this approach in the official docs. We will use the same model from the previous section to explore this approach. Let's start by defining a custom backend that we can use to understand how the torch.compile API works.

import torch._dynamo as dynamo
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch._guards import TracingContext

@dynamo.register_backend(name="my_backend")
def test_backend(gm, example_inputs, **kwargs):
    TracingContext.get().fake_mode.allow_non_fake_inputs = True
    print(example_inputs)
    print(gm.graph)
    # Export a graph where all ops are lowered to the aten namespace
    aten_gm = aot_export_joint_simple(gm, example_inputs, trace_joint=False)
    print(aten_gm.graph)
    return aten_gm

There is a lot to understand about the underlying mechanisms employed by a torch.compile call, but for our purposes we can focus on how a backend is defined. The dynamo.register_backend decorator is what tells dynamo where to look for the definition of "my_backend" in the call below.

mod = MyModule().eval()
in_x = torch.randn(3, 4)
mod_dynamo = torch.compile(mod, backend="my_backend")
mod_dynamo(in_x) # This line is when the test_backend function is invoked
[tensor([[ 1.0750, -0.4972,  0.7909,  0.1489],
        [-0.6334, -0.4037, -0.3144,  1.3126],
        [ 0.6391, -1.0924,  0.6623,  0.5520]])]
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l__self___param1 : [num_users=1] = get_attr[target=L__self___param1]
    %add : [num_users=1] = call_function[target=operator.add](args = (%l_x_, %l__self___param1), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%add,), kwargs = {})
    %l__self___param2 : [num_users=1] = get_attr[target=L__self___param2]
    %mul : [num_users=1] = call_method[target=mul](args = (%l__self___linear, %l__self___param2), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%mul,), kwargs = {})
    return (relu,)
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %_param_constant0), kwargs = {})
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant1,), kwargs = {})
    %_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant2, %add, %t), kwargs = {})
    %_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %_param_constant3), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mul,), kwargs = {})
    return (relu,)

Let's highlight some key observations:

  • example_inputs is a list of tensors (even though our model call was mod_dynamo(in_x), the tensor in_x is placed into the example input list)
    • If our model had multiple inputs (e.g. mod_dynamo(in_x, in_y, in_z)), example_inputs would be a list of 3 tensors corresponding to the 3 inputs
  • The dynamo machinery passes a GraphModule object to our backend (which we reference as gm). 
    • This object consists of nodes that point to internal methods and modules and is not really in a format that can be translated
  • We use the aot_export_joint_simple API provided by the functorch toolkit to export a graph that is reduced to function calls in the torch.ops.aten namespace
    • This is similar to the acc_ops normalization we did in FX Tracing, where all operations are written as call_function nodes with targets in a single namespace.
  • In this toy example, our backend just performs this export, prints the graph, and returns the exported torch GraphModule. In an actual backend implementation, the return value is expected to be a Python callable that takes inputs identical to example_inputs in shape and datatype.

We highly recommend going through this specific section of the official torch tutorial, as it shows what happens when there is data-dependent control flow in the model and how this backend function can be invoked multiple times, once for each subgraph that the control flow can feed into.
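
To make this concrete, here is a minimal example of data-dependent control flow: FX symbolic tracing fails on it, while torch.compile splits the model at the branch and invokes the registered backend once per traced subgraph.

class ControlFlowModule(torch.nn.Module):
    def forward(self, x):
        # The branch taken depends on runtime tensor values
        if x.sum() > 0:
            return torch.nn.functional.relu(x)
        return torch.sigmoid(x)

# symbolic_trace(ControlFlowModule()) raises a tracing error because the condition
# depends on tensor values; torch.compile instead inserts a graph break here.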

To return something meaningful (ie. a compiled program) from this backend, we need to actually translate these graphs to MIGraphX programs.

Conversion to MIGraphX IR

So far, we've merely used torch APIs to get to a point where we have graphical representations of torch models in a consistent format. Now we can dive into the actual conversion mechanism for generating MIGraphX programs from these torch graphs. At this point it is more difficult to explore with toy code and so we will examine relevant code in our real codebase. 

The way we translate the torch graph is relatively straightforward but requires a good understanding of the torch.fx.Interpreter class. This class is designed to traverse a GraphModule node by node and perform any required "transformations". In our case these "transformations" just mean adding instructions to our migraphx program that implement functionality equivalent to that of the torch node. This will become clear as we walk through the conversion of our example model below.

Our FX Interpreter is defined in fx2mgx.py. Keep in mind that this class will be used to iterate through torch graphs node by node in order. Note some key things about what's happening in this class:

  • On initialization, we create an empty migraphx program (using migraphx's python API). We will be adding instructions to this program as we traverse the nodes
  • Calling the run method initiates the node traversal
  • The methods placeholder, call_module, call_function, call_method, get_attr, and output define what happens when each of these types of nodes is encountered during the traversal

The placeholder, get_attr, and output methods are straightforward (a standalone sketch using these migraphx API calls follows this list):

  • placeholder adds an input to the migraphx program (using the add_parameter migraphx function)
  • get_attr adds literals to the migraphx program (using add_literal migraphx function)
  • output adds outputs to the migraphx program (using add_return migraphx function)
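
To get a feel for what the interpreter is doing, here is a small standalone sketch that builds a tiny migraphx program by hand using these same python API calls (assuming the standard migraphx python bindings; the interpreter performs the equivalent calls for every node it visits):

import numpy as np
import migraphx

prog = migraphx.program()
mm = prog.get_main_module()

# placeholder -> add_parameter: declare a model input
x = mm.add_parameter("x", migraphx.shape(type="float_type", lens=[3, 4]))

# get_attr -> add_literal: embed a constant tensor
const = mm.add_literal(np.random.rand(3, 4).astype(np.float32))

# call_function -> add_instruction: perform an elementwise add
add_ins = mm.add_instruction(migraphx.op("add"), [x, const])

# output -> add_return: mark the program output
mm.add_return([add_ins])
print(prog)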

The call_module, call_function, and call_method methods handle math/tensor operations that need to be translated to operations defined in MIGraphX. To keep track of the converters we have implemented, we maintain a CONVERTERS dictionary.

  • The keys of this dictionary are functions (specifically the acc_ops and aten functions that we saw as call_function targets in the previous section).
  • The values in this dictionary are also functions. These are the converter functions that define how an acc_ops or aten function should be translated to migraphx instructions.

Converter functions are defined in acc_ops_converters.py and aten_ops_converters.py. Note that there are some cases where functions/modules are not normalized to acc_ops and so there are some additional converters defined in the converters directory. For now we will look at the FX traced graph and focus on acc_ops_converters.py. The CONVERTERS dictionary is populated by the decorators that are applied to each of the converter functions in these files.

FX Translation Walkthrough

Here we will pass the acc_traced module from the previous section through the interpreter and understand the migraphx program that is generated.

mod = MyModule().eval()
in_x = torch.randn(3, 4)

import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
acc_traced = acc_tracer.trace(mod, [in_x])
print(acc_traced.graph)
# acc_traced.graph.print_tabular()

from torch_migraphx.fx.fx2mgx import MGXInterpreter
interp = MGXInterpreter(acc_traced, [in_x])
interp.run()
print(interp.program)
graph():
- %x : [num_users=1] = placeholder[target=x]
%param1 : [num_users=1] = get_attr[target=param1]
! %add_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.add](args = (), kwargs = {input: %x, other: %param1})
%linear_weight : [num_users=1] = get_attr[target=linear.weight]
%linear_bias : [num_users=1] = get_attr[target=linear.bias]
+ %linear_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.linear](args = (), kwargs = {input: %add_1, weight: %linear_weight, bias: %linear_bias})
%param2 : [num_users=1] = get_attr[target=param2]
# %mul_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.mul](args = (), kwargs = {input: %linear_1, other: %param2})
@@ %relu_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %mul_1, inplace: False}) @@
return relu_1

module: "main"
@0 = @literal{ ... } -> float_type, {3, 5}, {5, 1}
@1 = @literal{-0.0148876, -0.224566, 0.488997, -0.449717, 0.486762} -> float_type, {5}, {1}
@2 = @literal{ ... } -> float_type, {5, 4}, {4, 1}
@3 = @literal{ ... } -> float_type, {3, 4}, {4, 1}
- x = @param:x -> float_type, {3, 4}, {4, 1}
! @5 = add(x,@3) -> float_type, {3, 4}, {4, 1}
+ @6 = transpose[permutation={1, 0}](@2) -> float_type, {4, 5}, {1, 4}
+ @7 = multibroadcast[out_lens={4, 5},out_dyn_dims={}](@6) -> float_type, {4, 5}, {1, 4}
+ @8 = dot(@5,@7) -> float_type, {3, 5}, {5, 1}
+ @9 = multibroadcast[out_lens={3, 5},out_dyn_dims={}](@1) -> float_type, {3, 5}, {0, 1}
+ @10 = add(@8,@9) -> float_type, {3, 5}, {5, 1}
# @11 = mul(@10,@0) -> float_type, {3, 5}, {5, 1}
@@ @12 = relu(@11) -> float_type, {3, 5}, {5, 1} @@
@13 = @return(@12)

Take some time to understand which converters are called and how each instruction is added to the generated migraphx program. The output above is manually annotated (the -, !, +, #, and @@ prefixes mark which migraphx instruction, or set of migraphx instructions, corresponds to each torch node), so make sure you fully understand how each torch node is translated. Understanding this is key to implementing converters and contributing to this codebase.

Studying the implemented converters, you can note a few key things:

  • Converters can be trivial, a simple one-to-one mapping to a migraphx op (e.g. relu; see the sketch after this list)
  • Some are one-to-one mappings but need to account for the fact that torch allows implicit broadcasting (e.g. add, mul)
    • In our particular example none of the operands for add and mul actually need broadcasting, but in general this is allowed by torch and so it must be handled
  • Some converters require a series of migraphx ops to implement the equivalent functionality (e.g. linear)
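
As a simplified illustration of the first category, a relu converter essentially emits a single migraphx instruction. This is only a sketch: the real converter in acc_ops_converters.py is registered through its decorator and does additional bookkeeping, and the exact decorator and argument names may differ from what is shown here.

# Simplified sketch, not the exact torch_migraphx source
@migraphx_converter(acc_ops.relu)
def acc_ops_relu(mgx_module, node, args, kwargs):
    inp = kwargs["input"]
    # One torch node -> one migraphx instruction
    return mgx_module.add_instruction(migraphx.op("relu"), [inp])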

Dynamo Translation

Here is a modified backend definition that adds the interpreter into the mix.

import torch._dynamo as dynamo
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch._guards import TracingContext
from torch_migraphx.fx.fx2mgx import MGXInterpreter

@dynamo.register_backend(name="my_backend")
def test_backend(gm, example_inputs, **kwargs):
    TracingContext.get().fake_mode.allow_non_fake_inputs = True
    aten_gm = aot_export_joint_simple(gm, example_inputs, trace_joint=False)
    print(aten_gm.graph)
    interp = MGXInterpreter(aten_gm, example_inputs)
    interp.run()
    print(interp.program)
    return aten_gm

mod_dynamo = torch.compile(mod, backend="my_backend")
mod_dynamo(in_x)
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %_param_constant0), kwargs = {})
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant1,), kwargs = {})
    %_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant2, %add, %t), kwargs = {})
    %_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %_param_constant3), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mul,), kwargs = {})
    return (relu,)

module: "main"
@0 = @literal{ ... } -> float_type, {3, 5}, {5, 1}
@1 = @literal{-0.298655, -0.145471, -0.311266, -0.279463, -0.389797} -> float_type, {5}, {1}
@2 = @literal{ ... } -> float_type, {5, 4}, {4, 1}
@3 = @literal{ ... } -> float_type, {3, 4}, {4, 1}
arg0_1 = @param:arg0_1 -> float_type, {3, 4}, {4, 1}
@5 = add(arg0_1,@3) -> float_type, {3, 4}, {4, 1}
@6 = transpose[permutation={1, 0}](@2) -> float_type, {4, 5}, {1, 4}
@7 = multibroadcast[out_lens={3, 4},out_dyn_dims={}](@5) -> float_type, {3, 4}, {4, 1}
@8 = multibroadcast[out_lens={4, 5},out_dyn_dims={}](@6) -> float_type, {4, 5}, {1, 4}
@9 = dot(@7,@8) -> float_type, {3, 5}, {5, 1}
@10 = multibroadcast[out_lens={3, 5},out_dyn_dims={}](@1) -> float_type, {3, 5}, {0, 1}
@11 = multibroadcast[out_lens={3, 5},out_dyn_dims={}](@9) -> float_type, {3, 5}, {5, 1}
@12 = add(@10,@11) -> float_type, {3, 5}, {5, 1}
@13 = mul(@12,@0) -> float_type, {3, 5}, {5, 1}
@14 = relu(@13) -> float_type, {3, 5}, {5, 1}
@15 = @return(@14)

Here, the nodes have not been color coded so it is a very good exercise to go through the torch graph and identify which migraphx instructions correspond to each node by examining the associated converters. Some notable things to highlight about aten converters are:

  • They ALWAYS point to an acc_ops converter
    • This is because in PyTorch, aten ops are low-level implementations of the high-level torch API functions. For example, the high-level function torch.add will actually call the aten.add function "under the hood". This means that if we already have support (i.e. a converter exists) for a high-level op, we can just point these low-level ops to the existing converter with the right arguments.
  • In general, aten ops are implemented in C++ in torch and don't support keyword arguments in the same way as the high-level ops. So in aten converters we have to rely on args, which means the order of input arguments is important (a sketch of this delegation pattern follows this list)
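
A simplified sketch of that delegation pattern (hypothetical names; see aten_ops_converters.py for the real implementations): the aten converter repackages its positional args as the keyword arguments expected by the existing acc_ops converter.

# Illustrative sketch of an aten converter reusing an acc_ops converter
@migraphx_converter(torch.ops.aten.relu.default)
def aten_ops_relu(mgx_module, node, args, kwargs):
    # aten ops pass inputs positionally, so map them to the keyword form
    # expected by the acc_ops converter
    return acc_ops_relu(mgx_module, node, (), {"input": args[0]})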

Executing the MIGraphX Program

We're almost at the finish line. We have a migraphx program; now all that's left is to let MIGraphX do its magic and generate a compiled version of this program that we can execute. Before diving into the implementation, let's complete the workflow for MyModule and see how to execute our custom model using migraphx.

mod = MyModule().eval()
in_x = torch.randn(3, 4)
out = mod(in_x)

import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
acc_traced = acc_tracer.trace(mod, [in_x])
# print(acc_traced.graph)
# acc_traced.graph.print_tabular()

from torch_migraphx.fx.fx2mgx import MGXInterpreter
interp = MGXInterpreter(acc_traced, [in_x])
interp.run()
# print(interp.program)

from torch_migraphx.fx.mgx_module import MGXModule
mgx_mod = MGXModule(interp.program, interp.get_input_names())
mgx_out = mgx_mod(in_x)

print(out)
print(mgx_out)

Run this and verify that the two outputs are the same. At this point, circle back to the usage section and examine the 2 entrypoints listed there. The above code block is a simplified version of what happens when those entrypoints (lower_to_mgx and torch.compile) are used. There are a number of details that we ignore in this walkthrough for the sake of simplicity, but you now understand the core components of the pipeline invoked by those calls.

Let's examine the implementation of this final piece of the workflow.  The MGXModule class is implemented in mgx_module.py. Here are the core features of this class that you should look for in the code:

  • MGXModule inherits from torch.nn.Module which allows objects of this class to be executed in the same manner as normal torch models (notice how we invoke mgx_mod in the same way as mod)
  • When an MGXModule object is initialized with a program (from the interpreter):
    • The program is compiled using the program.compile call provided by the migraphx python API
    • Output buffers are also allocated so that they can be passed as parameters when we run the program (this allows us to keep the output tensors on the GPU) 
  • The code for actually running this model resides in the forward method as for all torch.nn.Module objects. Important details to note: 
    • The run_async call is used to avoid unnecessary synchronizations
    • The stream used for this async call is the default PyTorch stream. This is an important detail that prevents race conditions when inputs to the MGXModule are outputs from other torch models, or vice versa. E.g. if a user is running a workflow with a series of models mod1 → mod2 → mod3 where mod1 and mod3 are regular torch models but mod2 is an MGXModule, then an async call on a different stream may not wait for outputs from mod1, and similarly mod3 may not wait for migraphx to write to the output buffers defined in mod2. 
  • There are some additional functions implemented that allow MGXModule objects to be saved and loaded in the same manner as normal torch models (a short example follows)
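
For example, since MGXModule behaves like a normal torch module, saving and reloading can be done with the usual torch serialization calls (shown here as an illustrative sketch):

torch.save(mgx_mod, "my_module_mgx.pt")
reloaded_mod = torch.load("my_module_mgx.pt")
print(reloaded_mod(in_x))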