[Linalg] Refactor compare scalar op (llvm#3294)
penguin-wwy authored May 9, 2024
1 parent c4b28e8 commit 0f0f57c
Showing 1 changed file with 50 additions and 130 deletions.
180 changes: 50 additions & 130 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
@@ -195,6 +195,50 @@ static Value createCompareTensorOp(OpBuilder &b, Location loc, OpTy op,
  llvm_unreachable("unimplemented: op type not supported");
}

template <typename OpTy>
static Value createCompareScalarOp(OpBuilder &b, Location loc, OpTy op,
                                   Value lhs, Value rhs) {
  static_assert(std::is_same<OpTy, AtenLtScalarOp>() ||
                    std::is_same<OpTy, AtenLeScalarOp>() ||
                    std::is_same<OpTy, AtenEqScalarOp>() ||
                    std::is_same<OpTy, AtenNeScalarOp>() ||
                    std::is_same<OpTy, AtenGtScalarOp>() ||
                    std::is_same<OpTy, AtenGeScalarOp>(),
                "unimplemented: op type not supported");

  Type lhsDtype = lhs.getType();
  Type rhsDtype = rhs.getType();
  Type elementalType = cast<BaseTensorType>(op.getSelf().getType()).getDtype();
  Value otherPromoted = convertScalarToDtype(b, loc, rhs, lhsDtype);

  if (isa<mlir::IntegerType>(elementalType) &&
      !isa<mlir::IntegerType>(rhsDtype)) {
    // TODO: Promote tensor args from integer to float.
    op.emitError("unimplemented: type promotion from tensor to scalar.");
    return nullptr;
  }

  if constexpr (std::is_same<OpTy, AtenLtScalarOp>()) {
    return createLessThan(b, loc, elementalType, lhs, otherPromoted);
  }
  if constexpr (std::is_same<OpTy, AtenLeScalarOp>()) {
    return createLessThanOrEqual(b, loc, elementalType, lhs, otherPromoted);
  }
  if constexpr (std::is_same<OpTy, AtenGtScalarOp>()) {
    return createGreaterThan(b, loc, elementalType, lhs, otherPromoted);
  }
  if constexpr (std::is_same<OpTy, AtenGeScalarOp>()) {
    return createGreaterThanOrEqual(b, loc, elementalType, lhs, otherPromoted);
  }
  if constexpr (std::is_same<OpTy, AtenEqScalarOp>()) {
    return createEqual(b, loc, elementalType, lhs, otherPromoted);
  }
  if constexpr (std::is_same<OpTy, AtenNeScalarOp>()) {
    return createNotEqual(b, loc, elementalType, lhs, otherPromoted);
  }
  llvm_unreachable("unimplemented: op type not supported");
}

template <arith::CmpIPredicate predicate>
static LogicalResult
createTriangularMatrix(OpBuilder &b, Location loc, ValueRange payloadArgs,
@@ -959,151 +1003,27 @@ static Value createLinalgPayloadCalculationForElementwiseOp(
  }

  if (auto gtScalar = dyn_cast<AtenGtScalarOp>(op)) {
    Type dtype = cast<BaseTensorType>(gtScalar.getSelf().getType()).getDtype();

    // TODO: `gtTensor` and `gtScalar` share similar code and can be called
    // from one static function.
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (isa<mlir::FloatType>(dtype))
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGT,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor args from integer to float.
        gtScalar.emitError(
            "unimplemented: type promotion from tensor to scalar.");
        return nullptr;
      }

      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ugt,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt,
                                       payloadArgs[0], otherPromoted);
    }
    gtScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
    return createCompareScalarOp(b, loc, gtScalar, payloadArgs[0], operands[1]);
  }

  if (auto geScalar = dyn_cast<AtenGeScalarOp>(op)) {
    Type dtype = cast<BaseTensorType>(geScalar.getSelf().getType()).getDtype();

    // TODO: The `AtenGeScalarOp` and `AtenGtScalarOp` share a lot of code
    // that can be refactored.
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (isa<mlir::FloatType>(dtype))
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UGE,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor args from integer to float.
        geScalar.emitError(
            "unimplemented: type promotion from tensor to scalar.");
        return nullptr;
      }

      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::uge,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
                                       payloadArgs[0], otherPromoted);
    }
    geScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
    return createCompareScalarOp(b, loc, geScalar, payloadArgs[0], operands[1]);
  }

  if (auto eqScalar = dyn_cast<AtenEqScalarOp>(op)) {
    Type dtype = cast<BaseTensorType>(eqScalar.getSelf().getType()).getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (isa<mlir::IntegerType>(dtype)) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        eqScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
    }
    return createEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
    return createCompareScalarOp(b, loc, eqScalar, payloadArgs[0], operands[1]);
  }

  if (auto neScalar = dyn_cast<AtenNeScalarOp>(op)) {
    Type dtype = cast<BaseTensorType>(neScalar.getSelf().getType()).getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    if (isa<mlir::IntegerType>(dtype)) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        neScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
    }
    return createNotEqual(b, loc, dtype, payloadArgs[0], otherPromoted);
    return createCompareScalarOp(b, loc, neScalar, payloadArgs[0], operands[1]);
  }

  if (auto ltScalar = dyn_cast<AtenLtScalarOp>(op)) {
    Type dtype = cast<BaseTensorType>(ltScalar.getSelf().getType()).getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    // TODO: Both tensor and scalar variants of `aten.gt` and `aten.lt` share
    // a lot of code that can be refactored.
    if (isa<mlir::FloatType>(dtype))
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        ltScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                       payloadArgs[0], otherPromoted);
    }
    ltScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
    return createCompareScalarOp(b, loc, ltScalar, payloadArgs[0], operands[1]);
  }

  if (auto leScalar = dyn_cast<AtenLeScalarOp>(op)) {
    Type dtype = cast<BaseTensorType>(leScalar.getSelf().getType()).getDtype();
    Value otherPromoted =
        convertScalarToDtype(b, loc, operands[1], payloadArgs[0].getType());

    // TODO: The `AtenLeScalarOp` and `AtenLtScalarOp` share a lot of code
    // that can be refactored.
    if (isa<mlir::FloatType>(dtype))
      return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULE,
                                     payloadArgs[0], otherPromoted);
    if (IntegerType intType = dyn_cast<mlir::IntegerType>(dtype)) {
      if (!operands[1].getType().isa<mlir::IntegerType>()) {
        // TODO: Promote tensor operand from integer to float.
        leScalar.emitError(
            "unimplemented: type promotion from tensor to scalar");
        return nullptr;
      }
      if (intType.isUnsigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ule,
                                       payloadArgs[0], otherPromoted);
      if (intType.isSigned())
        return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sle,
                                       payloadArgs[0], otherPromoted);
    }
    leScalar.emitError("unimplemented: dtype isn't supported.");
    return nullptr;
    return createCompareScalarOp(b, loc, leScalar, payloadArgs[0], operands[1]);
  }

  if (auto whereSelf = dyn_cast<AtenWhereSelfOp>(op)) {
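The new createCompareScalarOp dispatches to per-dtype comparison builders (createLessThan, createLessThanOrEqual, createGreaterThan, createGreaterThanOrEqual, createEqual, createNotEqual) whose bodies are not shown in this diff. As a rough guide only, here is a minimal sketch of what such a builder looks like, inferred from the inline code removed above; the name createLessThanSketch and the exact signature are illustrative, not the upstream helper.

// Sketch only: inferred from the removed inline code above; the real
// createLessThan helper in Uncategorized.cpp may differ in details.
static Value createLessThanSketch(OpBuilder &b, Location loc,
                                  Type elementalType, Value lhs, Value rhs) {
  // Floats compare with an unordered predicate.
  if (isa<mlir::FloatType>(elementalType))
    return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::ULT, lhs, rhs);
  // Integers pick the predicate from the type's signedness.
  if (IntegerType intType = dyn_cast<mlir::IntegerType>(elementalType)) {
    if (intType.isUnsigned())
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, lhs, rhs);
    if (intType.isSigned())
      return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs);
  }
  llvm_unreachable("unimplemented: dtype isn't supported");
}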

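The refactor relies on C++17 if constexpr: branches whose condition is false are discarded at instantiation time, so each op type only references the one comparison builder it needs. A small, self-contained illustration of that idiom follows; it is plain C++ with made-up LtOp/EqOp tags and has no MLIR dependencies.

// Illustration of the `if constexpr` dispatch idiom used by
// createCompareScalarOp above (standalone C++17).
#include <iostream>
#include <type_traits>

struct LtOp {};
struct EqOp {};

template <typename OpTy> const char *compareKind() {
  static_assert(std::is_same<OpTy, LtOp>() || std::is_same<OpTy, EqOp>(),
                "unimplemented: op type not supported");
  // False branches are dropped at compile time, so each instantiation only
  // needs the code it actually uses.
  if constexpr (std::is_same<OpTy, LtOp>())
    return "less-than";
  if constexpr (std::is_same<OpTy, EqOp>())
    return "equal";
  return "unreachable";
}

int main() {
  std::cout << compareKind<LtOp>() << "\n"; // prints: less-than
  std::cout << compareKind<EqOp>() << "\n"; // prints: equal
}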