Skip to content

Commit

Permalink
made dim arg optional in forward func for Sum/All
Browse files Browse the repository at this point in the history
  • Loading branch information
TaykhoomDalal committed Oct 11, 2024
1 parent fbf1568 commit ecae766
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
19 changes: 6 additions & 13 deletions minitorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,16 +342,13 @@ def __rmul__(self, b: TensorLike) -> Tensor:
"""Multiplication of two tensors"""
return Mul.apply(self._ensure_tensor(b), self)

def all(self, dim: Optional[int] = None) -> Tensor:
    """Return 1.0 if every element (optionally along one dimension) is truthy.

    Args:
        dim: dimension to reduce over; when None, reduces over all elements.

    Returns:
        A tensor holding the product-reduction result.
    """
    if dim is None:
        # No dimension supplied: let All.forward handle the full reduction.
        return All.apply(self)
    return All.apply(self, self._ensure_tensor(dim))

def is_close(self, b: TensorLike) -> Tensor:
"""Check if two tensors are close"""
def sum(self, dim: Optional[int] = None) -> Tensor:
    """Sum of tensor elements.

    Args:
        dim: dimension to sum over; when None, sums every element.

    Returns:
        The summed tensor.
    """
    if dim is None:
        # Sum over all elements; Sum.forward flattens internally.
        return Sum.apply(self)
    return Sum.apply(self, self._ensure_tensor(dim))

def mean(self, dim: Optional[int] = None) -> Tensor:
    """Mean of tensor elements.

    Args:
        dim: dimension to average over; when None, averages every element.

    Returns:
        The mean tensor.
    """
    if dim is None:
        # Mean over all elements: total sum divided by total element count.
        return Sum.apply(self) / self.size
    # BUG FIX: a mean along one dimension must divide by the length of that
    # dimension (self.shape[dim]), not by the total number of elements
    # (self.size) — the old code gave wrong results for tensors with >1 dims.
    return Sum.apply(self, self._ensure_tensor(dim)) / self.shape[dim]

def permute(self, *dims: int) -> Tensor:
"""Permute the dimensions of the tensor."""
Expand Down
29 changes: 18 additions & 11 deletions minitorch/tensor_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .tensor_ops import SimpleBackend, TensorBackend

if TYPE_CHECKING:
from typing import Any, List, Tuple
from typing import Any, List, Tuple, Optional, Union

from .tensor import Tensor
from .tensor_data import UserIndex, UserShape
Expand Down Expand Up @@ -103,13 +103,12 @@ def backward(ctx: Context, grad_output: Tensor) -> Tuple[Tensor, Tensor]:

class All(Function):
    @staticmethod
    def forward(ctx: Context, a: Tensor, dim: Optional[Tensor] = None) -> Tensor:
        """Return 1 if all elements are true.

        Args:
            ctx: autodiff context (nothing is saved; no backward in view).
            a: input tensor.
            dim: optional 0-d tensor naming the dimension to reduce; when
                None, the reduction runs over every element.
        """
        if dim is not None:
            return a.f.mul_reduce(a, int(dim.item()))
        # Full reduction: flatten to one dimension, then reduce along it.
        flat = a.contiguous().view(int(operators.prod(a.shape)))
        return a.f.mul_reduce(flat, 0)


Expand Down Expand Up @@ -194,21 +193,29 @@ def backward(ctx: Context, grad_output: Tensor) -> Tensor:

# NOTE(review): Optional/Union are imported only under TYPE_CHECKING in this
# file's header — confirm `from __future__ import annotations` is present at
# the top of the file, otherwise these runtime annotations raise NameError.
class Sum(Function):
    @staticmethod
    def forward(ctx: Context, t1: Tensor, dim: Optional[Tensor] = None) -> Tensor:
        """Sum function $f(x) = sum(x)$

        Args:
            ctx: autodiff context; records which dim (if any) was reduced.
            t1: input tensor.
            dim: optional 0-d tensor naming the dimension to sum over; when
                None, sums over every element.
        """
        if dim is not None:
            axis = int(dim.item())
            ctx.save_for_backward(axis)
            return t1.f.add_reduce(t1, axis)
        # Full reduction: remember that no dim was supplied, then flatten
        # to one dimension and reduce along it.
        ctx.save_for_backward(dim)
        flat = t1.contiguous().view(int(operators.prod(t1.shape)))
        return t1.f.add_reduce(flat, 0)

    @staticmethod
    def backward(
        ctx: Context, grad_output: Tensor
    ) -> Union[Tensor, Tuple[Tensor, float]]:
        """Sum backward: the gradient of a sum passes straight through.

        Returns the gradient alone for a full reduction, or paired with a
        0.0 placeholder "gradient" for the dim argument otherwise.
        """
        (dim,) = ctx.saved_values
        if dim is None:
            return grad_output
        return grad_output, 0.0


class LT(Function):
Expand Down

0 comments on commit ecae766

Please sign in to comment.