Skip to content

Commit

Permalink
Rename GemmFusion to TritonFusionRewriter.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672524825
  • Loading branch information
derdrdirk authored and Google-ML-Automation committed Sep 12, 2024
1 parent 3d31124 commit 557c24b
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 164 deletions.
5 changes: 3 additions & 2 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ xla_cc_test(
deps = [
":triton_fusion_analysis",
"//xla/hlo/ir:hlo",
"//xla/service/gpu/transforms:gemm_fusion",
"//xla/service/gpu/transforms:triton_fusion_rewriter",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
Expand Down Expand Up @@ -1434,7 +1434,6 @@ cc_library(
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
"//xla/service/gpu/transforms:fusion_wrapper",
"//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter",
"//xla/service/gpu/transforms:gemm_fusion",
"//xla/service/gpu/transforms:gemm_rewriter",
"//xla/service/gpu/transforms:gemv_rewriter",
"//xla/service/gpu/transforms:layout_assignment",
Expand All @@ -1457,6 +1456,7 @@ cc_library(
"//xla/service/gpu/transforms:transpose_dimension_grouper",
"//xla/service/gpu/transforms:tree_reduction_rewriter",
"//xla/service/gpu/transforms:triton_fusion_numerics_verifier",
"//xla/service/gpu/transforms:triton_fusion_rewriter",
"//xla/service/gpu/transforms:windowed_einsum_handler",
"//xla/service/llvm_ir:llvm_util",
"//xla/service/spmd:collective_permute_motion",
Expand Down Expand Up @@ -1627,6 +1627,7 @@ xla_test(
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:literal_test_util",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/autotuning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ xla_test(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu/transforms:gemm_fusion",
"//xla/service/gpu/transforms:gemm_rewriter",
"//xla/service/gpu/transforms:triton_fusion_rewriter",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/stream_executor:semantic_version",
Expand Down
18 changes: 9 additions & 9 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/service/gpu/transforms/gemm_rewriter.h"
#include "xla/service/gpu/transforms/triton_fusion_rewriter.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/pattern_matcher.h"
Expand Down Expand Up @@ -198,10 +198,10 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest {
void CheckTritonAutotuning(absl::string_view hlo,
absl::string_view expected) {
HloPassPipeline pipeline("gemm_rewrite");
pipeline.AddPass<GemmFusion>(backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability());
pipeline.AddPass<TritonFusionRewriter>(backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability());
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
tsl::port::MaxParallelism());
DebugOptions opts;
Expand Down Expand Up @@ -782,10 +782,10 @@ ENTRY e {
)";

HloPassPipeline pipeline("gemm_rewrite_deviceless");
pipeline.AddPass<GemmFusion>(backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability());
pipeline.AddPass<TritonFusionRewriter>(backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability());
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
tsl::port::MaxParallelism());
DebugOptions opts;
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ limitations under the License.
#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
#include "xla/service/gpu/transforms/fusion_wrapper.h"
#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/service/gpu/transforms/gemm_rewriter.h"
#include "xla/service/gpu/transforms/gemv_rewriter.h"
#include "xla/service/gpu/transforms/layout_assignment.h"
Expand All @@ -183,6 +182,7 @@ limitations under the License.
#include "xla/service/gpu/transforms/transpose_dimension_grouper.h"
#include "xla/service/gpu/transforms/tree_reduction_rewriter.h"
#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h"
#include "xla/service/gpu/transforms/triton_fusion_rewriter.h"
#include "xla/service/gpu/transforms/windowed_einsum_handler.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_computation_deduplicator.h"
Expand Down Expand Up @@ -1457,7 +1457,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) ||
rocm_cc != nullptr)) {
pipeline.AddPass<GemvRewriter>();
pipeline.AddPass<GemmFusion>(gpu_version);
pipeline.AddPass<TritonFusionRewriter>(gpu_version);
} else if (cuda_cc != nullptr &&
cuda_cc->major == se::CudaComputeCapability::VOLTA) {
// Greedy pattern matching for custom kernel fusions.
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ limitations under the License.
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/verified_hlo_module.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
Expand Down Expand Up @@ -1000,7 +1001,7 @@ ENTRY main {
bool triton_gemm_rewriter_has_run = false;
for (const HloPassMetadata& pass_metadata : module_metadata.pass_metadata()) {
triton_gemm_rewriter_has_run |=
pass_metadata.pass_name() == "triton-gemm-rewriter";
pass_metadata.pass_name() == "triton-fusion-rewriter";
}

EXPECT_EQ(triton_gemm_rewriter_has_run, expect_triton_gemm_rewriter_has_run);
Expand Down
14 changes: 7 additions & 7 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ cc_library(
srcs = ["cublas_pad_for_gemms.cc"],
hdrs = ["cublas_pad_for_gemms.h"],
deps = [
":gemm_fusion",
":triton_fusion_rewriter",
"//xla:literal_util",
"//xla:shape_util",
"//xla:util",
Expand Down Expand Up @@ -1530,9 +1530,9 @@ xla_test(
)

cc_library(
name = "gemm_fusion",
srcs = ["gemm_fusion.cc"],
hdrs = ["gemm_fusion.h"],
name = "triton_fusion_rewriter",
srcs = ["triton_fusion_rewriter.cc"],
hdrs = ["triton_fusion_rewriter.h"],
deps = [
"//xla:shape_util",
"//xla:util",
Expand Down Expand Up @@ -1563,10 +1563,10 @@ cc_library(
)

xla_cc_test(
name = "gemm_fusion_test",
srcs = ["gemm_fusion_test.cc"],
name = "triton_fusion_rewriter_test",
srcs = ["triton_fusion_rewriter_test.cc"],
deps = [
":gemm_fusion",
":triton_fusion_rewriter",
"//xla:autotuning_proto_cc",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/transforms/cublas_pad_for_gemms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/service/gpu/fusions/triton/triton_support_legacy.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/service/gpu/transforms/triton_fusion_rewriter.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/service/gpu/transforms/triton_fusion_rewriter.h"

#include <array>
#include <cstddef>
Expand Down Expand Up @@ -717,9 +717,10 @@ absl::StatusOr<FusionDecision> CreateDotFusion(

// Extracts into fused computations parts of HLO graph including dot()
// operations that can target the triton GEMM emitter.
class GemmFusionVisitor : public DfsHloRewriteVisitor {
class TritonFusionRewriterVisitor : public DfsHloRewriteVisitor {
public:
explicit GemmFusionVisitor(const se::GpuComputeCapability& gpu_version)
explicit TritonFusionRewriterVisitor(
const se::GpuComputeCapability& gpu_version)
: gpu_version_(gpu_version) {}
// Checks that a dot() should be targeting the triton GEMM emitter;
// if so - fuses all its compatible inputs and outputs as a new computation
Expand Down Expand Up @@ -801,7 +802,7 @@ class GemmFusionVisitor : public DfsHloRewriteVisitor {

absl::StatusOr<bool> RunOnComputation(
HloComputation* computation, const se::GpuComputeCapability& gpu_version) {
GemmFusionVisitor visitor(gpu_version);
TritonFusionRewriterVisitor visitor(gpu_version);
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.changed();
}
Expand All @@ -817,7 +818,7 @@ bool ShouldTritonHandleGEMM(HloDotInstruction& dot,
->CanFuse();
}

absl::StatusOr<bool> GemmFusion::Run(
absl::StatusOr<bool> TritonFusionRewriter::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
TF_RETURN_IF_ERROR(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_REWRITER_H_
#define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_REWRITER_H_

// This file contains the code for fusing dots and other operations into Triton
// GEMM fusions.
Expand All @@ -24,7 +24,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/service/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"

namespace xla {
Expand All @@ -36,11 +35,12 @@ bool ShouldTritonHandleGEMM(HloDotInstruction&,

// Rewrites compatible dot() calls into custom calls with fused computations
// that target the Triton-based matmul emitter.
class GemmFusion : public HloModulePass {
class TritonFusionRewriter : public HloModulePass {
public:
explicit GemmFusion(const se::GpuComputeCapability& compute_capability)
explicit TritonFusionRewriter(
const se::GpuComputeCapability& compute_capability)
: compute_capability_(compute_capability) {}
absl::string_view name() const override { return "triton-gemm-rewriter"; }
absl::string_view name() const override { return "triton-fusion-rewriter"; }

using HloPassInterface::Run;
absl::StatusOr<bool> Run(
Expand All @@ -54,4 +54,4 @@ class GemmFusion : public HloModulePass {
} // namespace gpu
} // namespace xla

#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
#endif // XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_REWRITER_H_
Loading

0 comments on commit 557c24b

Please sign in to comment.