Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to use a more recent version of xla. #8

Merged
merged 3 commits into from
Oct 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/rust-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
toolchain: ${{ matrix.rust }}
override: true
- run: |
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
- uses: actions-rs/cargo@v1
with:
Expand All @@ -39,7 +39,7 @@ jobs:
toolchain: ${{ matrix.rust }}
override: true
- run: |
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
- uses: actions-rs/cargo@v1
with:
Expand Down Expand Up @@ -73,7 +73,7 @@ jobs:
override: true
- run: rustup component add clippy
- run: |
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget -nv https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
- uses: actions-rs/cargo@v1
with:
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "xla"
version = "0.1.5"
version = "0.1.6"
authors = ["laurent <[email protected]>"]
edition = "2021"
description = "Bindings for the XLA C++ library."
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
Experimentation using the xla compiler from rust

Pre-compiled binaries for the xla library can be downloaded from the
[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.4.4).
[elixir-nx/xla repo](https://github.com/elixir-nx/xla/releases/tag/v0.5.1).
These should be extracted at the root of this repository, resulting
in a `xla_extension` subdirectory being created, the currently supported version
is 0.4.4.
is 0.5.1.

For a linux platform, this can be done via:
```bash
wget https://github.com/elixir-nx/xla/releases/download/v0.4.4/xla_extension-x86_64-linux-gnu-cpu.tar.gz
wget https://github.com/elixir-nx/xla/releases/download/v0.5.1/xla_extension-x86_64-linux-gnu-cpu.tar.gz
tar -xzvf xla_extension-x86_64-linux-gnu-cpu.tar.gz
```

Expand Down
1 change: 1 addition & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn make_shared_lib<P: AsRef<Path>>(xla_dir: P) {
.flag("-std=c++17")
.flag("-Wno-deprecated-declarations")
.flag("-DLLVM_ON_UNIX=1")
.flag("-DLLVM_VERSION_STRING=")
.file("xla_rs/xla_rs.cc")
.compile("xla_rs");
}
Expand Down
4 changes: 2 additions & 2 deletions src/wrappers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub use shape::{ArrayShape, Shape};
pub use xla_builder::XlaBuilder;
pub use xla_op::XlaOp;

pub(self) unsafe fn c_ptr_to_string(ptr: *const std::ffi::c_char) -> String {
unsafe fn c_ptr_to_string(ptr: *const std::ffi::c_char) -> String {
let str = std::ffi::CStr::from_ptr(ptr).to_string_lossy().into_owned();
libc::free(ptr as *mut libc::c_void);
str
Expand Down Expand Up @@ -291,7 +291,7 @@ element_type!(f64, F64, 8);
/// specialized to a given device through a compilation step.
pub struct XlaComputation(c_lib::xla_computation);

pub(self) fn handle_status(status: c_lib::status) -> Result<()> {
fn handle_status(status: c_lib::status) -> Result<()> {
if status.is_null() {
Ok(())
} else {
Expand Down
1 change: 1 addition & 0 deletions src/wrappers/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ impl TryFrom<&Shape> for ArrayShape {

macro_rules! extract_dims {
($cnt:tt, $dims:expr, $out_type:ty) => {
#[allow(clippy::redundant_closure_call)]
impl TryFrom<&ArrayShape> for $out_type {
type Error = Error;

Expand Down
1 change: 1 addition & 0 deletions src/wrappers/xla_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub struct XlaOp {

macro_rules! extract_dims {
($fn_name:ident, $cnt:tt, $dims:expr, $out_type:ty) => {
#[allow(clippy::redundant_closure_call)]
pub fn $fn_name(&self) -> Result<$out_type> {
let dims = self.builder.get_dims(self)?;
if dims.len() != $cnt {
Expand Down
4 changes: 2 additions & 2 deletions xla_rs/xla_rs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ status pjrt_gpu_client_create(pjrt_client *output, double memory_fraction,
xla::GpuAllocatorConfig allocator = {.memory_fraction = memory_fraction,
.preallocate = preallocate};
ASSIGN_OR_RETURN_STATUS(
client, xla::GetStreamExecutorGpuClient(false, allocator, nullptr, 0));
client, xla::GetStreamExecutorGpuClient(false, allocator, 0, 0));
*output = new std::shared_ptr(std::move(client));
return nullptr;
}
Expand Down Expand Up @@ -1030,7 +1030,7 @@ char *xla_computation_name(xla_computation c) {
void xla_computation_free(xla_computation c) { delete c; }

char *status_error_message(status s) {
return strdup(s->error_message().c_str());
return strdup(tsl::NullTerminatedMessage(*s));
}

status hlo_module_proto_parse_and_return_unverified_module(
Expand Down
28 changes: 14 additions & 14 deletions xla_rs/xla_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#pragma GCC diagnostic ignored "-Winvalid-offsetof"
#pragma GCC diagnostic ignored "-Wreturn-type"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/pjrt/gpu/gpu_helpers.h"
#include "tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "xla/client/client_library.h"
#include "xla/client/lib/constants.h"
#include "xla/client/lib/matrix.h"
#include "xla/client/xla_builder.h"
#include "xla/literal_util.h"
#include "xla/pjrt/gpu/gpu_helpers.h"
#include "xla/pjrt/gpu/se_gpu_pjrt_client.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include "xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "xla/pjrt/tpu_client.h"
#include "xla/service/hlo_parser.h"
#include "xla/shape_util.h"
#include "xla/statusor.h"
#pragma GCC diagnostic pop
using namespace xla;

Expand Down
Loading