diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index dc7d73a..f6bb6b3 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -18,7 +18,7 @@ jobs: toolchain: ${{ matrix.rust }} override: true - run: | - wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz + wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz - uses: actions-rs/cargo@v1 with: @@ -39,7 +39,7 @@ jobs: toolchain: ${{ matrix.rust }} override: true - run: | - wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz + wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz - uses: actions-rs/cargo@v1 with: @@ -73,7 +73,7 @@ jobs: override: true - run: rustup component add clippy - run: | - wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz + wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz - uses: actions-rs/cargo@v1 with: diff --git a/Cargo.toml b/Cargo.toml index 98a5c44..778aa25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xla" -version = "0.1.5" +version = "0.1.6" authors = ["laurent "] edition = "2021" description = "Bindings for the XLA C++ library." diff --git a/README.md b/README.md index 45a8977..ba6e0f6 100644 --- a/README.md +++ b/README.md @@ -2,14 +2,14 @@ Experimentation using the xla compiler from rust Pre-compiled binaries for the xla library can be downloaded from the -[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.4.4). +[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.5.1). These should be extracted at the root of this repository, resulting in a `xla_extension` subdirectory being created, the currently supported version -is 0.4.4. +is 0.5.1. For a linux platform, this can be done via: ```bash -wget https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz +wget https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz ``` diff --git a/build.rs b/build.rs index 922065a..3c512ea 100644 --- a/build.rs +++ b/build.rs @@ -17,6 +17,7 @@ fn make_shared_lib>(xla_dir: P) { .flag("-std=c++17") .flag("-Wno-deprecated-declarations") .flag("-DLLVM_ON_UNIX=1") + .flag("-DLLVM_VERSION_STRING=") .file("xla_rs/xla_rs.cc") .compile("xla_rs"); } diff --git a/src/wrappers/mod.rs b/src/wrappers/mod.rs index e530cf2..cb6da9f 100644 --- a/src/wrappers/mod.rs +++ b/src/wrappers/mod.rs @@ -21,7 +21,7 @@ pub use shape::{ArrayShape, Shape}; pub use xla_builder::XlaBuilder; pub use xla_op::XlaOp; -pub(self) unsafe fn c_ptr_to_string(ptr: *const std::ffi::c_char) -> String { +unsafe fn c_ptr_to_string(ptr: *const std::ffi::c_char) -> String { let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned(); libc::free(ptr as *mut libc::c_void); str @@ -291,7 +291,7 @@ element_type!(f64, F64, 8); /// specialized to a given device through a compilation step. pub struct XlaComputation(c_lib::xla_computation); -pub(self) fn handle_status(status: c_lib::status) -> Result<()> { +fn handle_status(status: c_lib::status) -> Result<()> { if status.is_null() { Ok(()) } else { diff --git a/src/wrappers/shape.rs b/src/wrappers/shape.rs index 73d320e..dfcf235 100644 --- a/src/wrappers/shape.rs +++ b/src/wrappers/shape.rs @@ -133,6 +133,7 @@ impl TryFrom<&Shape> for ArrayShape { macro_rules! extract_dims { ($cnt:tt, $dims:expr, $out_type:ty) => { + #[allow(clippy::redundant_closure_call)] impl TryFrom<&ArrayShape> for $out_type { type Error = Error; diff --git a/src/wrappers/xla_op.rs b/src/wrappers/xla_op.rs index db90ad5..c5d5554 100644 --- a/src/wrappers/xla_op.rs +++ b/src/wrappers/xla_op.rs @@ -15,6 +15,7 @@ pub struct XlaOp { macro_rules! extract_dims { ($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => { + #[allow(clippy::redundant_closure_call)] pub fn $fn_name(&self) -> Result<$out_type> { let dims = self.builder.get_dims(self)?; if dims.len() != $cnt { diff --git a/xla_rs/xla_rs.cc b/xla_rs/xla_rs.cc index 1c779e7..220fc8b 100644 --- a/xla_rs/xla_rs.cc +++ b/xla_rs/xla_rs.cc @@ -43,7 +43,7 @@ status pjrt_gpu_client_create(pjrt_client *output, double memory_fraction, xla::GpuAllocatorConfig allocator = {.memory_fraction = memory_fraction, .preallocate = preallocate}; ASSIGN_OR_RETURN_STATUS( - client, xla::GetStreamExecutorGpuClient(false, allocator, nullptr, 0)); + client, xla::GetStreamExecutorGpuClient(false, allocator, 0, 0)); *output = new std::shared_ptr(std::move(client)); return nullptr; } @@ -1030,7 +1030,7 @@ char *xla_computation_name(xla_computation c) { void xla_computation_free(xla_computation c) { delete c; } char *status_error_message(status s) { - return strdup(s->error_message().c_str()); + return strdup(tsl::NullTerminatedMessage(*s)); } status hlo_module_proto_parse_and_return_unverified_module( diff --git a/xla_rs/xla_rs.h b/xla_rs/xla_rs.h index 149a77a..439b20a 100644 --- a/xla_rs/xla_rs.h +++ b/xla_rs/xla_rs.h @@ -7,20 +7,20 @@ #pragma GCC diagnostic ignored "-Wdeprecated-declarations" #pragma GCC diagnostic ignored "-Winvalid-offsetof" #pragma GCC diagnostic ignored "-Wreturn-type" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/client/lib/matrix.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h" -#include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" -#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h" -#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h" -#include "tensorflow/compiler/xla/pjrt/tpu_client.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/statusor.h" +#include "xla/client/client_library.h" +#include "xla/client/lib/constants.h" +#include "xla/client/lib/matrix.h" +#include "xla/client/xla_builder.h" +#include "xla/literal_util.h" +#include "xla/pjrt/gpu/gpu_helpers.h" +#include "xla/pjrt/gpu/se_gpu_pjrt_client.h" +#include "xla/pjrt/pjrt_client.h" +#include "xla/pjrt/pjrt_stream_executor_client.h" +#include "xla/pjrt/tfrt_cpu_pjrt_client.h" +#include "xla/pjrt/tpu_client.h" +#include "xla/service/hlo_parser.h" +#include "xla/shape_util.h" +#include "xla/statusor.h" #pragma GCC diagnostic pop using namespace xla;