Skip to content

Commit

Permalink
Update to use a more recent version of xla.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Oct 2, 2023
1 parent 9870a56 commit bd5a3d1
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "xla"
version = "0.1.5"
version = "0.1.6"
authors = ["laurent <[email protected]>"]
edition = "2021"
description = "Bindings for the XLA C++ library."
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
Experimentation using the xla compiler from rust

Pre-compiled binaries for the xla library can be downloaded from the
[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.4.4).
[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.5.1).
These should be extracted at the root of this repository, resulting
in a `xla_extension` subdirectory being created, the currently supported version
is 0.4.4.
is 0.5.1.

For a linux platform, this can be done via:
```bash
wget https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
```

Expand Down
1 change: 1 addition & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn make_shared_lib<P: AsRef<Path>>(xla_dir: P) {
.flag("-std=c++17")
.flag("-Wno-deprecated-declarations")
.flag("-DLLVM_ON_UNIX=1")
.flag("-DLLVM_VERSION_STRING=")
.file("xla_rs/xla_rs.cc")
.compile("xla_rs");
}
Expand Down
4 changes: 2 additions & 2 deletions xla_rs/xla_rs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ status pjrt_gpu_client_create(pjrt_client *output, double memory_fraction,
xla::GpuAllocatorConfig allocator = {.memory_fraction = memory_fraction,
.preallocate = preallocate};
ASSIGN_OR_RETURN_STATUS(
client, xla::GetStreamExecutorGpuClient(false, allocator, nullptr, 0));
client, xla::GetStreamExecutorGpuClient(false, allocator, 0, 0));
*output = new std::shared_ptr(std::move(client));
return nullptr;
}
Expand Down Expand Up @@ -1030,7 +1030,7 @@ char *xla_computation_name(xla_computation c) {
void xla_computation_free(xla_computation c) { delete c; }

char *status_error_message(status s) {
return strdup(s->error_message().c_str());
return strdup(tsl::NullTerminatedMessage(*s));
}

status hlo_module_proto_parse_and_return_unverified_module(
Expand Down
28 changes: 14 additions & 14 deletions xla_rs/xla_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
#pragma GCC diagnostic ignored "-Wreturn-type"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h"
#include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "xla/client/client_library.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/pjrt/tpu_client.h"
#include "xla/service/hlo_parser.h"
#include "xla/shape_util.h"
#include "xla/statusor.h"
#pragma GCC diagnostic pop
using namespace xla;

Expand Down

0 comments on commit bd5a3d1

Please sign in to comment.