Skip to content

[mlir][vector] Remove bit-width logic tests #143007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 0 additions & 58 deletions mlir/test/Dialect/Vector/linearize-subject-to-bitwidth.mlir

This file was deleted.

122 changes: 0 additions & 122 deletions mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,126 +837,6 @@ struct TestVectorEmulateMaskedLoadStore final
}
};

/// Get the set of operand/result types to check for sufficiently
/// small inner-most dimension size.
static SmallVector<std::pair<Type, unsigned>>
getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) {

if (auto insertOp = dyn_cast<vector::InsertOp>(op)) {
unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max()
? targetBitWidth + 1
: targetBitWidth;
return {{insertOp.getValueToStoreType(), w}};
}

auto resultTypes = op->getResultTypes();
SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth;
resultsWithBitWidth.reserve(resultTypes.size());
for (Type type : resultTypes) {
resultsWithBitWidth.push_back({type, targetBitWidth});
}
return resultsWithBitWidth;
}

/// If `type` is VectorType with trailing dimension of (bit) size greater than
/// or equal to `targetBitWidth`, its defining op is considered legal.
static bool
isNotLinearizableBecauseLargeInnerDimension(Type type,
unsigned targetBitWidth) {

VectorType vecType = dyn_cast<VectorType>(type);

// Not linearizable for reasons other than what this function checks.
if (!vecType || vecType.getRank() == 0)
return false;

// The width of the type 'index' is unbounded (and therefore potentially above
// the target width).
if (vecType.getElementType().isIndex())
return true;

unsigned finalDimSize = vecType.getShape().back();
unsigned nbBitsPerElm = vecType.getElementTypeBitWidth();
unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm;
return trailingVecDimBitWidth >= targetBitWidth;
}

static bool
isNotLinearizableBecauseLargeInnerDimension(Operation *op,
unsigned targetBitWidth) {
// Check on bitwidths.
SmallVector<std::pair<Type, unsigned>> toCheck =
getTypeBitWidthBoundPairs(op, targetBitWidth);
return llvm::any_of(toCheck, [&](std::pair<Type, unsigned> typeWidth) {
return isNotLinearizableBecauseLargeInnerDimension(typeWidth.first,
typeWidth.second);
});
}

void populateWithBitWidthConstraints(TypeConverter &typeConverter,
ConversionTarget &target,
unsigned targetBitWidth) {

// The general purpose definition of what ops are legal must come first.
populateForVectorLinearize(typeConverter, target);

// Extend the set of legal ops to include those with large inner-most
// dimensions on selected operands/results.
target.markUnknownOpDynamicallyLegal(
[=](Operation *op) -> std::optional<bool> {
if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) {
return true;
}
return {};
});
}

struct TestVectorBitWidthLinearize final
: public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize)

TestVectorBitWidthLinearize() = default;
TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass)
: PassWrapper(pass) {}

StringRef getArgument() const override {
return "test-bit-width-constrained-vector-linearize";
}
StringRef getDescription() const override {
return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints "
"in inner-most dimension's bit width.";
}
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<vector::VectorDialect>();
}

Option<unsigned> targetVectorBitwidth{
*this, "target-vector-bitwidth",
llvm::cl::desc(
"Minimum vector bitwidth to enable the flattening transformation"),
llvm::cl::init(std::numeric_limits<unsigned>::max())};
void runOnOperation() override {
auto *context = &getContext();

TypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);

populateWithBitWidthConstraints(typeConverter, target,
targetVectorBitwidth);

vector::populateVectorLinearizeBasePatterns(typeConverter, target,
patterns);

vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target,
patterns);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
}
};

struct TestVectorLinearize final
: public PassWrapper<TestVectorLinearize, OperationPass<>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
Expand Down Expand Up @@ -1064,8 +944,6 @@ void registerTestVectorLowerings() {

PassRegistration<TestVectorLinearize>();

PassRegistration<TestVectorBitWidthLinearize>();

PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
Expand Down
Loading