PR #16921: [PJRT:GPU] Treat GPU collective memory space as device memory space

Imported from GitHub PR #16921

This fixes a regression seen when using --xla_gpu_enable_nccl_user_buffers=true:
return the device memory space when the collective memory space is used for an output on GPU.
Copybara import of the project:

--
1b73040 by Jane Liu <[email protected]>:

Treat collective memory space as device memory space when using as an output

Merging this change closes #16921

FUTURE_COPYBARA_INTEGRATE_REVIEW=#16921 from zhenying-liu:nccl-buffer-output 1b73040
PiperOrigin-RevId: 672618973
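The gist of the change (the two switch statements updated in pjrt_stream_executor_client.cc below) can be sketched in plain Python. This is an illustrative model, not XLA code: the constant values mirror `xla::Layout`'s memory-space enum as far as the new test implies them (collective/generic-fast space is 1; the host-space value is an assumption).

```python
# Illustrative constants modeling xla::Layout's memory-space values.
# GENERIC_FAST_MEMORY_SPACE == 1 matches what the new test asserts;
# HOST_MEMORY_SPACE's exact value is an assumption for this sketch.
DEFAULT_MEMORY_SPACE = 0
GENERIC_FAST_MEMORY_SPACE = 1  # collective / NCCL user-buffer space
HOST_MEMORY_SPACE = 5

def memory_kind_for_output(space, default_kind="device"):
    """Map an output buffer's layout memory space to a PjRt memory kind."""
    if space == HOST_MEMORY_SPACE:
        return "pinned_host"
    # The fix: outputs in collective (generic fast) memory are reported
    # as ordinary device memory, the same as the default space.
    if space in (DEFAULT_MEMORY_SPACE, GENERIC_FAST_MEMORY_SPACE):
        return default_kind
    raise ValueError(f"unsupported memory space: {space}")
```

Before the fix, the collective space fell through to the error branch, so outputs allocated in NCCL user buffers could not be reported as device memory.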
zhenying-liu authored and Google-ML-Automation committed Sep 12, 2024
1 parent 873c080 commit 4d43e40
Showing 5 changed files with 230 additions and 0 deletions.
65 changes: 65 additions & 0 deletions .github/workflows/bazel_dependency_violations.yml
@@ -0,0 +1,65 @@
# Copyright 2024 The OpenXLA Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
name: Bazel Dependency Violations
permissions:
  contents: read
on:
  pull_request:
  push:
    branches:
      - main

env:
  # Have `go install` place binaries in $PATH
  GOBIN: "/usr/local/bin"

jobs:
  dependency-violations:
    strategy:
      matrix:
        tag: [gpu, no_rocm]
    name: no-${{ matrix.tag }}-targets-in-cpu-build
    runs-on: ubuntu-22.04
    defaults:
      run:
        shell: bash
    timeout-minutes: 3
    continue-on-error: true
    steps:
      - name: "Checking out repository"
        uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1
      - name: "Install bazelisk"
        run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0
      - name: "Run bazel cquery ... //xla/..."
        run: |
          set -euo pipefail
          OUTPUT=$(bazelisk cquery --aspects build_tools/dependencies/aspects.bzl%validate_${{ matrix.tag }}_tag //xla/... 2>&1)
          if echo "$OUTPUT" | grep 'Violation' >/dev/null; then
            echo "The following dependency violations were found:"
            echo "$OUTPUT" | grep 'Violation' | sed -e 's/^.*\[Violation\]/ -/'
            echo ""
            echo ""
            echo "There are a couple of potential solutions for this/these violation(s):"
            echo ""
            echo "1. Tag the dependent target with the same tag as the dependee."
            echo ""
            echo "2. If unavoidable, make the dependency selective using the"
            echo "   'if_{gpu|cuda|rocm}_is_configured' macro. This is discouraged"
            echo "   outside of stream_executor."
            echo ""
            exit 1
          fi
          echo "No dependency violations found for tag '${{ matrix.tag }}'."
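The `grep`/`sed` pipeline in the workflow step above can be modeled in Python; `extract_violations` is a hypothetical helper (not part of the PR) that pulls the `[Violation]` lines out of the cquery output and reformats them the same way the `sed` expression does.

```python
import re

def extract_violations(cquery_output):
    r"""Return '[Violation] ...' lines reformatted as ' - ...' bullets,
    mirroring: grep 'Violation' | sed -e 's/^.*\[Violation\]/ -/'"""
    bullets = []
    for line in cquery_output.splitlines():
        match = re.search(r"\[Violation\](.*)", line)
        if match:
            # sed replaces everything up to and including '[Violation]'
            # with ' -', keeping the rest of the line.
            bullets.append(" -" + match.group(1))
    return bullets
```

A line such as `DEBUG: aspects.bzl:30: [Violation] //xla:a (not tagged gpu) depends on //xla:b (tagged gpu)` thus becomes ` - //xla:a (not tagged gpu) depends on //xla:b (tagged gpu)`.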
19 changes: 19 additions & 0 deletions build_tools/dependencies/BUILD
@@ -0,0 +1,19 @@
# Copyright 2024 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

package(
    # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
    licenses = ["notice"],
)
82 changes: 82 additions & 0 deletions build_tools/dependencies/aspects.bzl
@@ -0,0 +1,82 @@
# Copyright 2024 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""A collection of Bazel aspects that help detect dependency violations.

The dependency violation detection works by iterating through all targets in XLA
and comparing the tags of each target to the tags of all its dependencies. If a
target is tagged `gpu`, it can only be used in an XLA build with the GPU backend
enabled. Hence a target that is NOT tagged `gpu` may never depend on a target
that IS tagged `gpu` when XLA is built with only the CPU backend enabled.

The Bazel aspect runs after Bazel's analysis phase. That means all `select`
expressions (and derivatives like the `if_gpu_is_configured` macro) have been
evaluated and the actual build configuration is taken into account.

The easiest way to run the aspect is during a build:
`bazel build --aspects build_tools/dependencies/aspects.bzl%validate_gpu_tag //xla/...`

But a cquery expression also works:
`bazel cquery --aspects build_tools/dependencies/aspects.bzl%validate_gpu_tag //xla/...`

The results are reported as debug prints and need to be fished out of stderr.
There are ways to make this less hacky, but they would also increase the
complexity of the aspect quite a bit.
"""

DependencyViolationInfo = provider(
    "Internal provider needed by the dependency violation check",
    fields = {
        # We can't access the tags of a dependency through the context, so instead
        # we "send" the tags to the dependee through this provider.
        "tags": "Tags of the dependency",
    },
)

def _dependency_violation_aspect_impl(_, ctx, tag):
    if not hasattr(ctx.rule.attr, "deps"):
        return [DependencyViolationInfo(tags = ctx.rule.attr.tags)]

    for dep in ctx.rule.attr.deps:
        if DependencyViolationInfo not in dep:
            continue
        dep_tags = dep[DependencyViolationInfo].tags
        if tag in dep_tags and tag not in ctx.rule.attr.tags:
            print("[Violation] {} (not tagged {}) depends on {} (tagged {})".format(
                ctx.label,
                tag,
                dep.label,
                tag,
            ))  # buildifier: disable=print

    return [DependencyViolationInfo(tags = ctx.rule.attr.tags)]

def _gpu_tag_violation_aspect_impl(target, ctx):
    return _dependency_violation_aspect_impl(target, ctx, "gpu")

validate_gpu_tag = aspect(
    implementation = _gpu_tag_violation_aspect_impl,
    attr_aspects = ["deps"],
)

def _no_rocm_tag_violation_aspect_impl(target, ctx):
    return _dependency_violation_aspect_impl(target, ctx, "no_rocm")

validate_no_rocm_tag = aspect(
    implementation = _no_rocm_tag_violation_aspect_impl,
    attr_aspects = ["deps"],
)
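The tag-propagation check the two aspects perform can be modeled with an ordinary dict-based target graph. This is a toy Python model of the logic (not Starlark, and the graph structure is an assumption for illustration):

```python
def find_violations(targets, tag):
    """targets: dict of label -> {"tags": set of tags, "deps": list of labels}.

    Returns (dependee, dependency) pairs where the dependency carries `tag`
    but the dependee does not -- exactly the condition the aspect prints.
    """
    violations = []
    for name, info in targets.items():
        if tag in info["tags"]:
            continue  # a tagged target may freely depend on same-tagged deps
        for dep in info["deps"]:
            if tag in targets[dep]["tags"]:
                violations.append((name, dep))
    return violations
```

Unlike this toy model, the real aspect cannot walk the whole graph at once: it runs per target, which is why `DependencyViolationInfo` is needed to carry each dependency's tags up to its dependee.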
62 changes: 62 additions & 0 deletions xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -1131,6 +1131,24 @@ constexpr char const* kD2HProgramTupleOutput = R"(
}
)";

constexpr char const* kCollectiveMemorySpaceOutput = R"(
  HloModule jit__psum, entry_computation_layout={(s32[1,4]{1,0})->s32[4]{0}}

  region_0.3 {
    Arg_0.0 = s32[] parameter(0)
    Arg_1.0 = s32[] parameter(1)
    ROOT add.0 = s32[] add(Arg_0.0, Arg_1.0)
  }

  ENTRY main.10_spmd {
    param = s32[1,4]{1,0} parameter(0)
    reshape = s32[4]{0} reshape(param)
    ROOT all-reduce = s32[4]{0} all-reduce(reshape), channel_id=1, to_apply=region_0.3
  }
)";

} // namespace

TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTest) {
@@ -1197,6 +1215,50 @@ TEST(StreamExecutorGpuClientTest, ExecutablePinnedHostOutputMemoryKindTest) {
EXPECT_EQ(memory_kinds[0][0], "pinned_host");
}

// Verify that the output is reported with the device memory kind when the
// collective memory space is used and NCCL user buffers are enabled.
TEST(StreamExecutorGpuClientTest,
     ExecutableCollectiveMemoryOutputMemoryKindTest) {
  TF_ASSERT_OK_AND_ASSIGN(auto client,
                          GetStreamExecutorGpuClient(GpuClientOptions()));
  xla::CompileOptions options;
  options.executable_build_options.mutable_debug_options()
      ->set_xla_gpu_enable_nccl_user_buffers(true);

  TF_ASSERT_OK_AND_ASSIGN(
      auto executable,
      CompileExecutable(kCollectiveMemorySpaceOutput, *client, options));
  std::vector<int32_t> data{1, 2, 3, 4};
  // Build the input shape with the correct memory space set.
  Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {1, 4},
                                                    /*minor_to_major=*/{1, 0});
  shape.mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);

  auto device = client->addressable_devices()[0];
  TF_EXPECT_OK(device->default_memory_space());
  TF_ASSERT_OK_AND_ASSIGN(
      auto input, client->BufferFromHostBuffer(
                      data.data(), shape.element_type(), shape.dimensions(),
                      /*byte_strides=*/std::nullopt,
                      PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
                      /*on_done_with_host_buffer=*/nullptr, device));
  EXPECT_EQ(input->memory_space()->kind(), "device");

  TF_ASSERT_OK_AND_ASSIGN(auto memory_kinds,
                          executable->GetOutputMemoryKinds());
  EXPECT_EQ(memory_kinds.size(), 1);
  EXPECT_EQ(memory_kinds[0].size(), 1);
  EXPECT_EQ(memory_kinds[0][0], "device");

  TF_ASSERT_OK_AND_ASSIGN(
      auto result, executable->Execute({{input.get()}}, ExecuteOptions()));
  std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0];
  EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "device");
  Shape result_shape = result_buffers[0]->on_device_shape();
  // The output itself should live in the collective memory space.
  auto memory_space = result_shape.layout().memory_space();
  EXPECT_EQ(memory_space, 1);
}

TEST(StreamExecutorGpuClientTest,
ExecutablePinnedHostTupleOutputMemoryKindTest) {
TF_ASSERT_OK_AND_ASSIGN(auto client,
2 changes: 2 additions & 0 deletions xla/pjrt/pjrt_stream_executor_client.cc
@@ -2286,6 +2286,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
      device->default_memory_space().value_or(nullptr);
  if (shape.has_layout()) {
    switch (shape.layout().memory_space()) {
      case Layout::kGenericFastMemorySpace:
      case Layout::kDefaultMemorySpace:
        // Nothing to do, we have already set the default memory space.
        break;
@@ -3322,6 +3323,7 @@ absl::StatusOr<absl::string_view> MemoryKindFromSimpleShape(
  switch (shape.layout().memory_space()) {
    case Layout::kHostMemorySpace:
      return PinnedHostMemorySpace::kKind;
    case Layout::kGenericFastMemorySpace:
    case Layout::kDefaultMemorySpace:
      return default_memory_kind;
    default: