Skip to content

Commit

Permalink
Merge pull request #5 from LaurentMazare/shape-refactoring
Browse files Browse the repository at this point in the history
Refactor the Shape type and module
  • Loading branch information
LaurentMazare committed May 8, 2023
2 parents 7fb08c1 + e66e357 commit cfc953b
Show file tree
Hide file tree
Showing 21 changed files with 397 additions and 270 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "xla"
version = "0.1.4"
version = "0.1.5"
authors = ["laurent <[email protected]>"]
edition = "2021"
description = "Bindings for the XLA C++ library."
Expand Down
2 changes: 1 addition & 1 deletion examples/basics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ fn main() -> Result<()> {
result.to_vec::<f32>(),
result.get_first_element::<f32>()?,
);
let param = xla_builder.parameter_with_shape(0, &xla::Shape::new::<f32>(vec![]), "p")?;
let param = xla_builder.parameter_with_shape(0, &xla::ArrayShape::new::<f32>(vec![]), "p")?;
let sum = param.add_(&param)?;
let sum = sum.sqrt()?.build()?;
let result = client.compile(&sum)?;
Expand Down
15 changes: 7 additions & 8 deletions examples/llama/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use clap::Parser;
use rand::prelude::*;

extern crate xla;
use xla::{PrimitiveType, XlaBuilder, XlaOp};
use xla::{ElementType, PrimitiveType, XlaBuilder, XlaOp};

mod sentencepiece;
use sentencepiece::Tokenizer;
Expand Down Expand Up @@ -195,9 +195,8 @@ impl Mlp {
}

fn masked_fill<T: xla::NativeType>(on_false: &XlaOp, mask: &XlaOp, on_true: T) -> Result<XlaOp> {
let shape = mask.shape()?;
let on_true =
mask.builder().c0(on_true)?.convert(on_false.ty()?)?.broadcast(shape.dimensions())?;
let shape = mask.array_shape()?;
let on_true = mask.builder().c0(on_true)?.convert(on_false.ty()?)?.broadcast(shape.dims())?;
let m = mask.select(&on_true, on_false)?;
Ok(m)
}
Expand Down Expand Up @@ -255,15 +254,15 @@ impl CausalSelfAttention {
let v = v.reshape(&target_dim)?.swap_dims(1, 2)?;
let q = self.apply_rotary_emb(&q, &freqs_cis)?;
let k = self.apply_rotary_emb(&k, &freqs_cis)?;
let k_shape = k.shape()?;
let k_shape = k.array_shape()?;
let att = (q.matmul(&k.swap_dims(-2, -1)?)?
* builder.c0(1f32 / (k_shape.last_dim().unwrap() as f32).sqrt())?.convert(ty)?)?;
let mask = builder
.one(PrimitiveType::S32)?
.one(ElementType::S32)?
.broadcast(&[t, t])?
.lower_triangle()?
.reshape(&[1, 1, t, t])?;
let zero = builder.zero(PrimitiveType::S32)?.broadcast(&[b, self.n_head as i64, t, t])?;
let zero = builder.zero(ElementType::S32)?.broadcast(&[b, self.n_head as i64, t, t])?;
let att = masked_fill(&att, &mask.eq(&zero)?, f32::NEG_INFINITY)?;
let y = att.softmax(-1)?.matmul(&v)?;
let y = y.swap_dims(1, 2)?.reshape(&[b, t, c])?;
Expand Down Expand Up @@ -351,7 +350,7 @@ fn llama_computation(args: &Args, bsize: i64) -> Result<(xla::XlaComputation, Va
let config = Config::config_7b();
let freqs_cis = precompute_freqs_cis(&config, &b)?;
let llama = Llama::new(vb.clone(), &config)?;
let input = vb.arg("tokens", PrimitiveType::U32, &[bsize as usize, CONTEXT_SIZE])?;
let input = vb.arg("tokens", ElementType::U32, &[bsize as usize, CONTEXT_SIZE])?;
let logits = llama.forward(&input, &freqs_cis)?.convert(PrimitiveType::F32)?;
let prs = (logits / b.c0(args.temperature)?)?.softmax(-1)?;
Ok((prs.build()?, vb.into_store()))
Expand Down
22 changes: 10 additions & 12 deletions examples/llama/var_store.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use xla::{ArrayElement, FromRawBytes, PjRtBuffer, PjRtClient, PrimitiveType, Result, XlaOp};
use xla::{ArrayElement, ElementType, FromRawBytes, PjRtBuffer, PjRtClient, Result, XlaOp};

#[allow(dead_code)]
#[derive(Clone)]
struct NamedVar {
path: String,
ty: PrimitiveType,
ty: ElementType,
dims: Vec<usize>,
is_arg: bool,
}
Expand All @@ -14,8 +14,8 @@ pub struct VarBuilder {
path: Vec<String>,
vars: std::rc::Rc<std::cell::RefCell<Vec<NamedVar>>>,
builder: xla::XlaBuilder,
default_buffer_type_for_var: PrimitiveType,
default_op_type_for_var: PrimitiveType,
default_buffer_type_for_var: ElementType,
default_op_type_for_var: ElementType,
}

#[allow(dead_code)]
Expand All @@ -30,8 +30,8 @@ impl VarBuilder {
builder: builder.clone(),
path: vec![],
vars,
default_buffer_type_for_var: B::PRIMITIVE_TYPE,
default_op_type_for_var: O::PRIMITIVE_TYPE,
default_buffer_type_for_var: B::TY,
default_op_type_for_var: O::TY,
}
}

Expand All @@ -42,7 +42,7 @@ impl VarBuilder {
pub fn var_(
&mut self,
s: &str,
ty: PrimitiveType,
ty: ElementType,
dims: &[usize],
is_arg: bool,
) -> Result<XlaOp> {
Expand All @@ -57,10 +57,10 @@ impl VarBuilder {

pub fn var(&mut self, s: &str, dims: &[usize]) -> Result<XlaOp> {
let v = self.var_(s, self.default_buffer_type_for_var, dims, false)?;
v.convert(self.default_op_type_for_var)
v.convert(self.default_op_type_for_var.primitive_type())
}

pub fn arg(&mut self, s: &str, ty: PrimitiveType, dims: &[usize]) -> Result<XlaOp> {
pub fn arg(&mut self, s: &str, ty: ElementType, dims: &[usize]) -> Result<XlaOp> {
self.var_(s, ty, dims, true)
}

Expand Down Expand Up @@ -119,9 +119,7 @@ impl VarStore {
let buffer = if var.is_arg {
let ty = var.ty;
let element_count: usize = var.dims.iter().product();
let element_size_in_bytes = ty
.element_size_in_bytes()
.ok_or(xla::Error::UnsupportedElementType { ty, op: "buffer_from_bytes" })?;
let element_size_in_bytes = ty.element_size_in_bytes();
let data = vec![0u8; element_count * element_size_in_bytes];
c.buffer_from_host_raw_bytes(ty, &data, &var.dims, None)?
} else {
Expand Down
10 changes: 5 additions & 5 deletions examples/loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ fn main() -> Result<()> {
let client = xla::PjRtClient::cpu()?;
loop {
let builder = xla::XlaBuilder::new("test");
let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[2], "x")?;
let x = builder.parameter(0, f32::TY, &[2], "x")?;
let sum = x.reduce_sum(&[], false)?.build()?.compile(&client)?;
let input = xla::Literal::vec1(&[4.2f32, 1.337f32]);
let result = sum.execute::<xla::Literal>(&[input])?;
Expand All @@ -16,7 +16,7 @@ fn main() -> Result<()> {
assert_eq!(result.to_vec::<f32>()?, [4.2, 1.337]);

let builder = xla::XlaBuilder::new("test");
let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[-2], "x")?;
let x = builder.parameter(0, f32::TY, &[-2], "x")?;
let sum = x.reduce_sum(&[0], false)?.build()?.compile(&client)?;
let input = xla::Literal::vec1(&[4.2f32, 1.337f32]);
let result = sum.execute::<xla::Literal>(&[input])?;
Expand All @@ -25,10 +25,10 @@ fn main() -> Result<()> {
drop(sum);
assert_eq!(result.to_vec::<f32>()?, [5.5369997]);
// Dimensions got reduced.
assert_eq!(result.shape()?.dimensions(), []);
assert_eq!(result.array_shape()?.dims(), []);

let builder = xla::XlaBuilder::new("test");
let x = builder.parameter(0, f32::PRIMITIVE_TYPE, &[-2], "x")?;
let x = builder.parameter(0, f32::TY, &[-2], "x")?;
let sum = x.reduce_sum(&[0], true)?.build()?.compile(&client)?;
let input = xla::Literal::vec1(&[4.2f32, 1.337f32]);
let result = sum.execute::<xla::Literal>(&[input])?;
Expand All @@ -37,7 +37,7 @@ fn main() -> Result<()> {
drop(sum);
assert_eq!(result.to_vec::<f32>()?, [5.5369997]);
// keep_dims = true in this case.
assert_eq!(result.shape()?.dimensions(), [1]);
assert_eq!(result.array_shape()?.dims(), [1]);
println!("Done!");
}
}
16 changes: 8 additions & 8 deletions examples/nanogpt/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@ use anyhow::Result;
use rand::prelude::*;

extern crate xla;
use xla::{Literal, PjRtLoadedExecutable, PrimitiveType, XlaBuilder, XlaOp};
use xla::{ElementType, Literal, PjRtLoadedExecutable, XlaBuilder, XlaOp};

mod tokenizer;
mod var_store;
use tokenizer::Tokenizer;
use var_store::VarStore;

const TY: PrimitiveType = PrimitiveType::F32;
const TY: ElementType = ElementType::F32;
const TEMPERATURE: f32 = 0.8f32;
const USE_CPU: bool = false;
const NUM_SAMPLES: usize = 10;
Expand Down Expand Up @@ -104,8 +104,8 @@ impl Linear {
}

fn masked_fill<T: xla::NativeType>(on_false: &XlaOp, mask: &XlaOp, on_true: T) -> Result<XlaOp> {
let shape = mask.shape()?;
let on_true = mask.builder().c0(on_true)?.broadcast(shape.dimensions())?;
let shape = mask.array_shape()?;
let on_true = mask.builder().c0(on_true)?.broadcast(shape.dims())?;
let m = mask.select(&on_true, on_false)?;
Ok(m)
}
Expand Down Expand Up @@ -137,15 +137,15 @@ impl CausalSelfAttention {
let k = k.reshape(&target_dim)?.swap_dims(1, 2)?;
let q = q.reshape(&target_dim)?.swap_dims(1, 2)?;
let v = v.reshape(&target_dim)?.swap_dims(1, 2)?;
let k_shape = k.shape()?;
let k_shape = k.array_shape()?;
let att = (q.matmul(&k.swap_dims(-2, -1)?)?
* builder.c0(1f32 / (k_shape.last_dim().unwrap() as f32).sqrt()))?;
let mask = builder
.one(PrimitiveType::S32)?
.one(ElementType::S32)?
.broadcast(&[t, t])?
.lower_triangle()?
.reshape(&[1, 1, t, t])?;
let zero = builder.zero(PrimitiveType::S32)?.broadcast(&[b, self.n_head as i64, t, t])?;
let zero = builder.zero(ElementType::S32)?.broadcast(&[b, self.n_head as i64, t, t])?;
let att = masked_fill(&att, &mask.eq(&zero)?, f32::NEG_INFINITY)?;
let y = att.softmax(-1)?.matmul(&v)?;
let y = y.swap_dims(1, 2)?.reshape(&[b, t, c])?;
Expand Down Expand Up @@ -253,7 +253,7 @@ fn gpt_computation(vs: VarStore, bsize: i64) -> Result<xla::XlaComputation> {
let b = XlaBuilder::new("gpt");
let config = GptConfig::default();
let gpt = Gpt::new(vs, &config)?;
let input = b.parameter(0, PrimitiveType::S32, &[bsize, config.block_size as i64], "tokens")?;
let input = b.parameter(0, ElementType::S32, &[bsize, config.block_size as i64], "tokens")?;
let logits = gpt.forward(&input)?;
let prs = (logits / b.c0(TEMPERATURE))?.softmax(-1)?;
Ok(prs.build()?)
Expand Down
10 changes: 5 additions & 5 deletions examples/nanogpt/var_store.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use anyhow::{Context, Result};
use std::collections::HashMap;

use xla::{FromRawBytes, Literal, PrimitiveType};
use xla::{ElementType, FromRawBytes, Literal};

#[derive(Clone)]
pub struct VarStore {
Expand All @@ -24,7 +24,7 @@ impl VarStore {
pub fn take(
&mut self,
s: &str,
expected_type: PrimitiveType,
expected_type: ElementType,
expected_dims: &[usize],
) -> Result<Literal> {
let path = format!("{}.{s}", self.path.join("."));
Expand All @@ -33,9 +33,9 @@ impl VarStore {
.borrow_mut()
.remove(&path)
.with_context(|| format!("cannot find {path} in VarStore"))?;
let shape = literal.shape()?;
let element_type = shape.element_type();
let dims = shape.dimensions();
let shape = literal.array_shape()?;
let element_type = shape.ty();
let dims = shape.dims();
if element_type != expected_type {
anyhow::bail!(
"unexpected element type for {}, got {:?} expected {:?}",
Expand Down
10 changes: 8 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,25 @@ pub enum Error {
#[error("unexpected number of dimensions, expected: {expected}, got: {got} ({dims:?})")]
UnexpectedNumberOfDims { expected: usize, got: usize, dims: Vec<i64> },

#[error("not an element type, got: {got:?}")]
NotAnElementType { got: crate::PrimitiveType },

#[error("not an array, expected: {expected:?}, got: {got:?}")]
NotAnArray { expected: Option<usize>, got: crate::Shape },

#[error("unexpected number of tuple elements, expected: {expected}, got: {got}")]
UnexpectedNumberOfElemsInTuple { expected: usize, got: usize },

#[error("element type mismatch, on-device: {on_device:?}, on-host: {on_host:?}")]
ElementTypeMismatch { on_device: crate::PrimitiveType, on_host: crate::PrimitiveType },
ElementTypeMismatch { on_device: crate::ElementType, on_host: crate::ElementType },

#[error("unsupported element type for {op}: {ty:?}")]
UnsupportedElementType { ty: crate::PrimitiveType, op: &'static str },

#[error(
"target buffer is too large, offset {offset}, shape {shape:?}, buffer_len: {buffer_len}"
)]
TargetBufferIsTooLarge { offset: usize, shape: crate::Shape, buffer_len: usize },
TargetBufferIsTooLarge { offset: usize, shape: crate::ArrayShape, buffer_len: usize },

#[error("binary buffer is too large, element count {element_count}, buffer_len: {buffer_len}")]
BinaryBufferIsTooLarge { element_count: usize, buffer_len: usize },
Expand Down
Loading

0 comments on commit cfc953b

Please sign in to comment.