Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#sdy Handle redundant shard maps during import. #16943

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xla/service/spmd/shardy/round_trip_common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ cc_library(
"//xla/mlir_hlo",
"//xla/mlir_hlo:mhlo_passes",
"//xla/service/spmd/shardy:constants",
"//xla/service/spmd/shardy:utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand All @@ -81,10 +82,12 @@ cc_library(
"@llvm-project//mlir:CallOpInterfaces",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InliningUtils",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@tsl//tsl/platform:errors",
],
)

Expand Down
114 changes: 100 additions & 14 deletions xla/service/spmd/shardy/round_trip_common/shard_map_import.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,16 @@ limitations under the License.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/InliningUtils.h"
#include "shardy/dialect/sdy/ir/constants.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "shardy/dialect/sdy/ir/utils.h"
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
#include "xla/mlir_hlo/mhlo/transforms/passes.h"
#include "xla/service/spmd/shardy/constants.h"
#include "xla/service/spmd/shardy/utils.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"

namespace xla {
namespace sdy {
Expand Down Expand Up @@ -86,7 +89,10 @@ using sdy::SdyDialect;
using sdy::TensorShardingAttr;
using sdy::TensorShardingPerValueAttr;

// A pair of custom calls. `shardingOp` has target name "Sharding".
// A sharding op and corresponding shape transform op.
// If `shardingOp` is a CustomCallOp, it has target name "Sharding". Else it is
// expected to be an sdy::ShardingConstraintOp.
//
// `shapeTransformOp` has target name "SPMDFullToShardShape" or
// "SPMDShardToFullShape". `shardingOp` has exactly one user, which is
// shapeTransformOp. `shapeTransformOp` has exactly one operand, which is
Expand All @@ -95,7 +101,7 @@ using sdy::TensorShardingPerValueAttr;
// Both `shardingOp` and `shapeTransformOp` will be nullptr for unused results
// of the shard map `func.call`.
struct ShardMapCustomCallPair {
CustomCallOp shardingOp;
Operation* shardingOp;
CustomCallOp shapeTransformOp;
};

Expand Down Expand Up @@ -188,19 +194,36 @@ absl::StatusOr<ShardMapArgumentsResults> getJaxShardMapPatternOps(CallOp op) {
return absl::NotFoundError(
"expecting SPMDFullToShardShape custom call as operand");
}
auto shardingCustomCall =
fullToShardCustomCall->getOperand(0).getDefiningOp<CustomCallOp>();
if (!shardingCustomCall) {
if (mlir::isa<mlir::BlockArgument>(fullToShardCustomCall->getOperand(0))) {
return absl::NotFoundError(
"expecting CustomCallOp as operand of SPMDFullToShardShape");
}
if (!isShardMapCustomCall(shardingCustomCall,
kShardingCustomCallTargetName)) {
return absl::NotFoundError(
"expecting Sharding CustomCallOp as operand of SPMDFullToShardShape");
}
argumentOps.push_back(
ShardMapCustomCallPair{shardingCustomCall, fullToShardCustomCall});
TF_RETURN_IF_ERROR(
mlir::TypeSwitch<Operation*, absl::Status>(
fullToShardCustomCall->getOperand(0).getDefiningOp())
.Case<CustomCallOp>([&](CustomCallOp customCallOp) {
if (!isShardMapCustomCall(customCallOp,
kShardingCustomCallTargetName)) {
return absl::NotFoundError(
"expecting Sharding CustomCallOp as operand of "
"SPMDFullToShardShape");
}
argumentOps.push_back(
ShardMapCustomCallPair{customCallOp, fullToShardCustomCall});
return absl::OkStatus();
})
.Case<sdy::ShardingConstraintOp>(
[&](sdy::ShardingConstraintOp shardingConstraintOp) {
argumentOps.push_back(ShardMapCustomCallPair{
shardingConstraintOp, fullToShardCustomCall});
return absl::OkStatus();
})
.Default([](Operation*) {
return absl::NotFoundError(
"expecting CustomCallOp or ShardingConstraintOp as operand "
"of "
"SPMDFullToShardShape");
}));
}

SmallVector<ShardMapCustomCallPair> resultOps;
Expand Down Expand Up @@ -251,6 +274,55 @@ absl::StatusOr<ShardMapArgumentsResults> getJaxShardMapPatternOps(CallOp op) {
return ShardMapArgumentsResults{argumentOps, resultOps};
}

// Inlines shard maps with no operand/result custom calls.
//
// When the shard map's manual axes are all of size 1, JAX will not create the
// shmap pattern with custom calls. Instead, the call op by itself will exist.
// If that is the case, just inline the op.
absl::Status inlineRedundantShardMap(FuncOp funcOp,
mlir::DenseSet<FuncOp>& toDeleteFuncOps) {
bool errorEmitted = false;
bool inlined = false;
mlir::SymbolTableCollection symbolTableCollection;
CallOp toDeleteCallOp;
funcOp->walk([&](CallOp op) {
bool anyOperandCustomCall = llvm::any_of(
op.getOperands(),
[](Value operand) { return operand.getDefiningOp<CustomCallOp>(); });
bool anyResultCustomCall = llvm::any_of(op.getResults(), [](Value result) {
return !result.use_empty() &&
mlir::dyn_cast<CustomCallOp>(*result.user_begin());
});
if (!anyOperandCustomCall && !anyResultCustomCall) {
mlir::InlinerInterface inliner(op.getContext());
auto calleeOp = mlir::cast<FuncOp>(
mlir::cast<mlir::CallOpInterface>(*op).resolveCallable(
&symbolTableCollection));
if (mlir::failed(inlineCall(
inliner, mlir::cast<mlir::CallOpInterface>(op.getOperation()),
mlir::cast<mlir::CallableOpInterface>(calleeOp.getOperation()),
calleeOp.getCallableRegion()))) {
op.emitOpError() << "failed to inline.\n";
errorEmitted = true;
return mlir::WalkResult::interrupt();
}
inlined = true;
toDeleteFuncOps.insert(calleeOp);
toDeleteCallOp = op;
}
return mlir::WalkResult::advance();
});
if (errorEmitted) {
return absl::InternalError("Failed to inline redundant shard maps.");
}

if (inlined) {
toDeleteCallOp.erase();
return inlineRedundantShardMap(funcOp, toDeleteFuncOps);
}
return absl::OkStatus();
}

// When calling `jax.shard_map`, we have the following pattern in the MHLO.
// ```
// %shard_arg0_0 = custom_call @Sharding(%0)
Expand Down Expand Up @@ -308,6 +380,20 @@ class ShardMapImportPass
// Subsequent CallOps with that symbol will clone the mapped region.
llvm::SmallDenseMap<StringRef, mlir::Region*> shardMapNameToMovedRegion;
bool success = true;
mlir::DenseSet<FuncOp> toDeleteFuncOps;
module->walk([&](FuncOp funcOp) {
if (isShmapBody(funcOp)) {
return mlir::WalkResult::advance();
}
if (!inlineRedundantShardMap(funcOp, toDeleteFuncOps).ok()) {
signalPassFailure();
return mlir::WalkResult::interrupt();
}
return mlir::WalkResult::advance();
});
for (FuncOp funcOp : llvm::make_early_inc_range(toDeleteFuncOps)) {
symbolTable.erase(funcOp);
}
module->walk([&](CallOp op) {
if (!op.getCallee().contains("shmap_body")) {
return mlir::WalkResult::advance();
Expand Down Expand Up @@ -472,7 +558,7 @@ class ShardMapImportPass
CHECK(shapeTransformOp && shapeTransformOp->use_empty());
shapeTransformOp.erase();
CHECK(shardingOp->use_empty());
shardingOp.erase();
shardingOp->erase();
}

// Erase the call op itself.
Expand All @@ -497,7 +583,7 @@ class ShardMapImportPass
shapeTransformOp.erase();
}
if (shardingOp->use_empty()) {
shardingOp.erase();
shardingOp->erase();
}
}
}
Expand Down
22 changes: 18 additions & 4 deletions xla/service/spmd/shardy/test/shard_map_import.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
sdy.mesh @mesh_0 = <["a"=4]>
sdy.mesh @mesh_1 = <["a"=4, "b"=2]>
sdy.mesh @mesh_2 = <["a"=4, "b"=2, "c"=3]>
sdy.mesh @mesh_3 = <["i"=1, "j"=2, "k"=2]>

// CHECK-LABEL: func.func public @call_op_with_no_operands_or_results()
func.func public @call_op_with_no_operands_or_results() {
// CHECK: sdy.manual_computation() in_shardings=[] out_shardings=[] manual_axes={} () {
// CHECK: sdy.return
// CHECK: } : () -> ()
// CHECK-NOT: sdy.manual_computation
// CHECK: return
call @shmap_body_empty() : () -> ()
call @shmap_body_empty() {mhlo.frontend_attributes = {xla.sdy.manual_axes = "[]"}} : () -> ()
return
}
// CHECK-NOT: func.func private @shmap_body_empty
Expand Down Expand Up @@ -380,3 +379,18 @@ func.func private @shmap_body_11(%arg0: tensor<4x32xf32>, %arg1: tensor<4x32xf32
%0 = mhlo.add %arg0, %arg1 : tensor<4x32xf32>
return %0 : tensor<4x32xf32>
}

// Even with a mesh with multiple axes size > 1, if the shmap is operating on no
// axes > 1, then JAX will create the call but not have the pattern of the
// custom calls. So just inline.
// CHECK-LABEL: no_custom_call_pattern
func.func public @no_custom_call_pattern(%arg0: tensor<4x8xf32> { sdy.sharding = #sdy.sharding<@mesh_3, [{"i"}, {"j"}]>}) -> tensor<4x8xf32> {
// CHECK: %[[MULT:.*]] = stablehlo.multiply %arg0, %arg0 : tensor<4x8xf32>
// CHECK-NEXT: return %[[MULT]] : tensor<4x8xf32>
%0 = call @shmap_body_14(%arg0) {xla.sdy.manual_axes = "[\\\22i\\\22]"} : (tensor<4x8xf32>) -> tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
func.func private @shmap_body_14(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = stablehlo.multiply %arg0, %arg0 : tensor<4x8xf32>
return %0 : tensor<4x8xf32>
}
5 changes: 5 additions & 0 deletions xla/service/spmd/shardy/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.
#include <string>

#include "absl/strings/escaping.h"
#include "absl/strings/match.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
Expand Down Expand Up @@ -161,5 +162,9 @@ void loadAllRequiredDialects(mlir::MLIRContext* context) {
context->loadAllAvailableDialects();
}

bool isShmapBody(mlir::func::FuncOp funcOp) {
return absl::StrContains(funcOp.getSymName(), "shmap_body");
}

} // namespace sdy
} // namespace xla
3 changes: 3 additions & 0 deletions xla/service/spmd/shardy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ void removeFrontendAttribute(mlir::func::FuncOp funcOp,

void loadAllRequiredDialects(mlir::MLIRContext* context);

// Returns whether this function is a body of a JAX shmap.
bool isShmapBody(mlir::func::FuncOp funcOp);

} // namespace sdy
} // namespace xla

Expand Down
Loading