From 00ac0db9d401c64fce2614d8ec67d1399bb29967 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 7 Dec 2024 14:01:51 +0800 Subject: [PATCH] np tensors have the memory from numpy in compile3 [pr] (#8098) --- examples/openpilot/compile3.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/openpilot/compile3.py b/examples/openpilot/compile3.py index 12f1b51470dfd..cc338bfbd51bf 100644 --- a/examples/openpilot/compile3.py +++ b/examples/openpilot/compile3.py @@ -60,15 +60,20 @@ def compile(): def test(test_val=None): with open(OUTPUT, "rb") as f: run = pickle.load(f) + + # same randomness as above Tensor.manual_seed(100) new_inputs = {nm:Tensor.randn(*st.shape, dtype=dtype).mul(8).realize() for nm, (st, _, dtype, _) in sorted(zip(run.captured.expected_names, run.captured.expected_st_vars_dtype_device))} new_inputs_numpy = {k:v.numpy() for k,v in new_inputs.items()} + + # create fake "from_blob" tensors for the inputs, and wrapped NPY tensors for the numpy inputs (these have the same underlying memory) + inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k}, + **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}} + + # run 20 times for _ in range(20): st = time.perf_counter() - # Need to cast non-image inputs from numpy, this is only realistic way to run it - inputs = {**{k:v for k,v in new_inputs.items() if 'img' in k}, - **{k:Tensor(v, device="NPY").realize() for k,v in new_inputs_numpy.items() if 'img' not in k}} out = run(**inputs) mt = time.perf_counter() val = out.numpy() @@ -78,6 +83,12 @@ def test(test_val=None): if test_val is not None: np.testing.assert_equal(test_val, val) print("**** test done ****") + # test that changing the numpy changes the model outputs + for v in new_inputs_numpy.values(): v *= 2 + out = run(**inputs) + changed_val = out.numpy() + np.testing.assert_raises(AssertionError, np.testing.assert_array_equal, val, changed_val) + if __name__ == "__main__": test_val = compile() if not getenv("RUN") else None test(test_val)