# Fix slice upstream - Incompatible dimensions (microsoft#16818)
### Fix slice upstream - (MatMul) [ShapeInferenceError] Incompatible dimensions

```
2023-07-22 14:58:16.918478478 [I:onnxruntime:Default, constant_sharing.cc:256 ApplyImpl] Total shared scalar initializer count: 10
2023-07-22 14:58:16.919494252 [W:onnxruntime:Default, graph.cc:108 MergeShapeInfo] Error merging shape info for output. 'onnx::Cast_424' source:{-1,31,-1,-1} target:{-1,32,-1,-1}. Falling back to lenient merge.
2023-07-22 14:58:16.921014114 [W:onnxruntime:Default, graph.cc:108 MergeShapeInfo] Error merging shape info for output. 'onnx::MatMul_425' source:{-1,31,-1,-1} target:{-1,32,-1,-1}. Falling back to lenient merge.

Traceback (most recent call last):
  File "examples/onnxruntime/training/language-modeling/run_clm.py", line 594, in <module>
    main()
  File "examples/onnxruntime/training/language-modeling/run_clm.py", line 542, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 454, in train
    return inner_training_loop(
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 755, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/transformers/trainer.py", line 2735, in training_step
    loss = self.compute_loss(model, inputs)
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 363, in compute_loss
    return model_with_loss(dict_inputs, return_outputs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1724, in forward
    loss = self.module(*inputs, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_utils.py", line 384, in _forward
    return ortmodule._torch_module.forward(*inputs, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_utils.py", line 364, in _forward
    return torch_module_ort._execution_manager(torch_module_ort.is_training()).forward(*inputs, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 345, in forward
    self._fallback_manager.handle_exception(
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
    raise exception
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward
    self._build_graph(graph_transformer_config)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 218, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 360, in _build_graph
    super()._build_graph(graph_transformer_config)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 186, in _build_graph
    self._graph_builder.build(config)
RuntimeError: /bert_ort/pengwa/onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (MatMul_403) Op (MatMul) [ShapeInferenceError] Incompatible dimensions

 
```

The transformer was passing the slicing axis to the propagated `Slice` node
as an `axis` attribute, but `Slice` (opset 10 and later) has no such
attribute and takes its axes via the optional `axes` input. Change the
transformer to pass an `axes` input instead.
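
For background, `Gather` carries its axis as a node attribute, while `Slice` (opset 10 and later) consumes `starts`/`ends`/`axes`/`steps` as inputs, so an axis attached to `Slice` as an attribute is silently ignored. A minimal sketch of the distinction (illustrative tensor names and shapes, not part of this PR):

```
import onnx
from onnx import TensorProto, helper

# Gather: the axis is a node attribute (shown only for contrast; not added to the graph).
gather = helper.make_node("Gather", ["data", "indices"], ["g_out"], axis=1)

# Slice (opset >= 10): starts/ends/axes are inputs, here constant initializers.
starts = helper.make_tensor("starts", TensorProto.INT64, [1], [0])
ends = helper.make_tensor("ends", TensorProto.INT64, [1], [31])
axes = helper.make_tensor("axes", TensorProto.INT64, [1], [1])
slice_node = helper.make_node("Slice", ["data", "starts", "ends", "axes"], ["s_out"])

graph = helper.make_graph(
    [slice_node],
    "slice_axes_as_input",
    [helper.make_tensor_value_info("data", TensorProto.FLOAT, [2, 32, 4])],
    [helper.make_tensor_value_info("s_out", TensorProto.FLOAT, None)],
    initializer=[starts, ends, axes],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(model)

# Shape inference can read the constant axes input: dim 1 of [2, 32, 4] becomes 31.
inferred = onnx.shape_inference.infer_shapes(model)
print(inferred.graph.output[0])  # expected: s_out, FLOAT, [2, 31, 4]
```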

pengwa authored Jul 25, 2023
1 parent b0279b1 commit f2c0470
Showing 4 changed files with 397 additions and 89 deletions.
4 changes: 3 additions & 1 deletion onnxruntime/core/optimizer/compute_optimizer/shared_utils.cc
@@ -184,7 +184,9 @@ NodeArg* CreateInitializerFromVector(Graph& graph,
total_count *= dim;
}

ORT_ENFORCE(total_count == static_cast<int64_t>(values.size()));
ORT_ENFORCE(total_count == static_cast<int64_t>(values.size()),
"The total count of dims does not match the size of values. ",
"total_count: ", total_count, " values.size(): ", values.size());

const_tensor.set_raw_data(values.data(), values.size() * sizeof(int64_t));
return &graph_utils::AddInitializer(graph, const_tensor);
83 changes: 67 additions & 16 deletions onnxruntime/core/optimizer/compute_optimizer/upstream_gather.cc
@@ -138,20 +138,65 @@ SliceInfo UpStreamGatherGraphTransformer::PropagateSlicingForInput(
std::to_string(!info.is_scalar_slice));

InlinedVector<NodeArg*> input_args;
input_args.reserve(slice_node.InputDefs().size());
input_args.resize(slice_node.InputDefs().size());

int axis_input_index = -1; // -1 means axis is passed in attribute.
if (std::holds_alternative<int>(info.axis_attr_name_or_input_index)) {
axis_input_index = std::get<int>(info.axis_attr_name_or_input_index);
}

auto create_axes_input = [&info, new_axis, &graph]() -> NodeArg* {
InlinedVector<int64_t> dims;
if (info.rank_of_axis_value == 1) {
dims.push_back(1);
}
return CreateInitializerFromVector(graph, dims, {new_axis}, graph.GenerateNodeArgName("axes"));
};

// The first slice op's data input should be current_node's current_node_input_index-th input.
// For some cases when rank changes, slice op's slice input should also be adapted.
input_args.push_back(current_node.MutableInputDefs()[current_node_input_index]);
for (size_t i = 1; i < slice_node.InputDefs().size(); ++i) {
input_args.push_back(slice_node.MutableInputDefs()[i]);
int i = 0;
for (; i < static_cast<int>(slice_node.InputDefs().size()); ++i) {
if (i == info.GetDataInputIndex()) {
input_args[i] = current_node.MutableInputDefs()[current_node_input_index];
} else if (axis_input_index != -1 && i == axis_input_index) {
if (info.non_negative_axis == new_axis) {
input_args[i] = slice_node.MutableInputDefs()[i];
} else {
input_args[i] = create_axes_input();
}
} else {
input_args[i] = slice_node.MutableInputDefs()[i];
}
}

// It is possible axes input is null.
if (axis_input_index != -1 && info.non_negative_axis != new_axis) {
for (; i <= axis_input_index; ++i) {
if (i == axis_input_index) {
input_args.push_back(create_axes_input());
} else {
NodeArg& empty_input = graph.GetOrCreateNodeArg("", nullptr);
input_args.push_back(&empty_input);
}
}
}

// Update the axis attribute if new_axis is not the same as the original slicing axis (which happens when data
// layout got changed by Transpose or Reshape ops)
onnxruntime::NodeAttributes attributes = slice_node.GetAttributes();
if (info.non_negative_axis != new_axis) {
attributes[info.axis_attr_name] =
ONNX_NAMESPACE::MakeAttribute(info.axis_attr_name, static_cast<int64_t>(new_axis));

if (axis_input_index == -1 && info.non_negative_axis != new_axis) {
std::string attr_name = std::get<std::string>(info.axis_attr_name_or_input_index);
if (info.rank_of_axis_value == 0) {
attributes[attr_name] =
ONNX_NAMESPACE::MakeAttribute(attr_name, static_cast<int64_t>(new_axis));
} else if (info.rank_of_axis_value == 1) {
attributes[attr_name] =
ONNX_NAMESPACE::MakeAttribute(attr_name, std::vector<int64_t>{static_cast<int64_t>(new_axis)});
} else {
ORT_THROW("Unexpected rank of axis attribute value: " + std::to_string(info.rank_of_axis_value));
}
}

InlinedVector<NodeArg*> output_args;
@@ -183,7 +228,8 @@ SliceInfo UpStreamGatherGraphTransformer::PropagateSlicingForInput(
auto new_slice_out_arg = new_slice_node->MutableOutputDefs()[new_slice_output_index_to_connect];
UpdateSliceOutputShape(*new_slice_out_arg, new_axis, info.output_dim_on_axis);

auto new_slice_info = SliceInfo(graph, new_slice_node, info.is_scalar_slice, info.axis_attr_name, new_axis);
auto new_slice_info = SliceInfo(graph, new_slice_node, info.is_scalar_slice, info.axis_attr_name_or_input_index,
new_axis, info.rank_of_axis_value);
new_slice_info.entry_node_name = info.entry_node_name;
new_slice_info.entry_slice_arg_name = info.entry_slice_arg_name;
return new_slice_info;
@@ -263,7 +309,8 @@ std::optional<SliceInfo> IsSupportedGatherND(Graph& graph, Node& node,
return std::nullopt;
}

return SliceInfo(graph, &node, false, "batch_dims", static_cast<int>(batch_dims), true);
return SliceInfo(graph, &node, false, "batch_dims", static_cast<int>(batch_dims),
0 /* rank of axis attribute value */, true);
}

std::optional<SliceInfo> IsSupportedGather(Graph& graph, Node& node,
@@ -304,7 +351,7 @@ std::optional<SliceInfo> IsSupportedGather(Graph& graph, Node& node,
}
}

return SliceInfo(graph, &node, dim_size == 0, "axis", axis, true);
return SliceInfo(graph, &node, dim_size == 0, "axis", axis, 0 /* rank of axis attribute value */, true);
}

std::optional<SliceInfo> IsSupportedShrunkenGather(Graph& graph, Node& node,
@@ -342,7 +389,7 @@ std::optional<SliceInfo> IsSupportedShrunkenGather(Graph& graph, Node& node,
return std::nullopt;
}

return SliceInfo(graph, &node, false /*is_slice_scalar*/, "axis", axis, true);
return SliceInfo(graph, &node, false /*is_slice_scalar*/, "axis", axis, 0 /* rank of axis attribute value */, true);
}

/**
@@ -366,42 +413,46 @@ std::optional<SliceInfo> IsSupportedSlice(Graph& graph, Node& node,
const NodeArg* axes_input = node.InputDefs().size() > 3 ? node.InputDefs()[3] : nullptr;

if (data_input->Shape() == nullptr || starts_input->Shape() == nullptr || ends_input->Shape() == nullptr ||
(axes_input && axes_input->Shape() == nullptr)) {
(axes_input && axes_input->Exists() && axes_input->Shape() == nullptr)) {
LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to undefined shape.");
return std::nullopt;
}

// Make sure starts/ends/axes/steps are all 1D tensors, since we only support single-dimension slicing.
if (starts_input->Shape()->dim_size() != 1 || ends_input->Shape()->dim_size() != 1 ||
(axes_input && axes_input->Shape()->dim_size() != 1)) {
(axes_input && axes_input->Exists() && axes_input->Shape()->dim_size() != 1)) {
LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to unsupported dim size: " +
std::to_string(starts_input->Shape()->dim_size()) + ", " +
std::to_string(ends_input->Shape()->dim_size()) + ", " +
std::to_string(axes_input ? axes_input->Shape()->dim_size() : 0));
std::to_string(axes_input && axes_input->Exists() ? axes_input->Shape()->dim_size() : 0));
return std::nullopt;
}

// Try to parse the 'axes' value.
int axis = 0;
if (axes_input) {
if (axes_input && axes_input->Exists()) {
InlinedVector<int64_t> axes_values;
if (!graph_utils::IsConstantInitializer(graph, axes_input->Name()) ||
!optimizer_utils::AppendTensorFromInitializer(graph, *axes_input, axes_values, true) ||
axes_values.size() != 1) {
LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to unsupported axes value.");
return std::nullopt;
}
axis = static_cast<int>(axes_values[0]);
} else {
// If 'axes' is not specified, then it is [0, .., r-1], so we force data rank to be 1.
if (data_input->Shape()->dim_size() != 1) {
LOG_DEBUG_INFO(logger, "Skip Slice node " + node.Name() + " due to unsupported data rank: " +
std::to_string(data_input->Shape()->dim_size()));
return std::nullopt;
}
}

if (axis < 0)
axis += data_input->Shape()->dim_size();

return SliceInfo(graph, &node, false /*is_slice_scalar*/, "axis", axis, true);
return SliceInfo(graph, &node, false /*is_slice_scalar*/, 3 /* axis input index */, axis,
1 /* rank of axes value */, true);
}

} // namespace
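As the comment in `IsSupportedSlice` above notes, when the `axes` input is omitted the ONNX spec treats it as `[0, ..., r-1]`, which is why the transform only accepts an omitted `axes` for rank-1 data. A tiny numpy emulation of that default (an illustration under that assumption, not ORT code):

```
import numpy as np

def slice_default_axes(data, starts, ends):
    # With `axes` omitted, the i-th (start, end) pair applies to dimension i,
    # so starts/ends must cover every dimension of `data`.
    assert len(starts) == len(ends) == data.ndim
    return data[tuple(slice(s, e) for s, e in zip(starts, ends))]

x = np.arange(8)
print(slice_default_axes(x, [2], [5]))  # rank-1: the single pair is axis 0 -> [2 3 4]
```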
@@ -25,11 +25,21 @@ struct SliceInfo : public UpstreamOperatorInfoBase {
public:
SliceInfo(const Graph& graph, Node* slice_node,
bool is_slice_scalar,
const std::string& slice_axis_attr_name,
std::variant<std::string, int> axis_name_or_index,
int slice_axis,
int rank_of_axis,
bool is_entry_node_ptr = false)
: UpstreamOperatorInfoBase(slice_node, is_entry_node_ptr), is_scalar_slice(is_slice_scalar) {
axis_attr_name = slice_axis_attr_name;
axis_attr_name_or_input_index = axis_name_or_index;
rank_of_axis_value = rank_of_axis;

if (std::holds_alternative<int>(axis_name_or_index)) {
int axis_input_index = std::get<int>(axis_name_or_index);
ORT_ENFORCE(axis_input_index >= 0, "Axis input index is invalid");
}

ORT_ENFORCE(rank_of_axis_value == 0 || rank_of_axis_value == 1, "Rank of axis value is invalid: " +
std::to_string(rank_of_axis_value));

const NodeArg* input = node_ptr->InputDefs()[kSliceDataInputIndex_];
const NodeArg* output = node_ptr->OutputDefs()[kSliceOutputIndex_];
@@ -65,8 +75,16 @@ struct SliceInfo : public UpstreamOperatorInfoBase {
}

bool is_scalar_slice; // whether the slice is a scalar, if it is after Gather, the rank will be reduced by 1.
std::string axis_attr_name;

// The index of the input that contains the axis value. If it is a string, then axis will be treated as an attribute.
std::variant<std::string, int> axis_attr_name_or_input_index;

int non_negative_axis; // The axis to slice on

// The rank of value for axis attribute. For example, for Gather, its axis attribute is a scalar, so the rank is 0.
// For Slice, its axes attribute is a 1D tensor, so the rank is 1.
int rank_of_axis_value;

std::string entry_slice_arg_name;

int input_rank; // rank of the Gather data input tensor
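To summarize the bookkeeping the struct above introduces: the axis location is now a `std::variant` (an attribute name for `Gather`/`GatherND`, an input index for `Slice`), plus the rank of the axis value. A hypothetical Python mirror of the same idea (field names follow the C++ struct; this is not an ORT API):

```
from dataclasses import dataclass
from typing import Union

@dataclass
class AxisSpec:
    # str -> axis lives in a node attribute of this name (Gather/GatherND);
    # int -> axis lives in the input at this index (Slice: input 3, `axes`).
    attr_name_or_input_index: Union[str, int]
    # 0 -> scalar attribute value (Gather's `axis`);
    # 1 -> 1-D tensor value (Slice's `axes` input).
    rank_of_axis_value: int

gather_axis = AxisSpec("axis", 0)
slice_axes = AxisSpec(3, 1)
```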
