Skip to content

Commit 6e1bd3e

Browse files
committed
2021 Version
1 parent 71351e7 commit 6e1bd3e

25 files changed

+2110
-0
lines changed

.gitignore

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
pip-wheel-metadata/
24+
share/python-wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.nox/
44+
.coverage
45+
.coverage.*
46+
.cache
47+
nosetests.xml
48+
coverage.xml
49+
*.cover
50+
*.py,cover
51+
.hypothesis/
52+
.pytest_cache/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# IPython
81+
profile_default/
82+
ipython_config.py
83+
84+
# pyenv
85+
.python-version
86+
87+
# pipenv
88+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
90+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
91+
# install all needed dependencies.
92+
#Pipfile.lock
93+
94+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
95+
__pypackages__/
96+
97+
# Celery stuff
98+
celerybeat-schedule
99+
celerybeat.pid
100+
101+
# SageMath parsed files
102+
*.sage.py
103+
104+
# Environments
105+
.env
106+
.venv
107+
env/
108+
venv/
109+
ENV/
110+
env.bak/
111+
venv.bak/
112+
113+
# Spyder project settings
114+
.spyderproject
115+
.spyproject
116+
117+
# Rope project settings
118+
.ropeproject
119+
120+
# mkdocs documentation
121+
/site
122+
123+
# mypy
124+
.mypy_cache/
125+
.dmypy.json
126+
dmypy.json
127+
128+
# Pyre type checker
129+
.pyre/
130+
*.\#*

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# MiniTorch Module 0
2+
3+
<img src="https://minitorch.github.io/_images/match.png" width="100px">
4+
5+
* Docs: https://minitorch.github.io/
6+
7+
* Overview: https://minitorch.github.io/module0.html

minitorch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .module import * # noqa: F401,F403
2+
from .testing import * # noqa: F401,F403
3+
from .datasets import * # noqa: F401,F403

minitorch/datasets.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from dataclasses import dataclass
2+
import random
3+
4+
5+
def make_pts(N):
6+
X = []
7+
for i in range(N):
8+
x_1 = random.random()
9+
x_2 = random.random()
10+
X.append((x_1, x_2))
11+
return X
12+
13+
14+
@dataclass
15+
class Graph:
16+
N: int
17+
X: list
18+
y: list
19+
20+
21+
def simple(N):
22+
X = make_pts(N)
23+
y = []
24+
for x_1, x_2 in X:
25+
y1 = 1 if x_1 < 0.5 else 0
26+
y.append(y1)
27+
return Graph(N, X, y)
28+
29+
30+
def diag(N):
31+
X = make_pts(N)
32+
y = []
33+
for x_1, x_2 in X:
34+
y1 = 1 if x_1 + x_2 < 0.5 else 0
35+
y.append(y1)
36+
return Graph(N, X, y)
37+
38+
39+
def split(N):
40+
X = make_pts(N)
41+
y = []
42+
for x_1, x_2 in X:
43+
y1 = 1 if x_1 < 0.2 or x_1 > 0.8 else 0
44+
y.append(y1)
45+
return Graph(N, X, y)
46+
47+
48+
def xor(N):
49+
X = make_pts(N)
50+
y = []
51+
for x_1, x_2 in X:
52+
y1 = 1 if ((x_1 < 0.5 and x_2 > 0.5) or (x_1 > 0.5 and x_2 < 0.5)) else 0
53+
y.append(y1)
54+
return Graph(N, X, y)
55+
56+
57+
datasets = {"Simple": simple, "Diag": diag, "Split": split, "Xor": xor}

minitorch/module.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
class Module:
2+
"""
3+
Modules form a tree that store parameters and other
4+
submodules. They make up the basis of neural network stacks.
5+
6+
Attributes:
7+
_modules (dict of name x :class:`Module`): Storage of the child modules
8+
_parameters (dict of name x :class:`Parameter`): Storage of the module's parameters
9+
training (bool): Whether the module is in training mode or evaluation mode
10+
11+
"""
12+
13+
def __init__(self):
14+
self._modules = {}
15+
self._parameters = {}
16+
self.training = True
17+
18+
def modules(self):
19+
"Return the direct child modules of this module."
20+
return self.__dict__["_modules"].values()
21+
22+
def train(self):
23+
"Set the mode of this module and all descendent modules to `train`."
24+
# TODO: Implement for Task 0.4.
25+
raise NotImplementedError('Need to implement for Task 0.4')
26+
27+
def eval(self):
28+
"Set the mode of this module and all descendent modules to `eval`."
29+
# TODO: Implement for Task 0.4.
30+
raise NotImplementedError('Need to implement for Task 0.4')
31+
32+
def named_parameters(self):
33+
"""
34+
Collect all the parameters of this module and its descendents.
35+
36+
37+
Returns:
38+
list of pairs: Contains the name and :class:`Parameter` of each ancestor parameter.
39+
"""
40+
# TODO: Implement for Task 0.4.
41+
raise NotImplementedError('Need to implement for Task 0.4')
42+
43+
def parameters(self):
44+
"Enumerate over all the parameters of this module and its descendents."
45+
# TODO: Implement for Task 0.4.
46+
raise NotImplementedError('Need to implement for Task 0.4')
47+
48+
def add_parameter(self, k, v):
49+
"""
50+
Manually add a parameter. Useful helper for scalar parameters.
51+
52+
Args:
53+
k (str): Local name of the parameter.
54+
v (value): Value for the parameter.
55+
56+
Returns:
57+
Parameter: Newly created parameter.
58+
"""
59+
val = Parameter(v, k)
60+
self.__dict__["_parameters"][k] = val
61+
return val
62+
63+
def __setattr__(self, key, val):
64+
if isinstance(val, Parameter):
65+
self.__dict__["_parameters"][key] = val
66+
elif isinstance(val, Module):
67+
self.__dict__["_modules"][key] = val
68+
else:
69+
super().__setattr__(key, val)
70+
71+
def __getattr__(self, key):
72+
if key in self.__dict__["_parameters"]:
73+
return self.__dict__["_parameters"][key]
74+
75+
if key in self.__dict__["_modules"]:
76+
return self.__dict__["_modules"][key]
77+
78+
def __call__(self, *args, **kwargs):
79+
return self.forward(*args, **kwargs)
80+
81+
def forward(self):
82+
assert False, "Not Implemented"
83+
84+
def __repr__(self):
85+
def _addindent(s_, numSpaces):
86+
s = s_.split("\n")
87+
if len(s) == 1:
88+
return s_
89+
first = s.pop(0)
90+
s = [(numSpaces * " ") + line for line in s]
91+
s = "\n".join(s)
92+
s = first + "\n" + s
93+
return s
94+
95+
child_lines = []
96+
97+
for key, module in self._modules.items():
98+
mod_str = repr(module)
99+
mod_str = _addindent(mod_str, 2)
100+
child_lines.append("(" + key + "): " + mod_str)
101+
lines = child_lines
102+
103+
main_str = self.__class__.__name__ + "("
104+
if lines:
105+
# simple one-liner info, which most builtin Modules will use
106+
main_str += "\n " + "\n ".join(lines) + "\n"
107+
108+
main_str += ")"
109+
return main_str
110+
111+
112+
class Parameter:
113+
"""
114+
A Parameter is a special container stored in a :class:`Module`.
115+
116+
It is designed to hold a :class:`Variable`, but we allow it to hold
117+
any value for testing.
118+
"""
119+
120+
def __init__(self, x=None, name=None):
121+
self.value = x
122+
self.name = name
123+
if hasattr(x, "requires_grad_"):
124+
self.value.requires_grad_(True)
125+
if self.name:
126+
self.value.name = self.name
127+
128+
def update(self, x):
129+
"Update the parameter value."
130+
self.value = x
131+
if hasattr(x, "requires_grad_"):
132+
self.value.requires_grad_(True)
133+
if self.name:
134+
self.value.name = self.name
135+
136+
def __repr__(self):
137+
return repr(self.value)
138+
139+
def __str__(self):
140+
return str(self.value)

0 commit comments

Comments
 (0)