Skip to content

Commit

Permalink
Rename ElementType in preparation of the Shape refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed May 7, 2023
1 parent bf3bf7b commit 7fb08c1
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 20 deletions.
4 changes: 2 additions & 2 deletions examples/llama/var_store.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use xla::{ElementType, FromRawBytes, PjRtBuffer, PjRtClient, PrimitiveType, Result, XlaOp};
use xla::{ArrayElement, FromRawBytes, PjRtBuffer, PjRtClient, PrimitiveType, Result, XlaOp};

#[allow(dead_code)]
#[derive(Clone)]
Expand All @@ -24,7 +24,7 @@ pub struct VarStore {
}

impl VarBuilder {
pub fn new<B: ElementType, O: ElementType>(builder: &xla::XlaBuilder) -> Self {
pub fn new<B: ArrayElement, O: ArrayElement>(builder: &xla::XlaBuilder) -> Self {
let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![]));
Self {
builder: builder.clone(),
Expand Down
2 changes: 1 addition & 1 deletion examples/loop.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use anyhow::Result;
extern crate xla;
use xla::ElementType;
use xla::ArrayElement;

fn main() -> Result<()> {
let client = xla::PjRtClient::cpu()?;
Expand Down
10 changes: 5 additions & 5 deletions src/wrappers/literal.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{ElementType, FromPrimitive, NativeType, PrimitiveType, Shape};
use super::{ArrayElement, FromPrimitive, NativeType, PrimitiveType, Shape};
use crate::{c_lib, Error, Result};

/// A literal represent a value, typically a multi-dimensional array, stored on the host device.
Expand Down Expand Up @@ -49,7 +49,7 @@ impl Literal {

/// Get the first element from a literal. This returns an error if type `T` is not the
/// primitive type that the literal uses.
pub fn get_first_element<T: NativeType + ElementType>(&self) -> Result<T> {
pub fn get_first_element<T: NativeType + ArrayElement>(&self) -> Result<T> {
let ty = self.ty()?;
if ty != T::PRIMITIVE_TYPE {
Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::PRIMITIVE_TYPE })?
Expand Down Expand Up @@ -104,7 +104,7 @@ impl Literal {

/// Copy the literal data to a slice. This returns an error if the primitive type used by the
/// literal is not `T` or if the number of elements in the slice and literal are different.
pub fn copy_raw_to<T: ElementType>(&self, dst: &mut [T]) -> Result<()> {
pub fn copy_raw_to<T: ArrayElement>(&self, dst: &mut [T]) -> Result<()> {
let ty = self.ty()?;
let element_count = self.element_count();
if ty != T::PRIMITIVE_TYPE {
Expand All @@ -126,7 +126,7 @@ impl Literal {
/// Copy data from a slice to the literal. This returns an error if the primitive type used
/// by the literal is not `T` or if number of elements in the slice and the literal are
/// different.
pub fn copy_raw_from<T: ElementType>(&mut self, src: &[T]) -> Result<()> {
pub fn copy_raw_from<T: ArrayElement>(&mut self, src: &[T]) -> Result<()> {
let ty = self.ty()?;
let element_count = self.element_count();
if ty != T::PRIMITIVE_TYPE {
Expand All @@ -147,7 +147,7 @@ impl Literal {

/// Copy the values stored in the literal in a newly created vector. The data is flattened out
/// for literals with more than one dimension.
pub fn to_vec<T: ElementType>(&self) -> Result<Vec<T>> {
pub fn to_vec<T: ArrayElement>(&self) -> Result<Vec<T>> {
let element_count = self.element_count();
// Maybe we should use an uninitialized vec instead?
let mut data = vec![T::ZERO; element_count];
Expand Down
27 changes: 23 additions & 4 deletions src/wrappers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,26 @@ impl PrimitiveType {
}
}

pub trait ElementType: Copy {
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum ElementType {
Pred,
S8,
S16,
S32,
S64,
U8,
U16,
U32,
U64,
F16,
F32,
Bf16,
F64,
C64,
C128,
}

pub trait ArrayElement: Copy {
const PRIMITIVE_TYPE: PrimitiveType;
const ELEMENT_SIZE_IN_BYTES: usize;
const ZERO: Self;
Expand Down Expand Up @@ -188,7 +207,7 @@ native_type!(

macro_rules! element_type {
($ty:ty, $v:ident, $sz:tt) => {
impl ElementType for $ty {
impl ArrayElement for $ty {
const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::$v;
const ELEMENT_SIZE_IN_BYTES: usize = $sz;
const ZERO: Self = 0 as Self;
Expand All @@ -200,7 +219,7 @@ macro_rules! element_type {
#[derive(Copy, Clone, Debug)]
pub struct F16;

impl ElementType for F16 {
impl ArrayElement for F16 {
const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::F16;
const ELEMENT_SIZE_IN_BYTES: usize = 2;
const ZERO: Self = Self;
Expand All @@ -210,7 +229,7 @@ impl ElementType for F16 {
#[derive(Copy, Clone, Debug)]
pub struct Bf16;

impl ElementType for Bf16 {
impl ArrayElement for Bf16 {
const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::Bf16;
const ELEMENT_SIZE_IN_BYTES: usize = 2;
const ZERO: Self = Self;
Expand Down
4 changes: 2 additions & 2 deletions src/wrappers/pjrt_buffer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! A view on a memory slice hosted on a device.
use super::{ElementType, FromPrimitive, Literal, PjRtDevice, Shape};
use super::{ArrayElement, FromPrimitive, Literal, PjRtDevice, Shape};
use crate::{c_lib, Error, Result};

/// A buffer represents a view on a memory slice hosted on a device.
Expand Down Expand Up @@ -47,7 +47,7 @@ impl PjRtBuffer {
}

/// Copy the data stored in a buffer to host memory in a blocking way.
pub fn copy_raw_to_host_sync<T: ElementType>(
pub fn copy_raw_to_host_sync<T: ArrayElement>(
&self,
dst: &mut [T],
offset: usize,
Expand Down
4 changes: 2 additions & 2 deletions src/wrappers/pjrt_client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//! A device (CPUs, GPUs, TPUs) where computations can be run.
use super::{ElementType, Literal, PjRtBuffer, PjRtDevice, PjRtLoadedExecutable, XlaComputation};
use super::{ArrayElement, Literal, PjRtBuffer, PjRtDevice, PjRtLoadedExecutable, XlaComputation};
use crate::{c_lib, Error, Result};
use std::marker::PhantomData;
use std::rc::Rc;
Expand Down Expand Up @@ -99,7 +99,7 @@ impl PjRtClient {
/// The source data is passed as a slice of the specified primitive type, as well as the
/// dimensions. The dimensions have to match the number of elements in the source data,
/// otherwise an error is returned.
pub fn buffer_from_host_buffer<T: ElementType>(
pub fn buffer_from_host_buffer<T: ArrayElement>(
&self,
data: &[T],
dims: &[usize],
Expand Down
4 changes: 2 additions & 2 deletions src/wrappers/shape.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{ElementType, PrimitiveType};
use super::{ArrayElement, PrimitiveType};
use crate::Error;

/// A shape specifies a primitive type as well as some array dimensions.
Expand All @@ -11,7 +11,7 @@ pub struct Shape {

impl Shape {
/// Create a new shape.
pub fn new<E: ElementType>(dimensions: Vec<i64>) -> Shape {
pub fn new<E: ArrayElement>(dimensions: Vec<i64>) -> Shape {
Shape { ty: E::PRIMITIVE_TYPE, dimensions, tuple_shapes_size: 0 }
}

Expand Down
2 changes: 1 addition & 1 deletion tests/basic_tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use xla::{ElementType, Result};
use xla::{ArrayElement, Result};

#[test]
fn add_op() -> Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion tests/control_flow_tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use xla::{ElementType, Result};
use xla::{ArrayElement, Result};

#[test]
fn while_op() -> Result<()> {
Expand Down

0 comments on commit 7fb08c1

Please sign in to comment.