Feedback #1 (Open): wants to merge 7 commits into base: feedback
31 changes: 30 additions & 1 deletion README.md
@@ -1,3 +1,4 @@
[![Review Assignment Due Date](https://classroom.github.com/assets/deadline-readme-button-22041afd0340ce965d47ae6ef1cefeee28c7c493a6346c4f15d667ab976d596c.svg)](https://classroom.github.com/a/YFgwt0yY)
# MiniTorch Module 2

<img src="https://minitorch.github.io/minitorch.svg" width="50%">
@@ -15,4 +16,32 @@ python sync_previous_module.py previous-module-dir current-module-dir

The files that will be synced are:

minitorch/operators.py minitorch/module.py minitorch/autodiff.py minitorch/scalar.py minitorch/scalar_functions.py minitorch/module.py project/run_manual.py project/run_scalar.py project/datasets.py

## Simple Dataset
<img src="images/Simple/SimpleDataSet_Model.png">
<img src="images/Simple/SimpleDataSet_Hyper.png">
<img src="images/Simple/SimpleDataSet_LossGraph.png">
<img src="images/Simple/SimpleDataSet_LossTable.png">
<img src="images/Simple/SimpleDataSet.png">

## Diag Dataset
<img src="images/Diag/DiagDataSet_Model.png">
<img src="images/Diag/DiagDataSet_Hyper.png">
<img src="images/Diag/DiagDataSet_LossGraph.png">
<img src="images/Diag/DiagDataSet_LossTable.png">
<img src="images/Diag/DiagDataSet.png">

## Split Dataset
<img src="images/Split/SplitDataSet_Model.png">
<img src="images/Split/SplitDataSet_Hyper.png">
<img src="images/Split/SplitDataSet_LossGraph.png">
<img src="images/Split/SplitDataSet_LossTable.png">
<img src="images/Split/SplitDataSet.png">

## XOR Dataset
<img src="images/XOR/XORDataSet_Model.png">
<img src="images/XOR/XORDataSet_Hyper.png">
<img src="images/XOR/XORDataSet_LossGraph.png">
<img src="images/XOR/XORDataSet_LossTable.png">
<img src="images/XOR/XORDataSet.png">
Binary file added images/Diag/DiagDataSet.png
Binary file added images/Diag/DiagDataSet_Hyper.png
Binary file added images/Diag/DiagDataSet_LossGraph.png
Binary file added images/Diag/DiagDataSet_LossTable.png
Binary file added images/Diag/DiagDataSet_Model.png
Binary file added images/Simple/SimpleDataSet.png
Binary file added images/Simple/SimpleDataSet_Hyper.png
Binary file added images/Simple/SimpleDataSet_LossGraph.png
Binary file added images/Simple/SimpleDataSet_LossTable.png
Binary file added images/Simple/SimpleDataSet_Model.png
Binary file added images/Split/SplitDataSet.png
Binary file added images/Split/SplitDataSet_Hyper.png
Binary file added images/Split/SplitDataSet_LossGraph.png
Binary file added images/Split/SplitDataSet_LossTable.png
Binary file added images/Split/SplitDataSet_Model.png
Binary file added images/XOR/XORDataSet.png
Binary file added images/XOR/XORDataSet_Hyper.png
Binary file added images/XOR/XORDataSet_LossGraph.png
Binary file added images/XOR/XORDataSet_LossTable.png
Binary file added images/XOR/XORDataSet_Model.png
2 changes: 1 addition & 1 deletion minitorch/__init__.py
@@ -1,4 +1,4 @@
from .testing import MathTest, MathTestVariable # type: ignore # noqa: F401,F403
from .testing import MathTest, MathTestVariable # type: ignore # noqa: F401,F403, D104
from .tensor_data import * # noqa: F401,F403
from .tensor import * # noqa: F401,F403
from .tensor_ops import * # noqa: F401,F403
95 changes: 83 additions & 12 deletions minitorch/autodiff.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Iterable, List, Tuple, Protocol
from typing import Any, Iterable, Tuple, Protocol


# ## Task 1.1
@@ -25,26 +25,46 @@ def central_difference(f: Any, *vals: Any, arg: int = 0, epsilon: float = 1e-6)
An approximation of $f'_i(x_0, \ldots, x_{n-1})$

"""
raise NotImplementedError("Need to include this file from past assignment.")
# TODO: Implement for Task 1.1.

    up_perturbed = vals[:arg] + (vals[arg] + epsilon,) + vals[arg + 1 :]
    down_perturbed = vals[:arg] + (vals[arg] - epsilon,) + vals[arg + 1 :]

    slope = (f(*up_perturbed) - f(*down_perturbed)) / (2 * epsilon)

    return slope
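
As a quick sanity check, the approximation can be compared against hand-computed partials. A hypothetical usage sketch (the function `f` and the sample point are made up for illustration):

```python
from minitorch.autodiff import central_difference

def f(x: float, y: float) -> float:
    return x * y + x ** 2

# Analytic partials at (3, 4): df/dx = y + 2x = 10, df/dy = x = 3
print(central_difference(f, 3.0, 4.0, arg=0))  # ~10.0
print(central_difference(f, 3.0, 4.0, arg=1))  # ~3.0
```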


variable_count = 1


class Variable(Protocol):
def accumulate_derivative(self, x: Any) -> None: ...
def accumulate_derivative(self, x: Any) -> None:
"""Accumulates the derivative of the output with respect to this variable."""

...

@property
def unique_id(self) -> int: ...
def unique_id(self) -> int:
"""Returns the unique identifier of this variable."""
...

def is_leaf(self) -> bool: ...
def is_leaf(self) -> bool:
"""Returns True if this variable is a leaf node."""
...

def is_constant(self) -> bool: ...
def is_constant(self) -> bool:
"""Returns True if this variable is a constant."""
...

@property
def parents(self) -> Iterable["Variable"]: ...
def parents(self) -> Iterable["Variable"]:
"""Returns the parents of this variable."""
...

def chain_rule(self, d_output: Any) -> Iterable[Tuple[Variable, Any]]: ...
def chain_rule(self, d_output: Any) -> Iterable[Tuple[Variable, Any]]:
"""Computes gradients of inputs using the chain rule."""
...


def topological_sort(variable: Variable) -> Iterable[Variable]:
@@ -59,7 +79,30 @@ def topological_sort(variable: Variable) -> Iterable[Variable]:
Non-constant Variables in topological order starting from the right.

"""
raise NotImplementedError("Need to include this file from past assignment.")
    # create a set to store visited variables
    visited = set()

    # create a list to store sorted variables
    sorted_vars = []

    def visit(node: Variable) -> None:
        """Visit the variable and its parents recursively."""
        if node.unique_id in visited or node.is_constant():
            return
        if not node.is_leaf():
            # visit all the parents of the variable
            for parent in node.parents:
                if not parent.is_constant():
                    visit(parent)
        # mark the variable as visited
        visited.add(node.unique_id)

        # once all the parents have been visited, add the variable to the front of the sorted list
        sorted_vars.insert(0, node)

    visit(variable)

    return sorted_vars
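
To see the ordering this produces, here is a small sketch, assuming the `Scalar` class from Module 1 is exported from the `minitorch` package as in previous modules:

```python
import minitorch
from minitorch.autodiff import topological_sort

x = minitorch.Scalar(2.0)
y = x * x
z = y + x

# The output z comes first ("starting from the right");
# the leaf x comes last, after everything that depends on it.
order = list(topological_sort(z))
print([v.unique_id for v in order])
```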


def backpropagate(variable: Variable, deriv: Any) -> None:
@@ -69,12 +112,39 @@ def backpropagate(variable: Variable, deriv: Any) -> None:
Args:
----
variable: The right-most variable
deriv : Its derivative that we want to propagate backward to the leaves.
deriv: Its derivative that we want to propagate backward to the leaves.

No return. Should write to its results to the derivative values of each leaf through `accumulate_derivative`.
Returns:
-------
No return. Writes its results to the derivative values of each leaf through `accumulate_derivative`.

"""
raise NotImplementedError("Need to include this file from past assignment.")
    # call topological sort
    sorted_vars = topological_sort(variable)

    # create dictionary to store variables and their derivatives
    derivatives = {}
    derivatives[variable.unique_id] = deriv

    # iterate through the sorted variables in backward order
    for var in sorted_vars:
        # if the variable is a leaf node, accumulate the derivative
        if var.is_leaf():
            var.accumulate_derivative(derivatives[var.unique_id])

        # if the variable is not a leaf node
        else:
            # call .chain_rule on the last function in the history of the variable
            grads = var.chain_rule(derivatives[var.unique_id])

            # loop through all the Scalars+derivatives provided by the chain rule
            for parent, derivative in grads:
                if parent.is_constant():
                    continue
                # accumulate derivatives for the Scalar in the dictionary
                derivatives.setdefault(parent.unique_id, 0)
                derivatives[parent.unique_id] += derivative
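
End to end, `Scalar.backward` is expected to drive this routine. A minimal sketch, assuming the Module 1 `Scalar` API with `backward()` and a `.derivative` attribute:

```python
import minitorch

x = minitorch.Scalar(2.0)
y = minitorch.Scalar(3.0)
z = x * y + x

z.backward()  # expected to call backpropagate(z, 1.0)
print(x.derivative)  # dz/dx = y + 1 = 4.0
print(y.derivative)  # dz/dy = x = 2.0
```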


@dataclass
@@ -92,4 +162,5 @@ def save_for_backward(self, *values: Any) -> None:

@property
def saved_tensors(self) -> Tuple[Any, ...]:
"""Returns the saved values."""
return self.saved_values
80 changes: 80 additions & 0 deletions minitorch/datasets.py
@@ -5,6 +5,17 @@


def make_pts(N: int) -> List[Tuple[float, float]]:
"""Generate a list of N random 2D points.

Args:
----
N: The number of points to generate.

Returns:
-------
A list of N points, each represented as a tuple of two floats.

"""
X = []
for i in range(N):
x_1 = random.random()
@@ -21,6 +32,17 @@ class Graph:


def simple(N: int) -> Graph:
"""Generate a simple dataset where the label is 1 if the first coordinate is less than 0.5 and 0 otherwise.

Args:
----
N: The number of points to generate.

Returns:
-------
A Graph object containing the generated data.

"""
X = make_pts(N)
y = []
for x_1, x_2 in X:
@@ -30,6 +52,17 @@ def simple(N: int) -> Graph:


def diag(N: int) -> Graph:
"""Generate a dataset where the label is 1 if the sum of the coordinates is less than 0.5 and 0 otherwise.

Args:
----
N: The number of points to generate.

Returns:
-------
A Graph object containing the generated data.

"""
X = make_pts(N)
y = []
for x_1, x_2 in X:
@@ -39,6 +72,17 @@ def diag(N: int) -> Graph:


def split(N: int) -> Graph:
"""Generate a dataset where the label is 1 if the first coordinate is less than 0.2 or greater than 0.8 and 0 otherwise.

Args:
----
N: The number of points to generate.

Returns:
-------
A Graph object containing the generated data.

"""
X = make_pts(N)
y = []
for x_1, x_2 in X:
@@ -48,6 +92,17 @@ def split(N: int) -> Graph:


def xor(N: int) -> Graph:
"""Generate a dataset where the label is 1 if the first coordinate is less than 0.5 and the second coordinate is greater than 0.5 or vice versa and 0 otherwise.

Args:
----
N: The number of points to generate.

Returns:
-------
A Graph object containing the generated data.

"""
X = make_pts(N)
y = []
for x_1, x_2 in X:
@@ -57,6 +112,17 @@ def xor(N: int) -> Graph:


def circle(N: int) -> Graph:
"""Generate a dataset where the label is 1 if the point's distance from the origin (after shifting by 0.5) is greater than sqrt(0.1), and 0 otherwise.

Args:
----
N: The number of points to generate.

Returns:
-------
A Graph object containing the generated data.

"""
X = make_pts(N)
y = []
for x_1, x_2 in X:
@@ -67,10 +133,24 @@


def spiral(N: int) -> Graph:
"""Generate a spiral dataset with N points where the first half of the points are in class 0 and the second half are in class 1.

Args:
----
N: The number of points to generate.

Returns:
-------
A Graph object containing the generated data.

"""

    def x(t: float) -> float:
        """Calculate the x-coordinate of a point on a spiral."""
        return t * math.cos(t) / 20.0

    def y(t: float) -> float:
        """Calculate the y-coordinate of a point on a spiral."""
        return t * math.sin(t) / 20.0

    X = [
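
A quick way to sanity-check these generators (hypothetical usage; `Graph` is the dataclass defined above holding `N`, `X`, and `y`):

```python
from minitorch.datasets import simple, xor

graph = simple(10)
print(graph.N)     # 10
print(graph.X[0])  # a random (x_1, x_2) point
print(graph.y[0])  # 1 if x_1 < 0.5 else 0

pts = xor(100)
print(sum(pts.y), "of", pts.N, "points are labeled 1")
```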
47 changes: 40 additions & 7 deletions minitorch/module.py
@@ -30,12 +30,29 @@ def modules(self) -> Sequence[Module]:
return list(m.values())

def train(self) -> None:
"""Set the mode of this module and all descendent modules to `train`."""
raise NotImplementedError("Need to include this file from past assignment.")
"""Set the `training` flag of this and descendent to true.

Returns
-------
None.

"""
self.training = True
for module in self.modules():
module.train()

def eval(self) -> None:
"""Set the mode of this module and all descendent modules to `eval`."""
raise NotImplementedError("Need to include this file from past assignment.")
"""Set the `training` flag of this and descendent to false.

Returns
-------
None.

"""
self.training = False

for module in self.modules():
module.training = False
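
Note that recursing through `module.eval()` (rather than setting `module.training` directly) is what lets the flag reach grandchildren, mirroring `train()`. A small sketch of the propagation (`Net` and `Layer` are hypothetical subclasses, not part of this diff):

```python
from minitorch.module import Module

class Layer(Module):
    pass

class Net(Module):
    def __init__(self) -> None:
        super().__init__()
        self.layer1 = Layer()
        self.layer2 = Layer()

net = Net()
net.eval()
assert not net.training and not net.layer1.training
net.train()
assert net.training and net.layer2.training
```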

def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
"""Collect all the parameters of this module and its descendents.
@@ -45,11 +62,26 @@ def named_parameters(self) -> Sequence[Tuple[str, Parameter]]:
The name and `Parameter` of each ancestor parameter.

"""
raise NotImplementedError("Need to include this file from past assignment.")
    parameters = {}

    for key, val in self._parameters.items():
        parameters[key] = val

    for name, module in self._modules.items():
        for key, val in module.named_parameters():
            parameters[f"{name}.{key}"] = val

    return list(parameters.items())

def parameters(self) -> Sequence[Parameter]:
"""Enumerate over all the parameters of this module and its descendents."""
raise NotImplementedError("Need to include this file from past assignment.")
"""Enumerate over all the parameters of this module and its descendents.

Returns
-------
The `Parameter` of this module and its descendents.

"""
return [val for _, val in self.named_parameters()]
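
The dotted naming convention is easiest to see on a tiny nested module. A hedged sketch (`Inner` and `Outer` are made-up subclasses):

```python
from minitorch.module import Module, Parameter

class Inner(Module):
    def __init__(self) -> None:
        super().__init__()
        self.weight = Parameter(1.0)

class Outer(Module):
    def __init__(self) -> None:
        super().__init__()
        self.bias = Parameter(0.0)
        self.inner = Inner()

print([name for name, _ in Outer().named_parameters()])
# ['bias', 'inner.weight']
```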

def add_parameter(self, k: str, v: Any) -> Parameter:
"""Manually add a parameter. Useful helper for scalar parameters.
@@ -85,6 +117,7 @@ def __getattr__(self, key: str) -> Any:
return None

def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""Turns the instance into a callable object."""
return self.forward(*args, **kwargs)

def __repr__(self) -> str: