
torch.aten.native_layer_norm to linalg #569

Closed
Tracked by #347
saienduri opened this issue Mar 28, 2024 · 17 comments

saienduri commented Mar 28, 2024

We are having the following problems with the ONNX lowering into tensor.expand_shape in these models (beit-base-patch16-224-pt22k-ft22k, deit-small-distilled-patch16-224, vit-base-patch16-224). Repro instructions can be found here:

beit-base-patch16-224-pt22k-ft22k.default.pytorch.torch.mlir:336:36: error: 'tensor.expand_shape' op invalid to have a single dimension (0) expanded into multiple dynamic dims (0,1)
    %result0, %result1, %result2 = torch.aten.native_layer_norm %293, %294, %10, %11, %float9.999990e-13 : !torch.vtensor<[?,?,768],f32>, !torch.list<int>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],f32>, !torch.float -> !torch.vtensor<[?,?,768],f32>, !torch.vtensor<[?,?,1],f32>, !torch.vtensor<[?,?,1],f32>
                                   ^
beit-base-patch16-224-pt22k-ft22k.default.pytorch.torch.mlir:336:36: note: see current operation: %295 = "tensor.expand_shape"(%294) <{reassociation = [[0, 1]]}> : (tensor<?xf32>) -> tensor<?x?xf32>
deit-small-distilled-patch16-224.default.pytorch.torch.mlir:352:36: error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
    %result0, %result1, %result2 = torch.aten.native_layer_norm %297, %298, %11, %12, %float9.999990e-13 : !torch.vtensor<[?,198,384],f32>, !torch.list<int>, !torch.vtensor<[384],f32>, !torch.vtensor<[384],f32>, !torch.float -> !torch.vtensor<[?,198,384],f32>, !torch.vtensor<[?,198,1],f32>, !torch.vtensor<[?,198,1],f32>
                                   ^
deit-small-distilled-patch16-224.default.pytorch.torch.mlir:352:36: note: see current operation: %263 = "tensor.expand_shape"(%262) <{reassociation = [[0, 1]]}> : (tensor<198xf32>) -> tensor<?x198xf32>
vit-base-patch16-224.default.pytorch.torch.mlir:316:36: error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
    %result0, %result1, %result2 = torch.aten.native_layer_norm %272, %273, %10, %11, %float9.999990e-13 : !torch.vtensor<[?,197,768],f32>, !torch.list<int>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],f32>, !torch.float -> !torch.vtensor<[?,197,768],f32>, !torch.vtensor<[?,197,1],f32>, !torch.vtensor<[?,197,1],f32>
                                   ^
vit-base-patch16-224.default.pytorch.torch.mlir:316:36: note: see current operation: %259 = "tensor.expand_shape"(%258) <{reassociation = [[0, 1]]}> : (tensor<197xf32>) -> tensor<?x197xf32>
@AmosLewis AmosLewis changed the title tensor.expand_shape torch.aten.native_layer_norm to linalg Apr 19, 2024

renxida commented May 6, 2024

Reproing with

HF_TOKEN=... python run.py --torchmlirbuild /path/to/torch-mlir/build --ireebuild /path/to/iree-build --cachedir ~/.cache/huggingface --tests pytorch/models/deit-small-distilled-patch16-224 -r test-onnx --tolerance .001 .001 --mode onnx --report


renxida commented May 8, 2024

My attempt at a minimal repro example:

func.func @native_layer_norm(%input: !torch.vtensor<[?,198,384],f32>, %weight: !torch.vtensor<[384],f32>, %bias: !torch.vtensor<[384],f32>, %eps: !torch.float) -> (!torch.vtensor<[?,198,384],f32>, !torch.vtensor<[?,198,1],f32>, !torch.vtensor<[?,198,1],f32>) {
  %int384 = torch.constant.int 384
  %normalized_shape = torch.prim.ListConstruct %int384 : (!torch.int) -> !torch.list<int>
  %result0, %result1, %result2 = torch.aten.native_layer_norm %input, %normalized_shape, %weight, %bias, %eps : !torch.vtensor<[?,198,384],f32>, !torch.list<int>, !torch.vtensor<[384],f32>, !torch.vtensor<[384],f32>, !torch.float -> !torch.vtensor<[?,198,384],f32>, !torch.vtensor<[?,198,1],f32>, !torch.vtensor<[?,198,1],f32>
  return %result0, %result1, %result2 : !torch.vtensor<[?,198,384],f32>, !torch.vtensor<[?,198,1],f32>, !torch.vtensor<[?,198,1],f32>
}

For some reason there is no compile error on this; it succeeds without issue:
/home/azureuser/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu minimal.mlir > minimal.vmfb


renxida commented May 8, 2024

Repro log with:

  • original
  • stripped
  • minimal example
+ echo --- reproing with original model mlir ---
--- reproing with original model mlir ---
+ /home/azureuser/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu deit-small-distilled-patch16-224.default.pytorch.torch.mlir -o deit-small-distilled-patch16-224.default.vmfb
/home/azureuser/iree-build/tools/iree-compile: /home/azureuser/miniconda/lib/libtinfo.so.6: no version information available (required by /home/azureuser/iree-build/lib/libIREECompiler.so)
deit-small-distilled-patch16-224.default.pytorch.torch.mlir:364:36: error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
    %result0, %result1, %result2 = torch.aten.native_layer_norm %309, %310, %11, %12, %float9.999990e-13 : !torch.vtensor<[?,198,384],f32>, !torch.list<int>, !torch.vtensor<[384],f32>, !torch.vtensor<[384],f32>, !torch.float -> !torch.vtensor<[?,198,384],f32>, !torch.vtensor<[?,198,1],f32>, !torch.vtensor<[?,198,1],f32>
                                   ^
deit-small-distilled-patch16-224.default.pytorch.torch.mlir:364:36: note: see current operation: %263 = "tensor.expand_shape"(%262) <{reassociation = [[0, 1]]}> : (tensor<198xf32>) -> tensor<?x198xf32>
+ echo Return code: 1
Return code: 1
+ echo --- reproing with stripped model mlir ---
--- reproing with stripped model mlir ---
+ /home/azureuser/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu stripped/deit-small-distilled-patch16-224.default.pytorch.torch.stripped.mlir -o deit-small-distilled-patch16-224.default.stripped.vmfb
/home/azureuser/iree-build/tools/iree-compile: /home/azureuser/miniconda/lib/libtinfo.so.6: no version information available (required by /home/azureuser/iree-build/lib/libIREECompiler.so)
stripped/deit-small-distilled-patch16-224.default.pytorch.torch.stripped.mlir:364:36: error: 'tensor.expand_shape' op expected dimension 0 of collapsed type to be dynamic since one or more of the corresponding dimensions in the expanded type is dynamic
    %result0, %result1, %result2 = torch.aten.native_layer_norm %309, %310, %11, %12, %float9.999990e-13 : !torch.vtensor<[?,198,384],f32>, !torch.list<int>, !torch.vtensor<[384],f32>, !torch.vtensor<[384],f32>, !torch.float -> !torch.vtensor<[?,198,384],f32>, !torch.vtensor<[?,198,1],f32>, !torch.vtensor<[?,198,1],f32>
                                   ^
stripped/deit-small-distilled-patch16-224.default.pytorch.torch.stripped.mlir:364:36: note: see current operation: %73 = "tensor.expand_shape"(%72) <{reassociation = [[0, 1]]}> : (tensor<198xf32>) -> tensor<?x198xf32>
+ echo Return code: 1
Return code: 1
+ echo --- reproing with minimal example ---
--- reproing with minimal example ---
+ /home/azureuser/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu stripped/minimal.mlir -o minimal.vmfb
/home/azureuser/iree-build/tools/iree-compile: /home/azureuser/miniconda/lib/libtinfo.so.6: no version information available (required by /home/azureuser/iree-build/lib/libIREECompiler.so)
/home/azureuser/iree-build/tools/iree-lld: /home/azureuser/miniconda/lib/libtinfo.so.6: no version information available (required by /home/azureuser/iree-build/lib/libIREECompiler.so)
+ echo Return code: 0
Return code: 0


renxida commented May 8, 2024

IR dump after failure (with failing line highlighted)
https://gist.github.com/renxida/fdcec365e89ae317fdd8545fbd4ddaf8#file-gistfile0-txt-L117

Relevant section:

  %72 = "linalg.generic"(%69, %71) <{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<reduction>], operandSegmentSizes = array<i32: 1, 1>}> ({
  ^bb0(%arg3: f32, %arg4: f32):
    %1595 = "arith.addf"(%arg3, %arg4) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
    "linalg.yield"(%1595) : (f32) -> ()
  }) : (tensor<198x384xf32>, tensor<198xf32>) -> tensor<198xf32>
  %73 = "tensor.expand_shape"(%72) <{reassociation = [[0, 1]]}> : (tensor<198xf32>) -> tensor<?x198xf32>
  %74 = "tensor.empty"(%53) : (index) -> tensor<?x198xf32>
  %75 = "linalg.generic"(%73, %74) <{indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = [#linalg.iterator_type<parallel>, #linalg.iterator_type<parallel>], operandSegmentSizes = array<i32: 1, 1>}> ({


renxida commented May 8, 2024

TLDR: it looks like we generate the problematic tensor.expand_shape during the FoldUnitExtentDims pass.

--mlir-print-ir-after-all produces a 15 MB text file.

/home/azureuser/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu  stripped/deit-small-distilled-patch16-224.default.pytorch.torch.stripped.mlir -o deit-small-distilled-patch16-224.default.stripped.vmfb   --mlir-pretty-debuginfo         --mlir-print-ir-after-all &> printirafterall.mlir

But we can grep it:

┌─[0]─[azureuser@xida-cpu-0]─[~/SHARK-TestSuite/e2eshark/test-onnx/pytorch/models/deit-small-distilled-patch16-224]
└──╼ $cat printirafterall.mlir | grep -E '\(tensor<198xf32>\)|IR Dump'

<... a bunch of ir dump after messages with no matches to the tensor<198xf32> pattern ...>

// -----// IR Dump After ConvertElementwiseToLinalgPass (convert-elementwise-to-linalg) //----- //
// -----// IR Dump After RaiseSpecialOps (iree-global-opt-raise-special-ops) //----- //
// -----// IR Dump After DecomposeConcat (iree-global-opt-decompose-concat) //----- //
// -----// IR Dump After GeneralizeLinalgNamedOps (iree-global-opt-generalize-linalg-named-ops) //----- //
// -----// IR Dump After FoldUnitExtentDims Failed (iree-flow-fold-unit-extent-dims) //----- //
  %73 = "tensor.expand_shape"(%72) <{reassociation = [[0, 1]]}> : (tensor<198xf32>) -> tensor<?x198xf32>
  %86 = "tensor.expand_shape"(%85) <{reassociation = [[0, 1]]}> : (tensor<198xf32>) -> tensor<?x198xf32>
stripped/deit-small-distilled-patch16-224.default.pytorch.torch.stripped.mlir:364:36: note: see current operation: %73 = "tensor.expand_shape"(%72) <{reassociation = [[0, 1]]}> : (tensor<198xf32>) -> tensor<?x198xf32>


renxida commented May 8, 2024

With the minimal IR, it correctly preserves the ?x198xf32. See output at: https://gist.github.com/renxida/e4347a1ef027e9bf7ed3487e7d87d577

Command:

/home/azureuser/iree-build/tools/iree-compile --iree-input-demote-i64-to-i32 --iree-hal-target-backends=llvm-cpu  stripped/minimal.mlir -o minimal.vmfb   --mlir-pretty-debuginfo --mlir-print-ir-before=iree-flow-fold-unit-extent-dims --mlir-print-ir-after=iree-flow-fold-unit-extent-dims |& gh gist create - -d "minimal ir before and after FoldUnitExtentDims"


renxida commented May 8, 2024

A search of the error message indicates that the error is generated by

https://github.com/shark-infra/llvm-project/blob/b3291793f11924a3b62601aabebebdcfbb12a9a1/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp#L250-L257

The function LogicalResult mlir::reshapeLikeShapesAreCompatible(...) is used for both the tensor.expand_shape and tensor.collapse_shape ops.

But this particular check

    if (dynamicShape) {
      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
        return emitError(
            "expected dimension " + Twine(map.index()) +
            " of collapsed type to be dynamic since one or more of the "
            "corresponding dimensions in the expanded type is dynamic");
      }
    }

may not make sense; I can see a world where we allow collapsing unknown-shape dims.
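A rough Python model of that check (hand-written here, not the actual MLIR helper) shows why the op from the logs, (tensor<198xf32>) -> tensor<?x198xf32> with reassociation [[0, 1]], trips it:

```python
DYNAMIC = -1  # stand-in for ShapedType::kDynamic

def reshape_shapes_are_compatible(collapsed_shape, expanded_shape, reassociation):
    """Rough model of the failing check in mlir::reshapeLikeShapesAreCompatible.

    collapsed_shape: shape of the lower-rank type (e.g. [198])
    expanded_shape:  shape of the higher-rank type (e.g. [DYNAMIC, 198])
    reassociation:   one group of expanded-dim indices per collapsed dim
    """
    for collapsed_idx, group in enumerate(reassociation):
        group_dims = [expanded_shape[i] for i in group]
        if any(d == DYNAMIC for d in group_dims):
            # If any expanded dim in a group is dynamic, the collapsed dim
            # must be dynamic too -- this is the check quoted above.
            if collapsed_shape[collapsed_idx] != DYNAMIC:
                return False
    return True

# The op from the error log: the static 198 expands into a group that
# contains a dynamic dim, so the verifier rejects it.
print(reshape_shapes_are_compatible([198], [DYNAMIC, 198], [[0, 1]]))  # False
```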



renxida commented May 8, 2024

Previously I thought the problem was introduced by iree-flow-fold-unit-extent-dims, because the expand_shape wasn't there before the pass and was there after it.

But after slogging through it with Stanley, I noticed a couple of weird things in the IR from right before the expand_shape:

  1. a linalg.generic that performs a reduce-sum over the last dim of a tensor, yet converts dim 0 from 1 to ?: ins tensor<1x198x384xf32>, outs(%27 : tensor<?x198x1xf32>) (at: https://gist.github.com/renxida/cff78af3cc3dfde1e7755a8a6cdf24b9#file-before-expand-shape-mlir-L113)

  2. a cast that turns a tensor with known shapes into one with unknown shapes: %cast = tensor.cast %8 : tensor<1x196x384xf32> to tensor<?x?x?xf32> (see: https://gist.github.com/renxida/cff78af3cc3dfde1e7755a8a6cdf24b9#file-before-expand-shape-mlir-L64)


renxida commented May 9, 2024

Looks like (1) from the last comment is introduced right after memref-expand:

// -----// IR Dump After ConvertTorchConversionToMLProgram (convert-torch-conversion-to-mlprogram) //----- //
// -----// IR Dump After ExpandOps (memref-expand) //----- //
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
  %100 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%97 : tensor<1x198x384xf32>) outs(%99 : tensor<?x198x1xf32>) {
  %111 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%108 : tensor<1x198x384xf32>) outs(%110 : tensor<?x198x1xf32>) {

Full logs at

https://gist.github.com/renxida/4e1fdcf2cd2b04a9001462b45096323d


renxida commented May 13, 2024

Got sidetracked from (1) a little bit.

There are a lot of linalg.generic ops that seem to turn known shapes into unknown shapes:

cat printirafterall.mlir | grep -E 'tensor<1x198x?x.*outs.*\?x198|IR Dump'

Result:

// -----// IR Dump After ConvertTorchToLinalg (convert-torch-to-linalg) //----- //
  %533 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_641, %cast_153 : tensor<1x198x384xf32>, tensor<?x198x384xf32>) outs(%532 : tensor<?x198x384xf32>) {
  %855 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cast_1398, %cast_906 : tensor<1x198x384xf32>, tensor<?x198x384xf32>) outs(%854 : tensor<?x198x384xf32>) {

there are things like this that come relatively early in the pipeline (right after convert-torch-to-linalg).

These make sense because they're broadcasting an elementwise op between a 1x198x384 and a ?x198x384 to produce a ?x198x384.

What doesn't make sense is when we have a 1x198x384 converted to a ?x198x1 with a single-operand reduce sum.

Need to filter specifically for those.

cat printirafterall.mlir | grep -E 'ins(%\d+tensor<1x198x?x.outs.?x198|IR Dump'


renxida commented May 13, 2024

Grepping specifically for one-input linalg.generic ops that convert known shapes into unknown shapes:

cat printirafterall-input.mlir | grep -E 'ins\(%[0-9]+ : tensor<1x198x384xf32>\) outs\(%[0-9]+ : tensor<\?x198x1xf32>\)|IR Dump'
// -----// IR Dump After ConvertTorchOnnxToTorch (convert-torch-onnx-to-torch) //----- //
// -----// IR Dump After SetStrictSymbolicShapesPass (torch-iree-set-strict-symbolic-shapes) //----- //
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
// -----// IR Dump After BitCastQuantTensorPass (torch-iree-bitcast-quant-tensor) //----- //
// -----// IR Dump After ReduceOpVariants (torch-reduce-op-variants) //----- //
// -----// IR Dump After ConvertCustomQuantOp (torch-convert-custom-quant-op) //----- //
// -----// IR Dump After DecomposeComplexOps (torch-decompose-complex-ops) //----- //
// -----// IR Dump After ConvertTorchToTMTensor (convert-torch-to-tmtensor) //----- //
// -----// IR Dump After ConvertTMTensorToLinalgExt (torch-iree-tm-tensor-to-linalg-ext) //----- //
// -----// IR Dump After ConvertTorchToTensor (convert-torch-to-tensor) //----- //
// -----// IR Dump After ConvertTorchToLinalg (convert-torch-to-linalg) //----- //
// -----// IR Dump After ConvertTorchToSCF (convert-torch-to-scf) //----- //
// -----// IR Dump After ConvertTorchToArith (convert-torch-to-arith) //----- //
// -----// IR Dump After ConvertTorchConversionToMLProgram (convert-torch-conversion-to-mlprogram) //----- //
// -----// IR Dump After ExpandOps (memref-expand) //----- //
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
  %100 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%97 : tensor<1x198x384xf32>) outs(%99 : tensor<?x198x1xf32>) {
  %111 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%108 : tensor<1x198x384xf32>) outs(%110 : tensor<?x198x1xf32>) {
// -----// IR Dump After ResolveShapedTypeResultDims (resolve-shaped-type-result-dims) //----- //
  %100 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%97 : tensor<1x198x384xf32>) outs(%99 : tensor<?x198x1xf32>) {
  %111 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%108 : tensor<1x198x384xf32>) outs(%110 : tensor<?x198x1xf32>) {
// -----// IR Dump After CSE (cse) //----- //
  %74 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%71 : tensor<1x198x384xf32>) outs(%73 : tensor<?x198x1xf32>) {
  %79 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%78 : tensor<1x198x384xf32>) outs(%73 : tensor<?x198x1xf32>) {
// -----// IR Dump After Canonicalizer (canonicalize) //----- //
  %74 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%71 : tensor<1x198x384xf32>) outs(%73 : tensor<?x198x1xf32>) {
  %79 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%78 : tensor<1x198x384xf32>) outs(%73 : tensor<?x198x1xf32>) {


renxida commented May 14, 2024

This issue is caused by incomplete shape inference during canonicalization, which assigned 1 to dim 0 of the input but not the output, while operating on a linalg.generic op that really only reduce-sums the last dim.

After ExpandOps and before canonicalization, the IR section looks like this (full file here):

%338 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%cast_173 : tensor<?x198x384xf32>) outs(%337 : tensor<?x?x1xf32>) {
  ^bb0(%in: f32, %out: f32):
    %4260 = arith.addf %in, %out : f32
    linalg.yield %4260 : f32
  } -> tensor<?x?x1xf32>

The canonicalization pass after ExpandOps above fills in the input shapes as follows:

  %100 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%97 : tensor<1x198x384xf32>) outs(%99 : tensor<?x198x1xf32>) {
  ^bb0(%in: f32 loc("after-memref-expand.mlir":587:8), %out: f32 loc("after-memref-expand.mlir":587:18)):
    %1984 = arith.addf %in, %out : f32 loc("after-memref-expand.mlir":588:13)
    linalg.yield %1984 : f32 loc("after-memref-expand.mlir":589:5)
  } -> tensor<?x198x1xf32> loc("after-memref-expand.mlir":586:10)

The eventual tensor.expand_shape error is caused by ins(%97 : tensor<1x198x384xf32>) outs(%99 : tensor<?x198x1xf32>) keeping a ? in the output shape where the input has a static 1.

So now the question is: is this linalg.generic not supposed to have a 1->? mismatch, or is the FoldUnitExtentDims pass supposed to be able to handle the 1->??
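As a sanity check on the first option, here is a hand-written sketch (a hypothetical helper, not torch-mlir/IREE code) of the propagation that seems to be missing: since the output map is (d0, d1, 0), dims d0 and d1 are shared between input and output, so a static input dim should force the matching output dim to the same static value.

```python
DYNAMIC = -1  # stand-in for a ? dim

def infer_reduce_last_dim_out_shape(in_shape, known_out_shape):
    """Sketch: propagate static dims from the input of a last-dim reduce-sum
    (indexing maps (d0,d1,d2)->(d0,d1,d2) and (d0,d1,d2)->(d0,d1,0)) into
    its output. d0 and d1 index both operands, so a static input dim
    implies the same static output dim."""
    out = list(known_out_shape)
    for i in range(2):  # d0 and d1 are the dims shared with the output
        if in_shape[i] != DYNAMIC:
            out[i] = in_shape[i]
    return out

# The problematic op: ins tensor<1x198x384xf32>, outs tensor<?x198x1xf32>.
print(infer_reduce_last_dim_out_shape([1, 198, 384], [DYNAMIC, 198, 1]))
# -> [1, 198, 1]: dim 0 of the output should have been refined to 1 as well
```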


renxida commented May 15, 2024

Summary so far:

  1. Original error: lowering native_layer_norm generated a tensor.expand_shape expanding a tensor<198xf32> to tensor<?x198xf32>.
  2. This tensor.expand_shape causes an error because expand_shape expects only ? dims to correspond to ? dims, not a static dim expanding out to a ?. This happens during the pass that eliminates unit-extent dims.
  3. The tensor.expand_shape is materialized in between the lowerings of the following pair of linalg ops:
// reduce sum over the last dim
// iterator types = parallel, parallel, reduction
%28 = linalg.generic ... ins(%25 : tensor<1x198x384xf32>) outs(%27 : tensor<?x198x1xf32>) ...

// elementwise division by a scalar
// iterator types = all parallel
%29 = linalg.generic ... ins(%28 : tensor<?x198x1xf32>) outs(%26 : tensor<?x198x1xf32>) ...
  4. The root cause seems to be a linalg.generic that performs a reduce-sum with ins(%25 : tensor<1x198x384xf32>) outs(%27 : tensor<?x198x1xf32>). Note that the reduce-sum is performed over the 384-extent dim, yet the first dim is 1 in the input but ? in the output.
  5. This doesn't make sense: the reduction doesn't touch the first dim, so if the input dim is 1, the output dim would also have to be 1.
  6. The reduce-sum originally looked like this:
    %338 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, 0)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%cast_173 : tensor<?x198x384xf32>) outs(%337 : tensor<?x?x1xf32>)
    and the ? for the input was filled in by shape inference, but the ? for the output remained a ?.
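To make the summary concrete, here is a hand-written model (not IREE code; the helper and its decomposition are illustrative assumptions) of what unit-extent-dim folding effectively does to that reduce-sum, and where the rejected expand_shape comes from:

```python
DYNAMIC = -1  # stand-in for a ? dim

def fold_unit_dims_then_reexpand(in_shape, consumer_shape):
    """Illustrative model of unit-dim folding around the reduce-sum:
    1. drop unit dims from the reduce input (1x198x384 -> 198x384),
    2. reduce away the last dim (-> 198),
    3. expand the result back to match the consumer's shape."""
    folded = [d for d in in_shape if d != 1]
    reduced = folded[:-1]  # reduce-sum over the last dim
    # Re-expansion target: the consumer shape with its trailing unit dim
    # dropped (?x198x1 -> ?x198).
    target = [d for d in consumer_shape if d != 1]
    # A static result expanding into a group with a dynamic dim is exactly
    # what the expand_shape verifier rejects.
    needs_dynamic_expand = (len(reduced) < len(target)
                            and any(d == DYNAMIC for d in target))
    return reduced, target, needs_dynamic_expand

reduced, target, bad = fold_unit_dims_then_reexpand([1, 198, 384],
                                                    [DYNAMIC, 198, 1])
print(reduced, target, bad)
# [198] must be expand_shape'd to [?, 198] -- the op the verifier rejects
```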


renxida commented May 28, 2024

Latest status: after updating to the latest torch-mlir and IREE, I'm encountering an Unsqueeze error. Either something fixed this, or an Unsqueeze error was masking it.


renxida commented May 28, 2024

Next steps to fix this:

  1. get the newest torch-mlir and iree
  2. try running it through to see whether it shows an Unsqueeze error or a native_layer_norm error
  3. if it's Unsqueeze, figure out who's currently working on unsqueeze and ask them whether it's related to what they're working on
  4. otherwise, read through the above notes and try to fix the shape inference issue with the linalg.generic that performs a reduce-sum
