@@ -43,21 +43,10 @@ def is_minus_one(arg):
 
 
 ##
-## Inferring shape helpers
+## infer_rank helpers
 ##
 
 
-def infer_broadcasted_shape(*input_shapes: Sequence[List[int]]):
-    """
-    Given dynamic input shapes of trace tensors, infers a broadcasted shape.
-    This does not do any error checking since that can be done more reliably
-    later in the compiler.
-    """
-    max_rank = max(len(shape) for shape in input_shapes)
-    input_shapes = [[1] * (max_rank - len(shape)) + shape for shape in input_shapes]
-    return [max(dim) for dim in zip(*input_shapes)]
-
-
 class InferRankPolicies:
     def same_as_input(idx=0):
         def impl(self):
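For context on the helper removed above, here is a minimal self-contained sketch of what `infer_broadcasted_shape` computed (the `padded` name and the example shapes are illustrative, not from this change): shorter shapes are left-padded with 1s to a common rank and the maximum extent is taken along each axis, with no compatibility checking.

```python
from typing import List, Sequence


def infer_broadcasted_shape(*input_shapes: Sequence[List[int]]):
    # Left-pad shorter shapes with 1s so all shapes share the same rank,
    # then take the largest extent along each axis. Error checking is
    # deliberately deferred (the compiler validates later).
    max_rank = max(len(shape) for shape in input_shapes)
    padded = [[1] * (max_rank - len(shape)) + list(shape) for shape in input_shapes]
    return [max(dims) for dims in zip(*padded)]


print(infer_broadcasted_shape([3, 1, 5], [4, 5]))  # [3, 4, 5]
print(infer_broadcasted_shape([2, 3], [1]))        # [2, 3]
```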
@@ -177,29 +166,6 @@ def reshape_scalar_to_1d(input: "FlatIRTensor"):
 ##
 
 
-def get_broadcast_compatible_shapes(shape1, shape2):
-    # Make the shorter shape the same length as the longer shape by padding with ones
-    if len(shape1) > len(shape2):
-        shape2 = (1,) * (len(shape1) - len(shape2)) + shape2
-    elif len(shape2) > len(shape1):
-        shape1 = (1,) * (len(shape2) - len(shape1)) + shape1
-
-    return shape1, shape2
-
-
-def is_broadcast_compatible(shape1, shape2) -> Result:
-    # Now check each dimension pair
-    for index, (dim1, dim2) in enumerate(zip(shape1, shape2)):
-        if dim1 != dim2 and dim1 != 1 and dim2 != 1:
-            return Result.err(
-                [
-                    f"for tensor shapes: {shape1} and {shape2}, dimensions on axis {index}: '{dim1}' and '{dim2}' are not broadcast compatible"
-                ],
-            )
-
-    return Result.ok()
-
-
 # Given two shapes, compute the shape of the resulting broadcast. Assumes that the shapes are of equal rank
 def compute_shape_of_broadcast(
     shape1, shape2, output_rank: int, shape1_name: Optional[str] = None, shape2_name: Optional[str] = None
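The two helpers removed above implement the standard broadcast-compatibility rule: two dimensions are compatible when they are equal or either is 1. Below is a self-contained rework of that logic that returns a plain `(ok, message)` tuple in place of tripy's `Result` type, purely for illustration:

```python
def get_broadcast_compatible_shapes(shape1, shape2):
    # Pad the shorter shape with leading 1s so both shapes have the same rank.
    if len(shape1) > len(shape2):
        shape2 = (1,) * (len(shape1) - len(shape2)) + tuple(shape2)
    elif len(shape2) > len(shape1):
        shape1 = (1,) * (len(shape2) - len(shape1)) + tuple(shape1)
    return tuple(shape1), tuple(shape2)


def is_broadcast_compatible(shape1, shape2):
    # Dimensions are compatible if they match or either one is 1.
    for index, (dim1, dim2) in enumerate(zip(shape1, shape2)):
        if dim1 != dim2 and dim1 != 1 and dim2 != 1:
            return False, f"axis {index}: {dim1} and {dim2} are not broadcast compatible"
    return True, ""


s1, s2 = get_broadcast_compatible_shapes((3, 1, 5), (4, 5))
print(is_broadcast_compatible(s1, s2))          # (True, '')
print(is_broadcast_compatible((3, 2), (3, 4)))  # (False, 'axis 1: 2 and 4 are not broadcast compatible')
```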
@@ -358,10 +324,6 @@ def is_quantized_dtype(dtype: "tripy.common.datatype.dtype") -> bool:
     return dtype in QUANTIZED_DTYPES
 
 
-def is_quantizable_dtype(dtype: "tripy.common.datatype.dtype") -> bool:
-    return dtype in QUANTIZABLE_DTYPES
-
-
 def get_clamp_min_max(element_dtype, quant_dtype):
     QUANT_CLAMP_MIN_MAX = {
         tp_dtype.int8: (-128.0, 127.0),
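The clamp bounds kept in `get_clamp_min_max` are presumably used to saturate values to the quantized dtype's representable range before casting. A hypothetical sketch of that idea for int8 (the `clamp_for_int8` helper and the sample values are not part of this diff):

```python
import numpy as np


def clamp_for_int8(values):
    # Saturate to int8's representable range before casting, so out-of-range
    # values do not wrap around.
    clamp_min, clamp_max = -128.0, 127.0
    return np.clip(values, clamp_min, clamp_max).astype(np.int8)


print(clamp_for_int8(np.array([-300.0, -5.5, 200.0])))  # [-128   -5  127]
```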