diff --git a/Cargo.toml b/Cargo.toml index 49eb021..98a5c44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "xla" -version = "0.1.4" +version = "0.1.5" authors = ["laurent "] edition = "2021" description = "Bindings for the XLA C++ library." diff --git a/examples/basics.rs b/examples/basics.rs index 06b52d2..578b3b4 100644 --- a/examples/basics.rs +++ b/examples/basics.rs @@ -29,7 +29,7 @@ fn main() -> Result<()> { result.to_vec::(), result.get_first_element::()?, ); - let param = xla_builder.parameter_with_shape(0, &xla::Shape::new::(vec![]), "p")?; + let param = xla_builder.parameter_with_shape(0, &xla::ArrayShape::new::(vec![]), "p")?; let sum = param.add_(¶m)?; let sum = sum.sqrt()?.build()?; let result = client.compile(&sum)?; diff --git a/examples/llama/main.rs b/examples/llama/main.rs index a27472d..7c35123 100644 --- a/examples/llama/main.rs +++ b/examples/llama/main.rs @@ -14,7 +14,7 @@ use clap::Parser; use rand::prelude::*; extern crate xla; -use xla::{PrimitiveType, XlaBuilder, XlaOp}; +use xla::{ElementType, PrimitiveType, XlaBuilder, XlaOp}; mod sentencepiece; use sentencepiece::Tokenizer; @@ -195,9 +195,8 @@ impl Mlp { } fn masked_fill(on_false: &XlaOp, mask: &XlaOp, on_true: T) -> Result { - let shape = mask.shape()?; - let on_true = - mask.builder().c0(on_true)?.convert(on_false.ty()?)?.broadcast(shape.dimensions())?; + let shape = mask.array_shape()?; + let on_true = mask.builder().c0(on_true)?.convert(on_false.ty()?)?.broadcast(shape.dims())?; let m = mask.select(&on_true, on_false)?; Ok(m) } @@ -255,15 +254,15 @@ impl CausalSelfAttention { let v = v.reshape(&target_dim)?.swap_dims(1, 2)?; let q = self.apply_rotary_emb(&q, &freqs_cis)?; let k = self.apply_rotary_emb(&k, &freqs_cis)?; - let k_shape = k.shape()?; + let k_shape = k.array_shape()?; let att = (q.matmul(&k.swap_dims(-2, -1)?)? * builder.c0(1f32 / (k_shape.last_dim().unwrap() as f32).sqrt())?.convert(ty)?)?; let mask = builder - .one(PrimitiveType::S32)? + .one(ElementType::S32)? .broadcast(&[t, t])? .lower_triangle()? .reshape(&[1, 1, t, t])?; - let zero = builder.zero(PrimitiveType::S32)?.broadcast(&[b, self.n_head as i64, t, t])?; + let zero = builder.zero(ElementType::S32)?.broadcast(&[b, self.n_head as i64, t, t])?; let att = masked_fill(&att, &mask.eq(&zero)?, f32::NEG_INFINITY)?; let y = att.softmax(-1)?.matmul(&v)?; let y = y.swap_dims(1, 2)?.reshape(&[b, t, c])?; @@ -351,7 +350,7 @@ fn llama_computation(args: &Args, bsize: i64) -> Result<(xla::XlaComputation, Va let config = Config::config_7b(); let freqs_cis = precompute_freqs_cis(&config, &b)?; let llama = Llama::new(vb.clone(), &config)?; - let input = vb.arg("tokens", PrimitiveType::U32, &[bsize as usize, CONTEXT_SIZE])?; + let input = vb.arg("tokens", ElementType::U32, &[bsize as usize, CONTEXT_SIZE])?; let logits = llama.forward(&input, &freqs_cis)?.convert(PrimitiveType::F32)?; let prs = (logits / b.c0(args.temperature)?)?.softmax(-1)?; Ok((prs.build()?, vb.into_store())) diff --git a/examples/llama/var_store.rs b/examples/llama/var_store.rs index fcd1403..b7713a5 100644 --- a/examples/llama/var_store.rs +++ b/examples/llama/var_store.rs @@ -1,10 +1,10 @@ -use xla::{ArrayElement, FromRawBytes, PjRtBuffer, PjRtClient, PrimitiveType, Result, XlaOp}; +use xla::{ArrayElement, ElementType, FromRawBytes, PjRtBuffer, PjRtClient, Result, XlaOp}; #[allow(dead_code)] #[derive(Clone)] struct NamedVar { path: String, - ty: PrimitiveType, + ty: ElementType, dims: Vec, is_arg: bool, } @@ -14,8 +14,8 @@ pub struct VarBuilder { path: Vec, vars: std::rc::Rc>>, builder: xla::XlaBuilder, - default_buffer_type_for_var: PrimitiveType, - default_op_type_for_var: PrimitiveType, + default_buffer_type_for_var: ElementType, + default_op_type_for_var: ElementType, } #[allow(dead_code)] @@ -30,8 +30,8 @@ impl VarBuilder { builder: builder.clone(), path: vec![], vars, - default_buffer_type_for_var: B::PRIMITIVE_TYPE, - default_op_type_for_var: O::PRIMITIVE_TYPE, + default_buffer_type_for_var: B::TY, + default_op_type_for_var: O::TY, } } @@ -42,7 +42,7 @@ impl VarBuilder { pub fn var_( &mut self, s: &str, - ty: PrimitiveType, + ty: ElementType, dims: &[usize], is_arg: bool, ) -> Result { @@ -57,10 +57,10 @@ impl VarBuilder { pub fn var(&mut self, s: &str, dims: &[usize]) -> Result { let v = self.var_(s, self.default_buffer_type_for_var, dims, false)?; - v.convert(self.default_op_type_for_var) + v.convert(self.default_op_type_for_var.primitive_type()) } - pub fn arg(&mut self, s: &str, ty: PrimitiveType, dims: &[usize]) -> Result { + pub fn arg(&mut self, s: &str, ty: ElementType, dims: &[usize]) -> Result { self.var_(s, ty, dims, true) } @@ -119,9 +119,7 @@ impl VarStore { let buffer = if var.is_arg { let ty = var.ty; let element_count: usize = var.dims.iter().product(); - let element_size_in_bytes = ty - .element_size_in_bytes() - .ok_or(xla::Error::UnsupportedElementType { ty, op: "buffer_from_bytes" })?; + let element_size_in_bytes = ty.element_size_in_bytes(); let data = vec![0u8; element_count * element_size_in_bytes]; c.buffer_from_host_raw_bytes(ty, &data, &var.dims, None)? } else { diff --git a/examples/loop.rs b/examples/loop.rs index 8251e90..39606f6 100644 --- a/examples/loop.rs +++ b/examples/loop.rs @@ -6,7 +6,7 @@ fn main() -> Result<()> { let client = xla::PjRtClient::cpu()?; loop { let builder = xla::XlaBuilder::new("test"); - let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[2], "x")?; + let x = builder.parameter(0, f32::TY, &[2], "x")?; let sum = x.reduce_sum(&[], false)?.build()?.compile(&client)?; let input = xla::Literal::vec1(&[4.2f32, 1.337f32]); let result = sum.execute::(&[input])?; @@ -16,7 +16,7 @@ fn main() -> Result<()> { assert_eq!(result.to_vec::()?, [4.2, 1.337]); let builder = xla::XlaBuilder::new("test"); - let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[-2], "x")?; + let x = builder.parameter(0, f32::TY, &[-2], "x")?; let sum = x.reduce_sum(&[0], false)?.build()?.compile(&client)?; let input = xla::Literal::vec1(&[4.2f32, 1.337f32]); let result = sum.execute::(&[input])?; @@ -25,10 +25,10 @@ fn main() -> Result<()> { drop(sum); assert_eq!(result.to_vec::()?, [5.5369997]); // Dimensions got reduced. - assert_eq!(result.shape()?.dimensions(), []); + assert_eq!(result.array_shape()?.dims(), []); let builder = xla::XlaBuilder::new("test"); - let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[-2], "x")?; + let x = builder.parameter(0, f32::TY, &[-2], "x")?; let sum = x.reduce_sum(&[0], true)?.build()?.compile(&client)?; let input = xla::Literal::vec1(&[4.2f32, 1.337f32]); let result = sum.execute::(&[input])?; @@ -37,7 +37,7 @@ fn main() -> Result<()> { drop(sum); assert_eq!(result.to_vec::()?, [5.5369997]); // keep_dims = true in this case. - assert_eq!(result.shape()?.dimensions(), [1]); + assert_eq!(result.array_shape()?.dims(), [1]); println!("Done!"); } } diff --git a/examples/nanogpt/main.rs b/examples/nanogpt/main.rs index dffb00b..6d8145f 100644 --- a/examples/nanogpt/main.rs +++ b/examples/nanogpt/main.rs @@ -9,14 +9,14 @@ use anyhow::Result; use rand::prelude::*; extern crate xla; -use xla::{Literal, PjRtLoadedExecutable, PrimitiveType, XlaBuilder, XlaOp}; +use xla::{ElementType, Literal, PjRtLoadedExecutable, XlaBuilder, XlaOp}; mod tokenizer; mod var_store; use tokenizer::Tokenizer; use var_store::VarStore; -const TY: PrimitiveType = PrimitiveType::F32; +const TY: ElementType = ElementType::F32; const TEMPERATURE: f32 = 0.8f32; const USE_CPU: bool = false; const NUM_SAMPLES: usize = 10; @@ -104,8 +104,8 @@ impl Linear { } fn masked_fill(on_false: &XlaOp, mask: &XlaOp, on_true: T) -> Result { - let shape = mask.shape()?; - let on_true = mask.builder().c0(on_true)?.broadcast(shape.dimensions())?; + let shape = mask.array_shape()?; + let on_true = mask.builder().c0(on_true)?.broadcast(shape.dims())?; let m = mask.select(&on_true, on_false)?; Ok(m) } @@ -137,15 +137,15 @@ impl CausalSelfAttention { let k = k.reshape(&target_dim)?.swap_dims(1, 2)?; let q = q.reshape(&target_dim)?.swap_dims(1, 2)?; let v = v.reshape(&target_dim)?.swap_dims(1, 2)?; - let k_shape = k.shape()?; + let k_shape = k.array_shape()?; let att = (q.matmul(&k.swap_dims(-2, -1)?)? * builder.c0(1f32 / (k_shape.last_dim().unwrap() as f32).sqrt()))?; let mask = builder - .one(PrimitiveType::S32)? + .one(ElementType::S32)? .broadcast(&[t, t])? .lower_triangle()? .reshape(&[1, 1, t, t])?; - let zero = builder.zero(PrimitiveType::S32)?.broadcast(&[b, self.n_head as i64, t, t])?; + let zero = builder.zero(ElementType::S32)?.broadcast(&[b, self.n_head as i64, t, t])?; let att = masked_fill(&att, &mask.eq(&zero)?, f32::NEG_INFINITY)?; let y = att.softmax(-1)?.matmul(&v)?; let y = y.swap_dims(1, 2)?.reshape(&[b, t, c])?; @@ -253,7 +253,7 @@ fn gpt_computation(vs: VarStore, bsize: i64) -> Result { let b = XlaBuilder::new("gpt"); let config = GptConfig::default(); let gpt = Gpt::new(vs, &config)?; - let input = b.parameter(0, PrimitiveType::S32, &[bsize, config.block_size as i64], "tokens")?; + let input = b.parameter(0, ElementType::S32, &[bsize, config.block_size as i64], "tokens")?; let logits = gpt.forward(&input)?; let prs = (logits / b.c0(TEMPERATURE))?.softmax(-1)?; Ok(prs.build()?) diff --git a/examples/nanogpt/var_store.rs b/examples/nanogpt/var_store.rs index eb84ca1..e5108c5 100644 --- a/examples/nanogpt/var_store.rs +++ b/examples/nanogpt/var_store.rs @@ -1,7 +1,7 @@ use anyhow::{Context, Result}; use std::collections::HashMap; -use xla::{FromRawBytes, Literal, PrimitiveType}; +use xla::{ElementType, FromRawBytes, Literal}; #[derive(Clone)] pub struct VarStore { @@ -24,7 +24,7 @@ impl VarStore { pub fn take( &mut self, s: &str, - expected_type: PrimitiveType, + expected_type: ElementType, expected_dims: &[usize], ) -> Result { let path = format!("{}.{s}", self.path.join(".")); @@ -33,9 +33,9 @@ impl VarStore { .borrow_mut() .remove(&path) .with_context(|| format!("cannot find {path} in VarStore"))?; - let shape = literal.shape()?; - let element_type = shape.element_type(); - let dims = shape.dimensions(); + let shape = literal.array_shape()?; + let element_type = shape.ty(); + let dims = shape.dims(); if element_type != expected_type { anyhow::bail!( "unexpected element type for {}, got {:?} expected {:?}", diff --git a/src/error.rs b/src/error.rs index 020349d..1afd0fc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -15,11 +15,17 @@ pub enum Error { #[error("unexpected number of dimensions, expected: {expected}, got: {got} ({dims:?})")] UnexpectedNumberOfDims { expected: usize, got: usize, dims: Vec }, + #[error("not an element type, got: {got:?}")] + NotAnElementType { got: crate::PrimitiveType }, + + #[error("not an array, expected: {expected:?}, got: {got:?}")] + NotAnArray { expected: Option, got: crate::Shape }, + #[error("unexpected number of tuple elements, expected: {expected}, got: {got}")] UnexpectedNumberOfElemsInTuple { expected: usize, got: usize }, #[error("element type mismatch, on-device: {on_device:?}, on-host: {on_host:?}")] - ElementTypeMismatch { on_device: crate::PrimitiveType, on_host: crate::PrimitiveType }, + ElementTypeMismatch { on_device: crate::ElementType, on_host: crate::ElementType }, #[error("unsupported element type for {op}: {ty:?}")] UnsupportedElementType { ty: crate::PrimitiveType, op: &'static str }, @@ -27,7 +33,7 @@ pub enum Error { #[error( "target buffer is too large, offset {offset}, shape {shape:?}, buffer_len: {buffer_len}" )] - TargetBufferIsTooLarge { offset: usize, shape: crate::Shape, buffer_len: usize }, + TargetBufferIsTooLarge { offset: usize, shape: crate::ArrayShape, buffer_len: usize }, #[error("binary buffer is too large, element count {element_count}, buffer_len: {buffer_len}")] BinaryBufferIsTooLarge { element_count: usize, buffer_len: usize }, diff --git a/src/npy.rs b/src/npy.rs index d7d0a20..6b0725b 100644 --- a/src/npy.rs +++ b/src/npy.rs @@ -26,7 +26,7 @@ //! # Load multiple values from a npz file. //! values = np.loadz("test.npz") //! ``` -use crate::{Error, Literal, PrimitiveType, Result}; +use crate::{ElementType, Error, Literal, Result}; use std::collections::HashMap; use std::fs::File; use std::io::{BufReader, Read, Write}; @@ -58,7 +58,7 @@ fn read_header(reader: &mut R) -> Result { #[derive(Debug, PartialEq)] struct Header { - descr: PrimitiveType, + descr: ElementType, fortran_order: bool, shape: Vec, } @@ -68,14 +68,14 @@ impl Header { let fortran_order = if self.fortran_order { "True" } else { "False" }; let mut shape = self.shape.iter().map(|x| x.to_string()).collect::>().join(","); let descr = match self.descr { - PrimitiveType::F16 => "f2", - PrimitiveType::F32 => "f4", - PrimitiveType::F64 => "f8", - PrimitiveType::S32 => "i4", - PrimitiveType::S64 => "i8", - PrimitiveType::S16 => "i2", - PrimitiveType::S8 => "i1", - PrimitiveType::U8 => "u1", + ElementType::F16 => "f2", + ElementType::F32 => "f4", + ElementType::F64 => "f8", + ElementType::S32 => "i4", + ElementType::S64 => "i8", + ElementType::S16 => "i2", + ElementType::S8 => "i1", + ElementType::U8 => "u1", descr => return Err(Error::Npy(format!("unsupported kind {descr:?}"))), }; if !shape.is_empty() { @@ -146,17 +146,17 @@ impl Header { // int64, int32, int16, int8, // uint8, and bool. match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') { - "e" | "f2" => PrimitiveType::F16, - "f" | "f4" => PrimitiveType::F32, - "d" | "f8" => PrimitiveType::F64, - "i" | "i4" => PrimitiveType::S32, - "q" | "i8" => PrimitiveType::S64, - "h" | "i2" => PrimitiveType::S16, - "b" | "i1" => PrimitiveType::S8, - "B" | "u1" => PrimitiveType::U8, - "?" | "b1" => PrimitiveType::Pred, - "F" | "F4" => PrimitiveType::C64, - "D" | "F8" => PrimitiveType::C128, + "e" | "f2" => ElementType::F16, + "f" | "f4" => ElementType::F32, + "d" | "f8" => ElementType::F64, + "i" | "i4" => ElementType::S32, + "q" | "i8" => ElementType::S64, + "h" | "i2" => ElementType::S16, + "b" | "i1" => ElementType::S8, + "B" | "u1" => ElementType::U8, + "?" | "b1" => ElementType::Pred, + "F" | "F4" => ElementType::C64, + "D" | "F8" => ElementType::C128, descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))), } } @@ -183,7 +183,7 @@ pub trait FromRawBytes: Sized { type Context; fn from_raw_bytes( h: &Self::Context, - ty: PrimitiveType, + ty: ElementType, dims: &[usize], bytes: &[u8], ) -> Result; @@ -261,7 +261,7 @@ impl FromRawBytes for crate::Literal { fn from_raw_bytes( _: &Self::Context, - ty: PrimitiveType, + ty: ElementType, dims: &[usize], bytes: &[u8], ) -> Result { @@ -274,7 +274,7 @@ impl FromRawBytes for crate::PjRtBuffer { fn from_raw_bytes( client: &Self::Context, - ty: PrimitiveType, + ty: ElementType, dims: &[usize], bytes: &[u8], ) -> Result { @@ -286,12 +286,9 @@ impl crate::Literal { fn write(&self, f: &mut T) -> Result<()> { f.write_all(NPY_MAGIC_STRING)?; f.write_all(&[1u8, 0u8])?; - let shape = self.shape()?; - let header = Header { - descr: shape.element_type(), - fortran_order: false, - shape: shape.dimensions().to_vec(), - }; + let shape = self.array_shape()?; + let header = + Header { descr: shape.ty(), fortran_order: false, shape: shape.dims().to_vec() }; let mut header = header.to_string()?; let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16; for _ in 0..pad % 16 { @@ -302,9 +299,7 @@ impl crate::Literal { f.write_all(header.as_bytes())?; let numel = self.element_count(); let element_type = self.element_type()?; - let elt_size_in_bytes = element_type.element_size_in_bytes().ok_or_else(|| { - Error::Npy(format!("unsupported element type for npy {element_type:?}")) - })?; + let elt_size_in_bytes = element_type.element_size_in_bytes(); let mut content = vec![0u8; numel * elt_size_in_bytes]; self.copy_raw_to(&mut content)?; f.write_all(&content)?; @@ -343,14 +338,14 @@ mod tests { let h = "{'descr': ' Result { let dims64: Vec<_> = dims.iter().map(|x| *x as i64).collect(); + let ty = ty.primitive_type(); let v = unsafe { c_lib::literal_create_from_shape_and_data( ty as i32, @@ -51,8 +54,8 @@ impl Literal { /// primitive type that the literal uses. pub fn get_first_element(&self) -> Result { let ty = self.ty()?; - if ty != T::PRIMITIVE_TYPE { - Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::PRIMITIVE_TYPE })? + if ty != T::TY { + Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::TY })? } if self.element_count() == 0 { Err(Error::EmptyLiteral)? @@ -67,15 +70,21 @@ impl Literal { } /// The primitive type used by element stored in this literal. - pub fn element_type(&self) -> Result { + pub fn primitive_type(&self) -> Result { let ty = unsafe { c_lib::literal_element_type(self.0) }; match FromPrimitive::from_i32(ty) { None => Err(Error::UnexpectedElementType(ty)), Some(ty) => Ok(ty), } } - /// The primitive type used by element stored in this literal, shortcut for `element_type`. - pub fn ty(&self) -> Result { + + /// The element type used by element stored in this literal. + pub fn element_type(&self) -> Result { + self.primitive_type()?.element_type() + } + + /// The element type used by element stored in this literal, shortcut for `element_type`. + pub fn ty(&self) -> Result { self.element_type() } @@ -90,16 +99,11 @@ impl Literal { pub fn shape(&self) -> Result { let mut out: c_lib::shape = std::ptr::null_mut(); unsafe { c_lib::literal_shape(self.0, &mut out) }; - let rank = unsafe { c_lib::shape_dimensions_size(out) }; - let dimensions: Vec<_> = - (0..rank).map(|i| unsafe { c_lib::shape_dimensions(out, i) }).collect(); - let ty = unsafe { c_lib::shape_element_type(out) }; - let tuple_shapes_size = unsafe { c_lib::shape_tuple_shapes_size(out) }; - unsafe { c_lib::shape_free(out) }; - match FromPrimitive::from_i32(ty) { - None => Err(Error::UnexpectedElementType(ty)), - Some(ty) => Ok(Shape { ty, dimensions, tuple_shapes_size }), - } + unsafe { Shape::from_ptr(out) } + } + + pub fn array_shape(&self) -> Result { + ArrayShape::try_from(&self.shape()?) } /// Copy the literal data to a slice. This returns an error if the primitive type used by the @@ -107,8 +111,8 @@ impl Literal { pub fn copy_raw_to(&self, dst: &mut [T]) -> Result<()> { let ty = self.ty()?; let element_count = self.element_count(); - if ty != T::PRIMITIVE_TYPE { - Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::PRIMITIVE_TYPE })? + if ty != T::TY { + Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::TY })? } if dst.len() > element_count { Err(Error::BinaryBufferIsTooLarge { element_count, buffer_len: dst.len() })? @@ -129,8 +133,8 @@ impl Literal { pub fn copy_raw_from(&mut self, src: &[T]) -> Result<()> { let ty = self.ty()?; let element_count = self.element_count(); - if ty != T::PRIMITIVE_TYPE { - Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::PRIMITIVE_TYPE })? + if ty != T::TY { + Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::TY })? } if src.len() > element_count { Err(Error::BinaryBufferIsTooLarge { element_count, buffer_len: src.len() })? @@ -193,9 +197,10 @@ impl Literal { /// When the input is a tuple, return a vector of its elements. This replaces the original /// value by an empty tuple, no copy is performed. pub fn decompose_tuple(&mut self) -> Result> { - match self.shape()?.tuple_size() { - None => Ok(vec![]), - Some(tuple_len) => { + match self.shape()? { + Shape::Array(_) | Shape::Unsupported(_) => Ok(vec![]), + Shape::Tuple(shapes) => { + let tuple_len = shapes.len(); let mut outputs = vec![std::ptr::null_mut::(); tuple_len]; unsafe { c_lib::literal_decompose_tuple(self.0, outputs.as_mut_ptr(), tuple_len) }; Ok(outputs.into_iter().map(Literal).collect()) diff --git a/src/wrappers/mod.rs b/src/wrappers/mod.rs index 4cae538..e530cf2 100644 --- a/src/wrappers/mod.rs +++ b/src/wrappers/mod.rs @@ -17,7 +17,7 @@ pub use pjrt_buffer::PjRtBuffer; pub use pjrt_client::PjRtClient; pub use pjrt_device::PjRtDevice; pub use pjrt_loaded_executable::PjRtLoadedExecutable; -pub use shape::Shape; +pub use shape::{ArrayShape, Shape}; pub use xla_builder::XlaBuilder; pub use xla_op::XlaOp; @@ -53,28 +53,26 @@ pub enum PrimitiveType { } impl PrimitiveType { - /// The size for this element type in bytes if defined. - pub fn element_size_in_bytes(&self) -> Option { + fn element_type(self) -> Result { match self { - PrimitiveType::Invalid => None, - PrimitiveType::Pred => None, - PrimitiveType::S8 => Some(1), - PrimitiveType::S16 => Some(2), - PrimitiveType::S32 => Some(4), - PrimitiveType::S64 => Some(8), - PrimitiveType::U8 => Some(1), - PrimitiveType::U16 => Some(2), - PrimitiveType::U32 => Some(4), - PrimitiveType::U64 => Some(8), - PrimitiveType::F16 => Some(2), - PrimitiveType::F32 => Some(4), - PrimitiveType::Bf16 => Some(2), - PrimitiveType::F64 => Some(8), - PrimitiveType::C64 => Some(8), - PrimitiveType::C128 => Some(16), - PrimitiveType::Tuple => None, - PrimitiveType::OpaqueType => None, - PrimitiveType::Token => None, + Self::Pred => Ok(ElementType::Pred), + Self::S8 => Ok(ElementType::S8), + Self::S16 => Ok(ElementType::S16), + Self::S32 => Ok(ElementType::S32), + Self::S64 => Ok(ElementType::S64), + Self::U8 => Ok(ElementType::U8), + Self::U16 => Ok(ElementType::U16), + Self::U32 => Ok(ElementType::U32), + Self::U64 => Ok(ElementType::U64), + Self::F16 => Ok(ElementType::F16), + Self::F32 => Ok(ElementType::F32), + Self::Bf16 => Ok(ElementType::Bf16), + Self::F64 => Ok(ElementType::F64), + Self::C64 => Ok(ElementType::C64), + Self::C128 => Ok(ElementType::C128), + Self::Invalid | Self::Tuple | Self::OpaqueType | Self::Token => { + Err(Error::NotAnElementType { got: self }) + } } } } @@ -98,8 +96,51 @@ pub enum ElementType { C128, } +impl ElementType { + /// The size for this element type in bytes. + pub fn element_size_in_bytes(&self) -> usize { + match self { + Self::Pred => 1, + Self::S8 => 1, + Self::S16 => 2, + Self::S32 => 4, + Self::S64 => 8, + Self::U8 => 1, + Self::U16 => 2, + Self::U32 => 4, + Self::U64 => 8, + Self::F16 => 2, + Self::F32 => 4, + Self::Bf16 => 2, + Self::F64 => 8, + Self::C64 => 8, + Self::C128 => 16, + } + } + + pub fn primitive_type(&self) -> PrimitiveType { + match self { + Self::Pred => PrimitiveType::Pred, + Self::S8 => PrimitiveType::S8, + Self::S16 => PrimitiveType::S16, + Self::S32 => PrimitiveType::S32, + Self::S64 => PrimitiveType::S64, + Self::U8 => PrimitiveType::U8, + Self::U16 => PrimitiveType::U16, + Self::U32 => PrimitiveType::U32, + Self::U64 => PrimitiveType::U64, + Self::F16 => PrimitiveType::F16, + Self::F32 => PrimitiveType::F32, + Self::Bf16 => PrimitiveType::Bf16, + Self::F64 => PrimitiveType::F64, + Self::C64 => PrimitiveType::C64, + Self::C128 => PrimitiveType::C128, + } + } +} + pub trait ArrayElement: Copy { - const PRIMITIVE_TYPE: PrimitiveType; + const TY: ElementType; const ELEMENT_SIZE_IN_BYTES: usize; const ZERO: Self; } @@ -208,7 +249,7 @@ native_type!( macro_rules! element_type { ($ty:ty, $v:ident, $sz:tt) => { impl ArrayElement for $ty { - const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::$v; + const TY: ElementType = ElementType::$v; const ELEMENT_SIZE_IN_BYTES: usize = $sz; const ZERO: Self = 0 as Self; } @@ -220,7 +261,7 @@ macro_rules! element_type { pub struct F16; impl ArrayElement for F16 { - const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::F16; + const TY: ElementType = ElementType::F16; const ELEMENT_SIZE_IN_BYTES: usize = 2; const ZERO: Self = Self; } @@ -230,7 +271,7 @@ impl ArrayElement for F16 { pub struct Bf16; impl ArrayElement for Bf16 { - const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::Bf16; + const TY: ElementType = ElementType::Bf16; const ELEMENT_SIZE_IN_BYTES: usize = 2; const ZERO: Self = Self; } diff --git a/src/wrappers/pjrt_buffer.rs b/src/wrappers/pjrt_buffer.rs index 19ec344..b839189 100644 --- a/src/wrappers/pjrt_buffer.rs +++ b/src/wrappers/pjrt_buffer.rs @@ -1,5 +1,5 @@ //! A view on a memory slice hosted on a device. -use super::{ArrayElement, FromPrimitive, Literal, PjRtDevice, Shape}; +use super::{ArrayElement, ArrayShape, Literal, PjRtDevice, Shape}; use crate::{c_lib, Error, Result}; /// A buffer represents a view on a memory slice hosted on a device. @@ -34,16 +34,7 @@ impl PjRtBuffer { /// Retrieve the shape used by this buffer. pub fn on_device_shape(&self) -> Result { let shape = unsafe { c_lib::pjrt_buffer_on_device_shape(self.buffer) }; - let rank = unsafe { c_lib::shape_dimensions_size(shape) }; - let dimensions: Vec<_> = - (0..rank).map(|i| unsafe { c_lib::shape_dimensions(shape, i) }).collect(); - let ty = unsafe { c_lib::shape_element_type(shape) }; - let tuple_shapes_size = unsafe { c_lib::shape_tuple_shapes_size(shape) }; - unsafe { c_lib::shape_free(shape) }; - match FromPrimitive::from_i32(ty) { - None => Err(Error::UnexpectedElementType(ty)), - Some(ty) => Ok(Shape { ty, dimensions, tuple_shapes_size }), - } + unsafe { Shape::from_ptr(shape) } } /// Copy the data stored in a buffer to host memory in a blocking way. @@ -52,9 +43,11 @@ impl PjRtBuffer { dst: &mut [T], offset: usize, ) -> Result<()> { - let shape = self.on_device_shape()?; - if shape.ty != T::PRIMITIVE_TYPE { - Err(Error::ElementTypeMismatch { on_device: shape.ty, on_host: T::PRIMITIVE_TYPE })? + let shape = ArrayShape::try_from(&self.on_device_shape()?)?; + let on_host = T::TY; + let on_device = shape.primitive_type().element_type()?; + if on_device != on_host { + Err(Error::ElementTypeMismatch { on_device, on_host })? } if offset + dst.len() > shape.element_count() { Err(Error::TargetBufferIsTooLarge { offset, shape, buffer_len: dst.len() })? diff --git a/src/wrappers/pjrt_client.rs b/src/wrappers/pjrt_client.rs index 7bf456e..fe044de 100644 --- a/src/wrappers/pjrt_client.rs +++ b/src/wrappers/pjrt_client.rs @@ -117,7 +117,7 @@ impl PjRtClient { self.ptr(), device, data.as_ptr() as *const libc::c_void, - T::PRIMITIVE_TYPE as i32, + T::TY.primitive_type() as i32, dims.len() as i32, dims.as_ptr(), &mut buffer, @@ -134,16 +134,14 @@ impl PjRtClient { /// is returned. pub fn buffer_from_host_raw_bytes( &self, - ty: super::PrimitiveType, + ty: super::ElementType, data: &[u8], dims: &[usize], device: Option<&PjRtDevice>, ) -> Result { let mut buffer: c_lib::pjrt_buffer = std::ptr::null_mut(); let element_count: usize = dims.iter().product(); - let element_size_in_bytes = ty - .element_size_in_bytes() - .ok_or(Error::UnsupportedElementType { ty, op: "buffer_from_bytes" })?; + let element_size_in_bytes = ty.element_size_in_bytes(); if element_count * element_size_in_bytes != data.len() { Err(Error::WrongElementCount { dims: dims.to_vec(), element_count })? } diff --git a/src/wrappers/shape.rs b/src/wrappers/shape.rs index 1a8f3ff..eff8e74 100644 --- a/src/wrappers/shape.rs +++ b/src/wrappers/shape.rs @@ -1,86 +1,173 @@ -use super::{ArrayElement, PrimitiveType}; -use crate::Error; +use super::{ArrayElement, ElementType, PrimitiveType}; +use crate::{c_lib, Error}; -/// A shape specifies a primitive type as well as some array dimensions. #[derive(Clone, PartialEq, Eq, Debug)] -pub struct Shape { - pub(super) ty: PrimitiveType, - pub(super) dimensions: Vec, - pub(super) tuple_shapes_size: usize, +pub struct ArrayShape { + ty: ElementType, + dims: Vec, } -impl Shape { - /// Create a new shape. - pub fn new(dimensions: Vec) -> Shape { - Shape { ty: E::PRIMITIVE_TYPE, dimensions, tuple_shapes_size: 0 } +impl ArrayShape { + /// Create a new array shape. + pub fn new(dims: Vec) -> Self { + Self { ty: E::TY, dims } } - /// Create a new shape. - pub fn with_type(ty: PrimitiveType, dimensions: Vec) -> Shape { - Shape { ty, dimensions, tuple_shapes_size: 0 } + /// Create a new array shape. + pub fn new_with_type(ty: ElementType, dims: Vec) -> Self { + Self { ty, dims } } - /// Create a new tuple shape. - pub fn tuple(size: usize) -> Shape { - Shape { ty: PrimitiveType::Tuple, dimensions: vec![], tuple_shapes_size: size } + pub fn element_type(&self) -> ElementType { + self.ty } - /// The stored primitive type. - pub fn element_type(&self) -> PrimitiveType { + pub fn ty(&self) -> ElementType { self.ty } - /// The stored primitive type, shortcut for `element_type`. - pub fn ty(&self) -> PrimitiveType { - self.ty + /// The stored primitive type. + pub fn primitive_type(&self) -> PrimitiveType { + self.ty.primitive_type() } /// The number of elements stored in arrays that use this shape, this is the product of sizes /// across each dimension. pub fn element_count(&self) -> usize { - self.dimensions.iter().map(|d| *d as usize).product::() + self.dims.iter().map(|d| *d as usize).product::() } - pub fn dimensions(&self) -> &[i64] { - &self.dimensions + pub fn dims(&self) -> &[i64] { + &self.dims } pub fn first_dim(&self) -> Option { - self.dimensions.first().copied() + self.dims.first().copied() } pub fn last_dim(&self) -> Option { - self.dimensions.last().copied() + self.dims.last().copied() + } +} + +/// A shape specifies a primitive type as well as some array dimensions. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum Shape { + Tuple(Vec), + Array(ArrayShape), + Unsupported(PrimitiveType), +} + +impl Shape { + /// Create a new array shape. + pub fn array(dims: Vec) -> Self { + Self::Array(ArrayShape { ty: E::TY, dims }) + } + + /// Create a new array shape. + pub fn array_with_type(ty: ElementType, dims: Vec) -> Self { + Self::Array(ArrayShape { ty, dims }) + } + + /// Create a new tuple shape. + pub fn tuple(shapes: Vec) -> Self { + Self::Tuple(shapes) + } + + /// The stored primitive type. + pub fn primitive_type(&self) -> PrimitiveType { + match self { + Self::Tuple(_) => PrimitiveType::Tuple, + Self::Array(a) => a.ty.primitive_type(), + Self::Unsupported(ty) => *ty, + } } pub fn is_tuple(&self) -> bool { - self.ty == PrimitiveType::Tuple + match self { + Self::Tuple(_) => true, + Self::Array { .. } | Self::Unsupported(_) => false, + } } pub fn tuple_size(&self) -> Option { - if self.ty == PrimitiveType::Tuple { - Some(self.tuple_shapes_size) - } else { - None + match self { + Self::Tuple(shapes) => Some(shapes.len()), + Self::Array { .. } | Self::Unsupported(_) => None, + } + } + + pub(crate) unsafe fn from_ptr(ptr: c_lib::shape) -> crate::Result { + fn from_ptr_rec(ptr: c_lib::shape) -> crate::Result { + let ty = unsafe { c_lib::shape_element_type(ptr) }; + let ty = super::FromPrimitive::from_i32(ty) + .ok_or_else(|| Error::UnexpectedElementType(ty))?; + match ty { + PrimitiveType::Tuple => { + let elem_cnt = unsafe { c_lib::shape_tuple_shapes_size(ptr) }; + let shapes: crate::Result> = (0..elem_cnt) + .map(|i| from_ptr_rec(unsafe { c_lib::shape_tuple_shapes(ptr, i as i32) })) + .collect(); + Ok(Shape::Tuple(shapes?)) + } + ty => match ty.element_type() { + Ok(ty) => { + let rank = unsafe { c_lib::shape_dimensions_size(ptr) }; + let dims: Vec<_> = + (0..rank).map(|i| unsafe { c_lib::shape_dimensions(ptr, i) }).collect(); + Ok(Shape::Array(ArrayShape { ty, dims })) + } + Err(_) => Ok(Shape::Unsupported(ty)), + }, + } + } + + let shape = from_ptr_rec(ptr); + unsafe { c_lib::shape_free(ptr) }; + shape + } +} + +impl TryFrom<&Shape> for ArrayShape { + type Error = Error; + + fn try_from(value: &Shape) -> Result { + match value { + Shape::Tuple(_) | Shape::Unsupported(_) => { + Err(Error::NotAnArray { expected: None, got: value.clone() }) + } + Shape::Array(a) => Ok(a.clone()), } } } macro_rules! extract_dims { ($cnt:tt, $dims:expr, $out_type:ty) => { - impl TryFrom<&Shape> for $out_type { + impl TryFrom<&ArrayShape> for $out_type { type Error = Error; - fn try_from(value: &Shape) -> Result { - let dims = &value.dimensions; - if dims.len() != $cnt { + fn try_from(value: &ArrayShape) -> Result { + if value.dims.len() != $cnt { Err(Error::UnexpectedNumberOfDims { expected: $cnt, - got: dims.len(), - dims: dims.clone(), + got: value.dims.len(), + dims: value.dims.clone(), }) } else { - Ok($dims(dims)) + Ok($dims(&value.dims)) + } + } + } + + impl TryFrom<&Shape> for $out_type { + type Error = Error; + + fn try_from(value: &Shape) -> Result { + match value { + Shape::Tuple(_) | Shape::Unsupported(_) => { + Err(Error::NotAnArray { expected: Some($cnt), got: value.clone() }) + } + Shape::Array(a) => Self::try_from(a), } } } diff --git a/src/wrappers/xla_builder.rs b/src/wrappers/xla_builder.rs index 1b462b6..68cc50a 100644 --- a/src/wrappers/xla_builder.rs +++ b/src/wrappers/xla_builder.rs @@ -1,5 +1,6 @@ use super::{ - handle_status, FromPrimitive, Literal, NativeType, PrimitiveType, Shape, XlaComputation, XlaOp, + handle_status, ArrayShape, FromPrimitive, Literal, NativeType, PrimitiveType, Shape, + XlaComputation, XlaOp, }; use crate::{c_lib, Error, Result}; use std::rc::Rc; @@ -77,7 +78,7 @@ impl XlaBuilder { pub fn parameter( &self, parameter_number: i64, - ty: PrimitiveType, + ty: super::ElementType, dims: &[i64], name: &str, ) -> Result { @@ -86,7 +87,7 @@ impl XlaBuilder { c_lib::parameter( self.ptr(), parameter_number, - ty as i32, + ty.primitive_type() as i32, dims.len() as i32, dims.as_ptr(), name.as_ptr(), @@ -107,10 +108,11 @@ impl XlaBuilder { pub fn parameter_with_shape( &self, parameter_number: i64, - shape: &Shape, + shape: &ArrayShape, name: &str, ) -> Result { - self.parameter(parameter_number, shape.ty, &shape.dimensions, name) + let dims = shape.dims(); + self.parameter(parameter_number, shape.ty(), dims, name) } pub fn constant_r1c(&self, f: T, len: usize) -> Result { @@ -130,46 +132,47 @@ impl XlaBuilder { } /// A scalar node with the zero value for the associated type. - pub fn zero(&self, ty: super::PrimitiveType) -> Result { - let op = unsafe { c_lib::op_zero(self.ptr(), ty as i32) }; + pub fn zero(&self, ty: super::ElementType) -> Result { + let op = unsafe { c_lib::op_zero(self.ptr(), ty.primitive_type() as i32) }; self.wrap(op) } /// A scalar node with the one value for the associated type. - pub fn one(&self, ty: super::PrimitiveType) -> Result { - let op = unsafe { c_lib::op_one(self.ptr(), ty as i32) }; + pub fn one(&self, ty: super::ElementType) -> Result { + let op = unsafe { c_lib::op_one(self.ptr(), ty.primitive_type() as i32) }; self.wrap(op) } /// A scalar node with the minimum value for the associated type. - pub fn min_value(&self, ty: super::PrimitiveType) -> Result { - let op = unsafe { c_lib::op_min_value(self.ptr(), ty as i32) }; + pub fn min_value(&self, ty: super::ElementType) -> Result { + let op = unsafe { c_lib::op_min_value(self.ptr(), ty.primitive_type() as i32) }; self.wrap(op) } /// A scalar node with the maximum value for the associated type. - pub fn max_value(&self, ty: super::PrimitiveType) -> Result { - let op = unsafe { c_lib::op_max_value(self.ptr(), ty as i32) }; + pub fn max_value(&self, ty: super::ElementType) -> Result { + let op = unsafe { c_lib::op_max_value(self.ptr(), ty.primitive_type() as i32) }; self.wrap(op) } /// A constant node with the specified shape that holds increasing values starting from 0 along /// the iota dimension. - pub fn iota( - &self, - ty: super::PrimitiveType, - dims: &[i64], - iota_dimension: i64, - ) -> Result { + pub fn iota(&self, ty: super::ElementType, dims: &[i64], iota_dimension: i64) -> Result { let op = unsafe { - c_lib::op_iota(self.ptr(), ty as i32, dims.len(), dims.as_ptr(), iota_dimension) + c_lib::op_iota( + self.ptr(), + ty.primitive_type() as i32, + dims.len(), + dims.as_ptr(), + iota_dimension, + ) }; self.wrap(op) } /// A constant node for a unidimensional array of increasing values starting from 0. - pub fn iota1(&self, ty: super::PrimitiveType, size: usize) -> Result { - let op = unsafe { c_lib::op_iota1(self.ptr(), ty as i32, size) }; + pub fn iota1(&self, ty: super::ElementType, size: usize) -> Result { + let op = unsafe { c_lib::op_iota1(self.ptr(), ty.primitive_type() as i32, size) }; self.wrap(op) } @@ -208,16 +211,7 @@ impl XlaBuilder { let mut out: c_lib::shape = std::ptr::null_mut(); let status = unsafe { c_lib::get_shape(self.ptr(), op.op, &mut out) }; handle_status(status)?; - let rank = unsafe { c_lib::shape_dimensions_size(out) }; - let dimensions: Vec<_> = - (0..rank).map(|i| unsafe { c_lib::shape_dimensions(out, i) }).collect(); - let ty = unsafe { c_lib::shape_element_type(out) }; - let tuple_shapes_size = unsafe { c_lib::shape_tuple_shapes_size(out) }; - unsafe { c_lib::shape_free(out) }; - match FromPrimitive::from_i32(ty) { - None => Err(Error::UnexpectedElementType(ty)), - Some(ty) => Ok(Shape { ty, dimensions, tuple_shapes_size }), - } + unsafe { Shape::from_ptr(out) } } /// The dimension sizes associated with this op. @@ -230,7 +224,7 @@ impl XlaBuilder { } /// The element type associated with this op. - pub fn get_element_type(&self, op: &XlaOp) -> Result { + pub fn get_primitive_type(&self, op: &XlaOp) -> Result { let mut ty = 0i32; let status = unsafe { c_lib::get_element_type(self.ptr(), op.op, &mut ty) }; handle_status(status)?; diff --git a/src/wrappers/xla_op.rs b/src/wrappers/xla_op.rs index 938f39c..db90ad5 100644 --- a/src/wrappers/xla_op.rs +++ b/src/wrappers/xla_op.rs @@ -5,7 +5,7 @@ //! //! For details on the semantics, see //! [operation_semantics](https://www.tensorflow.org/xla/operation_semantics). -use super::{PrimitiveType, Shape, XlaBuilder, XlaComputation}; +use super::{ArrayShape, PrimitiveType, Shape, XlaBuilder, XlaComputation}; use crate::{c_lib, Error, Result}; pub struct XlaOp { @@ -268,28 +268,30 @@ impl XlaOp { } /// A node that when executed generates values using a random uniform distribution. - pub fn rng_uniform(min: &Self, max: &Self, shape: &Shape) -> Result { + pub fn rng_uniform(min: &Self, max: &Self, shape: &ArrayShape) -> Result { + let dims = shape.dims(); let op = unsafe { c_lib::op_rng_uniform( min.op, max.op, - shape.ty as i32, - shape.dimensions.len() as i32, - shape.dimensions.as_ptr(), + shape.primitive_type() as i32, + dims.len() as i32, + dims.as_ptr(), ) }; min.wrap(op) } /// A node that when executed generates values using a random normal distribution. - pub fn rng_normal(mu: &Self, sigma: &Self, shape: &Shape) -> Result { + pub fn rng_normal(mu: &Self, sigma: &Self, shape: &ArrayShape) -> Result { + let dims = shape.dims(); let op = unsafe { c_lib::op_rng_normal( mu.op, sigma.op, - shape.ty as i32, - shape.dimensions.len() as i32, - shape.dimensions.as_ptr(), + shape.primitive_type() as i32, + dims.len() as i32, + dims.as_ptr(), ) }; mu.wrap(op) @@ -395,13 +397,13 @@ impl XlaOp { } /// The kind of elements that are computed by this operand. - pub fn element_type(&self) -> Result { - self.builder.get_element_type(self) + pub fn primitive_type(&self) -> Result { + self.builder.get_primitive_type(self) } - /// The kind of elements that are computed by this operand, shortcut for `element_type`. + /// The kind of elements that are computed by this operand, shortcut for `primitive_type`. pub fn ty(&self) -> Result { - self.element_type() + self.primitive_type() } /// The number of dimensions for this node. @@ -413,6 +415,10 @@ impl XlaOp { self.builder.get_shape(self) } + pub fn array_shape(&self) -> Result { + ArrayShape::try_from(&self.builder.get_shape(self)?) + } + pub fn dims(&self) -> Result> { self.builder.get_dims(self) } @@ -486,10 +492,10 @@ impl XlaOp { pub fn take(&self, indices: &XlaOp, axis: i64) -> Result { let axis = self.normalize_index(axis)?; - let shape = self.shape()?; - let indices_shape = indices.shape()?; - let index_dims = indices_shape.dimensions(); - let dims = shape.dimensions(); + let shape = self.array_shape()?; + let indices_shape = indices.array_shape()?; + let index_dims = indices_shape.dims(); + let dims = shape.dims(); let offset_dims: Vec<_> = (0..((dims.len() + index_dims.len()) as i64 - 1)) .filter(|x| *x < axis || *x >= axis + index_dims.len() as i64) .collect(); @@ -503,14 +509,14 @@ impl XlaOp { self.gather(&indices, &offset_dims, &[axis], &[axis], index_vector_dim, &slice_sizes) } - fn maybe_keep_dims(&self, res: XlaOp, dims: &[i64], keep_dims: bool) -> Result { - if keep_dims && !dims.is_empty() { - let shape = self.shape()?; - let mut dimensions = shape.dimensions().to_vec(); - for d in dims.iter() { - dimensions[*d as usize] = 1; + fn maybe_keep_dims(&self, res: XlaOp, dims_to_keep: &[i64], keep_dims: bool) -> Result { + if keep_dims && !dims_to_keep.is_empty() { + let shape = self.array_shape()?; + let mut dims = shape.dims().to_vec(); + for d in dims_to_keep.iter() { + dims[*d as usize] = 1; } - res.reshape(&dimensions) + res.reshape(&dims) } else { Ok(res) } @@ -521,7 +527,7 @@ impl XlaOp { /// original node. pub fn reduce_sum(&self, dims: &[i64], keep_dims: bool) -> Result { let builder = XlaBuilder::new("Sum"); - let ty = self.element_type()?; + let ty = self.primitive_type()?.element_type()?; let x = builder.parameter(0, ty, &[], "x")?; let y = builder.parameter(1, ty, &[], "y")?; let sum = x.add_(&y)?.build()?; @@ -532,8 +538,8 @@ impl XlaOp { /// A node that computes the average value across the specified dimensions. pub fn reduce_mean(&self, dims: &[i64], keep_dims: bool) -> Result { let b = &self.builder(); - let ty = self.element_type()?; - let mut scale = b.one(PrimitiveType::S32)?; + let ty = self.primitive_type()?; + let mut scale = b.one(crate::ElementType::S32)?; for d in dims.iter() { scale = (scale * self.dimensions_size(*d)?)?; } @@ -544,7 +550,7 @@ impl XlaOp { /// A node that computes the maximum value across the specified dimensions. pub fn reduce_max(&self, dims: &[i64], keep_dims: bool) -> Result { let builder = XlaBuilder::new("Max"); - let ty = self.element_type()?; + let ty = self.primitive_type()?.element_type()?; let x = builder.parameter(0, ty, &[], "x")?; let y = builder.parameter(1, ty, &[], "y")?; let sum = x.max(&y)?.build()?; @@ -555,7 +561,7 @@ impl XlaOp { /// A node that computes the minimum value across the specified dimensions. pub fn reduce_min(&self, dims: &[i64], keep_dims: bool) -> Result { let builder = XlaBuilder::new("Min"); - let ty = self.element_type()?; + let ty = self.primitive_type()?.element_type()?; let x = builder.parameter(0, ty, &[], "x")?; let y = builder.parameter(1, ty, &[], "y")?; let sum = x.min(&y)?.build()?; @@ -573,7 +579,7 @@ impl XlaOp { /// Layer normalization, this normalizes values on the target dimension to be of zero mean and /// standard deviation one, and then scales the result by `scale` and adds `bias`. pub fn layer_norm(&self, dim: i64, scale: &XlaOp, bias: &XlaOp) -> Result { - let ty = self.element_type().unwrap_or(PrimitiveType::F32); + let ty = self.primitive_type().unwrap_or(PrimitiveType::F32); let eps = self.builder().c0(1e-5)?.convert(ty)?; let mean = self.reduce_mean(&[dim], true)?; let mean2 = (self * self)?.reduce_mean(&[dim], true)?; @@ -587,10 +593,10 @@ impl XlaOp { pub fn matmul(&self, rhs: &Self) -> Result { // Similar to the jax implementation but without the squeezing. // https://github.com/google/jax/blob/849e47f79ac64ccba1a762804217c00a9905025b/jax/_src/numpy/lax_numpy.py#L3028 - let lhs_shape = self.shape()?; - let rhs_shape = self.shape()?; - let lhs_dims = lhs_shape.dimensions(); - let rhs_dims = rhs_shape.dimensions(); + let lhs_shape = self.array_shape()?; + let rhs_shape = self.array_shape()?; + let lhs_dims = lhs_shape.dims(); + let rhs_dims = rhs_shape.dims(); let lhs_ndims = lhs_dims.len(); let rhs_ndims = rhs_dims.len(); if lhs_ndims < 1 || rhs_ndims < 1 { diff --git a/tests/basic_tests.rs b/tests/basic_tests.rs index e09a05d..12b54a0 100644 --- a/tests/basic_tests.rs +++ b/tests/basic_tests.rs @@ -12,7 +12,7 @@ fn add_op() -> Result<()> { let result = result.execute::(&[])?; let result = result[0][0].to_literal_sync()?; assert_eq!(result.element_count(), 2); - assert_eq!(result.shape()?, xla::Shape::new::(vec![2])); + assert_eq!(result.array_shape()?, xla::ArrayShape::new::(vec![2])); assert_eq!(result.get_first_element::()?, 85.); assert_eq!(result.to_vec::()?, [85., 85.]); Ok(()) @@ -22,7 +22,7 @@ fn add_op() -> Result<()> { fn sum_op() -> Result<()> { let client = xla::PjRtClient::cpu()?; let builder = xla::XlaBuilder::new("test"); - let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[2], "x")?; + let x = builder.parameter(0, f32::TY, &[2], "x")?; let sum = x.reduce_sum(&[], false)?.build()?.compile(&client)?; let input = xla::Literal::vec1(&[4.2f32, 1.337f32]); let result = sum.execute::(&[input])?; @@ -30,24 +30,24 @@ fn sum_op() -> Result<()> { assert_eq!(result.to_vec::()?, [4.2, 1.337]); let builder = xla::XlaBuilder::new("test"); - let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[-2], "x")?; + let x = builder.parameter(0, f32::TY, &[-2], "x")?; let sum = x.reduce_sum(&[0], false)?.build()?.compile(&client)?; let input = xla::Literal::vec1(&[4.2f32, 1.337f32]); let result = sum.execute::(&[input])?; let result = result[0][0].to_literal_sync()?; assert_eq!(result.to_vec::()?, [5.5369997]); // Dimensions got reduced. - assert_eq!(result.shape()?.dimensions(), []); + assert_eq!(result.array_shape()?.dims(), []); let builder = xla::XlaBuilder::new("test"); - let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[-2], "x")?; + let x = builder.parameter(0, f32::TY, &[-2], "x")?; let sum = x.reduce_sum(&[0], true)?.build()?.compile(&client)?; let input = xla::Literal::vec1(&[4.2f32, 1.337f32]); let result = sum.execute::(&[input])?; let result = result[0][0].to_literal_sync()?; assert_eq!(result.to_vec::()?, [5.5369997]); // keep_dims = true in this case. - assert_eq!(result.shape()?.dimensions(), [1]); + assert_eq!(result.array_shape()?.dims(), [1]); Ok(()) } @@ -55,14 +55,14 @@ fn sum_op() -> Result<()> { fn mean_op() -> Result<()> { let client = xla::PjRtClient::cpu()?; let builder = xla::XlaBuilder::new("test"); - let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[-2], "x")?; + let x = builder.parameter(0, f32::TY, &[-2], "x")?; let sum = x.reduce_mean(&[0], false)?.build()?.compile(&client)?; let input = xla::Literal::vec1(&[4.2f32, 1.337f32]); let result = sum.execute::(&[input])?; let result = result[0][0].to_literal_sync()?; assert_eq!(result.to_vec::()?, [2.7684999]); // Dimensions got reduced. - assert_eq!(result.shape()?.dimensions(), []); + assert_eq!(result.array_shape()?.dims(), []); Ok(()) } @@ -70,8 +70,8 @@ fn mean_op() -> Result<()> { fn tuple_op() -> Result<()> { let client = xla::PjRtClient::cpu()?; let builder = xla::XlaBuilder::new("test"); - let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[-1], "x")?; - let y = builder.parameter(1, f32::PRIMITIVE_TYPE, &[2], "x")?; + let x = builder.parameter(0, f32::TY, &[-1], "x")?; + let y = builder.parameter(1, f32::TY, &[2], "x")?; let tuple = builder.tuple(&[x, y])?.build()?.compile(&client)?; let x = xla::Literal::scalar(3.1f32); let y = xla::Literal::vec1(&[4.2f32, 1.337f32]); diff --git a/tests/control_flow_tests.rs b/tests/control_flow_tests.rs index 10fc99a..bd035e4 100644 --- a/tests/control_flow_tests.rs +++ b/tests/control_flow_tests.rs @@ -6,12 +6,12 @@ fn while_op() -> Result<()> { let builder = xla::XlaBuilder::new("test"); let cond = { let builder = xla::XlaBuilder::new("cond"); - let x = builder.parameter(0, i32::PRIMITIVE_TYPE, &[], "x")?; + let x = builder.parameter(0, i32::TY, &[], "x")?; x.le(&builder.constant_r0(10i32)?)?.build()? }; let body = { let builder = xla::XlaBuilder::new("cond"); - let x = builder.parameter(0, i32::PRIMITIVE_TYPE, &[], "x")?; + let x = builder.parameter(0, i32::TY, &[], "x")?; (x + builder.constant_r0(1i32)?)?.build()? }; let init = builder.constant_r0(0i32)?; @@ -21,7 +21,7 @@ fn while_op() -> Result<()> { let result = result.execute::(&[])?; let result = result[0][0].to_literal_sync()?; assert_eq!(result.element_count(), 1); - assert_eq!(result.shape()?, xla::Shape::new::(vec![])); + assert_eq!(result.shape()?, xla::Shape::array::(vec![])); assert_eq!(result.to_vec::()?, [11]); Ok(()) } diff --git a/tests/tuple_tests.rs b/tests/tuple_tests.rs index 7002d02..9db2901 100644 --- a/tests/tuple_tests.rs +++ b/tests/tuple_tests.rs @@ -10,12 +10,12 @@ fn tuple_op() -> Result<()> { let result = client.compile(&computation)?; let result = result.execute::(&[])?; let mut result = result[0][0].to_literal_sync()?; - assert_eq!(result.shape()?, xla::Shape::tuple(2)); + assert_eq!(result.shape()?.tuple_size(), Some(2)); let as_tuple = result.decompose_tuple()?; - assert_eq!(result.shape()?, xla::Shape::tuple(0)); + assert_eq!(result.shape()?.tuple_size(), Some(0)); assert_eq!(as_tuple.len(), 2); - assert_eq!(as_tuple[0].shape()?, xla::Shape::new::(vec![])); - assert_eq!(as_tuple[1].shape()?, xla::Shape::new::(vec![2])); + assert_eq!(as_tuple[0].array_shape()?, xla::ArrayShape::new::(vec![])); + assert_eq!(as_tuple[1].array_shape()?, xla::ArrayShape::new::(vec![2])); assert_eq!(as_tuple[1].to_vec::()?, vec![43f32, 43f32]); Ok(()) } diff --git a/xla_rs/xla_rs.cc b/xla_rs/xla_rs.cc index 5cd9ef2..e547650 100644 --- a/xla_rs/xla_rs.cc +++ b/xla_rs/xla_rs.cc @@ -799,6 +799,10 @@ void xla_op_free(xla_op o) { delete o; } size_t shape_tuple_shapes_size(const shape s) { return s->tuple_shapes_size(); } +shape shape_tuple_shapes(const shape s, int i) { + return (shape)&s->tuple_shapes(i); +} + int shape_dimensions_size(const shape s) { return s->dimensions_size(); } int shape_element_type(const shape s) { return s->element_type(); } diff --git a/xla_rs/xla_rs.h b/xla_rs/xla_rs.h index 66ebb55..413f6b1 100644 --- a/xla_rs/xla_rs.h +++ b/xla_rs/xla_rs.h @@ -184,6 +184,7 @@ void xla_op_free(xla_op); int shape_dimensions_size(const shape); size_t shape_tuple_shapes_size(const shape); +shape shape_tuple_shapes(const shape, int); int shape_element_type(const shape); int64_t shape_dimensions(const shape, int); void shape_free(shape);