Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added doc-comments to tensor.rs #12

Merged
merged 3 commits into from
Mar 31, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions libs/nox/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! Provides the core functionality for manipulating tensors.
use crate::{
AsBuffer, Buffer, Field, FromOp, IntoOp, Noxpr, NoxprScalarExt, Op, Param, Scalar, Vector,
};
Expand All @@ -10,6 +11,8 @@ use std::{
};
use xla::{ArrayElement, ElementType, NativeType};

/// Represents a tensor with a specific type `T`, dimensionality `D`, and parameterization `P`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"underlying representation P" is a bit more accurate.

/// `P` dictates the underlying representation and operations available on the tensor.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"operations available on the tensor" is not quite accurate. Would just drop this sentence, and change the previous sentence to include the "underlying representation" part.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

#[repr(transparent)]
pub struct Tensor<T, D: TensorDim, P: Param = Op> {
pub(crate) inner: P::Inner,
Expand All @@ -27,6 +30,8 @@ where
}
}

/// Trait for items that can be represented as tensors.
/// Specifies the type of the item, its tensor representation, and dimensionality.
pub trait TensorItem {
type Item: FromOp;
type Tensor<D>
Expand Down Expand Up @@ -62,8 +67,10 @@ impl<T, D: TensorDim> FromOp for Tensor<T, D> {
}
}

/// Trait for collapsing a tensor into a simpler form, typically by reducing its dimensionality.
pub trait Collapse {
type Out;
/// Collapses the tensor into a simpler form.
fn collapse(self) -> Self::Out;
}

Expand Down Expand Up @@ -139,8 +146,11 @@ impl<T, D: TensorDim> IntoOp for Tensor<T, D, Op> {
}
}

/// Represents a dimensionality of a tensor. This trait is a marker for types that can specify tensor dimensions.
pub trait TensorDim {}
/// Represents non-scalar dimensions, i.e., dimensions other than `()`.
pub trait NonScalarDim {}
/// Represents constant dimensions, specified at compile-time.
pub trait ConstDim {}

pub type ScalarDim = ();
Expand Down Expand Up @@ -328,7 +338,9 @@ impl<T, D: TensorDim + XlaDim> FixedSliceExt<T, D> for Tensor<T, D, Op> {
}
}

/// Extension trait for tensors supporting fixed-size slicing operations.
pub trait FixedSliceExt<T, D: TensorDim> {
/// Returns a tensor slice with dimensions specified by `ND`, starting at the given `offsets`.
fn fixed_slice<ND: TensorDim + XlaDim>(&self, offsets: &[usize]) -> Tensor<T, ND, Op>;
}

Expand Down Expand Up @@ -389,6 +401,8 @@ impl<T, D: TensorDim> AsBuffer for Tensor<T, D, Buffer> {
}
}

/// Trait for mapping dimensions in tensor operations.
/// Allows for transforming and replacing dimensions in tensor types.
pub trait MapDim<D> {
type Item: TensorDim;
type MappedDim: TensorDim;
Expand All @@ -409,10 +423,12 @@ impl<D: TensorDim> MapDim<D> for Mapped {
const MAPPED_DIM: usize = 0;
}

/// Trait for default dimension mapping in tensor operations.
pub trait DefaultMap
where
Self: Sized,
{
/// The default dimension mapping for the implementing type.
type DefaultMapDim: MapDim<Self>;
}

Expand Down Expand Up @@ -495,6 +511,7 @@ impl<const N: usize> DefaultMap for Const<N> {
type DefaultMapDim = Mapped;
}

/// Trait representing the concatenation of dimensions.
pub trait DimConcat<A, B> {
type Output;
}
Expand All @@ -511,6 +528,7 @@ impl<A: NonScalarDim + NonTupleDim, B: NonScalarDim + NonTupleDim> DimConcat<A,
type Output = (A, B);
}

/// Represents types that are not tuples in dimension concatenation contexts.
pub trait NonTupleDim {}

impl NonTupleDim for ScalarDim {}
Expand Down Expand Up @@ -594,6 +612,7 @@ impl<T: TensorItem, D: TensorDim + DefaultMap> Tensor<T, D, crate::Op> {
}
}

/// Trait for broadcasting dimensions in tensor operations, used to unify dimensions for element-wise operations.
pub trait BroadcastDim<D1, D2> {
type Output: TensorDim;
}
Expand Down Expand Up @@ -628,6 +647,7 @@ impl<D: TensorDim + NotConst1> BroadcastDim<ScalarDim, D> for ShapeConstraint {
type Output = D;
}

/// Marker trait for types not equivalent to `Const<1>`, used in broadcasting logic.
pub trait NotConst1 {}

seq_macro::seq!(N in 2..99 {
Expand All @@ -640,9 +660,11 @@ impl<T: TensorItem, D: TensorDim> Tensor<T, D> {
}
}

/// Trait for indexing into tensors, allowing for the extraction of sub-tensors or elements based on indices.
pub trait TensorIndex<T, D: TensorDim> {
type Output;

/// Performs the indexing operation on a tensor, returning the result.
fn index(self, tensor: Tensor<T, D>) -> Self::Output;
}

Expand Down