Skip to content

Commit

Permalink
move copy into the JIT for openpilot compile3 (tinygrad#7937)
Browse files Browse the repository at this point in the history
* move copy into the JIT, test fails

* ahh, prune was the issue
  • Loading branch information
geohot authored Dec 7, 2024
1 parent 0ed731b commit 22feb3a
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions examples/openpilot/compile3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
if "NOLOCALS" not in os.environ: os.environ["NOLOCALS"] = "1"
if "JIT_BATCH_SIZE" not in os.environ: os.environ["JIT_BATCH_SIZE"] = "0"

from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters
from tinygrad import fetch, Tensor, TinyJit, Context, GlobalCounters, Device
from tinygrad.helpers import DEBUG, getenv
from tinygrad.tensor import _from_np_dtype

Expand All @@ -30,18 +30,22 @@ def compile():
if getenv("FLOAT16", 0) == 0: input_types = {k:(np.float32 if v==np.float16 else v) for k,v in input_types.items()}
Tensor.manual_seed(100)
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()}
print("created tensors")

run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
run_onnx_jit = TinyJit(lambda **kwargs:
next(iter(run_onnx({k:v.to(Device.DEFAULT) for k,v in kwargs.items()}).values())).cast('float32'), prune=True)
for i in range(3):
GlobalCounters.reset()
print(f"run {i}")
inputs = {**{k:v.clone() for k,v in new_inputs.items() if 'img' in k},
**{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
with Context(DEBUG=max(DEBUG.value, 2 if i == 2 else 1)):
ret = next(iter(run_onnx_jit(**new_inputs).values())).cast('float32').numpy()
ret = run_onnx_jit(**inputs).numpy()
# copy i == 1 so use of JITBEAM is okay
if i == 1: test_val = np.copy(ret)
print(f"captured {len(run_onnx_jit.captured.jit_cache)} kernels")
np.testing.assert_equal(test_val, ret)
np.testing.assert_equal(test_val, ret, "JIT run failed")
print("jit run validated")

with open(OUTPUT, "wb") as f:
Expand All @@ -64,10 +68,10 @@ def test(test_val=None):
st = time.perf_counter()
# Need to cast non-image inputs from numpy, this is only realistic way to run it
inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k},
**{k:Tensor(v) for k,v in new_inputs_numpy.items() if 'img' not in k}}
**{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}}
out = run(**inputs)
mt = time.perf_counter()
val = out['outputs'].numpy()
val = out.numpy()
et = time.perf_counter()
print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {(et-st)*1e3:6.2f} ms")
print(out, val.shape, val.dtype)
Expand Down

0 comments on commit 22feb3a

Please sign in to comment.