Skip to content

Commit 3ac751b

Browse files
Speeds up datatype constraints tests
Makes datatype constraints tests significantly faster by not evaluating inputs.
1 parent 3a77a99 commit 3ac751b

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

tripy/tests/constraints/object_builders.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@
2626
def tensor_builder(init, dtype, namespace):
2727
if init is None:
2828
out = tp.ones(dtype=namespace[dtype], shape=(3, 2))
29-
out.eval()
3029
return out
3130
elif not isinstance(init, tp.Tensor):
3231
return init
3332

3433
out = init
3534
if dtype is not None:
3635
out = tp.cast(out, dtype=namespace[dtype])
37-
out.eval()
36+
# Need to evaluate when casting because we run into MLIR-TRT bugs while deriving upper bounds.
37+
out.eval()
3838
return out
3939

4040

@@ -47,8 +47,6 @@ def tensor_list_builder(init, dtype, namespace):
4747
out = [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)]
4848
else:
4949
out = [tp.cast(tens, dtype=namespace[dtype]) for tens in init]
50-
for t in out:
51-
t.eval()
5250
return out
5351

5452

@@ -132,7 +130,7 @@ def default_builder(init, dtype, namespace):
132130
"pad": {"pad": [(0, 1), (1, 0)]},
133131
"permute": {"perm": [1, 0]},
134132
"prod": {"dim": 0},
135-
"quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
133+
"quantize": {"input": tp.ones((3, 2)), "scale": tp.Tensor([1, 1, 1]), "dim": 0},
136134
"repeat": {"repeats": 2, "dim": 0},
137135
"reshape": {"shape": [6]},
138136
"resize": {

0 commit comments

Comments
 (0)