Commit
PR #16921: [PJRT:GPU] Treat GPU collective memory space as device memory space

Imported from GitHub PR openxla/xla#16921

This is a regression fix when using --xla_gpu_enable_nccl_user_buffers=true. Return device memory space when collective memory space is used as an output on GPU.

Copybara import of the project:

--
1b730405577b926030c3fbde1132141717590089 by Jane Liu <[email protected]>:

Treat collective memory space as device memory space when using as an output

Merging this change closes #16921

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16921 from zhenying-liu:nccl-buffer-output 1b730405577b926030c3fbde1132141717590089
PiperOrigin-RevId: 672618973
1 parent a31cd33, commit 537b30e
Showing 5 changed files with 230 additions and 0 deletions.
65 changes: 65 additions & 0 deletions
third_party/xla/.github/workflows/bazel_dependency_violations.yml
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright 2024 The OpenXLA Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
name: Bazel Dependency Violations | ||
permissions: | ||
contents: read | ||
on: | ||
pull_request: | ||
push: | ||
branches: | ||
- main | ||
|
||
env: | ||
# Have `go install` place binaries in $PATH | ||
GOBIN: "/usr/local/bin" | ||
|
||
jobs: | ||
dependency-violations: | ||
strategy: | ||
matrix: | ||
tag: [gpu, no_rocm] | ||
name: no-${{ matrix.tag }}-targets-in-cpu-build | ||
runs-on: ubuntu-22.04 | ||
defaults: | ||
run: | ||
shell: bash | ||
timeout-minutes: 3 | ||
continue-on-error: true | ||
steps: | ||
- name: "Checking out repository" | ||
uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 | ||
- name: "Install bazelisk" | ||
run: go install github.com/bazelbuild/bazelisk@24651ab # v1.20.0 | ||
- name: "Run bazel cquery ... //xla/..." | ||
run: | | ||
set -euo pipefail | ||
OUTPUT=$(bazelisk cquery --aspects build_tools/dependencies/aspects.bzl%validate_${{ matrix.tag }}_tag //xla/... 2>&1) | ||
if echo "$OUTPUT" | grep 'Violation' >/dev/null; then | ||
echo "The following dependency violations were found:" | ||
echo "$OUTPUT" | grep 'Violation' | sed -e 's/^.*\[Violation\]/ -/' | ||
echo "" | ||
echo "" | ||
echo "There are a couple of potential solutions for this/these violation(s):" | ||
echo "" | ||
echo "1. Tag the dependent target with the same tag as the dependee." | ||
echo "" | ||
echo "2. If unavoidable, make the dependency selective using the" | ||
echo " 'if_{gpu|cuda|rocm}_is_configured' macro. This is discouraged" | ||
echo " outside of stream_executor." | ||
echo "" | ||
exit 1 | ||
fi | ||
echo "No dependency violations found for tag '${{ matrix.tag }}'." |
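The reporting step above filters the combined cquery output with `grep 'Violation'` and then strips everything up to the `[Violation]` marker with `sed`. As a rough illustration of what that pipeline extracts, the Python sketch below applies the same filter to a sample string. The function name `extract_violations`, the sample log lines (including the `DEBUG:` prefix and position numbers), and the target labels are all assumptions made for this example; they are not part of the commit.

```python
import re

# Hypothetical sample of combined stdout/stderr from the cquery run.
# The exact prefix Bazel prints before "[Violation]" is an assumption here.
SAMPLE_OUTPUT = "\n".join([
    "INFO: Analyzed 3 targets.",
    "DEBUG: build_tools/dependencies/aspects.bzl:57:18: [Violation] "
    "//xla:cpu_lib (not tagged gpu) depends on //xla:gpu_kernel (tagged gpu)",
    "INFO: Build completed successfully.",
])


def extract_violations(output: str) -> list[str]:
    """Mimic `grep 'Violation' | sed -e 's/^.*\\[Violation\\]/ -/'`."""
    violations = []
    for line in output.splitlines():
        # grep 'Violation': keep only lines that mention a violation.
        if "Violation" not in line:
            continue
        # sed: replace everything up to and including "[Violation]" with " -".
        violations.append(re.sub(r"^.*\[Violation\]", " -", line, count=1))
    return violations


if __name__ == "__main__":
    found = extract_violations(SAMPLE_OUTPUT)
    if found:
        print("The following dependency violations were found:")
        for entry in found:
            print(entry)
    else:
        print("No dependency violations found.")
```

In the workflow itself a non-empty result leads to `exit 1`, although `continue-on-error: true` keeps a failing job from failing the overall run.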
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2024 The OpenXLA Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
|
||
package( | ||
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], | ||
licenses = ["notice"], | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright 2024 The OpenXLA Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================ | ||
|
||
"""A collection of Bazel aspects that can help detect dependency violations. | ||
The dependency violation detection works by iterating through all targets in XLA | ||
and comparing the applied tags of each target to the tags of all its dependencies. | ||
If a target is tagged `gpu` it means it can only be used in an XLA build with the | ||
GPU backend enabled. Hence all targets that are NOT tagged `gpu` may never depend | ||
on a target that IS tagged `gpu` if we are building XLA with only the CPU backend | ||
enabled. | ||
The Bazel aspect runs after Bazel's analysis phase. That means all `select` expressions | ||
(and their derivatives like the `if_gpu_is_configured` macro) have been evaluated and | ||
the actual build configuration is taken into account. | ||
The easiest way to run the aspect is during a build: | ||
`bazel build --aspects build_tools/dependencies/aspects.bzl%validate_gpu_tag //xla/...` | ||
But a cquery expression also works: | ||
`bazel cquery --aspects build_tools/dependencies/aspects.bzl%validate_gpu_tag //xla/...` | ||
The results are reported as debug prints and need to be fished out of stderr. There | ||
are ways to make it less hacky but the complexity of the aspect would also increase | ||
quite a bit. | ||
""" | ||
|
||
DependencyViolationInfo = provider( | ||
"Internal provider needed by the dependency violation check", | ||
fields = { | ||
# We can't access the tags of a dependency through the context, so instead we | ||
# "send" the tags to the dependent target through this provider. | ||
"tags": "Tags of the dependency", | ||
}, | ||
) | ||
|
||
def _dependency_violation_aspect_impl(_, ctx, tag): | ||
if not hasattr(ctx.rule.attr, "deps"): | ||
return [DependencyViolationInfo(tags = ctx.rule.attr.tags)] | ||
|
||
for dep in ctx.rule.attr.deps: | ||
if DependencyViolationInfo not in dep: | ||
continue | ||
dep_tags = dep[DependencyViolationInfo].tags | ||
if tag in dep_tags and tag not in ctx.rule.attr.tags: | ||
print("[Violation] {} (not tagged {}) depends on {} (tagged {})".format( | ||
ctx.label, | ||
tag, | ||
dep.label, | ||
tag, | ||
)) # buildifier: disable=print | ||
|
||
return [DependencyViolationInfo(tags = ctx.rule.attr.tags)] | ||
|
||
def _gpu_tag_violation_aspect_impl(target, ctx): | ||
return _dependency_violation_aspect_impl(target, ctx, "gpu") | ||
|
||
validate_gpu_tag = aspect( | ||
implementation = _gpu_tag_violation_aspect_impl, | ||
attr_aspects = ["deps"], | ||
) | ||
|
||
def _no_rocm_tag_violation_aspect_impl(target, ctx): | ||
return _dependency_violation_aspect_impl(target, ctx, "no_rocm") | ||
|
||
validate_no_rocm_tag = aspect( | ||
implementation = _no_rocm_tag_violation_aspect_impl, | ||
attr_aspects = ["deps"], | ||
) |
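The check performed by the aspect above can be sketched outside Bazel. The following Python model walks a toy dependency graph and reports every target that lacks a tag while one of its direct dependencies carries it, which is the condition `_dependency_violation_aspect_impl` tests. The `Target` class, the `find_violations` function, and the target labels are hypothetical and exist only for illustration; the real aspect receives tags through `DependencyViolationInfo` rather than a dictionary lookup.

```python
from dataclasses import dataclass, field


@dataclass
class Target:
    """Toy stand-in for a Bazel target (hypothetical, for illustration)."""
    label: str
    tags: list[str] = field(default_factory=list)
    deps: list[str] = field(default_factory=list)


def find_violations(targets: dict[str, Target], tag: str) -> list[str]:
    """Report targets without `tag` that directly depend on a target with it."""
    violations = []
    for target in targets.values():
        for dep_label in target.deps:
            dep = targets.get(dep_label)
            if dep is None:
                # Comparable to a dependency that carries no provider: skip it.
                continue
            if tag in dep.tags and tag not in target.tags:
                violations.append(
                    "[Violation] {} (not tagged {}) depends on {} (tagged {})".format(
                        target.label, tag, dep.label, tag
                    )
                )
    return violations


# Hypothetical graph: one untagged target wrongly depends on a gpu-tagged one.
GRAPH = {
    "//xla:gpu_kernel": Target("//xla:gpu_kernel", tags=["gpu"]),
    "//xla:gpu_runtime": Target(
        "//xla:gpu_runtime", tags=["gpu"], deps=["//xla:gpu_kernel"]
    ),
    "//xla:cpu_lib": Target("//xla:cpu_lib", deps=["//xla:gpu_kernel"]),
}

if __name__ == "__main__":
    for message in find_violations(GRAPH, "gpu"):
        print(message)
```

As in the aspect, only direct dependencies are compared, so a violation is reported at the edge where an untagged target first reaches a tagged one.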