Skip to content

Commit 96fbf5d

Browse files
Removes dead code
1 parent 35bb22b commit 96fbf5d

File tree

10 files changed

+9
-194
lines changed

10 files changed

+9
-194
lines changed

tripy/tests/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,13 @@ detect dead code. *This **will** include false positives for our code, so be car
7474
You can run it with:
7575

7676
```bash
77-
vulture tripy tests --sort-by-size
77+
vulture . --sort-by-size
7878
```
7979

8080
To exclude false positives, use:
8181

8282
```bash
83-
vulture tripy tests --sort-by-size --min-confidence=100
83+
vulture . --sort-by-size --min-confidence=100
8484
```
8585

8686

tripy/tests/backend/api/test_stream.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
import pytest
16-
import gc, sys
17-
import tripy as tp
1815
import cupy as cp
1916

17+
import tripy as tp
18+
2019

2120
def test_default_stream_creation():
2221
default_stream1 = tp.default_stream()

tripy/tests/common/test_utils.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,11 @@
1616
#
1717

1818
import pytest
19-
import struct
20-
from collections import ChainMap
21-
from textwrap import dedent
22-
23-
import cupy as cp
24-
import torch
19+
from tests import helper
2520

2621
import tripy.common.datatype
27-
28-
from tests import helper
2922
from tripy.common.exception import TripyException
30-
from tripy.common.utils import (
31-
convert_list_to_array,
32-
get_element_type,
33-
)
23+
from tripy.common.utils import convert_list_to_array, get_element_type
3424

3525

3626
def test_get_element_type():

tripy/tests/constraints/test_dtypes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
PUBLIC_API_TENSOR_FUNCTIONS = []
3737
PUBLIC_API_TENSOR_FUNCTION_NAMES = []
3838
for api in PUBLIC_APIS:
39-
is_module = False
4039
if inspect.isfunction(api.obj):
4140
funcs = [api.obj]
4241
elif inspect.isclass(api.obj):

tripy/tripy/backend/mlir/utils.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -107,41 +107,6 @@ def list_to_dense_attr(data: List, mlir_dtype):
107107
return attrs
108108

109109

110-
def get_mlir_quant_dtype(
111-
origin_dtype: "tripy.dtype",
112-
quant_dtype: "tripy.dtype",
113-
scale: float,
114-
zero_point: int,
115-
storage_type_min: int,
116-
storage_type_max: int,
117-
):
118-
"""
119-
Converts a tripy data type to an MLIR quantized data type.
120-
121-
Args:
122-
origin_dtype: original data type to be quantized
123-
quant_dtype: target data type to quantize
124-
dtype: One of int4, int8, float8
125-
scale: scale value of quantized tensor
126-
zero_point: zero point of quantized tensor
127-
storage_type_min: min value of quantized dtype
128-
storage_type_max: max value of quantized dtype
129-
"""
130-
from mlir_tensorrt.compiler.dialects import quant
131-
132-
storage_type = get_mlir_dtype(quant_dtype)
133-
expressed_type = get_mlir_dtype(origin_dtype)
134-
return quant.UniformQuantizedType.get(
135-
quant.UniformQuantizedType.FLAG_SIGNED,
136-
storage_type,
137-
expressed_type,
138-
scale,
139-
zero_point,
140-
storage_type_min,
141-
storage_type_max,
142-
)
143-
144-
145110
def make_mlir_tensor(
146111
dtype: "tripy.common.dtype", shape: Optional[Sequence[int]] = None, rank: Optional[int] = None
147112
) -> ir.RankedTensorType:

tripy/tripy/flat_ir/flat_ir.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,6 @@ def _create_new_function(
177177
def _get_function_input_types(func: FlatIRFunction, mlir_tensor_map: Dict[str, ir.Value]) -> List[ir.Type]:
178178
"""Get the input types for a function, converting to dynamic tensors if necessary."""
179179

180-
def convert_to_dynamic_tensor(rtt: ir.RankedTensorType) -> ir.RankedTensorType:
181-
dynamic_shape = [ir.ShapedType.get_dynamic_size()] * rtt.rank
182-
return ir.RankedTensorType.get(dynamic_shape, rtt.element_type)
183-
184180
# Skip converting to dynamic tensor for Quantize/Dequantize scale operation.
185181
if "Quantize" in func.name or "Dequantize" in func.name:
186182
return [

tripy/tripy/frontend/trace/ops/convolution.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from dataclasses import dataclass
2020

2121
import tripy.frontend.trace.ops.utils as op_utils
22-
from tripy import constraints, utils
22+
from tripy import constraints
2323
from tripy.frontend.trace.ops.base import BaseTraceOp
2424

2525

@@ -31,35 +31,6 @@ class Convolution(BaseTraceOp):
3131
lhs_dilation: Sequence[int]
3232
rhs_dilation: Sequence[int]
3333

34-
def verify_spatial_rank(self, attr, rank, string):
35-
spatial_rank = rank - 2
36-
if attr and len(attr) != spatial_rank:
37-
utils.raise_error_io_info(
38-
self,
39-
f"Number of {string} values does not match number of spatial dimensions in the input.",
40-
details=[
41-
f"Got {len(attr)} {string} value pairs but the number of spatial dimensions is: {spatial_rank}.",
42-
],
43-
)
44-
45-
def validate_inputs(self, tensor_shape, kernel_shape):
46-
if len(tensor_shape) != len(kernel_shape):
47-
utils.raise_error_io_info(
48-
self,
49-
"Input tensor and kernel must have the same rank.",
50-
details=[
51-
f"Input tensor for operation: 'convolution' has shape: {tensor_shape} [rank = {len(tensor_shape)}], "
52-
f"but should have the same rank as the kernel of shape: {kernel_shape} [rank = {len(kernel_shape)}]."
53-
],
54-
)
55-
56-
rank = len(tensor_shape)
57-
58-
self.verify_spatial_rank(self.padding, rank, "padding")
59-
self.verify_spatial_rank(self.stride, rank, "stride")
60-
self.verify_spatial_rank(self.lhs_dilation, rank, "lhs_dilation")
61-
self.verify_spatial_rank(self.rhs_dilation, rank, "rhs_dilation")
62-
6334
infer_rank = op_utils.InferRankPolicies.same_as_input()
6435

6536
def infer_dtypes(self):

tripy/tripy/frontend/trace/ops/utils.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -43,21 +43,10 @@ def is_minus_one(arg):
4343

4444

4545
##
46-
## Inferring shape helpers
46+
## infer_rank helpers
4747
##
4848

4949

50-
def infer_broadcasted_shape(*input_shapes: Sequence[List[int]]):
51-
"""
52-
Given dynamic input shapes of trace tensors, infers a broadcasted shape.
53-
This does not do any error checking since that can be done more reliably
54-
later in the compiler.
55-
"""
56-
max_rank = max(len(shape) for shape in input_shapes)
57-
input_shapes = [[1] * (max_rank - len(shape)) + shape for shape in input_shapes]
58-
return [max(dim) for dim in zip(*input_shapes)]
59-
60-
6150
class InferRankPolicies:
6251
def same_as_input(idx=0):
6352
def impl(self):
@@ -177,29 +166,6 @@ def reshape_scalar_to_1d(input: "FlatIRTensor"):
177166
##
178167

179168

180-
def get_broadcast_compatible_shapes(shape1, shape2):
181-
# Make the shorter shape the same length as the longer shape by padding with ones
182-
if len(shape1) > len(shape2):
183-
shape2 = (1,) * (len(shape1) - len(shape2)) + shape2
184-
elif len(shape2) > len(shape1):
185-
shape1 = (1,) * (len(shape2) - len(shape1)) + shape1
186-
187-
return shape1, shape2
188-
189-
190-
def is_broadcast_compatible(shape1, shape2) -> Result:
191-
# Now check each dimension pair
192-
for index, (dim1, dim2) in enumerate(zip(shape1, shape2)):
193-
if dim1 != dim2 and dim1 != 1 and dim2 != 1:
194-
return Result.err(
195-
[
196-
f"for tensor shapes: {shape1} and {shape2}, dimensions on axis {index}: '{dim1}' and '{dim2}' are not broadcast compatible"
197-
],
198-
)
199-
200-
return Result.ok()
201-
202-
203169
# Given two shapes, compute the shape of the resulting broadcast. Assumes that the shapes are of equal rank
204170
def compute_shape_of_broadcast(
205171
shape1, shape2, output_rank: int, shape1_name: Optional[str] = None, shape2_name: Optional[str] = None
@@ -358,10 +324,6 @@ def is_quantized_dtype(dtype: "tripy.common.datatype.dtype") -> bool:
358324
return dtype in QUANTIZED_DTYPES
359325

360326

361-
def is_quantizable_dtype(dtype: "tripy.common.datatype.dtype") -> bool:
362-
return dtype in QUANTIZABLE_DTYPES
363-
364-
365327
def get_clamp_min_max(element_dtype, quant_dtype):
366328
QUANT_CLAMP_MIN_MAX = {
367329
tp_dtype.int8: (-128.0, 127.0),

tripy/tripy/utils/ast.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@
1616
#
1717

1818
import ast
19-
import inspect
20-
import textwrap
21-
from typing import Callable, List, Optional, Tuple
19+
from typing import List, Optional, Tuple
2220

2321
from tripy.utils.result import Result
2422
from tripy.utils.stack_info import SourceInfo
@@ -181,25 +179,3 @@ def check_name_matches():
181179
candidate_column_offsets.append((indentation + node.col_offset, indentation + node.end_col_offset))
182180

183181
return candidate_column_offsets
184-
185-
186-
def find_node_in_method(method, node_finder: Callable) -> List[str]:
187-
"""
188-
Returns a list of source line of code where node is found.
189-
190-
Args:
191-
method: Source function where node is searched.
192-
node_finder (Callable): User function that takes (node, source) and returns a bool whether node is found in ast or not.
193-
194-
Returns:
195-
List[str]: List of source line of code
196-
"""
197-
source = textwrap.dedent(inspect.getsource(method))
198-
tree = ast.parse(source)
199-
source = source.splitlines()
200-
nodes_found = []
201-
for node in ast.walk(tree):
202-
if node_finder(node, source):
203-
nodes_found.append(source[node.lineno - 1].strip())
204-
205-
return nodes_found

tripy/tripy/utils/utils.py

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -254,21 +254,6 @@ def custom_setattr(self, name, value):
254254
##
255255

256256

257-
def find_file_in_dir(file_name: str, search_directory: str) -> List:
258-
"""
259-
Search for file_name recursively in the root_directory.
260-
261-
Args:
262-
file_name: The file name or pattern with wildcards.
263-
search_directory: The root directory from where to search for file_name.
264-
Returns:
265-
List of absolute path for matching files.
266-
"""
267-
search_pattern = os.path.join(search_directory, "**", file_name)
268-
matching_files = glob.glob(search_pattern, recursive=True)
269-
return matching_files
270-
271-
272257
def warn_if_wrong_mode(file_like: typing.IO, mode: str):
273258
def binary(mode):
274259
return "b" in mode
@@ -456,31 +441,3 @@ def merge_function_arguments(func, *args, **kwargs):
456441
all_args = get_positional_arg_names(func, *args)
457442
all_args.extend(kwargs.items())
458443
return all_args
459-
460-
461-
def get_arg_by_name(name, func, *args, **kwargs):
462-
if name in kwargs:
463-
return kwargs[name]
464-
465-
args = dict(get_positional_arg_names(func, *args))
466-
if name in args:
467-
return args[name]
468-
469-
assert False, f"No such argument: {name}"
470-
471-
472-
def modify_arg(name, modify_func, func, *args, **kwargs):
473-
"""
474-
Modifies an argument corresponding to the provided name if it is present.
475-
`modify_func` should be a function that accepts the argument and returns the modified argument.
476-
"""
477-
if name in kwargs:
478-
kwargs[name] = modify_func(kwargs[name])
479-
480-
all_args = get_positional_arg_names(func, *args)
481-
args = list(args)
482-
for index, (arg_name, _) in enumerate(all_args):
483-
if name == arg_name:
484-
args[index] = modify_func(args[index])
485-
486-
return args, kwargs

0 commit comments

Comments
 (0)