-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added doc-comments to tensor.rs #12
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
//! Provides the core functionality for manipulating tensors. | ||
use crate::{ | ||
AsBuffer, Buffer, Field, FromOp, IntoOp, Noxpr, NoxprScalarExt, Op, Param, Scalar, Vector, | ||
}; | ||
|
@@ -10,6 +11,8 @@ use std::{ | |
}; | ||
use xla::{ArrayElement, ElementType, NativeType}; | ||
|
||
/// Represents a tensor with a specific type `T`, dimensionality `D`, and parameterization `P`. | ||
/// `P` dictates the underlying representation and operations available on the tensor. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "operations available on the tensor" is not quite accurate. Would just drop this sentence, and change the previous sentence to include the "underlying representation" part. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! |
||
#[repr(transparent)] | ||
pub struct Tensor<T, D: TensorDim, P: Param = Op> { | ||
pub(crate) inner: P::Inner, | ||
|
@@ -27,6 +30,8 @@ where | |
} | ||
} | ||
|
||
/// Trait for items that can be represented as tensors. | ||
/// Specifies the type of the item, its tensor representation, and dimensionality. | ||
pub trait TensorItem { | ||
type Item: FromOp; | ||
type Tensor<D> | ||
|
@@ -62,8 +67,10 @@ impl<T, D: TensorDim> FromOp for Tensor<T, D> { | |
} | ||
} | ||
|
||
/// Trait for collapsing a tensor into a simpler form, typically by reducing its dimensionality. | ||
pub trait Collapse { | ||
type Out; | ||
/// Collapses the tensor into a simpler form. | ||
fn collapse(self) -> Self::Out; | ||
} | ||
|
||
|
@@ -139,8 +146,11 @@ impl<T, D: TensorDim> IntoOp for Tensor<T, D, Op> { | |
} | ||
} | ||
|
||
/// Represents a dimensionality of a tensor. This trait is a marker for types that can specify tensor dimensions. | ||
pub trait TensorDim {} | ||
/// Represents non-scalar dimensions, i.e., dimensions other than `()`. | ||
pub trait NonScalarDim {} | ||
/// Represents constant dimensions, specified at compile-time. | ||
pub trait ConstDim {} | ||
|
||
pub type ScalarDim = (); | ||
|
@@ -328,7 +338,9 @@ impl<T, D: TensorDim + XlaDim> FixedSliceExt<T, D> for Tensor<T, D, Op> { | |
} | ||
} | ||
|
||
/// Extension trait for tensors supporting fixed-size slicing operations. | ||
pub trait FixedSliceExt<T, D: TensorDim> { | ||
/// Returns a tensor slice with dimensions specified by `ND`, starting at the given `offsets`. | ||
fn fixed_slice<ND: TensorDim + XlaDim>(&self, offsets: &[usize]) -> Tensor<T, ND, Op>; | ||
} | ||
|
||
|
@@ -389,6 +401,8 @@ impl<T, D: TensorDim> AsBuffer for Tensor<T, D, Buffer> { | |
} | ||
} | ||
|
||
/// Trait for mapping dimensions in tensor operations. | ||
/// Allows for transforming and replacing dimensions in tensor types. | ||
pub trait MapDim<D> { | ||
type Item: TensorDim; | ||
type MappedDim: TensorDim; | ||
|
@@ -409,10 +423,12 @@ impl<D: TensorDim> MapDim<D> for Mapped { | |
const MAPPED_DIM: usize = 0; | ||
} | ||
|
||
/// Trait for default dimension mapping in tensor operations. | ||
pub trait DefaultMap | ||
where | ||
Self: Sized, | ||
{ | ||
/// The default dimension mapping for the implementing type. | ||
type DefaultMapDim: MapDim<Self>; | ||
} | ||
|
||
|
@@ -495,6 +511,7 @@ impl<const N: usize> DefaultMap for Const<N> { | |
type DefaultMapDim = Mapped; | ||
} | ||
|
||
/// Trait representing the concatenation of dimensions. | ||
pub trait DimConcat<A, B> { | ||
type Output; | ||
} | ||
|
@@ -511,6 +528,7 @@ impl<A: NonScalarDim + NonTupleDim, B: NonScalarDim + NonTupleDim> DimConcat<A, | |
type Output = (A, B); | ||
} | ||
|
||
/// Represents types that are not tuples in dimension concatenation contexts. | ||
pub trait NonTupleDim {} | ||
|
||
impl NonTupleDim for ScalarDim {} | ||
|
@@ -594,6 +612,7 @@ impl<T: TensorItem, D: TensorDim + DefaultMap> Tensor<T, D, crate::Op> { | |
} | ||
} | ||
|
||
/// Trait for broadcasting dimensions in tensor operations, used to unify dimensions for element-wise operations. | ||
pub trait BroadcastDim<D1, D2> { | ||
type Output: TensorDim; | ||
} | ||
|
@@ -628,6 +647,7 @@ impl<D: TensorDim + NotConst1> BroadcastDim<ScalarDim, D> for ShapeConstraint { | |
type Output = D; | ||
} | ||
|
||
/// Marker trait for types not equivalent to `Const<1>`, used in broadcasting logic. | ||
pub trait NotConst1 {} | ||
|
||
seq_macro::seq!(N in 2..99 { | ||
|
@@ -640,9 +660,11 @@ impl<T: TensorItem, D: TensorDim> Tensor<T, D> { | |
} | ||
} | ||
|
||
/// Trait for indexing into tensors, allowing for the extraction of sub-tensors or elements based on indices. | ||
pub trait TensorIndex<T, D: TensorDim> { | ||
type Output; | ||
|
||
/// Performs the indexing operation on a tensor, returning the result. | ||
fn index(self, tensor: Tensor<T, D>) -> Self::Output; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"underlying representation
P
" is a bit more accurate.