Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename GemmFusion to TritonFusionRewriter. #16939

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ xla_cc_test(
deps = [
":triton_fusion_analysis",
"//xla/hlo/ir:hlo",
"//xla/service/gpu/transforms:gemm_fusion",
"//xla/service/gpu/transforms:triton_fusion_rewriter",
"//xla/stream_executor:device_description",
"//xla/tests:hlo_test_base",
"//xla/tests:verified_hlo_module",
Expand Down Expand Up @@ -1434,7 +1434,6 @@ cc_library(
"//xla/service/gpu/transforms:dynamic_slice_fusion_rewriter",
"//xla/service/gpu/transforms:fusion_wrapper",
"//xla/service/gpu/transforms:gemm_broadcast_folding_rewriter",
"//xla/service/gpu/transforms:gemm_fusion",
"//xla/service/gpu/transforms:gemm_rewriter",
"//xla/service/gpu/transforms:gemv_rewriter",
"//xla/service/gpu/transforms:layout_assignment",
Expand All @@ -1457,6 +1456,7 @@ cc_library(
"//xla/service/gpu/transforms:transpose_dimension_grouper",
"//xla/service/gpu/transforms:tree_reduction_rewriter",
"//xla/service/gpu/transforms:triton_fusion_numerics_verifier",
"//xla/service/gpu/transforms:triton_fusion_rewriter",
"//xla/service/gpu/transforms:windowed_einsum_handler",
"//xla/service/llvm_ir:llvm_util",
"//xla/service/spmd:collective_permute_motion",
Expand Down Expand Up @@ -1627,6 +1627,7 @@ xla_test(
"//xla/tests:filecheck",
"//xla/tests:hlo_test_base",
"//xla/tests:literal_test_util",
"//xla/tests:verified_hlo_module",
"//xla/tests:xla_internal_test_main",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/log",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/autotuning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ xla_test(
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:ir_emission_utils",
"//xla/service/gpu:matmul_utils",
"//xla/service/gpu/transforms:gemm_fusion",
"//xla/service/gpu/transforms:gemm_rewriter",
"//xla/service/gpu/transforms:triton_fusion_rewriter",
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/stream_executor:semantic_version",
Expand Down
18 changes: 9 additions & 9 deletions xla/service/gpu/autotuning/gemm_fusion_autotuner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ limitations under the License.
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/matmul_utils.h"
#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/service/gpu/transforms/gemm_rewriter.h"
#include "xla/service/gpu/transforms/triton_fusion_rewriter.h"
#include "xla/service/hlo_module_config.h"
#include "xla/service/hlo_pass_pipeline.h"
#include "xla/service/pattern_matcher.h"
Expand Down Expand Up @@ -198,10 +198,10 @@ class GemmFusionAutotunerTest : public StatelessAutotunerTest {
void CheckTritonAutotuning(absl::string_view hlo,
absl::string_view expected) {
HloPassPipeline pipeline("gemm_rewrite");
pipeline.AddPass<GemmFusion>(backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability());
pipeline.AddPass<TritonFusionRewriter>(backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability());
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
tsl::port::MaxParallelism());
DebugOptions opts;
Expand Down Expand Up @@ -782,10 +782,10 @@ ENTRY e {
)";

HloPassPipeline pipeline("gemm_rewrite_deviceless");
pipeline.AddPass<GemmFusion>(backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability());
pipeline.AddPass<TritonFusionRewriter>(backend()
.default_stream_executor()
->GetDeviceDescription()
.cuda_compute_capability());
tsl::thread::ThreadPool thread_pool(tsl::Env::Default(), "",
tsl::port::MaxParallelism());
DebugOptions opts;
Expand Down
4 changes: 2 additions & 2 deletions xla/service/gpu/gpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,6 @@ limitations under the License.
#include "xla/service/gpu/transforms/dynamic_slice_fusion_rewriter.h"
#include "xla/service/gpu/transforms/fusion_wrapper.h"
#include "xla/service/gpu/transforms/gemm_broadcast_folding_rewriter.h"
#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/service/gpu/transforms/gemm_rewriter.h"
#include "xla/service/gpu/transforms/gemv_rewriter.h"
#include "xla/service/gpu/transforms/layout_assignment.h"
Expand All @@ -183,6 +182,7 @@ limitations under the License.
#include "xla/service/gpu/transforms/transpose_dimension_grouper.h"
#include "xla/service/gpu/transforms/tree_reduction_rewriter.h"
#include "xla/service/gpu/transforms/triton_fusion_numerics_verifier.h"
#include "xla/service/gpu/transforms/triton_fusion_rewriter.h"
#include "xla/service/gpu/transforms/windowed_einsum_handler.h"
#include "xla/service/hlo.pb.h"
#include "xla/service/hlo_computation_deduplicator.h"
Expand Down Expand Up @@ -1457,7 +1457,7 @@ absl::Status GpuCompiler::OptimizeHloPostLayoutAssignment(
cuda_cc->IsAtLeast(se::CudaComputeCapability::AMPERE)) ||
rocm_cc != nullptr)) {
pipeline.AddPass<GemvRewriter>();
pipeline.AddPass<GemmFusion>(gpu_version);
pipeline.AddPass<TritonFusionRewriter>(gpu_version);
} else if (cuda_cc != nullptr &&
cuda_cc->major == se::CudaComputeCapability::VOLTA) {
// Greedy pattern matching for custom kernel fusions.
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/gpu_compiler_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ limitations under the License.
#include "xla/tests/filecheck.h"
#include "xla/tests/hlo_test_base.h"
#include "xla/tests/literal_test_util.h"
#include "xla/tests/verified_hlo_module.h"
#include "xla/tsl/lib/core/status_test_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/casts.h"
Expand Down Expand Up @@ -1000,7 +1001,7 @@ ENTRY main {
bool triton_gemm_rewriter_has_run = false;
for (const HloPassMetadata& pass_metadata : module_metadata.pass_metadata()) {
triton_gemm_rewriter_has_run |=
pass_metadata.pass_name() == "triton-gemm-rewriter";
pass_metadata.pass_name() == "triton-fusion-rewriter";
}

EXPECT_EQ(triton_gemm_rewriter_has_run, expect_triton_gemm_rewriter_has_run);
Expand Down
14 changes: 7 additions & 7 deletions xla/service/gpu/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ cc_library(
srcs = ["cublas_pad_for_gemms.cc"],
hdrs = ["cublas_pad_for_gemms.h"],
deps = [
":gemm_fusion",
":triton_fusion_rewriter",
"//xla:literal_util",
"//xla:shape_util",
"//xla:util",
Expand Down Expand Up @@ -1530,9 +1530,9 @@ xla_test(
)

cc_library(
name = "gemm_fusion",
srcs = ["gemm_fusion.cc"],
hdrs = ["gemm_fusion.h"],
name = "triton_fusion_rewriter",
srcs = ["triton_fusion_rewriter.cc"],
hdrs = ["triton_fusion_rewriter.h"],
deps = [
"//xla:shape_util",
"//xla:util",
Expand Down Expand Up @@ -1563,10 +1563,10 @@ cc_library(
)

xla_cc_test(
name = "gemm_fusion_test",
srcs = ["gemm_fusion_test.cc"],
name = "triton_fusion_rewriter_test",
srcs = ["triton_fusion_rewriter_test.cc"],
deps = [
":gemm_fusion",
":triton_fusion_rewriter",
"//xla:autotuning_proto_cc",
"//xla:xla_data_proto_cc",
"//xla:xla_proto_cc",
Expand Down
2 changes: 1 addition & 1 deletion xla/service/gpu/transforms/cublas_pad_for_gemms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ limitations under the License.
#include "xla/literal_util.h"
#include "xla/service/gpu/fusions/triton/triton_support_legacy.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/service/gpu/transforms/triton_fusion_rewriter.h"
#include "xla/shape.h"
#include "xla/stream_executor/device_description.h"
#include "xla/util.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/gpu/transforms/gemm_fusion.h"
#include "xla/service/gpu/transforms/triton_fusion_rewriter.h"

#include <array>
#include <cstddef>
Expand Down Expand Up @@ -717,9 +717,10 @@ absl::StatusOr<FusionDecision> CreateDotFusion(

// Extracts into fused computations those parts of the HLO graph, including
// dot() operations, that can target the Triton GEMM emitter.
class GemmFusionVisitor : public DfsHloRewriteVisitor {
class TritonFusionRewriterVisitor : public DfsHloRewriteVisitor {
public:
explicit GemmFusionVisitor(const se::GpuComputeCapability& gpu_version)
explicit TritonFusionRewriterVisitor(
const se::GpuComputeCapability& gpu_version)
: gpu_version_(gpu_version) {}
// Checks that a dot() should be targeting the triton GEMM emitter;
// if so - fuses all its compatible inputs and outputs as a new computation
Expand Down Expand Up @@ -801,7 +802,7 @@ class GemmFusionVisitor : public DfsHloRewriteVisitor {

absl::StatusOr<bool> RunOnComputation(
HloComputation* computation, const se::GpuComputeCapability& gpu_version) {
GemmFusionVisitor visitor(gpu_version);
TritonFusionRewriterVisitor visitor(gpu_version);
TF_RETURN_IF_ERROR(computation->Accept(&visitor));
return visitor.changed();
}
Expand All @@ -817,7 +818,7 @@ bool ShouldTritonHandleGEMM(HloDotInstruction& dot,
->CanFuse();
}

absl::StatusOr<bool> GemmFusion::Run(
absl::StatusOr<bool> TritonFusionRewriter::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
TF_RETURN_IF_ERROR(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
#define XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
#ifndef XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_REWRITER_H_
#define XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_REWRITER_H_

// This file contains the code for fusing dots and other operations into Triton
// GEMM fusions.
Expand All @@ -24,7 +24,6 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/service/hlo_pass_interface.h"
#include "xla/service/instruction_fusion.h"
#include "xla/stream_executor/device_description.h"

namespace xla {
Expand All @@ -36,11 +35,12 @@ bool ShouldTritonHandleGEMM(HloDotInstruction&,

// Rewrites compatible dot() calls into custom calls with fused computations
// that target the Triton-based matmul emitter.
class GemmFusion : public HloModulePass {
class TritonFusionRewriter : public HloModulePass {
public:
explicit GemmFusion(const se::GpuComputeCapability& compute_capability)
explicit TritonFusionRewriter(
const se::GpuComputeCapability& compute_capability)
: compute_capability_(compute_capability) {}
absl::string_view name() const override { return "triton-gemm-rewriter"; }
absl::string_view name() const override { return "triton-fusion-rewriter"; }

using HloPassInterface::Run;
absl::StatusOr<bool> Run(
Expand All @@ -54,4 +54,4 @@ class GemmFusion : public HloModulePass {
} // namespace gpu
} // namespace xla

#endif // XLA_SERVICE_GPU_TRANSFORMS_GEMM_FUSION_H_
#endif // XLA_SERVICE_GPU_TRANSFORMS_TRITON_FUSION_REWRITER_H_
Loading
Loading