From 7fb08c19406aeb2779bc72d8f058f8e9e166f138 Mon Sep 17 00:00:00 2001 From: laurent Date: Sun, 7 May 2023 20:20:50 +0100 Subject: [PATCH] Rename ElementType in preparation of the Shape refactoring. --- examples/llama/var_store.rs | 4 ++-- examples/loop.rs | 2 +- src/wrappers/literal.rs | 10 +++++----- src/wrappers/mod.rs | 27 +++++++++++++++++++++++---- src/wrappers/pjrt_buffer.rs | 4 ++-- src/wrappers/pjrt_client.rs | 4 ++-- src/wrappers/shape.rs | 4 ++-- tests/basic_tests.rs | 2 +- tests/control_flow_tests.rs | 2 +- 9 files changed, 39 insertions(+), 20 deletions(-) diff --git a/examples/llama/var_store.rs b/examples/llama/var_store.rs index d8f2888..fcd1403 100644 --- a/examples/llama/var_store.rs +++ b/examples/llama/var_store.rs @@ -1,4 +1,4 @@ -use xla::{ElementType, FromRawBytes, PjRtBuffer, PjRtClient, PrimitiveType, Result, XlaOp}; +use xla::{ArrayElement, FromRawBytes, PjRtBuffer, PjRtClient, PrimitiveType, Result, XlaOp}; #[allow(dead_code)] #[derive(Clone)] @@ -24,7 +24,7 @@ pub struct VarStore { } impl VarBuilder { - pub fn new(builder: &xla::XlaBuilder) -> Self { + pub fn new(builder: &xla::XlaBuilder) -> Self { let vars = std::rc::Rc::new(std::cell::RefCell::new(vec![])); Self { builder: builder.clone(), diff --git a/examples/loop.rs b/examples/loop.rs index 368cd2a..8251e90 100644 --- a/examples/loop.rs +++ b/examples/loop.rs @@ -1,6 +1,6 @@ use anyhow::Result; extern crate xla; -use xla::ElementType; +use xla::ArrayElement; fn main() -> Result<()> { let client = xla::PjRtClient::cpu()?; diff --git a/src/wrappers/literal.rs b/src/wrappers/literal.rs index 61005af..d303620 100644 --- a/src/wrappers/literal.rs +++ b/src/wrappers/literal.rs @@ -1,4 +1,4 @@ -use super::{ElementType, FromPrimitive, NativeType, PrimitiveType, Shape}; +use super::{ArrayElement, FromPrimitive, NativeType, PrimitiveType, Shape}; use crate::{c_lib, Error, Result}; /// A literal represent a value, typically a multi-dimensional array, stored on the host device. @@ -49,7 +49,7 @@ impl Literal { /// Get the first element from a literal. This returns an error if type `T` is not the /// primitive type that the literal uses. - pub fn get_first_element(&self) -> Result { + pub fn get_first_element(&self) -> Result { let ty = self.ty()?; if ty != T::PRIMITIVE_TYPE { Err(Error::ElementTypeMismatch { on_device: ty, on_host: T::PRIMITIVE_TYPE })? @@ -104,7 +104,7 @@ impl Literal { /// Copy the literal data to a slice. This returns an error if the primitive type used by the /// literal is not `T` or if the number of elements in the slice and literal are different. - pub fn copy_raw_to(&self, dst: &mut [T]) -> Result<()> { + pub fn copy_raw_to(&self, dst: &mut [T]) -> Result<()> { let ty = self.ty()?; let element_count = self.element_count(); if ty != T::PRIMITIVE_TYPE { @@ -126,7 +126,7 @@ impl Literal { /// Copy data from a slice to the literal. This returns an error if the primitive type used /// by the literal is not `T` or if number of elements in the slice and the literal are /// different. - pub fn copy_raw_from(&mut self, src: &[T]) -> Result<()> { + pub fn copy_raw_from(&mut self, src: &[T]) -> Result<()> { let ty = self.ty()?; let element_count = self.element_count(); if ty != T::PRIMITIVE_TYPE { @@ -147,7 +147,7 @@ impl Literal { /// Copy the values stored in the literal in a newly created vector. The data is flattened out /// for literals with more than one dimension. - pub fn to_vec(&self) -> Result> { + pub fn to_vec(&self) -> Result> { let element_count = self.element_count(); // Maybe we should use an uninitialized vec instead? let mut data = vec![T::ZERO; element_count]; diff --git a/src/wrappers/mod.rs b/src/wrappers/mod.rs index 12dd777..4cae538 100644 --- a/src/wrappers/mod.rs +++ b/src/wrappers/mod.rs @@ -79,7 +79,26 @@ impl PrimitiveType { } } -pub trait ElementType: Copy { +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum ElementType { + Pred, + S8, + S16, + S32, + S64, + U8, + U16, + U32, + U64, + F16, + F32, + Bf16, + F64, + C64, + C128, +} + +pub trait ArrayElement: Copy { const PRIMITIVE_TYPE: PrimitiveType; const ELEMENT_SIZE_IN_BYTES: usize; const ZERO: Self; @@ -188,7 +207,7 @@ native_type!( macro_rules! element_type { ($ty:ty, $v:ident, $sz:tt) => { - impl ElementType for $ty { + impl ArrayElement for $ty { const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::$v; const ELEMENT_SIZE_IN_BYTES: usize = $sz; const ZERO: Self = 0 as Self; @@ -200,7 +219,7 @@ macro_rules! element_type { #[derive(Copy, Clone, Debug)] pub struct F16; -impl ElementType for F16 { +impl ArrayElement for F16 { const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::F16; const ELEMENT_SIZE_IN_BYTES: usize = 2; const ZERO: Self = Self; @@ -210,7 +229,7 @@ impl ElementType for F16 { #[derive(Copy, Clone, Debug)] pub struct Bf16; -impl ElementType for Bf16 { +impl ArrayElement for Bf16 { const PRIMITIVE_TYPE: PrimitiveType = PrimitiveType::Bf16; const ELEMENT_SIZE_IN_BYTES: usize = 2; const ZERO: Self = Self; diff --git a/src/wrappers/pjrt_buffer.rs b/src/wrappers/pjrt_buffer.rs index 4f5a626..19ec344 100644 --- a/src/wrappers/pjrt_buffer.rs +++ b/src/wrappers/pjrt_buffer.rs @@ -1,5 +1,5 @@ //! A view on a memory slice hosted on a device. -use super::{ElementType, FromPrimitive, Literal, PjRtDevice, Shape}; +use super::{ArrayElement, FromPrimitive, Literal, PjRtDevice, Shape}; use crate::{c_lib, Error, Result}; /// A buffer represents a view on a memory slice hosted on a device. @@ -47,7 +47,7 @@ impl PjRtBuffer { } /// Copy the data stored in a buffer to host memory in a blocking way. - pub fn copy_raw_to_host_sync( + pub fn copy_raw_to_host_sync( &self, dst: &mut [T], offset: usize, diff --git a/src/wrappers/pjrt_client.rs b/src/wrappers/pjrt_client.rs index 2185032..7bf456e 100644 --- a/src/wrappers/pjrt_client.rs +++ b/src/wrappers/pjrt_client.rs @@ -1,5 +1,5 @@ //! A device (CPUs, GPUs, TPUs) where computations can be run. -use super::{ElementType, Literal, PjRtBuffer, PjRtDevice, PjRtLoadedExecutable, XlaComputation}; +use super::{ArrayElement, Literal, PjRtBuffer, PjRtDevice, PjRtLoadedExecutable, XlaComputation}; use crate::{c_lib, Error, Result}; use std::marker::PhantomData; use std::rc::Rc; @@ -99,7 +99,7 @@ impl PjRtClient { /// The source data is passed as a slice of the specified primitive type, as well as the /// dimensions. The dimensions have to match the number of elements in the source data, /// otherwise an error is returned. - pub fn buffer_from_host_buffer( + pub fn buffer_from_host_buffer( &self, data: &[T], dims: &[usize], diff --git a/src/wrappers/shape.rs b/src/wrappers/shape.rs index 015bf79..1a8f3ff 100644 --- a/src/wrappers/shape.rs +++ b/src/wrappers/shape.rs @@ -1,4 +1,4 @@ -use super::{ElementType, PrimitiveType}; +use super::{ArrayElement, PrimitiveType}; use crate::Error; /// A shape specifies a primitive type as well as some array dimensions. @@ -11,7 +11,7 @@ pub struct Shape { impl Shape { /// Create a new shape. - pub fn new(dimensions: Vec) -> Shape { + pub fn new(dimensions: Vec) -> Shape { Shape { ty: E::PRIMITIVE_TYPE, dimensions, tuple_shapes_size: 0 } } diff --git a/tests/basic_tests.rs b/tests/basic_tests.rs index 71ae3fc..e09a05d 100644 --- a/tests/basic_tests.rs +++ b/tests/basic_tests.rs @@ -1,4 +1,4 @@ -use xla::{ElementType, Result}; +use xla::{ArrayElement, Result}; #[test] fn add_op() -> Result<()> { diff --git a/tests/control_flow_tests.rs b/tests/control_flow_tests.rs index 7d2c1e6..10fc99a 100644 --- a/tests/control_flow_tests.rs +++ b/tests/control_flow_tests.rs @@ -1,4 +1,4 @@ -use xla::{ElementType, Result}; +use xla::{ArrayElement, Result}; #[test] fn while_op() -> Result<()> {