Internal breakage.
Reverts 5adafde

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#16921 from zhenying-liu:nccl-buffer-output b5e43d6455adc49f5ac99a9a9e95cf495eb46170
PiperOrigin-RevId: 674073765
tensorflower-gardener committed Sep 13, 2024
1 parent d58d166 commit 1bd8688
Showing 4 changed files with 74 additions and 4 deletions.
62 changes: 62 additions & 0 deletions third_party/xla/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc
@@ -1131,6 +1131,24 @@ constexpr char const* kD2HProgramTupleOutput = R"(
}
)";

+ constexpr char const* kCollectiveMemorySpaceOutput = R"(
+ HloModule jit__psum, entry_computation_layout={(s32[1,4]{1,0})->s32[4]{0}}
+ region_0.3 {
+   Arg_0.0 = s32[] parameter(0)
+   Arg_1.0 = s32[] parameter(1)
+   ROOT add.0 = s32[] add(Arg_0.0, Arg_1.0)
+ }
+ ENTRY main.10_spmd {
+   param = s32[1,4]{1,0} parameter(0)
+   reshape = s32[4]{0} reshape(param)
+   ROOT all-reduce = s32[4]{0} all-reduce(reshape), channel_id=1, to_apply=region_0.3
+ }
+ )";

} // namespace

TEST(StreamExecutorGpuClientTest, ExecutePinnedHostOutputTest) {
@@ -1197,6 +1215,50 @@ TEST(StreamExecutorGpuClientTest, ExecutablePinnedHostOutputMemoryKindTest) {
EXPECT_EQ(memory_kinds[0][0], "pinned_host");
}

+ // Verify the output device memory kind with collective memory space shape when
+ // NCCL user buffer is enabled.
+ TEST(StreamExecutorGpuClientTest,
+      ExecutableCollectiveMemoryOutputMemoryKindTest) {
+   TF_ASSERT_OK_AND_ASSIGN(auto client,
+                           GetStreamExecutorGpuClient(GpuClientOptions()));
+   xla::CompileOptions options;
+   options.executable_build_options.mutable_debug_options()
+       ->set_xla_gpu_enable_nccl_user_buffers(true);
+
+   TF_ASSERT_OK_AND_ASSIGN(
+       auto executable,
+       CompileExecutable(kCollectiveMemorySpaceOutput, *client, options));
+   std::vector<int32_t> data{1, 2, 3, 4};
+   // Build the input shape with the correct memory space set.
+   Shape shape = ShapeUtil::MakeShapeWithDenseLayout(S32, {1, 4},
+                                                     /*major_to_minor=*/{1, 0});
+   shape.mutable_layout()->set_memory_space(Layout::kDefaultMemorySpace);
+
+   auto device = client->addressable_devices()[0];
+   TF_EXPECT_OK(device->default_memory_space());
+   TF_ASSERT_OK_AND_ASSIGN(
+       auto input, client->BufferFromHostBuffer(
+                       data.data(), shape.element_type(), shape.dimensions(),
+                       /*byte_strides=*/std::nullopt,
+                       PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
+                       /*on_done_with_host_buffer=*/nullptr, device));
+   EXPECT_EQ(input->memory_space()->kind(), "device");
+
+   TF_ASSERT_OK_AND_ASSIGN(auto memory_kinds,
+                           executable->GetOutputMemoryKinds());
+   EXPECT_EQ(memory_kinds.size(), 1);
+   EXPECT_EQ(memory_kinds[0].size(), 1);
+   EXPECT_EQ(memory_kinds[0][0], "device");
+
+   TF_ASSERT_OK_AND_ASSIGN(
+       auto result, executable->Execute({{input.get()}}, ExecuteOptions()));
+   std::vector<std::unique_ptr<xla::PjRtBuffer>>& result_buffers = result[0];
+   EXPECT_EQ(result_buffers[0]->memory_space()->kind(), "device");
+   Shape result_shape = result_buffers[0]->on_device_shape();
+   auto memory_space = result_shape.layout().memory_space();
+   EXPECT_EQ(memory_space, 1);
+ }

TEST(StreamExecutorGpuClientTest,
ExecutablePinnedHostTupleOutputMemoryKindTest) {
TF_ASSERT_OK_AND_ASSIGN(auto client,
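An aside, not part of the diff above: the new test corresponds roughly to the following JAX-level usage. This is a minimal sketch that assumes a single CUDA device and a jaxlib recent enough to understand the flag; the flag name mirrors the xla_gpu_enable_nccl_user_buffers debug option set in the test, and everything else is illustrative.

# Set the flag before JAX initializes its backends.
import os
os.environ["XLA_FLAGS"] = "--xla_gpu_enable_nccl_user_buffers=true"

import jax
import jax.numpy as jnp

# One leading-axis entry per device; a single-GPU run uses shape (1, 4),
# matching the s32[1,4] parameter of kCollectiveMemorySpaceOutput.
x = jnp.arange(1, 5, dtype=jnp.int32).reshape(1, 4)

# psum lowers to the all-reduce that the test expects to land in
# collective (NCCL user-buffer) memory.
out = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)
print(out)  # [[1 2 3 4]] when there is a single participant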
2 changes: 2 additions & 0 deletions third_party/xla/xla/pjrt/pjrt_stream_executor_client.cc
@@ -2286,6 +2286,7 @@ absl::StatusOr<std::unique_ptr<PjRtBuffer>> OutputBufferHelper(
device->default_memory_space().value_or(nullptr);
if (shape.has_layout()) {
switch (shape.layout().memory_space()) {
+ case Layout::kGenericFastMemorySpace:
case Layout::kDefaultMemorySpace:
// Nothing to do, we have already set the default memory space.
break;
@@ -3322,6 +3323,7 @@ absl::StatusOr<absl::string_view> MemoryKindFromSimpleShape(
switch (shape.layout().memory_space()) {
case Layout::kHostMemorySpace:
return PinnedHostMemorySpace::kKind;
+ case Layout::kGenericFastMemorySpace:
case Layout::kDefaultMemorySpace:
return default_memory_kind;
default:
12 changes: 9 additions & 3 deletions third_party/xla/xla/python/pytree.cc
@@ -595,9 +595,15 @@ nb::list PyTreeDef::FlattenUpTo(nb::handle xs) const {

case PyTreeKind::kNone:
if (!object.is_none()) {
- throw std::invalid_argument(
-     absl::StrFormat("Expected None, got %s.",
-                     nb::cast<std::string_view>(nb::repr(object))));
+ PythonDeprecationWarning(
+     /*stacklevel=*/3,
+     "In a future release of JAX, flatten-up-to will no longer "
+     "consider None to be a tree-prefix of non-None values, got: "
+     "%s.\n\n"
+     "To preserve the current behavior, you can usually write:\n"
+     " jax.tree.map(lambda x, y: None if x is None else f(x, y), a, "
+     "b, is_leaf=lambda x: x is None)",
+     nb::cast<std::string_view>(nb::repr(object)));
}
break;

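Also an aside, not part of the diff: the workaround quoted in the restored warning text, written out as a runnable sketch. Here `a` carries a None where `b` carries a value; treating None as a leaf and mapping over it preserves the current flatten-up-to behavior.

import jax

a = {"w": None, "b": 1.0}
b = {"w": 2.0, "b": 3.0}

# None in `a` is treated as a leaf, so it is passed through unchanged
# instead of being matched structurally against b["w"].
out = jax.tree.map(lambda x, y: None if x is None else x + y, a, b,
                   is_leaf=lambda x: x is None)
print(out)  # {'b': 4.0, 'w': None}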
2 changes: 1 addition & 1 deletion third_party/xla/xla/python/xla_client.py
@@ -50,7 +50,7 @@

# Just an internal arbitrary increasing number to help with backward-compatible
# changes. In JAX, reference this via jax._src.lib.xla_extension_version.
- _version = 284
+ _version = 283

# Version number for MLIR:Python components.
mlir_api_version = 57
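One more aside, not part of the diff: as the comment above says, downstream JAX code reads this counter via jax._src.lib.xla_extension_version and gates features on it rather than parsing version strings. A minimal sketch of that pattern; the threshold is illustrative.

from jax._src.lib import xla_extension_version

# Guard code paths that need behavior present at extension version 283+.
if xla_extension_version >= 283:
    pass  # use the newer behavior here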
