16
16
#
17
17
18
18
from textwrap import indent
19
- from typing import Any , List , Optional , Sequence , Union
19
+ from typing import Any , Optional
20
+
21
+ import mlir_tensorrt .runtime .api as runtime
20
22
21
23
# Import ops to populate the registry before we define our Tensor class
22
24
import tripy .frontend .ops
27
29
from tripy .common .exception import raise_error
28
30
from tripy .frontend .ops .registry import TENSOR_METHOD_REGISTRY
29
31
from tripy .frontend .trace .ops import Storage
30
-
31
- import mlir_tensorrt .runtime .api as runtime
32
+ from tripy .utils .stack_info import StackInfo
32
33
33
34
34
35
class TensorMeta (type ):
@@ -73,19 +74,21 @@ def _get_unique_name(cls):
73
74
74
75
def __init__ (
75
76
self ,
76
- data : Union [ List , "np.ndarray" , "cp.ndarray" , "torch.Tensor" , "jnp.ndarray" ] ,
77
+ data : Any ,
77
78
dtype : Optional ["tripy.dtype" ] = None ,
78
79
device : Optional ["tripy.device" ] = None ,
79
80
name : Optional [str ] = None ,
80
- stack_info : Optional [ "StackInfo" ] = None ,
81
+ fetch_stack_info : bool = True ,
81
82
) -> None :
82
83
"""
83
84
Args:
84
85
data: The data with which to initialize the tensor.
85
86
dtype: The data type of the tensor.
86
87
device: The device on which to allocate the tensor.
87
88
name: The name of the tensor. If provided, this must be a unique string.
88
- stack_info: The stack infomation of the tensor.
89
+ fetch_stack_info: Whether to fetch stack information for the tensor.
90
+ Stack information allows Tripy to generate much higher quality error
91
+ messages at the cost of a small overhead when initializing the tensor.
89
92
90
93
.. code-block:: python
91
94
:linenos:
@@ -95,13 +98,12 @@ def __init__(
95
98
"""
96
99
from tripy .frontend .trace .tensor import TraceTensor
97
100
98
- # We include code for everything above the `BaseTraceOp.build` function, which is called at most
99
- # this many stack frames above the constructor.
100
- STACK_DEPTH_OF_BUILD = 4
101
- # not using utils.default() because it always evaluates the `default` argument.
102
- stack_info = (
103
- stack_info if stack_info is not None else utils .get_stack_info (include_code_index = STACK_DEPTH_OF_BUILD )
104
- )
101
+ stack_info = StackInfo ([])
102
+ if fetch_stack_info :
103
+ # We include code for everything above the `BaseTraceOp.build` function, which is called at most
104
+ # this many stack frames above the constructor.
105
+ STACK_DEPTH_OF_BUILD = 4
106
+ stack_info = utils .get_stack_info (include_code_index = STACK_DEPTH_OF_BUILD )
105
107
106
108
name = name if name is not None else Tensor ._get_unique_name ()
107
109
0 commit comments