Skip to content

Commit

Permalink
Merge pull request #8 from LaurentMazare/openxla-0.5.1
Browse files Browse the repository at this point in the history
Update to use a more recent version of xla.
  • Loading branch information
LaurentMazare committed Oct 2, 2023
2 parents 9870a56 + 35bf309 commit 3068dcc
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 25 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/rust-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
toolchain: ${{ matrix.rust }}
override: true
- run: |
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
- uses: actions-rs/cargo@v1
with:
Expand All @@ -39,7 +39,7 @@ jobs:
toolchain: ${{ matrix.rust }}
override: true
- run: |
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
- uses: actions-rs/cargo@v1
with:
Expand Down Expand Up @@ -73,7 +73,7 @@ jobs:
override: true
- run: rustup component add clippy
- run: |
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
- uses: actions-rs/cargo@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "xla"
version = "0.1.5"
version = "0.1.6"
authors = ["laurent <[email protected]>"]
edition = "2021"
description = "Bindings for the XLA C++ library."
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
Experimentation using the xla compiler from rust

Pre-compiled binaries for the xla library can be downloaded from the
[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.4.4).
[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.5.1).
These should be extracted at the root of this repository, resulting
in a `xla_extension` subdirectory being created, the currently supported version
is 0.4.4.
is 0.5.1.

For a linux platform, this can be done via:
```bash
wget https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
```

Expand Down
1 change: 1 addition & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn make_shared_lib<P: AsRef<Path>>(xla_dir: P) {
.flag("-std=c++17")
.flag("-Wno-deprecated-declarations")
.flag("-DLLVM_ON_UNIX=1")
.flag("-DLLVM_VERSION_STRING=")
.file("xla_rs/xla_rs.cc")
.compile("xla_rs");
}
Expand Down
4 changes: 2 additions & 2 deletions src/wrappers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub use shape::{ArrayShape, Shape};
pub use xla_builder::XlaBuilder;
pub use xla_op::XlaOp;

pub(self) unsafe fn c_ptr_to_string(ptr: *const std::ffi::c_char) -> String {
unsafe fn c_ptr_to_string(ptr: *const std::ffi::c_char) -> String {
let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();
libc::free(ptr as *mut libc::c_void);
str
Expand Down Expand Up @@ -291,7 +291,7 @@ element_type!(f64, F64, 8);
/// specialized to a given device through a compilation step.
pub struct XlaComputation(c_lib::xla_computation);

pub(self) fn handle_status(status: c_lib::status) -> Result<()> {
fn handle_status(status: c_lib::status) -> Result<()> {
if status.is_null() {
Ok(())
} else {
Expand Down
1 change: 1 addition & 0 deletions src/wrappers/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ impl TryFrom<&Shape> for ArrayShape {

macro_rules! extract_dims {
($cnt:tt, $dims:expr, $out_type:ty) => {
#[allow(clippy::redundant_closure_call)]
impl TryFrom<&ArrayShape> for $out_type {
type Error = Error;

Expand Down
1 change: 1 addition & 0 deletions src/wrappers/xla_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct XlaOp {

macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
#[allow(clippy::redundant_closure_call)]
pub fn $fn_name(&self) -> Result<$out_type> {
let dims = self.builder.get_dims(self)?;
if dims.len() != $cnt {
Expand Down
4 changes: 2 additions & 2 deletions xla_rs/xla_rs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ status pjrt_gpu_client_create(pjrt_client *output, double memory_fraction,
xla::GpuAllocatorConfig allocator = {.memory_fraction = memory_fraction,
.preallocate = preallocate};
ASSIGN_OR_RETURN_STATUS(
client, xla::GetStreamExecutorGpuClient(false, allocator, nullptr, 0));
client, xla::GetStreamExecutorGpuClient(false, allocator, 0, 0));
*output = new std::shared_ptr(std::move(client));
return nullptr;
}
Expand Down Expand Up @@ -1030,7 +1030,7 @@ char *xla_computation_name(xla_computation c) {
void xla_computation_free(xla_computation c) { delete c; }

char *status_error_message(status s) {
return strdup(s->error_message().c_str());
return strdup(tsl::NullTerminatedMessage(*s));
}

status hlo_module_proto_parse_and_return_unverified_module(
Expand Down
28 changes: 14 additions & 14 deletions xla_rs/xla_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
#pragma GCC diagnostic ignored "-Wreturn-type"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h"
#include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "xla/client/client_library.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/pjrt/tpu_client.h"
#include "xla/service/hlo_parser.h"
#include "xla/shape_util.h"
#include "xla/statusor.h"
#pragma GCC diagnostic pop
using namespace xla;

Expand Down

0 comments on commit 3068dcc

Please sign in to comment.