Fix slice upstream - Incompatible dimensions (microsoft#16818)
### Fix slice upstream - (MatMul) [ShapeInferenceError] Incompatible dimensions

```
2023-07-22 14:58:16.918478478 [I:onnxruntime:Default, constant_sharing.cc:256 ApplyImpl] Total shared scalar initializer count: 10
2023-07-22 14:58:16.919494252 [W:onnxruntime:Default, graph.cc:108 MergeShapeInfo] Error merging shape info for output. 'onnx::Cast_424' source:{-1,31,-1,-1} target:{-1,32,-1,-1}. Falling back to lenient merge.
2023-07-22 14:58:16.921014114 [W:onnxruntime:Default, graph.cc:108 MergeShapeInfo] Error merging shape info for output. 'onnx::MatMul_425' source:{-1,31,-1,-1} target:{-1,32,-1,-1}. Falling back to lenient merge.
Traceback (most recent call last):
  File "examples/onnxruntime/training/language-modeling/run_clm.py", line 594, in <module>
    main()
  File "examples/onnxruntime/training/language-modeling/run_clm.py", line 542, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 454, in train
    return inner_training_loop(
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 755, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/transformers/trainer.py", line 2735, in training_step
    loss = self.compute_loss(model, inputs)
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 363, in compute_loss
    return model_with_loss(dict_inputs, return_outputs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 1724, in forward
    loss = self.module(*inputs, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_utils.py", line 384, in _forward
    return ortmodule._torch_module.forward(*inputs, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_utils.py", line 364, in _forward
    return torch_module_ort._execution_manager(torch_module_ort.is_training()).forward(*inputs, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 345, in forward
    self._fallback_manager.handle_exception(
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_fallback.py", line 157, in handle_exception
    raise exception
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 280, in forward
    self._build_graph(graph_transformer_config)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_logger.py", line 218, in wrapper
    result = func(graph_execution_manager, *args, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_training_manager.py", line 360, in _build_graph
    super()._build_graph(graph_transformer_config)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_graph_execution_manager.py", line 186, in _build_graph
    self._graph_builder.build(config)
RuntimeError: /bert_ort/pengwa/onnxruntime/orttraining/orttraining/python/orttraining_pybind_state.cc:823 onnxruntime::python::addObjectMethodsForTraining(pybind11::module&, onnxruntime::python::ExecutionProviderRegistrationFn)::<lambda(onnxruntime::training::OrtModuleGraphBuilder*, const onnxruntime::training::TrainingGraphTransformerConfiguration&)> [ONNXRuntimeError] : 1 : FAIL : Node (MatMul_403) Op (MatMul) [ShapeInferenceError] Incompatible dimensions
```

Missed using `axis` attribute for `Slice` op, so change to use `axes` inputs instead.

### Motivation and Context

<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
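
As a rough illustration of why the axis a `Slice` acts on matters for a downstream `MatMul`, the sketch below computes static output shapes in plain Python. It is not the onnxruntime code: `slice_output_shape` and `matmul_compatible` are hypothetical helpers, the tensor shapes are made up, and only step 1 is handled. It assumes the ONNX convention that `Slice` receives `axes` as an input and, when `axes` is omitted, slices the leading axes.

```python
from typing import List, Optional, Sequence


def slice_output_shape(
    shape: Sequence[int],
    starts: Sequence[int],
    ends: Sequence[int],
    axes: Optional[Sequence[int]] = None,
) -> List[int]:
    """Static output shape of an ONNX-style Slice, step 1 on every axis."""
    rank = len(shape)
    if axes is None:
        # Assumed default: slice the leading axes, one per entry in `starts`.
        axes = list(range(len(starts)))
    out = list(shape)
    for start, end, axis in zip(starts, ends, axes):
        axis = axis + rank if axis < 0 else axis
        dim = shape[axis]
        # Negative indices count from the end; then clamp into [0, dim].
        start = start + dim if start < 0 else start
        end = end + dim if end < 0 else end
        start = min(max(start, 0), dim)
        end = min(max(end, 0), dim)
        out[axis] = max(end - start, 0)
    return out


def matmul_compatible(lhs: Sequence[int], rhs: Sequence[int]) -> bool:
    """True if the inner dims match and the batch dims broadcast."""
    if lhs[-1] != rhs[-2]:
        return False
    # Walk the batch dims from the innermost outward.
    return all(a == b or a == 1 or b == 1
               for a, b in zip(lhs[-3::-1], rhs[-3::-1]))


# Hypothetical shapes: (batch, heads, seq, seq) and (batch, heads, seq, head_dim).
scores = (2, 32, 8, 8)
values = (2, 32, 8, 64)

# The same starts/ends give different shapes depending on the axis sliced.
on_axis_1 = slice_output_shape(scores, starts=[0], ends=[31], axes=[1])
on_default_axis = slice_output_shape(scores, starts=[0], ends=[31])

print(on_axis_1, matmul_compatible(on_axis_1, values))
print(on_default_axis, matmul_compatible(on_default_axis, values))
```

A 31-versus-32 disagreement on one dimension, like the one reported by `MergeShapeInfo` in the log, is the kind of mismatch that makes the `MatMul` check fail in this sketch.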