
Commit

Co-authored-by: Megan Hampton <[email protected]>
hamptonm1 and MegoHam21 authored Dec 19, 2024
1 parent 9f9de36 commit b800036
Showing 16 changed files with 43 additions and 87 deletions.
2 changes: 1 addition & 1 deletion docs/BuildOnLinuxOSX.md
@@ -15,7 +15,7 @@ Firstly, install MLIR (as a part of LLVM-Project):
``` bash
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout 00128a20eec27246719d73ba427bf821883b00b4 && cd ..
+cd llvm-project && git checkout 01d233ff403823389f8480897e41aea84ecbb3d3 && cd ..
```

[same-as-file]: <> (utils/build-mlir.sh)
2 changes: 1 addition & 1 deletion docs/BuildOnWindows.md
@@ -52,7 +52,7 @@ Install MLIR (as a part of LLVM-Project):
```shell
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout 00128a20eec27246719d73ba427bf821883b00b4 && cd ..
+cd llvm-project && git checkout 01d233ff403823389f8480897e41aea84ecbb3d3 && cd ..
```

[same-as-file]: <> (utils/build-mlir.cmd)
3 changes: 2 additions & 1 deletion docs/Dialects/krnl.md
@@ -329,6 +329,7 @@ _Indicate ONNX entry point_

The "krnl.entry_point" function indicates the main entry
point of ONNX model.

### `krnl.erf` (KrnlErfOp)

_Krnl erf scalar operation_
@@ -453,7 +454,7 @@ in the `value` dense element attribute.

Traits: `AlwaysSpeculatableImplTrait`, `MemRefsNormalizable`

-Interfaces: `ConditionallySpeculatable`, `KrnlGlobalOpInterface`, `NoMemoryEffect (MemoryEffectOpInterface)`
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

3 changes: 0 additions & 3 deletions docs/Dialects/zhigh.md
@@ -793,8 +793,6 @@ Effects: `MemoryEffects::Effect{}`
_ZHigh Stickified Constant operation_

This operator produces a constant tensor to store stickified data.
-`value` attribute has original constant or stickified constant.
-`stickified` attribute indicates the `value` is already stickified or not.
Stickified data is opaque and must be 4K-aligned. Whoever produces
the stickified data must make sure its size in bytes is consistent with
the output tensor's size.
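
As a hedged illustration of that contract (plain C++, not code from this repo; `allocStickifiedBuffer` is a hypothetical helper), a producer of stickified data might allocate it like so:

```cpp
#include <cstddef>
#include <cstdlib>

// Hypothetical helper, for illustration only: stickified data is opaque,
// must be 4K-aligned, and its byte size must match the output tensor's.
void *allocStickifiedBuffer(std::size_t tensorSizeInBytes) {
  void *buf = nullptr;
  // posix_memalign enforces the 4096-byte (4K) alignment the op requires.
  if (posix_memalign(&buf, 4096, tensorSizeInBytes) != 0)
    return nullptr;
  return buf;
}
```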
@@ -809,7 +807,6 @@ Effects: `MemoryEffects::Effect{}`

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>stickified</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
<tr><td><code>value</code></td><td>::mlir::Attribute</td><td>any attribute</td></tr>
<tr><td><code>alignment</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
</table>
28 changes: 0 additions & 28 deletions docs/Dialects/zlow.md
@@ -752,34 +752,6 @@ Interfaces: `MemoryEffectOpInterface`
| `X` | memref of 16-bit float or 32-bit float values
| `Out` | memref of dlfloat16 type values

-### `zlow.stickifiedConstant` (::onnx_mlir::zlow::ZLowStickifiedConstantOp)
-
-_ZLow Stickified Constant operation._
-
-
-Traits: `MemRefsNormalizable`
-
-Interfaces: `KrnlGlobalOpInterface`
-
-#### Attributes:
-
-<table>
-<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
-<tr><td><code>shape</code></td><td>::mlir::Attribute</td><td>any attribute</td></tr>
-<tr><td><code>name</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
-<tr><td><code>stickified</code></td><td>::mlir::BoolAttr</td><td>bool attribute</td></tr>
-<tr><td><code>value</code></td><td>::mlir::Attribute</td><td>any attribute</td></tr>
-<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
-<tr><td><code>offset</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
-<tr><td><code>alignment</code></td><td>::mlir::IntegerAttr</td><td>64-bit signless integer attribute</td></tr>
-</table>
-
-#### Results:
-
-| Result | Description |
-| :----: | ----------- |
-| `output` | memref of dlfloat16 type values
-
### `zlow.sub` (::onnx_mlir::zlow::ZLowSubOp)

_ZLow sub operation_
7 changes: 2 additions & 5 deletions src/Accelerators/NNPA/Transform/ZLow/ZLowStickExpansion.cpp
@@ -153,9 +153,7 @@ class UnstickExpansionPattern : public OpRewritePattern<ZLowUnstickOp> {
// 64.
IndexExpr T = LitIE(2);
DimsExpr reallocTileDims = {T, lit64};
-Value inputAsTx64 =
-    create.mem.reinterpretCast(input, litZero.getValue(), reallocTileDims);
-
+Value inputAsTx64 = create.mem.reinterpretCast(input, reallocTileDims);
// Outer loop (E4, E3, E2, E1 iterates over tiles of 64 elements)
create.krnl.iterateIE(loopDefs, loopDefs, lbs, ubs,
[&](const KrnlBuilder &b, ValueRange loopInd) {
@@ -456,8 +454,7 @@ class StickExpansionPattern : public OpRewritePattern<ZLowStickOp> {
// 64 elements.
IndexExpr T = LitIE(2);
DimsExpr reallocTileDims = {T, lit64};
-Value allocAsTx64 =
-    create.mem.reinterpretCast(alloc, litZero.getValue(), reallocTileDims);
+Value allocAsTx64 = create.mem.reinterpretCast(alloc, reallocTileDims);

// Outer loop (E1 iterates over tiles of 64 elements).
create.krnl.iterateIE(loopDefs, loopDefs, lbs, ubs,
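Both hunks in this file drop the explicit `litZero` offset when building the reinterpret cast. The exact onnx-mlir signature is not shown in this diff, but the new call sites imply a `MemRefBuilder::reinterpretCast` overload that assumes a zero offset; a minimal sketch of such an overload (the delegation below is an assumption):

```cpp
// Assumed convenience overload: behaves like the removed calls, with the
// offset fixed to a literal zero instead of being passed in by the caller.
Value MemRefBuilder::reinterpretCast(Value input, DimsExpr &outputDims) const {
  IndexExpr zeroOffset = LitIE(0);
  return reinterpretCast(input, zeroOffset.getValue(), outputDims);
}
```

This is consistent with the test updates later in this commit, where the `memref.reinterpret_cast` offset becomes the literal `[0]` and the separate `arith.constant 0 : index` disappears.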
10 changes: 4 additions & 6 deletions src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp
@@ -814,20 +814,18 @@ AffineTypeConverter::AffineTypeConverter() {
addConversion([](Type type) { return type; });

addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-ValueRange inputs,
-Location loc) -> std::optional<Value> {
+ValueRange inputs, Location loc) -> Value {
if (inputs.size() != 1)
-return std::nullopt;
+return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-ValueRange inputs,
-Location loc) -> std::optional<Value> {
+ValueRange inputs, Location loc) -> Value {
if (inputs.size() != 1)
-return std::nullopt;
+return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
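These hunks (and the identical KrnlTypeConverter change in the next file) move to the newer MLIR `TypeConverter` convention in which materialization callbacks return a plain `Value`, with a null `Value` signaling failure, instead of `std::optional<Value>`. A minimal self-contained sketch of the new style, assuming that API:

```cpp
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch: register single-input unrealized_conversion_cast materializations
// under the Value-returning callback convention.
static void addCastMaterializations(TypeConverter &converter) {
  auto materialize = [](OpBuilder &builder, Type resultType,
                         ValueRange inputs, Location loc) -> Value {
    if (inputs.size() != 1)
      return Value(); // null Value replaces the old std::nullopt
    return builder
        .create<UnrealizedConversionCastOp>(loc, resultType, inputs)
        .getResult(0);
  };
  converter.addSourceMaterialization(materialize);
  converter.addTargetMaterialization(materialize);
}
```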
10 changes: 4 additions & 6 deletions src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -547,20 +547,18 @@ KrnlTypeConverter::KrnlTypeConverter() {
});

addSourceMaterialization([&](OpBuilder &builder, Type resultType,
-ValueRange inputs,
-Location loc) -> std::optional<Value> {
+ValueRange inputs, Location loc) -> Value {
if (inputs.size() != 1)
-return std::nullopt;
+return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
});

addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-ValueRange inputs,
-Location loc) -> std::optional<Value> {
+ValueRange inputs, Location loc) -> Value {
if (inputs.size() != 1)
-return std::nullopt;
+return Value();

return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
.getResult(0);
15 changes: 9 additions & 6 deletions src/Dialect/Mlir/DialectBuilder.cpp
@@ -617,7 +617,7 @@ Value MathBuilder::constant(Type type, double val) const {
b().create<arith::ConstantOp>(loc(), b().getF64FloatAttr(val));
})
.Case<IntegerType>([&](IntegerType elementType) {
-assert(val == (int64_t)val && "value is ambiguous");
+assert(val == static_cast<int64_t>(val) && "value is ambiguous");
unsigned width = elementType.getWidth();

if (width == 1)
@@ -628,11 +628,13 @@ Value MathBuilder::constant(Type type, double val) const {
if (elementType.isUnsignedInteger()) {
Type signlessTy = b().getIntegerType(width);
constant = b().create<arith::ConstantOp>(loc(),
-b().getIntegerAttr(signlessTy, APInt(width, (int64_t)val)));
+b().getIntegerAttr(signlessTy,
+    APInt(width, static_cast<int64_t>(val), false, true)));
constant = castToUnsigned(constant, width);
} else {
constant = b().create<arith::ConstantOp>(loc(),
-b().getIntegerAttr(elementType, APInt(width, (int64_t)val)));
+b().getIntegerAttr(elementType,
+    APInt(width, static_cast<int64_t>(val), false, true)));
}
}
})
@@ -695,7 +697,7 @@ TypedAttr MathBuilder::negativeInfAttr(Type type) const {
default:
llvm_unreachable("unsupported element type");
}
-attr = b().getIntegerAttr(type, APInt(width, value));
+attr = b().getIntegerAttr(type, APInt(width, value, false, true));
})
.Default([](Type) { llvm_unreachable("unsupported element type"); });
assert(attr != nullptr && "Expecting valid attribute");
@@ -740,7 +742,7 @@ TypedAttr MathBuilder::positiveInfAttr(Type type) const {
default:
llvm_unreachable("unsupported element type");
}
-attr = b().getIntegerAttr(type, APInt(width, value));
+attr = b().getIntegerAttr(type, APInt(width, value, false, true));
})
.Default([](Type) { llvm_unreachable("unsupported element type"); });
assert(attr != nullptr && "Expecting valid attribute");
@@ -2263,7 +2265,7 @@ Value LLVMBuilder::constant(Type type, int64_t val) const {
assert(type.isSignless() &&
"LLVM::ConstantOp requires a signless type.");
constant = b().create<LLVM::ConstantOp>(loc(), type,
-b().getIntegerAttr(type, APInt(width, (int64_t)val)));
+b().getIntegerAttr(
+    type, APInt(width, static_cast<int64_t>(val), false, true)));
}
})
.Case<IndexType>([&](Type) {
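The recurring `false, true` arguments in this file track the newer four-parameter `APInt` constructor, `APInt(unsigned numBits, uint64_t val, bool isSigned, bool implicitTrunc)` (signature stated as an assumption about the pinned LLVM). With `implicitTrunc=false`, the constructor asserts that `val` fits in `numBits` bits; passing `true` preserves the silent truncation the old two-argument calls relied on:

```cpp
#include "llvm/ADT/APInt.h"

#include <cstdint>

// Sketch: build a width-bit constant, truncating val exactly as the
// pre-change two-argument APInt(width, val) calls implicitly did.
llvm::APInt makeTruncatedConstant(unsigned width, int64_t val) {
  return llvm::APInt(width, static_cast<uint64_t>(val),
      /*isSigned=*/false, /*implicitTrunc=*/true);
}
```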
@@ -8,13 +8,13 @@ func.func @matmul_div(%arg0: tensor<?x12x?x64xf32>) -> tensor<?x?x?x?xf32> {
%r = "onnx.Div"(%m, %scalar) : (tensor<?x12x?x?xf32>, tensor<f32>) -> tensor<?x12x?x?xf32>
"onnx.Return"(%r) : (tensor<?x12x?x?xf32>) -> ()

-// CHECK-LABEL: func.func @matmul_div
+// CHECK-LABEL: func.func @matmul_div
// CHECK: memref.alloc
// CHECK: memref.alloc
// CHECK: [[ALLOC:%.+]] = memref.alloc({{.*}}) {{.*}}: memref<?x?x1x?x32x64xf16>
// CHECK-DAG: [[MATMUL_RES:%.+]] = memref.cast [[ALLOC]] : memref<?x?x1x?x32x64xf16> to memref<?x?x1x?x?x?xf16>
// CHECK: "zlow.matmul"({{.*}}, {{.*}}, {{.*}}, {{.*}}, [[MATMUL_RES]]) {is_bcast = 0 : si64, is_stacked = -1 : si64} : (memref<?x1x1x?x?x64xf16>, memref<?x?x1x2x32x?xf16>, memref<?x?x1x1x32x?xf16>, memref<4xi64>, memref<?x?x1x?x?x?xf16>) -> ()
// CHECK-NOT: "zlow.stick"
// CHECK-NOT: "zlow.unstick"
// CHECK-NOT: "zlow.stick"
// CHECK-NOT: "zlow.unstick"
// CHECK: "zlow.div"([[MATMUL_RES]], {{.*}}, {{.*}}, {{.*}}) {layout = "3DS"} : (memref<?x?x1x?x?x?xf16>, memref<?x?x1x?x?x?xf16>, memref<3xi64>, memref<?x?x1x?x?x?xf16>) -> ()
}
16 changes: 8 additions & 8 deletions test/mlir/accelerators/nnpa/driver/saturation.mlir
@@ -12,31 +12,31 @@
func.func @saturation(%arg0 : tensor<10x10xf32>) -> tensor<*xf32> {
%0 = "onnx.Relu"(%arg0) : (tensor<10x10xf32>) -> tensor<*xf32>
"func.return"(%0) : (tensor<*xf32>) -> ()
-// ZHIGH_OFF-LABEL: func @saturation
+// ZHIGH_OFF-LABEL: func @saturation
// ZHIGH_OFF: "zhigh.Stick"({{.*}}) {layout = "2D"} : {{.*}}

-// ZHIGH_ON-LABEL: func @saturation
+// ZHIGH_ON-LABEL: func @saturation
// ZHIGH_ON: "zhigh.Stick"({{.*}}) {layout = "2D", saturation = -1 : si64} : {{.*}}


-// ZLOW_OFF-LABEL: func @saturation
+// ZLOW_OFF-LABEL: func @saturation
// ZLOW_OFF: "zlow.stick"({{.*}}, {{.*}}) {layout = "2D"} : {{.*}}

-// ZLOW_ON-LABEL: func @saturation
+// ZLOW_ON-LABEL: func @saturation
// ZLOW_ON: "zlow.stick"({{.*}}, {{.*}}) {layout = "2D", saturation = -1 : si64} : {{.*}}

-// DECOMPOSE_OFF-LABEL: func @saturation
+// DECOMPOSE_OFF-LABEL: func @saturation
// DECOMPOSE_OFF: "zhigh.F32ToDLF16"(%arg0) : {{.*}}

-// DECOMPOSE_ON-LABEL: func @saturation
+// DECOMPOSE_ON-LABEL: func @saturation
// DECOMPOSE_ON: "zhigh.F32ToDLF16"(%arg0) {saturation = -1 : si64} : {{.*}}

-// COMPILER_STICK_OFF-LABEL: func @saturation
+// COMPILER_STICK_OFF-LABEL: func @saturation
// COMPILER_STICK_OFF-NOT: arith.minnumf
// COMPILER_STICK_OFF-NOT: arith.maxnumf
// COMPILER_STICK_OFF: zlow.relu

-// COMPILER_STICK_ON-LABEL: func @saturation
+// COMPILER_STICK_ON-LABEL: func @saturation
// COMPILER_STICK_ON: arith.minnumf
// COMPILER_STICK_ON: arith.maxnumf
// COMPILER_STICK_ON: zlow.relu
@@ -12,5 +12,5 @@ module {
"onnx.EntryPoint"() {func = @main_graph} : () -> ()
}
// CHECK: {{.*}} opt {{.*}} -o {{.*}}.bc
-// CHECK: {{.*}} llc {{.*}} {{.*}} {{.*}}.bc
+// CHECK: {{.*}} llc {{.*}} {{.*}} {{.*}}.bc
// CHECK: {{.*}} {{clang|c|g}}++{{.*}} {{.*}}.o -o {{.*}}.so -shared -fPIC -L{{.*}}/lib -lRuntimeNNPA -lzdnn -lcruntime
@@ -28,10 +28,9 @@ func.func @test_stick_expansion_with_sat(%arg0: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> {
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
// CHECK-DAG: [[VAR_cst_:%.+]] = arith.constant dense<8.57315738E+9> : vector<4xf32>
// CHECK-DAG: [[VAR_cst_0_:%.+]] = arith.constant dense<-8.57315738E+9> : vector<4xf32>
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf16, #map>
// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
-// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16>
+// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: [0], sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16>
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){
// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2)
@@ -124,10 +123,9 @@ func.func @test_stick_expansion_without_sat(%arg0: memref<16x8x128xf32>) -> memref<16x8x128xf16, #map> {
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK-DAG: [[CST_8_:%.+]] = arith.constant 8 : index
// CHECK-DAG: [[CST_4_:%.+]] = arith.constant 4 : index
-// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf16, #map>
// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
-// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16>
+// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[RES_]] to offset: [0], sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16>
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){
// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2)
@@ -208,7 +206,7 @@ func.func @test_unstick_expansion(%arg0: memref<16x8x128xf16, #map>) -> memref<16x8x128xf32> {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x128xf32>
// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
-// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16>
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [2, 64], strides: [64, 1] : memref<16x8x128xf16, #map> to memref<2x64xf16>
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){
// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2)
@@ -292,7 +290,7 @@ func.func @test_unstick_expansion_127(%arg0: memref<16x8x127xf16, #map>) -> memref<16x8x127xf32> {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<16x8x127xf32>
// CHECK-DAG: [[LOOP_0_:%.+]]:3 = krnl.define_loops 3
-// CHECK: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: {{.}}[[CST_0_]]{{.}}, sizes: [2, 64], strides: [64, 1] : memref<16x8x127xf16, #map> to memref<2x64xf16>
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [2, 64], strides: [64, 1] : memref<16x8x127xf16, #map> to memref<2x64xf16>
// CHECK: krnl.iterate([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) with ([[LOOP_0_]]#0 -> [[I_0_:%.+]] = 0 to 16, [[LOOP_0_]]#1 -> [[I_1_:%.+]] = 0 to 8, [[LOOP_0_]]#2 -> [[I_2_:%.+]] = 0 to 2){
// CHECK: [[VAR_1_:%.+]]:3 = krnl.get_induction_var_value([[LOOP_0_]]#0, [[LOOP_0_]]#1, [[LOOP_0_]]#2) : (!krnl.loop, !krnl.loop, !krnl.loop) -> (index, index, index)
// CHECK: [[VAR_2_:%.+]] = affine.apply [[MAP_1_]]([[VAR_1_]]#2)
@@ -360,5 +358,4 @@ func.func @test_unstick_expansion_127(%arg0: memref<16x8x127xf16, #map>) -> memref<16x8x127xf32> {
// CHECK: }
// CHECK: return [[RES_]] : memref<16x8x127xf32>
// CHECK: }
-}
-
+}
9 changes: 1 addition & 8 deletions test/mlir/parallel/krnl_parallel_clause_to_omp.mlir
@@ -32,7 +32,6 @@ func.func @omp_threads_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
}
omp.yield
}
-omp.terminator
}
omp.terminator
}
@@ -43,9 +42,7 @@ func.func @omp_threads_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
// CHECK: [[CST_8_:%.+]] = arith.constant 8 : i32
// CHECK: omp.parallel num_threads([[CST_8_]] : i32) proc_bind(spread) {
}

// -----

func.func @omp_threads(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
%c32 = arith.constant 32 : index
%c0 = arith.constant 0 : index
@@ -76,7 +73,6 @@ func.func @omp_threads(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
}
omp.yield
}
-omp.terminator
}
omp.terminator
}
@@ -120,7 +116,6 @@ func.func @omp_affinity(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
}
omp.yield
}
-omp.terminator
}
omp.terminator
}
@@ -162,7 +157,6 @@ func.func @omp_normal(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
}
omp.yield
}
-omp.terminator
}
omp.terminator
}
@@ -171,5 +165,4 @@ func.func @omp_normal(%arg0: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
// CHECK-LABEL: func.func @omp_normal
// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<16x8x128xf32> {onnx.name = "x"}) -> (memref<16x8x128xf32> {onnx.name = "y"}) {
// CHECK: omp.parallel {
-}
-
+}
2 changes: 1 addition & 1 deletion third_party/stablehlo
Submodule stablehlo updated 54 files
+17 −0 BUILD.bazel
+4 −0 CMakeLists.txt
+7 −2 README.md
+2 −2 WORKSPACE.bazel
+1 −1 build_tools/llvm_version.txt
+124 −0 build_tools/update_version_h_cpp.sh
+2 −0 docs/_toc.yaml
+48 −0 docs/awesome.md
+1 −0 docs/generated/stablehlo_linalg_passes.md
+7 −0 docs/generated/stablehlo_passes.md
+1 −0 docs/generated/stablehlo_tosa_passes.md
+6 −2 docs/spec.md
+435 −430 docs/tutorials/jax-export.ipynb
+118 −99 docs/tutorials/pytorch-export.ipynb
+199 −0 rfcs/20241001-microscaling-formats.md
+19 −0 stablehlo/conversions/linalg/tests/miscellaneous.mlir
+48 −5 stablehlo/conversions/linalg/tests/pointwise.mlir
+47 −45 stablehlo/conversions/linalg/transforms/StablehloToLinalgPointwise.cpp
+9 −10 stablehlo/conversions/linalg/transforms/TypeConversion.cpp
+2 −19 stablehlo/dialect/Base.cpp
+3 −2 stablehlo/dialect/Base.td
+44 −4 stablehlo/dialect/StablehloOps.cpp
+5 −2 stablehlo/dialect/Version.cpp
+1 −1 stablehlo/dialect/Version.h
+49 −1 stablehlo/dialect/VhloBytecode.cpp
+1 −0 stablehlo/dialect/VhloDialect.td
+24 −0 stablehlo/dialect/VhloTypes.cpp
+12 −0 stablehlo/dialect/VhloTypes.td
+53 −77 stablehlo/reference/Tensor.cpp
+6 −4 stablehlo/reference/Types.cpp
+1 −1 stablehlo/testdata/igamma_float64_20_20_float64_20_20_chlo.mlir
+1 −1 stablehlo/testdata/igammac_float64_20_20_float64_20_20_chlo.mlir
+16 −0 stablehlo/tests/interpret/api_input_arguments.mlir
+32 −0 stablehlo/tests/interpret/constant.mlir
+40 −8 stablehlo/tests/ops_stablehlo.mlir
+53 −53 stablehlo/tests/ops_stablehlo_quantized.mlir
+4 −0 stablehlo/tests/ops_stablehlo_roundtrip.mlir
+220 −0 stablehlo/tests/transforms/stablehlo_aggressive_folder.mlir
+1,033 −485 stablehlo/tests/transforms/stablehlo_aggressive_simplification.mlir
+2,936 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir
+ stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.1_8_0.mlir.bc
+32 −0 stablehlo/tests/vhlo/stablehlo_legalize_to_vhlo.mlir
+35 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_invalid.1_7_0.mlir
+15 −0 stablehlo/tests/vhlo/vhlo_to_version_downgrade_patch.mlir
+41 −2 stablehlo/tools/StablehloTranslateMain.cpp
+7 −2 stablehlo/transforms/CMakeLists.txt
+31 −2 stablehlo/transforms/PassUtils.cpp
+27 −12 stablehlo/transforms/PassUtils.h
+5 −0 stablehlo/transforms/Passes.h
+2 −1 stablehlo/transforms/Passes.td
+245 −7 stablehlo/transforms/StablehloAggressiveFolder.cpp
+873 −605 stablehlo/transforms/StablehloAggressiveSimplification.cpp
+417 −0 stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td
+7 −0 stablehlo/transforms/VhloToVersion.cpp
2 changes: 1 addition & 1 deletion utils/clone-mlir.sh
@@ -1,3 +1,3 @@
git clone -n https://github.com/llvm/llvm-project.git
# Check out a specific branch that is known to work with ONNX-MLIR.
-cd llvm-project && git checkout 00128a20eec27246719d73ba427bf821883b00b4 && cd ..
+cd llvm-project && git checkout 01d233ff403823389f8480897e41aea84ecbb3d3 && cd ..
