Decompose aten.channel_shuffle op (#4243) #4259

Open

wants to merge 20 commits into base: main
Commits (20)
ce672a9
Decompose aten.channel_shuffle op (#4243)
ivangarcia44 Jul 8, 2025
c9f48e2
Filtering channel shuffle tests on stablehlo.
ivangarcia44 Jul 8, 2025
463bf52
Added description field in wrong TableGen op.
ivangarcia44 Jul 8, 2025
b9217cb
Removing extra space
ivangarcia44 Jul 8, 2025
17401a3
Removing passing tests from stablehlo failed list.
ivangarcia44 Jul 8, 2025
4434a3c
Trying build fix for torch dynamo dependency test failures.
ivangarcia44 Jul 9, 2025
2b92bae
Attempt 2: Trying build fix for torch dynamo dependency test failures.
ivangarcia44 Jul 9, 2025
7fb42c1
Backtrack torch dynamo runtime failure fixes. Will create a separate …
ivangarcia44 Jul 9, 2025
27c6998
Undoing torch dynamo filtering.
ivangarcia44 Jul 9, 2025
841df07
Adding CI dependency on OMP for e2e test failures.
ivangarcia44 Jul 9, 2025
013d592
Backtracking torch dynamo fix attempt.
ivangarcia44 Jul 9, 2025
17b891b
Adding channel shuffle to a couple of illegal ops where the pixel shu…
ivangarcia44 Jul 10, 2025
e91b835
Another attempt to fix torch dynamo blood bath with OMP library depen…
ivangarcia44 Jul 11, 2025
2fce1bf
Backtracking failed attempt to fix blood bath issue (not related to m…
ivangarcia44 Jul 11, 2025
553a934
Adding comments in tests.
ivangarcia44 Jul 14, 2025
e383dd5
Running ./build_tools/update_torch_ods.sh on TableGen file.
ivangarcia44 Jul 14, 2025
99a13a4
Adding the removed channel shuffle back to GeneratedTorchOps.td
ivangarcia44 Jul 14, 2025
1eb5699
Using torch operation addition process defined in https://github.com/…
ivangarcia44 Jul 14, 2025
b9eeea0
Addressing Alaa's feedback.
ivangarcia44 Jul 15, 2025
254af35
Addressing feedback from Sayan and Alaa.
ivangarcia44 Jul 17, 2025
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8668,6 +8668,30 @@ def Torch_AtenPixelShuffleOp : Torch_Op<"aten.pixel_shuffle", [
}];
}

def Torch_AtenChannelShuffleOp : Torch_Op<"aten.channel_shuffle", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::channel_shuffle : (Tensor, int) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$groups
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenChannelShuffleOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenChannelShuffleOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
AllowsTypeRefinement,
ReadOnly
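The op definition above is generated rather than hand-written: the commits note re-running ./build_tools/update_torch_ods.sh. A minimal sketch of the registration that would produce it, following the emit() pattern used for neighboring ops such as aten.pixel_shuffle (exact file path and surrounding lines are assumptions, not taken from this PR):

# In torch_ods_gen.py (path assumed), alongside the pixel_shuffle entry;
# running ./build_tools/update_torch_ods.sh then regenerates
# GeneratedTorchOps.td with the Torch_AtenChannelShuffleOp definition above.
emit("aten::channel_shuffle : (Tensor, int) -> (Tensor)")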
18 changes: 18 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7613,6 +7613,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %15 = torch.aten.append.t %6, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" return %6 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.channel_shuffle\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: input must be at least rank-3 in channel_shuffle\"\n"
" %int3 = torch.constant.int 3\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" return %arg0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.permute\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.permute(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -12271,6 +12285,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.channel_shuffle\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.bool, %arg5: !torch.bool) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
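These string-encoded AbstractInterpLibrary entries are likewise generated from Python shape and dtype functions. A hedged sketch of what the Python source probably looks like, with the 〇/〡 naming convention mirrored from neighboring functions in torch-mlir's abstract_interp_lib_gen.py (exact file location and context assumed):

# Sketch, not copied from the PR: shape/dtype functions that would generate
# the two entries above.
from typing import List, Tuple

def aten〇channel_shuffle〡shape(self: List[int], groups: int) -> List[int]:
    # Matches the generated assertion: rank must be at least 3.
    assert len(self) >= 3, "input must be at least rank-3 in channel_shuffle"
    return self  # channel_shuffle preserves the input shape

def aten〇channel_shuffle〡dtype(self_rank_dtype: Tuple[int, int], groups: int) -> int:
    # The result dtype is the input dtype (the second tuple element,
    # corresponding to %0#1 in the generated IR).
    self_rank, self_dtype = self_rank_dtype
    return self_dtype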
203 changes: 172 additions & 31 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -3537,6 +3537,30 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
};
} // namespace

namespace { // Start of rearrangement ops utility functions
// Extracts shape as vector of int64_t from vector of Value
SmallVector<int64_t> getIntShapeFromValues(ArrayRef<Value> vals) {
SmallVector<int64_t> shape;
shape.reserve(vals.size());
for (Value v : vals) {
int64_t cst_val;
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
shape.push_back(cst_val);
} else {
shape.push_back(kUnknownSize);
}
}
return shape;
}

// Converts a vector of Value (shape dimensions) into a ValueTensorType
ValueTensorType getTypeFromShape(ArrayRef<Value> vals, Type inOptionalDType) {
SmallVector<int64_t> intShape = getIntShapeFromValues(vals);
return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape),
inOptionalDType);
}
} // namespace

// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
@@ -3562,7 +3586,6 @@ class DecomposeAtenPixelShuffleOp
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
PatternRewriter &rewriter) const override {

Location loc = op.getLoc();
Value inValue = op.getSelf();
auto inType = cast<BaseTensorType>(inValue.getType());
@@ -3585,27 +3608,6 @@

const auto inOptionalDType = inType.getOptionalDtype();

auto getTypeFromShape = [inOptionalDType](auto &&vals) {
// Get a vector of integers from a vector of Values.
auto getIntShape = [](auto &&vals) {
SmallVector<int64_t> shape;
shape.reserve(vals.size());
for (auto v : vals) {
int64_t cst_val;
if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
shape.push_back(cst_val);
} else {
shape.push_back(kUnknownSize);
}
}
return shape;
};

const auto intShape = getIntShape(vals);
return ValueTensorType::get(vals[0].getContext(),
llvm::ArrayRef(intShape), inOptionalDType);
};

auto nLeadingDims = inRank - 3;

// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
@@ -3677,24 +3679,24 @@
auto partiallyExpanded =
rewriter
.create<PrimsSplitDimOp>(
loc, getTypeFromShape(partiallyExpandedShape), inValue,
dimensionConstants[nLeadingDims], outC)
loc, getTypeFromShape(partiallyExpandedShape, inOptionalDType),
inValue, dimensionConstants[nLeadingDims], outC)
.getResult();

// Split new dimension factorSquared -> (factor, factor)
auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
dimensionConstants[nLeadingDims + 1], factor);
loc, getTypeFromShape(prePermuteShape, inOptionalDType),
partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor);

// Perform the permutation
auto permuted =
rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
fullyExpanded, permuteDimsOrder);
auto permuted = rewriter.create<AtenPermuteOp>(
loc, getTypeFromShape(postPermuteShape, inOptionalDType), fullyExpanded,
permuteDimsOrder);

// Collapse final 2 dimension
auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
loc, getTypeFromShape(partiallyCollapsedShape), permuted,
dimensionConstants[nLeadingDims + 3],
loc, getTypeFromShape(partiallyCollapsedShape, inOptionalDType),
permuted, dimensionConstants[nLeadingDims + 3],
dimensionConstants[nLeadingDims + 4]);

// Collapse back to original rank
@@ -3708,6 +3710,144 @@
};
} // namespace

// Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
// If input is a tensor of shape
// (N, g*C, H, W),
//
// then
// X = channel_shuffle(input, groups)
//
// gets replaced with
// X = input.split_dim(...) # shape (N, g, C, *)
// X = X.permute(0, 2, 1, ...) # shape (N, C, g, *)
// X = X.collapse(...) # shape (N, C*g, *)
//
// 'g' above is referred to as the number of 'groups'. N is the batch
// dimension and cannot be omitted: in PyTorch's ChannelShuffle operator,
// if the batch dimension is omitted, the first spatial dimension is
// treated as the channel dimension. PyTorch errors out for the code below,
// reporting that 4 is not divisible by 3:
// input_tensor = torch.arange(1, 37, dtype=torch.float32).view(3, 4, 3)
// channel_shuffle_layer = nn.ChannelShuffle(groups=3)
// output_tensor = channel_shuffle_layer(input_tensor)
//
// The decomposition is based on this specification:
// https://pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html
// and PyTorch implementation: aten/src/ATen/native/ChanelShuffle.cpp
// (yes, the filename is misspelled "Chanel" in upstream PyTorch)
//
namespace {
class DecomposeAtenChannelShuffleOp
: public OpRewritePattern<AtenChannelShuffleOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenChannelShuffleOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value inValue = op.getSelf();
auto inType = cast<BaseTensorType>(inValue.getType());
auto maybeSizes = inType.getOptionalSizes();
if (!maybeSizes) {
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have known rank.");
}
auto inShape = maybeSizes.value();
auto inRank = inShape.size();

// The input tensor must have at least 3 dimensions: batch size,
// channel size, and at least one spatial dimension.
if (inRank < 3)
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have rank greater than or equal to 3.");

auto numOfSpatialDims = inRank - 2;

// Get the size of the dimension 'i'. Note the use of 'createOrFold'
// instead of 'create': if the dimension size is known, then the
// AtenSizeIntOp is folded to a ConstantOp.
auto getDimSize = [&rewriter, &inValue, loc](uint64_t i) -> Value {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
};

auto inC = getDimSize(1);
SmallVector<Value> inSpatialDims;
inSpatialDims.reserve(numOfSpatialDims);
for (unsigned i = 2; i < (2 + numOfSpatialDims); ++i) {
inSpatialDims.push_back(getDimSize(i));
}

auto groups = op.getGroups();

// Temporary channel dimension size = inC / groups
// Assumes input has been validated: `inC % groups == 0`
// This is enforced by PyTorch's runtime and is required for correctness.
Value tempC = rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, groups);

// Create constants for split/permute/collapse operations. Note that we
// need an extra constant for the channel dimension split.
SmallVector<Value> dimensionConstants;
dimensionConstants.reserve(inRank + 1);
for (unsigned i = 0; i < inRank + 1; ++i) {
dimensionConstants.push_back(
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
}

Value batchDimSize = rewriter.createOrFold<AtenSizeIntOp>(
loc, inValue, dimensionConstants[0]);

SmallVector<Value> splitShape;
splitShape.reserve(inRank + 1);
splitShape.append({batchDimSize, groups, tempC});
splitShape.append(inSpatialDims); // Appends all spatial dimensions

SmallVector<Value> permuteShape;
permuteShape.reserve(inRank + 1);
permuteShape.append({batchDimSize, tempC, groups});
permuteShape.append(inSpatialDims); // Appends all spatial dimensions

// Permute (N, groups, tempC, *) -> (N, tempC, groups, *)
SmallVector<Value> permutation{dimensionConstants[0], // batch dimension
dimensionConstants[2], // tempC
dimensionConstants[1]}; // groups
for (unsigned i = 3; i < inRank + 1; ++i) {
permutation.push_back(dimensionConstants[i]);
}

Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
permutation);

const auto inOptionalDType = inType.getOptionalDtype();

Value dimC = dimensionConstants[1];
Value dimG = dimensionConstants[2];

// Split input channel inC -> (groups, inC/groups)
auto expandedTensor =
rewriter
.create<PrimsSplitDimOp>(
loc, getTypeFromShape(splitShape, inOptionalDType), inValue,
dimC, tempC)
.getResult();

// Perform the permutation
auto permuted = rewriter.create<AtenPermuteOp>(
loc, getTypeFromShape(permuteShape, inOptionalDType), expandedTensor,
permuteDimsOrder);

// Collapse (C, groups) back into a single channel dimension
rewriter.replaceOpWithNewOp<PrimsCollapseOp>(op, op.getType(), permuted,
dimC, dimG);

return success();
}
};
} // namespace
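To make the split_dim -> permute -> collapse sequence implemented above concrete, here is a minimal PyTorch sketch (not part of the PR) that mirrors the decomposition and checks it against nn.ChannelShuffle:

import torch

def channel_shuffle_decomposed(x: torch.Tensor, groups: int) -> torch.Tensor:
    n, c, spatial = x.shape[0], x.shape[1], x.shape[2:]
    temp_c = c // groups  # tempC = inC / groups; assumes c % groups == 0
    x = x.reshape(n, groups, temp_c, *spatial)   # prims.split_dim
    order = [0, 2, 1] + list(range(3, x.dim()))  # swap the groups/tempC dims
    x = x.permute(order)                         # aten.permute
    return x.reshape(n, c, *spatial)             # prims.collapse

x = torch.arange(24, dtype=torch.float32).reshape(1, 6, 2, 2)
expected = torch.nn.ChannelShuffle(3)(x)
assert torch.equal(channel_shuffle_decomposed(x, 3), expected)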

// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
Value input) {
@@ -12444,6 +12584,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
patterns);
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -421,6 +421,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<Aten_LinalgDetOp>();
target.addIllegalOp<AtenLinalgSlogdetOp>();
target.addIllegalOp<AtenPixelShuffleOp>();
target.addIllegalOp<AtenChannelShuffleOp>();
target.addIllegalOp<AtenTOp>();
target.addIllegalOp<Aten_LogSoftmaxBackwardDataOp>();
target.addDynamicallyLegalOp<AtenMatmulOp>([](AtenMatmulOp op) {
22 changes: 11 additions & 11 deletions lib/Dialect/Torch/Utils/Utils.cpp
@@ -317,17 +317,17 @@ bool Torch::isViewLikeOp(Operation *op) {
// correct. We could potentially be more precise and identify the cases
// that it does not return a view and treat those as having value
// semantics.
return isa<AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp,
AtenUnflattenIntOp, AtenPermuteOp, AtenReshapeOp,
Aten_ReshapeAliasOp, AtenSelectIntOp, AtenSliceTensorOp,
AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp, AtenToDtypeOp,
AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp,
AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp,
AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp,
PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp,
AtenPixelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
return isa<
AtenAsStridedOp, AtenBroadcastToOp, AtenContiguousOp, AtenDetachOp,
AtenExpandAsOp, AtenExpandOp, AtenFlattenUsingIntsOp, AtenUnflattenIntOp,
AtenPermuteOp, AtenReshapeOp, Aten_ReshapeAliasOp, AtenSelectIntOp,
AtenSliceTensorOp, AtenSqueezeDimOp, AtenSqueezeOp, AtenTOp,
AtenToDtypeOp, AtenTransposeIntOp, AtenUnsqueezeOp, AtenViewOp,
TensorStaticInfoCastOp, AtenToDtypeLayoutOp, AtenNumpyTOp, AtenNarrowOp,
AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp,
PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp,
AtenViewAsComplexOp, AtenViewAsRealOp, AtenPixelShuffleOp,
AtenChannelShuffleOp, AtenDiagonalOp, AtenUnfoldOp>(op);
}

Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter,