Skip to content

Commit

Permalink
PR #16921: [PJRT:GPU] Treat GPU collective memory space as device mem…
Browse files Browse the repository at this point in the history
…ory space

Imported from GitHub PR #16921

This is a regression fix when using --xla_gpu_enable_nccl_user_buffers=true.
Return device memory space when collective memory space is used as an output on GPU.
Copybara import of the project:

--
1b73040 by Jane Liu <[email protected]>:

Treat collective memory space as device memory space when using as an output

Merging this change closes #16921

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16921 from zhenying-liu:nccl-buffer-output 1b73040
PiperOrigin-RevId: 672618973
  • Loading branch information
zhenying-liu authored and Google-ML-Automation committed Sep 12, 2024
1 parent e46de06 commit d202f4f
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 0 deletions.
62 changes: 62 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,24 @@ constexpr char const* kD2HProgramTupleOutput = R"(
}
)";

constexpr char const* kCollectiveMemorySpaceOutput = R"(
HloModule jit__psum, entry_computation_layout={(s32[1,4]{1,0})->s32[4]{0}}
region_0.3 {
Arg_0.0 = s32[] parameter(0)
Arg_1.0 = s32[] parameter(1)
ROOT add.0 = s32[] add(Arg_0.0, Arg_1.0)
}
ENTRY main.10_spmd {
param = s32[1,4]{1,0} parameter(0)
reshape = s32[4]{0} reshape(param)
ROOT all-reduce = s32[4]{0} all-reduce(reshape), channel_id=1, to_apply=region_0.3
}
)";

} // namespace

TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTest) {
Expand Down Expand Up @@ -1197,6 +1215,50 @@ TEST(StreamExecutorGpuClientTest, ExecutablePinnedHostOutputMemoryKindTest) {
EXPECT_EQ(memory_kinds[0][0], "pinned_host");
}

// Verify the output device memory kind with collective memory space shape when
// NCCL user buffer is enabled.
TEST(StreamExecutorGpuClientTest,
ExecutableCollectiveMemoryOutputMemoryKindTest) {
TF_ASSERT_OK_AND_ASSIGN(auto client,
GetStreamExecutorGpuClient(GpuClientOptions()));
xla::CompileOptions options;
options.executable_build_options.mutable_debug_options()
->set_xla_gpu_enable_nccl_user_buffers(true);

TF_ASSERT_OK_AND_ASSIGN(
auto executable,
CompileExecutable(kCollectiveMemorySpaceOutput, *client, options));
std::vector<int32_t> data{1, 2, 3, 4};
// Build the input shape with the correct memory space set.
Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {1, 4},
/*major_to_minor=*/{1, 0});
shape.mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);

auto device = client->addressable_devices()[0];
TF_EXPECT_OK(device->default_memory_space());
TF_ASSIGN_OR_RETURN(
auto input, client->BufferFromHostBuffer(
data.data(), shape.element_type(), shape.dimensions(),
/*byte_strides=*/std::nullopt,
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
/*on_done_with_host_buffer=*/nullptr, device));
EXPECT_EQ(input->memory_space()->kind(), "device");

TF_ASSERT_OK_AND_ASSIGN(auto memory_kinds,
executable->GetOutputMemoryKinds());
EXPECT_EQ(memory_kinds.size(), 1);
EXPECT_EQ(memory_kinds[0].size(), 1);
EXPECT_EQ(memory_kinds[0][0], "device");

TF_ASSERT_OK_AND_ASSIGN(
auto result, executable->Execute({{input.get()}}, ExecuteOptions()));
std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0];
EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "device");
Shape result_shape = result_buffers[0]->on_device_shape();
auto memory_space = result_shape.layout().memory_space();
EXPECT_EQ(memory_space, 1);
}

TEST(StreamExecutorGpuClientTest,
ExecutablePinnedHostTupleOutputMemoryKindTest) {
TF_ASSERT_OK_AND_ASSIGN(auto client,
Expand Down
2 changes: 2 additions & 0 deletions xla/pjrt/pjrt_stream_executor_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2286,6 +2286,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
device->default_memory_space().value_or(nullptr);
if (shape.has_layout()) {
switch (shape.layout().memory_space()) {
case Layout::kGenericFastMemorySpace:
case Layout::kDefaultMemorySpace:
// Nothing to do, we have already set the default memory space.
break;
Expand Down Expand Up @@ -3322,6 +3323,7 @@ absl::StatusOr<absl::string_view> MemoryKindFromSimpleShape(
switch (shape.layout().memory_space()) {
case Layout::kHostMemorySpace:
return PinnedHostMemorySpace::kKind;
case Layout::kGenericFastMemorySpace:
case Layout::kDefaultMemorySpace:
return default_memory_kind;
default:
Expand Down
66 changes: 66 additions & 0 deletions xla/service/gpu/fusions/cudnn_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,72 @@ ENTRY e {
ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

TEST_F(CuDnnFusionFileCheckTest, VectorTensorMultiplicationWorksCorrectly) {
const std::string kHloText = R"(
f {
p0 = bf16[64,1] parameter(0)
p1 = s8[64,128] parameter(1)
p1c = bf16[64,128] convert(p1)
ROOT out = bf16[1,128] dot(p0, p1c),
lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
ENTRY e {
p0 = bf16[64,1] parameter(0)
p1 = s8[64,128] parameter(1)
ROOT r = bf16[1,128] fusion(p0, p1), kind=kCustom, calls=f,
backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion"}}
})";

EXPECT_TRUE(*RunCuDnnFileCheck(kHloText, R"(
CHECK: "tensors"
CHECK: "out"
CHECK: "dim": [1,1,128]
CHECK: "stride": [1,128,1]
CHECK: "p0"
CHECK: "dim": [1,1,64]
CHECK: "stride": [1,64,1]
CHECK: "p1"
CHECK: "dim": [1,64,128]
CHECK: "stride": [1,128,1]
)"));

EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

TEST_F(CuDnnFusionFileCheckTest, TensorVectorMultiplicationWorksCorrectly) {
const std::string kHloText = R"(
f {
p0 = bf16[64,256] parameter(0)
p1 = s8[64,1] parameter(1)
p1c = bf16[64,1] convert(p1)
ROOT out = bf16[256,1] dot(p0, p1c),
lhs_contracting_dims={0}, rhs_contracting_dims={0}
}
ENTRY e {
p0 = bf16[64,256] parameter(0)
p1 = s8[64,1] parameter(1)
ROOT r = bf16[256,1] fusion(p0, p1), kind=kCustom, calls=f,
backend_config={"fusion_backend_config":{"kind":"__cudnn$fusion"}}
})";

EXPECT_TRUE(*RunCuDnnFileCheck(kHloText, R"(
CHECK: "tensors"
CHECK: "out"
CHECK: "dim": [1,256,1]
CHECK: "stride": [1,1,256]
CHECK: "p0"
CHECK: "dim": [1,256,64]
CHECK: "stride": [1,1,256]
CHECK: "p1"
CHECK: "dim": [1,64,1]
CHECK: "stride": [1,1,64]
)"));

EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3}));
}

TEST_F(CuDnnFusionExecutionTest, DotBF16WithCopyExecutesCorrectly) {
EXPECT_TRUE(RunAndCompare(R"(
fusion1 {
Expand Down
14 changes: 14 additions & 0 deletions xla/service/gpu/transforms/cudnn_fusion_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,20 @@ class GemmDimensionAdapter {
result.strides[kOutputLHSNonContractingDimensionIndex] *
result.sizes[kOutputLHSNonContractingDimensionIndex];
}

// 0 (kBatchDimensionIndex) is always the batch dimension;
// 1 and 2 are the non-batch ones. cuDNN relies on strides to determine
// layouts and gets confused when both strides of non-batch dimensions
// are equal to 1 - this is the case for tensors with 1-sized dimension
// like [A,1]. The stride of the 1-sized dimension does not matter for
// correctness because there is no iteration along this dimension, but
// setting it to A and representing the tensor as its equivalent [1,A]
// helps cuDNN.
if (result.strides[1] == 1 && result.strides[2] == 1) {
const int one_sized_dim_idx = (result.sizes[1] == 1) ? 1 : 2;
result.strides[one_sized_dim_idx] = result.sizes[1] * result.sizes[2];
}

if (!slicing_is_present) {
result.slices.reset();
}
Expand Down

0 comments on commit d202f4f

Please sign in to comment.