From bd5a3d1d2d43f7bb8a3e9795e82b5fdd933845df Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 2 Oct 2023 18:29:19 +0100 Subject: [PATCH] Update to use a more recent version of xla. --- Cargo.toml | 2 +- README.md | 6 +++--- build.rs | 1 + xla_rs/xla_rs.cc | 4 ++-- xla_rs/xla_rs.h | 28 ++++++++++++++-------------- 5 files changed, 21 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 98a5c44..778aa25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xla" -version = "0.1.5" +version = "0.1.6" authors = ["laurent "] edition = "2021" description = "Bindings for the XLA C++ library." diff --git a/README.md b/README.md index 45a8977..ba6e0f6 100644 --- a/README.md +++ b/README.md @@ -2,14 +2,14 @@ Experimentation using the xla compiler from rust Pre-compiled binaries for the xla library can be downloaded from the -[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.4.4). +[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.5.1). These should be extracted at the root of this repository, resulting in a `xla_extension` subdirectory being created, the currently supported version -is 0.4.4. +is 0.5.1. For a linux platform, this can be done via: ```bash -wget https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz +wget https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz ``` diff --git a/build.rs b/build.rs index 922065a..3c512ea 100644 --- a/build.rs +++ b/build.rs @@ -17,6 +17,7 @@ fn make_shared_lib>(xla_dir: P) { .flag("-std=c++17") .flag("-Wno-deprecated-declarations") .flag("-DLLVM_ON_UNIX=1") + .flag("-DLLVM_VERSION_STRING=") .file("xla_rs/xla_rs.cc") .compile("xla_rs"); } diff --git a/xla_rs/xla_rs.cc b/xla_rs/xla_rs.cc index 1c779e7..220fc8b 100644 --- a/xla_rs/xla_rs.cc +++ b/xla_rs/xla_rs.cc @@ -43,7 +43,7 @@ status pjrt_gpu_client_create(pjrt_client *output, double memory_fraction, xla::GpuAllocatorConfig allocator = {.memory_fraction = memory_fraction, .preallocate = preallocate}; ASSIGN_OR_RETURN_STATUS( - client, xla::GetStreamExecutorGpuClient(false, allocator, nullptr, 0)); + client, xla::GetStreamExecutorGpuClient(false, allocator, 0, 0)); *output = new std::shared_ptr(std::move(client)); return nullptr; } @@ -1030,7 +1030,7 @@ char *xla_computation_name(xla_computation c) { void xla_computation_free(xla_computation c) { delete c; } char *status_error_message(status s) { - return strdup(s->error_message().c_str()); + return strdup(tsl::NullTerminatedMessage(*s)); } status hlo_module_proto_parse_and_return_unverified_module( diff --git a/xla_rs/xla_rs.h b/xla_rs/xla_rs.h index 149a77a..439b20a 100644 --- a/xla_rs/xla_rs.h +++ b/xla_rs/xla_rs.h @@ -7,20 +7,20 @@ #pragma GCC diagnostic ignored "-Wdeprecated-declarations" #pragma GCC diagnostic ignored "-Winvalid-offsetof" #pragma GCC diagnostic ignored "-Wreturn-type" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/matrix.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h" -#include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" -#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h" -#include "tensorflow/compiler/xla/pjrt/tpu_client.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "xla/client/client_library.h" +#include "xla/client/lib/constants.h" +#include "xla/client/lib/matrix.h" +#include "xla/client/xla_builder.h" +#include "xla/literal_util.h" +#include "xla/pjrt/gpu/gpu_helpers.h" +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_stream_executor_client.h" +#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/tpu_client.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" #pragma GCC diagnostic pop using namespace xla;