Control flow support #124

Open · wants to merge 13 commits into base: master
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/client/executable_build_options.cc
@@ -136,6 +136,7 @@ ExecutionOptions CreateExecutionOptions(
*execution_options.mutable_shape_with_output_layout() =
result_shape.ToProto();
}
execution_options.set_seed(build_options.seed());
execution_options.set_num_replicas(build_options.num_replicas());
execution_options.set_num_partitions(build_options.num_partitions());
execution_options.set_use_spmd_partitioning(
6 changes: 6 additions & 0 deletions tensorflow/compiler/xla/client/executable_build_options.h
@@ -77,6 +77,10 @@ class ExecutableBuildOptions {
// debugging.
std::string ToString() const;

// The random seed for compilation.
int seed() const { return seed_; }
void set_seed(int seed) { seed_ = seed; }

// The number of replicas of this computation that are to be executed.
// Defaults to 1.
int num_replicas() const { return num_replicas_; }
@@ -189,6 +193,8 @@ class ExecutableBuildOptions {
bool run_backend_only_ = false;
bool allow_spmd_sharding_propagation_to_output_ = false;
tensorflow::thread::ThreadPool* compile_thread_pool_ = nullptr;
// Added by Alpa
int seed_ = 42;
};

// Creates an ExecutionOptions based on a given ExecutableBuildOptions and
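A minimal usage sketch for the new compilation seed (not part of the diff); it assumes the existing CreateExecutionOptions helper keeps its (build_options, program_shape) signature:

#include "tensorflow/compiler/xla/client/executable_build_options.h"

xla::ExecutionOptions OptionsWithSeed(const xla::ProgramShape& program_shape) {
  xla::ExecutableBuildOptions build_options;
  build_options.set_seed(1234);  // new setter above; the default stays 42
  // Per the first hunk, CreateExecutionOptions now copies the seed into the
  // ExecutionOptions proto, so the compiler sees seed() == 1234.
  return xla::CreateExecutionOptions(build_options, &program_shape);
}
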
37 changes: 29 additions & 8 deletions tensorflow/compiler/xla/client/xla_builder.cc
@@ -2935,7 +2935,8 @@ XlaOp XlaBuilder::CrossReplicaSum(
XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Shape>& shape_with_layout) {
const std::optional<Shape>& shape_with_layout,
const std::optional<bool> use_global_device_ids) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
@@ -2992,6 +2993,9 @@ XlaOp XlaBuilder::AllReduce(XlaOp operand, const XlaComputation& computation,
if (channel_id.has_value()) {
instr.set_channel_id(channel_id->handle());
}
if (use_global_device_ids.has_value()) {
instr.set_use_global_device_ids(use_global_device_ids.value());
}

AddCalledComputation(computation, &instr);

@@ -3071,20 +3075,24 @@ XlaOp XlaBuilder::ReduceScatter(
XlaOp XlaBuilder::AllToAll(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout) {
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
const std::optional<bool> use_global_device_ids) {
// Array all_to_all may need to violate layout constraint to be legal so use
// the tuple version.
if (layout.has_value()) {
return AllToAllTuple(operand, split_dimension, concat_dimension,
split_count, replica_groups, layout);
}
return AllToAllArray(operand, split_dimension, concat_dimension, split_count,
replica_groups);
replica_groups, channel_id, use_global_device_ids);
}

XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups) {
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<bool> use_global_device_ids) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(
@@ -3103,6 +3111,14 @@ XlaOp XlaBuilder::AllToAllArray(XlaOp operand, int64_t split_dimension,
*instr.add_replica_groups() = group;
}
}

if (channel_id.has_value()) {
instr.set_channel_id(channel_id->handle());
}
if (use_global_device_ids.has_value()) {
instr.set_use_global_device_ids(use_global_device_ids.value());
}

instr.add_dimensions(split_dimension);
TF_ASSIGN_OR_RETURN(
XlaOp all_to_all,
@@ -4663,9 +4679,11 @@ XlaOp CrossReplicaSum(const XlaOp operand,
XlaOp AllReduce(const XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Shape>& shape_with_layout) {
const std::optional<Shape>& shape_with_layout,
const std::optional<bool> use_global_device_ids) {
return operand.builder()->AllReduce(operand, computation, replica_groups,
channel_id, shape_with_layout);
channel_id, shape_with_layout,
use_global_device_ids);
}

XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
@@ -4682,9 +4700,12 @@ XlaOp ReduceScatter(const XlaOp operand, const XlaComputation& computation,
XlaOp AllToAll(const XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout) {
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
const std::optional<bool> use_global_device_ids) {
return operand.builder()->AllToAll(operand, split_dimension, concat_dimension,
split_count, replica_groups, layout);
split_count, replica_groups, channel_id, layout,
use_global_device_ids);
}

XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
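A hedged sketch (not part of the diff) of how a caller might use the widened AllReduce entry point; the reduction computation and replica groups are assumed to be built elsewhere:

#include <optional>

#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp GlobalSum(xla::XlaOp operand, const xla::XlaComputation& add,
                     absl::Span<const xla::ReplicaGroup> groups) {
  xla::ChannelHandle channel;
  channel.set_handle(1);
  channel.set_type(xla::ChannelHandle::DEVICE_TO_DEVICE);
  // use_global_device_ids makes the backend read replica_groups as flattened
  // global device ids; per HLO semantics it requires a channel_id as well.
  return xla::AllReduce(operand, add, groups, channel,
                        /*shape_with_layout=*/std::nullopt,
                        /*use_global_device_ids=*/true);
}
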
25 changes: 18 additions & 7 deletions tensorflow/compiler/xla/client/xla_builder.h
@@ -749,7 +749,8 @@ class XlaBuilder {
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Shape>& shape_with_layout = std::nullopt);
const std::optional<Shape>& shape_with_layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp ReduceScatter(
XlaOp operand, const XlaComputation& computation,
@@ -762,7 +763,9 @@ XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout = std::nullopt);
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
absl::Span<const ReplicaGroup> replica_groups,
@@ -1362,7 +1365,8 @@ class XlaBuilder {
friend XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id,
const std::optional<Shape>& shape_with_layout);
const std::optional<Shape>& shape_with_layout,
const std::optional<bool> use_global_device_ids);
friend XlaOp ReduceScatter(XlaOp operand, const XlaComputation& computation,
int64_t scatter_dimension, int64_t shard_count,
absl::Span<const ReplicaGroup> replica_groups,
@@ -1373,7 +1377,9 @@ friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
friend XlaOp AllToAll(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout);
const std::optional<ChannelHandle>& channel_id,
const std::optional<Layout>& layout,
const std::optional<bool> use_global_device_ids);
friend XlaOp AllToAllTuple(absl::Span<const XlaOp> operands,
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<Layout>& layout);
@@ -1517,7 +1523,9 @@ class XlaBuilder {

XlaOp AllToAllArray(XlaOp operand, int64_t split_dimension,
int64_t concat_dimension, int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups);
absl::Span<const ReplicaGroup> replica_groups,
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

// Creates an op with the given opcode and the output shape.
virtual StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
@@ -2343,7 +2351,8 @@ XlaOp AllGather(XlaOp operand, int64_t all_gather_dimension,
XlaOp AllReduce(XlaOp operand, const XlaComputation& computation,
absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Shape>& shape_with_layout = std::nullopt);
const std::optional<Shape>& shape_with_layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp ReduceScatter(
XlaOp operand, const XlaComputation& computation, int64_t scatter_dimension,
@@ -2359,7 +2368,9 @@ XlaOp ReduceScatter(
XlaOp AllToAll(XlaOp operand, int64_t split_dimension, int64_t concat_dimension,
int64_t split_count,
absl::Span<const ReplicaGroup> replica_groups = {},
const std::optional<Layout>& layout = std::nullopt);
const std::optional<ChannelHandle>& channel_id = std::nullopt,
const std::optional<Layout>& layout = std::nullopt,
const std::optional<bool> use_global_device_ids = std::nullopt);

XlaOp AllToAllTuple(absl::Span<const XlaOp> operand,
absl::Span<const ReplicaGroup> replica_groups = {},
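Because channel_id is inserted ahead of layout, positional callers of AllToAll need updating. A small sketch under assumed builder and replica-group setup (not part of the diff):

#include <optional>

#include "tensorflow/compiler/xla/client/xla_builder.h"

xla::XlaOp TransposeShards(xla::XlaOp operand,
                           absl::Span<const xla::ReplicaGroup> groups,
                           const xla::Layout& layout) {
  return xla::AllToAll(operand, /*split_dimension=*/0, /*concat_dimension=*/1,
                       /*split_count=*/4, groups,
                       /*channel_id=*/std::nullopt,  // new slot, before layout
                       layout,
                       /*use_global_device_ids=*/std::nullopt);
}
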
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/pjrt/BUILD
@@ -273,6 +273,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_cost_analysis",
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
"//tensorflow/compiler/xla/service:pass_context",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
5 changes: 5 additions & 0 deletions tensorflow/compiler/xla/pjrt/local_device_state.cc
@@ -172,6 +172,11 @@ void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
usage_stream_pool_.push(std::move(stream));
}

void LocalDeviceState::SetPrngSeed(int seed) {
absl::MutexLock lock(&mu_);
prng_seed_generator_.seed(seed);
}

int LocalDeviceState::GetNewPrngSeed() {
absl::MutexLock lock(&mu_);
int x = 0;
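A short sketch (not part of the diff) of the intended use: re-seeding the per-device generator makes the stream of per-computation seeds reproducible.

#include "tensorflow/compiler/xla/pjrt/local_device_state.h"

void MakeSeedsReproducible(xla::LocalDeviceState* device) {
  device->SetPrngSeed(42);
  int first = device->GetNewPrngSeed();   // identical across runs after
  int second = device->GetNewPrngSeed();  // re-seeding with the same value
  (void)first;
  (void)second;
}
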
4 changes: 3 additions & 1 deletion tensorflow/compiler/xla/pjrt/local_device_state.h
@@ -157,12 +157,14 @@ class LocalDeviceState {

Semaphore& compute_semaphore() { return compute_semaphore_; }

void SetPrngSeed(int seed);

// Returns a fresh, PRNG-generated random seed for an XLA computation.
int GetNewPrngSeed();

private:
Status SynchronizeAllActivity();

private:
AllocationModel allocation_model_;

EventPool event_pool_;
8 changes: 8 additions & 0 deletions tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -126,6 +126,9 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/stream.h"

// Added by Alpa
#include "tensorflow/compiler/xla/service/pass_context.h"

namespace xla {

PjRtPlatformId PjRtStreamExecutorDevice::platform_id() const {
@@ -2388,6 +2391,11 @@ PjRtStreamExecutorExecutable::GetHloModules() const {
StatusOr<PjRtStreamExecutorClient::ExecutableExtras>
PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
ExecutableExtras extras;

if (pass_context::GetBool("build_option::bypass_device_assignment_check", false)) {
return extras;
}

std::shared_ptr<DeviceAssignment>& device_assignment =
extras.device_assignment;
std::vector<PjRtStreamExecutorExecutable::LogicalDeviceIds>&
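For context (not part of the diff): pass_context is an Alpa-side key/value store of compiler options, and GetBool is assumed to fall back to the given default when a key is unset. A minimal sketch of the lookup used above:

#include "tensorflow/compiler/xla/service/pass_context.h"

bool BypassDeviceAssignment() {
  // Returns false unless Alpa has explicitly set the flag; when true,
  // GetExecutableExtras returns empty extras and skips device assignment.
  return xla::pass_context::GetBool(
      "build_option::bypass_device_assignment_check", false);
}
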
7 changes: 7 additions & 0 deletions tensorflow/compiler/xla/python/BUILD
@@ -651,6 +651,11 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@pybind11",
# Added by Alpa
"//tensorflow/compiler/xla/service:pass_context",
"//tensorflow/compiler/xla/service/gpu:gpu_cost_model",
"//tensorflow/compiler/xla/service/spmd:alpa_compiler",
"//tensorflow/compiler/xla/service/spmd:grad_acc_rewrite",
],
)

@@ -746,6 +751,8 @@ pybind_extension(
"//tensorflow/core:lib_internal_impl", # buildcleaner: keep
"//tensorflow/core/distributed_runtime/preemption:preemption_sync_manager",
"//tensorflow/python:bfloat16_lib",
# Added by Alpa
"//tensorflow/compiler/xla/service/gpu:alpa_nccl_wrapper",
] + select({
":gpu_enabled": [
"//tensorflow/compiler/xla/pjrt:gpu_device",
5 changes: 5 additions & 0 deletions tensorflow/compiler/xla/python/dlpack.cc
@@ -203,6 +203,11 @@ StatusOr<std::vector<int64_t>> StridesToLayout(
if (strides[a] > strides[b]) {
return false;
}
// FIXME(yonghao): This is only a workaround.
// Should support isConsistent([1,1]{1,0}, [1,1]{0,1}) in type check
if (dims[a] == dims[b]) {
return a > b;
}
return dims[a] == 1 && dims[b] != 1;
});
int64_t stride = 1;
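A note on the tie-break above (not part of the diff): strides alone do not determine a layout once a dimension has size 1. For example, dims = {1, 1} with strides = {1, 1} is consistent with both layout {1,0} and layout {0,1}; previously the comparator treated such a pair as unordered, so the resulting minor-to-major order (and hence the inferred layout) was unspecified. The added dims[a] == dims[b] branch orders equal-size dimensions by index, with the higher index sorting as more minor, making the result deterministic.
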
18 changes: 14 additions & 4 deletions tensorflow/compiler/xla/python/ops.cc
@@ -36,6 +36,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

// Added by Alpa for TorchIndexSelect
#include "tensorflow/compiler/xla/client/lib/slicing.h"

namespace xla {

namespace py = pybind11;
@@ -80,12 +83,13 @@ void BuildOpsSubmodule(py::module* m) {
"AllReduce",
static_cast<XlaOp (*)(
XlaOp, const XlaComputation&, absl::Span<const ReplicaGroup>,
const std::optional<ChannelHandle>&, const std::optional<Shape>&)>(
&AllReduce),
const std::optional<ChannelHandle>&, const std::optional<Shape>&,
const std::optional<bool>)>(&AllReduce),
py::arg("operand"), py::arg("computation"),
py::arg("replica_groups") = py::list(),
py::arg("channel_id") = std::nullopt,
py::arg("shape_with_layout") = std::nullopt);
py::arg("shape_with_layout") = std::nullopt,
py::arg("use_global_device_ids") = std::nullopt);
ops.def("ReduceScatter", &ReduceScatter, py::arg("operand"),
py::arg("computation"), py::arg("scatter_dimension"),
py::arg("shard_count"), py::arg("replica_groups") = py::list(),
@@ -95,7 +99,9 @@ ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
ops.def("AllToAll", &AllToAll, py::arg("operand"), py::arg("split_dimension"),
py::arg("concat_dimension"), py::arg("split_count"),
py::arg("replica_groups") = py::list(),
py::arg("layout") = std::nullopt);
py::arg("channel_id") = std::nullopt,
py::arg("layout") = std::nullopt,
py::arg("use_global_device_ids") = std::nullopt);
ops.def("ApproxTopK", &ApproxTopK, py::arg("builder"), py::arg("operands"),
py::arg("init_values"), py::arg("top_k"), py::arg("reduction_dim"),
py::arg("comparator"), py::arg("recall_target") = 0.9,
@@ -429,6 +435,10 @@ void BuildOpsSubmodule(py::module* m) {
py::arg("b"), py::arg("x"));
ops.def("Zeta", &Zeta, py::arg("x"), py::arg("q"));

// Added by Alpa
ops.def("IndexSelect", &TorchIndexSelect, py::arg("input"), py::arg("index"),
py::arg("dim"), py::arg("batch_dims") = 0);

#define BINARY_OP(op) \
ops.def( \
#op, \
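A usage sketch for the new binding (not part of the diff), assuming TorchIndexSelect keeps its (input, index, dim, batch_dims) signature from client/lib/slicing.h; it gathers slices of input along dim at the positions in index, mirroring torch.index_select:

#include <cstdint>

#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp SelectRows(xla::XlaBuilder* builder) {
  // input: f32[4,3], index: s32[2]  ->  result: f32[2,3]
  xla::XlaOp input = xla::Parameter(
      builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {4, 3}), "input");
  xla::XlaOp index = xla::ConstantR1<int32_t>(builder, {2, 0});
  return xla::TorchIndexSelect(input, index, /*dim=*/0);
}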