Update isScalarTensor and krnl.load and krnl.store to handle both tensor<dtype> and tensor<1xdtype> (#2887)

* Scalar tensor can be tensor<dtype> or tensor<1xdtype>

Signed-off-by: Tung D. Le <[email protected]>

* No fusion for scalar tensors

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
tungld authored Jul 26, 2024
1 parent 603a8a2 commit 2044d52
Showing 7 changed files with 249 additions and 220 deletions.
8 changes: 8 additions & 0 deletions src/Compiler/CompilerUtils.cpp
@@ -182,6 +182,10 @@ void showCompilePhase(std::string msg) {

llvm::outs() << "[" << CURRENT_COMPILE_PHASE++ << "/" << TOTAL_COMPILE_PHASE
<< "] " << currentTime << " " << msg << "\n";

// Reset current phase.
if (CURRENT_COMPILE_PHASE > TOTAL_COMPILE_PHASE)
CURRENT_COMPILE_PHASE = 1;
}

} // namespace onnx_mlir
@@ -923,6 +927,10 @@ int compileModule(mlir::OwningOpRef<ModuleOp> &module,
mlir::MLIRContext &context, std::string outputNameNoExt,
EmissionTargetType emissionTarget) {
std::string msg = "Compiling and Optimizing MLIR Module";
// If there is no importing phase (e.g. the model is .mlir, not .onnx),
// adjust the counter to correctly reflect the current phase.
if (CURRENT_COMPILE_PHASE == 1)
CURRENT_COMPILE_PHASE++;
showCompilePhase(msg);
auto compileModuleTiming = rootTimingScope.nest("[onnx-mlir] " + msg);
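
Taken together, these two hunks keep the phase counter consistent: showCompilePhase now wraps the counter back to 1 after the last phase, and compileModule skips the import phase when compilation starts from an .mlir file rather than an .onnx one. A minimal standalone sketch of the resulting numbering, assuming a total of 3 phases and generic phase messages (both assumptions, not values from this diff; the timestamp printed by the real function is omitted):

#include <iostream>
#include <string>

// Hypothetical stand-ins for the globals used by showCompilePhase.
static int CURRENT_COMPILE_PHASE = 1;
static const int TOTAL_COMPILE_PHASE = 3; // assumed total

static void showCompilePhase(const std::string &msg) {
  std::cout << "[" << CURRENT_COMPILE_PHASE++ << "/" << TOTAL_COMPILE_PHASE
            << "] " << msg << "\n";
  // Reset so the next compilation starts again at phase 1.
  if (CURRENT_COMPILE_PHASE > TOTAL_COMPILE_PHASE)
    CURRENT_COMPILE_PHASE = 1;
}

int main() {
  // No importing phase (the model is .mlir, not .onnx): skip phase 1.
  if (CURRENT_COMPILE_PHASE == 1)
    CURRENT_COMPILE_PHASE++;
  showCompilePhase("compile phase"); // prints [2/3] compile phase
  showCompilePhase("final phase");   // prints [3/3] final phase, then resets
  return 0;
}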

5 changes: 5 additions & 0 deletions src/Conversion/ONNXToKrnl/Math/Elementwise.cpp
@@ -1814,6 +1814,11 @@ bool OpFusionHelper::isControlFlowValidForFusion(
// function by fold function.
bool OpFusionHelper::areInputsValidForFusion(
Operation *useOp, Operation *defOp, DimAnalysis *dimAnalysis) {
// Do not fuse ops with scalar tensors.
if (llvm::all_of(
useOp->getOperands(), [](Value v) { return isScalarTensor(v); }))
return false;

// Elementwise unary operation is always fusible
if (useOp->getOperands().size() == 1)
return true;
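
The rationale: as the DialectBuilder changes below show, scalar tensors are loaded and stored directly, without a loop nest, so an op whose operands are all scalars offers no loops to fuse into. A standalone sketch of the guard, with llvm::all_of swapped for std::all_of and each operand reduced to a precomputed is-scalar flag (hypothetical stand-ins for Operation, Value, and isScalarTensor):

#include <algorithm>
#include <cassert>
#include <vector>

// Mirrors the guard above: refuse fusion when every operand is scalar.
static bool okToFuse(const std::vector<bool> &operandIsScalar) {
  return !std::all_of(operandIsScalar.begin(), operandIsScalar.end(),
                      [](bool isScalar) { return isScalar; });
}

int main() {
  assert(!okToFuse({true, true}));  // all scalar operands: no fusion
  assert(okToFuse({false, true}));  // a non-scalar operand: fusion possible
  return 0;
}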
3 changes: 2 additions & 1 deletion src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.cpp
@@ -112,7 +112,8 @@ Value OnnxToKrnlBuilder::transpose(const Value input,
bool isScalarValue(Value value) {
ShapedType stype = mlir::dyn_cast<ShapedType>(value.getType());
assert(stype && "expected shaped type");
return stype.getRank() == 0;
return (stype.getRank() == 0) ||
(stype.getRank() == 1 && stype.getShape()[0] == 1);
}

/// Check if all operands are scalar values at compile time.
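
This is the central predicate change: a value now counts as scalar both when its type is rank 0 (tensor<dtype>) and when it is rank 1 with a single element (tensor<1xdtype>). A runnable sketch of the same test over plain static shapes (a hypothetical stand-in for ShapedType):

#include <cassert>
#include <vector>

// Rank 0, or rank 1 with extent 1, counts as a scalar.
static bool isScalarShape(const std::vector<long> &shape) {
  return shape.empty() || (shape.size() == 1 && shape[0] == 1);
}

int main() {
  assert(isScalarShape({}));      // tensor<f32>
  assert(isScalarShape({1}));     // tensor<1xf32>
  assert(!isScalarShape({2}));    // tensor<2xf32>
  assert(!isScalarShape({1, 1})); // tensor<1x1xf32> is not treated as scalar
  return 0;
}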
42 changes: 42 additions & 0 deletions src/Dialect/Krnl/DialectBuilder.cpp
@@ -56,6 +56,16 @@ static StringRef getFormat(const Type &inputType) {
//====---------------- Support for Krnl Builder ----------------------===//

Value KrnlBuilder::load(Value memref, ValueRange indices) const {
if (indices.size() == 0) {
// case memref<1xdtype>
MemRefType type = dyn_cast_or_null<MemRefType>(memref.getType());
assert(type && "Not MemRefType");
if (type.getRank() == 1 && type.getShape()[0] == 1) {
MultiDialectBuilder<MathBuilder> create(*this);
Value iZero = create.math.constantIndex(0);
return b().create<KrnlLoadOp>(loc(), memref, ValueRange({iZero}));
}
}
return b().create<KrnlLoadOp>(loc(), memref, indices);
}

@@ -68,12 +78,33 @@ mlir::Value KrnlBuilder::load(mlir::Value memref, mlir::ValueRange indices,
}

Value KrnlBuilder::loadIE(Value memref, ArrayRef<IndexExpr> indices) const {
if (indices.size() == 0) {
// case memref<1xdtype>
MemRefType type = dyn_cast_or_null<MemRefType>(memref.getType());
assert(type && "Not MemRefType");
if (type.getRank() == 1 && type.getShape()[0] == 1) {
MultiDialectBuilder<MathBuilder> create(*this);
Value iZero = create.math.constantIndex(0);
return b().create<KrnlLoadOp>(loc(), memref, ValueRange({iZero}));
}
}
SmallVector<Value, 4> indexValues;
IndexExpr::getValues(indices, indexValues);
return b().create<KrnlLoadOp>(loc(), memref, indexValues);
}

void KrnlBuilder::store(Value val, Value memref, ValueRange indices) const {
if (indices.size() == 0) {
// case memref<1xdtype>
MemRefType type = dyn_cast_or_null<MemRefType>(memref.getType());
assert(type && "Not MemRefType");
if (type.getRank() == 1 && type.getShape()[0] == 1) {
MultiDialectBuilder<MathBuilder> create(*this);
Value iZero = create.math.constantIndex(0);
b().create<KrnlStoreOp>(loc(), val, memref, ValueRange({iZero}));
return;
}
}
b().create<KrnlStoreOp>(loc(), val, memref, indices);
}

@@ -87,6 +118,17 @@ void KrnlBuilder::store(mlir::Value val, mlir::Value memref,

void KrnlBuilder::storeIE(
Value val, Value memref, ArrayRef<IndexExpr> indices) const {
if (indices.size() == 0) {
// case memref<1xdtype>
MemRefType type = dyn_cast_or_null<MemRefType>(memref.getType());
assert(type && "Not MemRefType");
if (type.getRank() == 1 && type.getShape()[0] == 1) {
MultiDialectBuilder<MathBuilder> create(*this);
Value iZero = create.math.constantIndex(0);
b().create<KrnlStoreOp>(loc(), val, memref, ValueRange({iZero}));
return;
}
}
SmallVector<Value, 4> indexValues;
IndexExpr::getValues(indices, indexValues);
b().create<KrnlStoreOp>(loc(), val, memref, indexValues);
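
All four builders (load, loadIE, store, storeIE) repeat the same guard: a memref<1xdtype> accessed with no indices gets an explicit constant index 0, while a true rank-0 memref<dtype> keeps its empty index list, which krnl.load/krnl.store already accept. A standalone sketch of that guard, with MemRefType reduced to its static shape (a hypothetical stand-in, not a helper from the patch):

#include <cassert>
#include <cstddef>
#include <vector>

// True when the access must be rewritten to use an explicit index 0.
static bool needsExplicitZeroIndex(const std::vector<long> &memrefShape,
                                   std::size_t numIndices) {
  return numIndices == 0 && memrefShape.size() == 1 && memrefShape[0] == 1;
}

int main() {
  assert(needsExplicitZeroIndex({1}, 0));  // memref<1xi64> with no indices
  assert(!needsExplicitZeroIndex({}, 0));  // memref<i64>: [] is already valid
  assert(!needsExplicitZeroIndex({1}, 1)); // caller gave an index: keep it
  return 0;
}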
39 changes: 17 additions & 22 deletions test/mlir/conversion/onnx_to_krnl/ControlFlow/Loop.mlir
@@ -41,32 +41,27 @@ func.func private @test_loop_simple_main_graph(%arg0: tensor<i64>, %arg1: tensor
// CHECK-DAG: [[CST_1_2_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_1_3_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[RES_3_:%.+]] = memref.alloc() {{.*}}: memref<1xi64>
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
// CHECK-DAG: [[CST_0_2_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_1_4_:%.+]] = arith.constant 1 : index
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 1){
// CHECK-DAG: [[VAR_14_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index
// CHECK-DAG: [[CST_1_5_:%.+]] = arith.constant 1 : index
// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_3_]]{{.}} : memref<1xi64>
// CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = krnl.load [[RES_2_]][] : memref<i64>
// CHECK: [[VAR_17_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[LOAD_RES_2_MEM_]] : i64
// CHECK: krnl.store [[VAR_17_]], [[RES_3_]]{{.}}[[VAR_14_]]{{.}} : memref<1xi64>
// CHECK: }
// CHECK-DAG: [[VAR_9_:%.+]] = builtin.unrealized_conversion_cast [[RES_3_]] : memref<1xi64> to tensor<1xi64>
// CHECK-DAG: [[VAR_10_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : memref<i1> to memref<i1>
// CHECK-DAG: [[LOAD_RES_MEM_:%.+]] = krnl.load [[RES_]]{{.}}[[CST_0_2_]]{{.}} : memref<1xi64>
// CHECK-DAG: [[LOAD_RES_2_MEM_:%.+]] = krnl.load [[RES_2_]][] : memref<i64>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_11_:%.+]] = builtin.unrealized_conversion_cast [[VAR_9_]] : tensor<1xi64> to memref<1xi64>
// CHECK-DAG: [[LOAD_VAR_10_MEM_:%.+]] = krnl.load [[VAR_10_]][] : memref<i1>
// CHECK: krnl.store [[LOAD_VAR_10_MEM_]], [[RES_1_]][] : memref<i1>
// CHECK-DAG: [[LOOP_3_:%.+]] = krnl.define_loops 1
// CHECK-DAG: [[VAR_10_:%.+]] = arith.addi [[LOAD_RES_MEM_]], [[LOAD_RES_2_MEM_]] : i64
// CHECK-DAG: [[CST_0_3_:%.+]] = arith.constant 0 : index
// CHECK: krnl.store [[VAR_10_]], [[RES_3_]]{{.}}[[CST_0_3_]]{{.}} : memref<1xi64>
// CHECK-DAG: [[VAR_11_:%.+]] = builtin.unrealized_conversion_cast [[RES_3_]] : memref<1xi64> to tensor<1xi64>
// CHECK-DAG: [[VAR_12_:%.+]] = builtin.unrealized_conversion_cast [[PARAM_1_]] : memref<i1> to memref<i1>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_13_:%.+]] = builtin.unrealized_conversion_cast [[VAR_11_]] : tensor<1xi64> to memref<1xi64>
// CHECK-DAG: [[LOAD_VAR_12_MEM_:%.+]] = krnl.load [[VAR_12_]][] : memref<i1>
// CHECK: krnl.store [[LOAD_VAR_12_MEM_]], [[RES_1_]][] : memref<i1>
// CHECK-DAG: [[LOOP_2_:%.+]] = krnl.define_loops 1
// CHECK-DAG: [[CST_0_4_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_1_6_:%.+]] = arith.constant 1 : index
// CHECK: krnl.iterate([[LOOP_3_]]) with ([[LOOP_3_]] -> [[I_3_:%.+]] = 0 to 1){
// CHECK: [[VAR_14_1_:%.+]] = krnl.get_induction_var_value([[LOOP_3_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_RES_MEM_1_:%.+]] = krnl.load [[VAR_11_]]{{.}}[[VAR_14_1_]]{{.}} : memref<1xi64>
// CHECK: krnl.store [[LOAD_RES_MEM_1_]], [[RES_]]{{.}}[[VAR_14_1_]]{{.}} : memref<1xi64>
// CHECK-DAG: [[CST_1_4_:%.+]] = arith.constant 1 : index
// CHECK: krnl.iterate([[LOOP_2_]]) with ([[LOOP_2_]] -> [[I_2_:%.+]] = 0 to 1){
// CHECK: [[VAR_16_:%.+]] = krnl.get_induction_var_value([[LOOP_2_]]) : (!krnl.loop) -> index
// CHECK: [[LOAD_VAR_13_MEM_:%.+]] = krnl.load [[VAR_13_]]{{.}}[[VAR_16_]]{{.}} : memref<1xi64>
// CHECK: krnl.store [[LOAD_VAR_13_MEM_]], [[RES_]]{{.}}[[VAR_16_]]{{.}} : memref<1xi64>
// CHECK: }
// CHECK: }) : () -> ()
// CHECK: }
