Revert "switch beautiful_mnist to use new optimizer [pr] (tinygrad#8231
Browse files Browse the repository at this point in the history
…)" (tinygrad#8233)

This reverts commit e9ee39d.
geohot authored Dec 14, 2024
1 parent e9ee39d commit 37fa38d
Showing 7 changed files with 29 additions and 61 deletions.
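What the revert changes, in short: with tinygrad#8231 the optimizer accepted the loss tensor directly (step(loss) / schedule_step(loss), computing gradients via Tensor.gradient); this commit restores the previous flow where backward() populates .grad and step() takes no arguments. A minimal sketch of both call patterns, pieced together from the diff below (w stands in for a real model parameter; this is not authoritative beyond what the diff shows):

from tinygrad import Tensor
from tinygrad.nn.optim import SGD

with Tensor.train():
    w = Tensor([1.0, 2.0, 3.0])
    opt = SGD([w], lr=0.1)

    # Reverted API (tinygrad#8231): hand the loss to the optimizer, which
    # derives the gradients itself via Tensor.gradient.
    # opt.step(w.sum())

    # Restored API (this commit): backward() fills in .grad, then step() applies it.
    opt.zero_grad()
    w.sum().backward()
    opt.step()  # w becomes [0.9, 1.9, 2.9], matching the removed test_optim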
3 changes: 1 addition & 2 deletions docs/abstractions3.py
@@ -26,11 +26,10 @@ def model(x): return x.flatten(1).dot(l1.T).relu().dot(l2.T)
from tinygrad.nn.optim import SGD
optim = SGD([l1, l2])

Tensor.training = True
X, Y = X_train[(samples:=Tensor.randint(128, high=X_train.shape[0]))], Y_train[samples]
optim.zero_grad()
model(X).sparse_categorical_crossentropy(Y).backward()
optim.schedule_step() # this will step the optimizer without running realize
optim._step() # this will step the optimizer without running realize

# *****
# 3. Create a schedule.
4 changes: 0 additions & 4 deletions docs/tensor/properties.md
@@ -31,8 +31,4 @@
::: tinygrad.Tensor.shard_
::: tinygrad.Tensor.contiguous
::: tinygrad.Tensor.contiguous_backward

## Gradient

::: tinygrad.Tensor.gradient
::: tinygrad.Tensor.backward
4 changes: 3 additions & 1 deletion examples/beautiful_mnist.py
@@ -26,9 +26,11 @@ def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
@TinyJit
@Tensor.train()
def train_step() -> Tensor:
opt.zero_grad()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
opt.step(loss:=model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]))
loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
opt.step()
return loss

@TinyJit
24 changes: 4 additions & 20 deletions test/unit/test_gradient.py
@@ -1,13 +1,11 @@
from typing import Callable
import unittest, math
import numpy as np
import jax
import jax.numpy as jnp
from tinygrad import Tensor
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp
from tinygrad.gradient import gradient
from tinygrad.nn.optim import SGD, Adam

class TestGradient(unittest.TestCase):
def _cmp_nan_okay(self, x, y):
@@ -62,29 +60,15 @@ def test_big_chain(self): self._test_two_input_function(lambda x,y: (1.0/x*y)+x*

class TestTensorGradient(unittest.TestCase):
def test_example(self):
x = Tensor.eye(3)
# NOTE: this contiguous shouldn't be needed. gradient should go to base
x = Tensor.eye(3).contiguous()
y = Tensor([[2.0,0,-2.0]])
z = y.matmul(x).sum()
dx, dy = z.gradient(x, y)
print(dx.tolist())
print(dy.tolist())
self.assertListEqual(dx.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
self.assertListEqual(dy.tolist(), [[1.0, 1.0, 1.0]])

def test_raises(self):
x = Tensor([1.0, 2.0, 3.0])
w = Tensor.randn((3,))
with self.assertRaises(RuntimeError): x.sum().gradient(w)

def test_optim(self):
with Tensor.train():
w = Tensor([1.0, 2.0, 3.0])
SGD([w], lr=0.1).step(w.sum())
np.testing.assert_almost_equal(w.tolist(), [0.9, 1.9, 2.9])

def test_optim_rng(self):
with Tensor.train():
x = Tensor([1.0, 2.0, 3.0])
w = Tensor.randn((3,))
Adam([w], lr=0.1).step((x*w).sum())

if __name__ == '__main__':
unittest.main()
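For reference, a minimal self-contained sketch of the Tensor.gradient call pattern that test_example above checks; the commented values are the ones the test asserts:

from tinygrad import Tensor

# per the test's NOTE, the contiguous() shouldn't be needed; gradient should go to base
x = Tensor.eye(3).contiguous()
y = Tensor([[2.0, 0, -2.0]])
z = y.matmul(x).sum()
dx, dy = z.gradient(x, y)
print(dx.tolist())  # [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]]
print(dy.tolist())  # [[1.0, 1.0, 1.0]]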
7 changes: 2 additions & 5 deletions tinygrad/gradient.py
@@ -36,9 +36,6 @@ def reduce_gradient(ctx:UOp, ret:UOp):
# TODO: this cast can be removed by putting the casts around the EXPAND
(UPat(Ops.EXPAND, name="ret"), lambda ctx, ret:
(ctx.cast(sum_acc_dtype(ctx.dtype)).r(Ops.ADD, tuple(i for i,(si,so) in enumerate(zip(ret.src[0].shape, ret.arg)) if si!=so)).cast(ctx.dtype),)),

# there's no gradient for...is this ASSIGN?
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.BUFFER_VIEW))), lambda: (None, None)),
])

# copied from tensor.py, get relevant toposort of gradients
@@ -59,8 +56,8 @@ def gradient(root:UOp, targets:list[UOp]) -> list[UOp]:
for t0 in reversed(_deepwalk(root, targets)):
if t0 not in grads: continue
lgrads: tuple[UOp|None, ...]|None = cast(tuple[UOp, ...]|None, pm_gradient.rewrite(t0, ctx=grads[t0]))
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}\n\nin {str(t0)[0:1000]}...")
assert len(lgrads) == len(t0.src), f"got {len(lgrads)} gradient, expected {len(t0.src)}"
if lgrads is None: raise RuntimeError(f"failed to compute gradient for {t0.op}")
assert len(lgrads) == len(t0.src)
for k,v in zip(t0.src, lgrads):
if v is None: continue
if k in grads: grads[k] = grads[k] + v
35 changes: 19 additions & 16 deletions tinygrad/nn/optim.py
@@ -1,6 +1,6 @@
# sorted in order of increasing complexity
from typing import List, Optional
from tinygrad.helpers import dedup, flatten, getenv, unwrap
from typing import List
from tinygrad.helpers import dedup, flatten, getenv
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, least_upper_dtype

@@ -27,20 +27,20 @@ def zero_grad(self):
"""
for param in self.params: param.grad = None

def step(self, loss:Optional[Tensor]=None):
def step(self):
"""
Performs a single optimization step.
"""
Tensor.realize(*self.schedule_step(loss))
def schedule_step(self, loss:Optional[Tensor]=None) -> List[Tensor]:
Tensor.realize(*self.schedule_step())
def schedule_step(self) -> List[Tensor]:
"""
Returns the tensors that need to be realized to perform a single optimization step.
"""
assert Tensor.training, (
f"""Tensor.training={Tensor.training}, Tensor.training must be enabled to use the optimizer.
- help: Consider setting Tensor.training=True before calling Optimizer.step().""")
return self._step([unwrap(t.grad) for t in self.params] if loss is None else loss.gradient(*self.params))+self.params+self.buffers
def _step(self, grads:List[Tensor]) -> List[Tensor]: raise NotImplementedError
return self._step()+self.params+self.buffers
def _step(self) -> List[Tensor]: raise NotImplementedError

class OptimizerGroup(Optimizer):
"""
@@ -51,7 +51,7 @@ def __init__(self, *optimizers: Optimizer): # pylint: disable=super-init-not-cal
self.params, self.buffers = flatten([o.params for o in self.optimizers]), flatten([o.buffers for o in self.optimizers])
def __getitem__(self, i): return self.optimizers[i]
def zero_grad(self): [o.zero_grad() for o in self.optimizers]
def schedule_step(self, loss:Optional[Tensor]=None) -> List[Tensor]: return [x for o in self.optimizers for x in o.schedule_step(loss)]
def _step(self) -> List[Tensor]: return [x for o in self.optimizers for x in o._step()]

# LARS is essentially just trust ratio to SGD so if we just set the trust coeff 0.0 its just standard SGD.
def SGD(params: List[Tensor], lr=0.001, momentum=0.0, weight_decay=0.0, nesterov=False, classic=False):
@@ -76,10 +76,12 @@ def __init__(self, params:List[Tensor], lr=0.001, momentum=0.9, weight_decay=1e-
self.momentum, self.wd, self.nesterov, self.classic, self.tcoef = momentum, weight_decay, nesterov, classic, tcoef
self.b = [Tensor.zeros(*t.shape, dtype=t.dtype, device=t.device, requires_grad=False) for t in self.params] if self.momentum else []

def _step(self, grads:List[Tensor]) -> List[Tensor]:
for i, (t, g) in enumerate(zip(self.params, grads)):
# contiguous is needed since the grads can allegedly form a "diamond". TODO: is this fixed?
g = g.contiguous()
def _step(self) -> List[Tensor]:
for i, t in enumerate(self.params):
assert t.grad is not None
# contiguous is needed since the grads can allegedly form a "diamond"
# TODO: fix this in lazy.py
g = t.grad.contiguous()
if self.tcoef != 0:
r1 = t.detach().square().sum().sqrt()
r2 = g.square().sum().sqrt()
@@ -128,12 +130,13 @@ def __init__(self, params: List[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, w
self.m = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]
self.v = [Tensor.zeros(*t.shape, dtype=dtypes.float32, device=t.device, requires_grad=False).contiguous() for t in self.params]

def _step(self, grads:List[Tensor]) -> List[Tensor]:
def _step(self) -> List[Tensor]:
self.b1_t *= self.b1
self.b2_t *= self.b2
for i, (t, g) in enumerate(zip(self.params, grads)):
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * g)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (g * g))
for i, t in enumerate(self.params):
assert t.grad is not None
self.m[i].assign(self.b1 * self.m[i] + (1.0 - self.b1) * t.grad)
self.v[i].assign(self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad))
m_hat = self.m[i] / (1.0 - self.b1_t)
v_hat = self.v[i] / (1.0 - self.b2_t)
up = (m_hat / (v_hat.sqrt() + self.eps)) + self.wd * t.detach()
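As the optim.py hunks above show, step() is just Tensor.realize(*self.schedule_step()), so an update can still be scheduled without being realized, which is what abstractions3.py does via optim._step(). A minimal self-contained sketch under the restored API, with w as a stand-in parameter:

from tinygrad import Tensor
from tinygrad.nn.optim import SGD

with Tensor.train():
    w = Tensor([1.0, 2.0, 3.0])
    opt = SGD([w], lr=0.1)
    opt.zero_grad()
    w.sum().backward()
    # schedule_step() returns the tensors that need realizing for one update
    pending = opt.schedule_step()
    Tensor.realize(*pending)  # equivalent to opt.step()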
13 changes: 0 additions & 13 deletions tinygrad/tensor.py
@@ -866,19 +866,6 @@ def multinomial(self:Tensor, num_samples:int = 1, replacement:bool = False) -> T
# ***** toposort and backward pass *****

def gradient(self, *targets:Tensor) -> list[Tensor]:
"""
Compute the gradient of the targets with respect to self.
```python exec="true" source="above" session="tensor" result="python"
x = Tensor.eye(3)
y = Tensor([[2.0,0,-2.0]])
z = y.matmul(x).sum()
dx, dy = z.gradient(x, y)
print(dx.tolist()) # dz/dx
print(dy.tolist()) # dz/dy
```
"""
assert isinstance(self.lazydata, UOp), "multi isn't supported yet"
target_uops: List[UOp] = [x.lazydata for x in targets if isinstance(x.lazydata, UOp)]
return [Tensor(y) for y in gradient(self.lazydata, target_uops)]
