Skip to content

Commit

Permalink
Shape to CShape conversion.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed May 9, 2023
1 parent fba8c8b commit 09351d8
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 11 deletions.
3 changes: 3 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ pub enum Error {
#[error("not an array, expected: {expected:?}, got: {got:?}")]
NotAnArray { expected: Option<usize>, got: crate::Shape },

#[error("cannot handle unsupported shapes {shape:?}")]
UnsupportedShape { shape: crate::Shape },

#[error("unexpected number of tuple elements, expected: {expected}, got: {got}")]
UnexpectedNumberOfElemsInTuple { expected: usize, got: usize },

Expand Down
34 changes: 27 additions & 7 deletions src/wrappers/shape.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{ArrayElement, ElementType, PrimitiveType};
use crate::{c_lib, Error};
use crate::{c_lib, Error, Result};

#[derive(Clone, PartialEq, Eq, Debug)]
pub struct ArrayShape {
Expand Down Expand Up @@ -96,12 +96,32 @@ impl Shape {
Self::Array { .. } | Self::Unsupported(_) => None,
}
}

#[allow(dead_code)]
pub(crate) fn c_shape(&self) -> Result<CShape> {
match self {
Self::Tuple(shapes) => {
let shapes = shapes.iter().map(|s| s.c_shape()).collect::<Result<Vec<_>>>()?;
let ptrs: Vec<_> = shapes.iter().map(|s| s.0).collect();
let c_shape = CShape(unsafe { c_lib::make_shape_tuple(ptrs.len(), ptrs.as_ptr()) });
drop(shapes);
Ok(c_shape)
}
Self::Array(a) => {
let dims = a.dims();
Ok(CShape(unsafe {
c_lib::make_shape_array(a.primitive_type() as i32, dims.len(), dims.as_ptr())
}))
}
Self::Unsupported(_) => Err(Error::UnsupportedShape { shape: self.clone() }),
}
}
}

impl TryFrom<&Shape> for ArrayShape {
type Error = Error;

fn try_from(value: &Shape) -> Result<Self, Self::Error> {
fn try_from(value: &Shape) -> Result<Self> {
match value {
Shape::Tuple(_) | Shape::Unsupported(_) => {
Err(Error::NotAnArray { expected: None, got: value.clone() })
Expand All @@ -116,7 +136,7 @@ macro_rules! extract_dims {
impl TryFrom<&ArrayShape> for $out_type {
type Error = Error;

fn try_from(value: &ArrayShape) -> Result<Self, Self::Error> {
fn try_from(value: &ArrayShape) -> Result<Self> {
if value.dims.len() != $cnt {
Err(Error::UnexpectedNumberOfDims {
expected: $cnt,
Expand All @@ -132,7 +152,7 @@ macro_rules! extract_dims {
impl TryFrom<&Shape> for $out_type {
type Error = Error;

fn try_from(value: &Shape) -> Result<Self, Self::Error> {
fn try_from(value: &Shape) -> Result<Self> {
match value {
Shape::Tuple(_) | Shape::Unsupported(_) => {
Err(Error::NotAnArray { expected: Some($cnt), got: value.clone() })
Expand All @@ -157,15 +177,15 @@ impl CShape {
Self(ptr)
}

pub(crate) fn shape(&self) -> crate::Result<Shape> {
fn from_ptr_rec(ptr: c_lib::shape) -> crate::Result<Shape> {
pub(crate) fn shape(&self) -> Result<Shape> {
fn from_ptr_rec(ptr: c_lib::shape) -> Result<Shape> {
let ty = unsafe { c_lib::shape_element_type(ptr) };
let ty = super::FromPrimitive::from_i32(ty)
.ok_or_else(|| Error::UnexpectedElementType(ty))?;
match ty {
PrimitiveType::Tuple => {
let elem_cnt = unsafe { c_lib::shape_tuple_shapes_size(ptr) };
let shapes: crate::Result<Vec<_>> = (0..elem_cnt)
let shapes: Result<Vec<_>> = (0..elem_cnt)
.map(|i| from_ptr_rec(unsafe { c_lib::shape_tuple_shapes(ptr, i as i32) }))
.collect();
Ok(Shape::Tuple(shapes?))
Expand Down
20 changes: 16 additions & 4 deletions xla_rs/xla_rs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ xla_op constant_literal(const xla_builder b, const literal l) {
FOR_EACH_NATIVE_TYPE(CONST_OP_R01)
#undef CONST_OP_R01

Shape make_shape(int pr_type, int dsize, const int64_t *ds) {
Shape make_shape_internal(int pr_type, int dsize, const int64_t *ds) {
bool has_negative_dim = false;
for (int i = 0; i < dsize; ++i) {
if (ds[i] < 0) {
Expand Down Expand Up @@ -238,25 +238,37 @@ Shape make_shape(int pr_type, int dsize, const int64_t *ds) {
return shape;
}

shape make_shape_array(int pr_type, size_t dsize, const int64_t *ds) {
return new Shape(make_shape_internal(pr_type, dsize, ds));
}

shape make_shape_tuple(size_t dsize, const shape *ds) {
std::vector<Shape> elts;
for (size_t i = 0; i < dsize; ++i) {
elts.push_back(*ds[i]);
}
return new Shape(ShapeUtil::MakeTupleShape(elts));
}

xla_op parameter(const xla_builder b, int64_t id, int pr_type, int dsize,
const int64_t *ds, const char *name) {
BEGIN_PROTECT_OP
Shape shape = make_shape(pr_type, dsize, ds);
Shape shape = make_shape_internal(pr_type, dsize, ds);
return new XlaOp(Parameter(b, id, shape, std::string(name)));
END_PROTECT_OP_B(b)
}

xla_op infeed(const xla_builder b, int pr_type, int dsize, const int64_t *ds,
const char *config) {
BEGIN_PROTECT_OP
Shape shape = make_shape(pr_type, dsize, ds);
Shape shape = make_shape_internal(pr_type, dsize, ds);
return new XlaOp(Infeed(b, shape, std::string(config)));
END_PROTECT_OP_B(b)
}

void outfeed(const xla_op op, int pr_type, int dsize, const int64_t *ds,
const char *outfeed_config) {
Shape shape = make_shape(pr_type, dsize, ds);
Shape shape = make_shape_internal(pr_type, dsize, ds);
Outfeed(*op, shape, std::string(outfeed_config));
}

Expand Down
3 changes: 3 additions & 0 deletions xla_rs/xla_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#pragma GCC diagnostic pop
using namespace xla;
Expand Down Expand Up @@ -188,6 +189,8 @@ shape shape_tuple_shapes(const shape, int);
int shape_element_type(const shape);
int64_t shape_dimensions(const shape, int);
void shape_free(shape);
shape make_shape_array(int, size_t, const int64_t *);
shape make_shape_tuple(size_t, const shape *);

status get_shape(const xla_builder, const xla_op, shape *);
status get_element_type(const xla_builder, const xla_op, int *);
Expand Down

0 comments on commit 09351d8

Please sign in to comment.