26
26
def tensor_builder (init , dtype , namespace ):
27
27
if init is None :
28
28
out = tp .ones (dtype = namespace [dtype ], shape = (3 , 2 ))
29
- out .eval ()
30
29
return out
31
30
elif not isinstance (init , tp .Tensor ):
32
31
return init
33
32
34
33
out = init
35
34
if dtype is not None :
36
35
out = tp .cast (out , dtype = namespace [dtype ])
37
- out .eval ()
36
+ # Need to evaluate when casting because we run into MLIR-TRT bugs while deriving upper bounds.
37
+ out .eval ()
38
38
return out
39
39
40
40
@@ -47,8 +47,6 @@ def tensor_list_builder(init, dtype, namespace):
47
47
out = [tp .ones (shape = (3 , 2 ), dtype = namespace [dtype ]) for _ in range (2 )]
48
48
else :
49
49
out = [tp .cast (tens , dtype = namespace [dtype ]) for tens in init ]
50
- for t in out :
51
- t .eval ()
52
50
return out
53
51
54
52
@@ -132,7 +130,7 @@ def default_builder(init, dtype, namespace):
132
130
"pad" : {"pad" : [(0 , 1 ), (1 , 0 )]},
133
131
"permute" : {"perm" : [1 , 0 ]},
134
132
"prod" : {"dim" : 0 },
135
- "quantize" : {"scale" : tp .Tensor ([1 , 1 , 1 ]), "dim" : 0 },
133
+ "quantize" : {"input" : tp . ones (( 3 , 2 )), " scale" : tp .Tensor ([1 , 1 , 1 ]), "dim" : 0 },
136
134
"repeat" : {"repeats" : 2 , "dim" : 0 },
137
135
"reshape" : {"shape" : [6 ]},
138
136
"resize" : {
0 commit comments