|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from typing import List, Optional, Sequence, Tuple, Union |
| 3 | +from typing import List, Optional, Tuple |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | import tensorrt as trt |
@@ -123,88 +123,3 @@ def get_shape_with_dynamic_shape( |
123 | 123 | select_layer = ctx.net.add_select(condition_val, input_shape, scale_res) |
124 | 124 | set_layer_name(select_layer, target, f"{name}_select") |
125 | 125 | return select_layer.get_output(0) |
126 | | - |
127 | | - |
128 | | -def to_trt_shape_tensor( |
129 | | - ctx: ConversionContext, target: Target, name: str, shape_list: List[int | TRTTensor] |
130 | | -) -> TRTTensor: |
131 | | - """ |
132 | | - Convert a mixed shape list (ints + ITensors) into a single ITensor. |
133 | | -
|
134 | | - Args: |
135 | | - ctx (ConversionContext): TensorRT ConversionContext object. |
136 | | - target (Target): Target of fx node. |
137 | | - name (str): base name for layer naming. |
138 | | - shape_list (list[int | ITensor]): list containing static ints and/or ITensors. |
139 | | -
|
140 | | - Returns: |
141 | | - ITensor if shape_list contains any ITensors, else plain Python list of ints. |
142 | | - """ |
143 | | - trt_tensors = [] |
144 | | - |
145 | | - for i, s in enumerate(shape_list): |
146 | | - if isinstance(s, (int, torch.Tensor)): |
147 | | - const = ctx.net.add_constant((1,), np.array([s], dtype=np.int32)) |
148 | | - set_layer_name(const, target, f"{name}_dim{i}_const") |
149 | | - trt_tensors.append(const.get_output(0)) |
150 | | - else: |
151 | | - trt_tensors.append(s) |
152 | | - |
153 | | - if any(not isinstance(s, int) for s in shape_list): |
154 | | - # Concatenate everything into a single ITensor if there are any ITensors/Tensors |
155 | | - concat_layer = ctx.net.add_concatenation(trt_tensors) |
156 | | - concat_layer.axis = 0 |
157 | | - set_layer_name(concat_layer, target, f"{name}_shape_concat") |
158 | | - return concat_layer.get_output(0) |
159 | | - |
160 | | - # If no ITensor found, return plain list of ints |
161 | | - return shape_list |
162 | | - |
163 | | - |
164 | | -def collect_and_concat_trt_inputs( |
165 | | - ctx: ConversionContext, |
166 | | - target: Target, |
167 | | - name: str, |
168 | | - inputs: Sequence[Union[int, TRTTensor, torch.Tensor, np.ndarray]], |
169 | | - concat_axis: int = 0, |
170 | | - allow_static_return: bool = False, |
171 | | -) -> Union[TRTTensor, List[int]]: |
172 | | - """ |
173 | | - Normalize a sequence of values into TRT ITensors and concatenate them. |
174 | | - If `allow_static_return=True` and all inputs are ints, return a Python |
175 | | - list of ints instead of creating any TRT layers. |
176 | | - """ |
177 | | - trt_tensors = [] |
178 | | - has_dynamic = False |
179 | | - |
180 | | - for i, x in enumerate(inputs): |
181 | | - if isinstance(x, TRTTensor): |
182 | | - trt_tensors.append(x) |
183 | | - has_dynamic = True |
184 | | - |
185 | | - elif isinstance(x, (int, np.integer)): |
186 | | - # keep raw for now, convert only if dynamic found |
187 | | - trt_tensors.append(int(x)) |
188 | | - |
189 | | - else: |
190 | | - # torch/np tensor -> TRT tensor |
191 | | - t = get_trt_tensor(ctx, x, f"{name}_tensor_{i}") |
192 | | - trt_tensors.append(t) |
193 | | - has_dynamic = True |
194 | | - |
195 | | - # fully static shape case |
196 | | - if not has_dynamic and allow_static_return: |
197 | | - return [int(v) for v in trt_tensors] |
198 | | - |
199 | | - # promote remaining ints to TRT constants |
200 | | - for i, v in enumerate(trt_tensors): |
201 | | - if isinstance(v, int): |
202 | | - const = ctx.net.add_constant((1,), np.array([v], dtype=np.int32)) |
203 | | - set_layer_name(const, target, f"{name}_static_dim{i}_const") |
204 | | - trt_tensors[i] = const.get_output(0) |
205 | | - |
206 | | - # concatenate |
207 | | - concat = ctx.net.add_concatenation(trt_tensors) |
208 | | - concat.axis = concat_axis |
209 | | - set_layer_name(concat, target, f"{name}_concat") |
210 | | - return concat.get_output(0) |
0 commit comments