31
31
32
32
class Executor :
33
33
def __init__(self, executable: runtime.Executable) -> None:
    """
    Prepare everything needed to run ``executable``.

    Creates the runtime client and a single-device session, looks up the signature of the
    executable's "main" function, and caches the argument counts and output memref types
    derived from that signature.

    Args:
        executable: The executable handed to the runtime session.
    """
    # NOTE(review): this turns on global runtime debugging for the "allocator" and "runtime"
    # categories every time an Executor is constructed - confirm that this is intended.
    runtime.GlobalDebug.flag = True
    runtime.GlobalDebug.set_types(["allocator", "runtime"])

    self.runtime_client = MLIRRuntimeClient()
    self.session = runtime.RuntimeSession(
        runtime.RuntimeSessionOptions(num_devices=1, device_id=0),
        executable,
    )
    # Only the first reported device is used; a single device is assumed to be available.
    self.device = self.runtime_client.get_devices()[0]
    self.signature = executable.get_signature("main")
    self.stream = default_stream()

    self.num_input_args = self.signature.get_num_input_args()
    self.num_output_args = self.signature.get_num_output_args()
    # Output arguments are fetched from the signature at the positions right after the inputs.
    first_output = self.num_input_args
    self.output_args = [
        self.signature.get_arg(position)
        for position in range(first_output, first_output + self.num_output_args)
    ]
    self.output_memrefs = list(map(runtime.MemRefType, self.output_args))
46
-
47
def _create_shape_memref(self, shape):
    """
    Create an int64 memref for ``shape`` on the executor's active CUDA stream.

    Args:
        shape: The shape to wrap; it is normalized with ``make_tuple`` first.

    Returns:
        The result of ``runtime_client.create_memref``. For an empty shape a memref of
        shape ``(0,)`` is requested without any source data; otherwise the memref is
        created from the shape values converted to an int64 array and has shape
        ``(len(shape),)``.
    """
    shape = make_tuple(shape)
    # Both branches use the same dtype spelling. The empty-shape branch previously went
    # through `runtime.runtime`, unlike every other use of the `runtime` module here.
    i64 = runtime.ScalarTypeCode.i64
    stream = self.stream._active_cuda_stream

    if len(shape) == 0:
        # Nothing to copy for an empty shape: request a zero-length memref.
        return self.runtime_client.create_memref(shape=(0,), dtype=i64, stream=stream)

    return self.runtime_client.create_memref(
        convert_list_to_array(shape, datatype.int64),
        shape=(len(shape),),
        dtype=i64,
        stream=stream,
    )
60
-
61
- def _get_outputs_shape (self ):
62
- outputs_shape = []
63
- all_outputs_known = True
64
- for memref in self .output_memrefs :
65
- outputs_shape .append (memref .shape )
66
- all_outputs_known &= all (dim >= 0 for dim in memref .shape )
67
- return outputs_shape , all_outputs_known
68
-
69
- def _get_inputs_runtime_shape (self , inputs ):
70
- inputs_shape = []
71
- for input in inputs :
72
- inputs_shape .append (input .trace_tensor .producer .data .shape )
73
- return inputs_shape
74
-
75
- def _execute_shape_inference (self , inputs_shape , outputs_shape ):
76
- inputs_shape_memref = [self ._create_shape_memref (inp_shape ) for inp_shape in inputs_shape ]
77
- outputs_shape_memref = [self ._create_shape_memref (out_shape ) for out_shape in outputs_shape ]
78
- self .session .execute_function (
79
- name = self .signature .get_shape_func_name (), in_args = inputs_shape_memref , out_args = outputs_shape_memref
80
- )
81
-
82
- outputs_runtime_shape = [memoryview (s ).tolist () for s in outputs_shape_memref ]
83
- return outputs_runtime_shape
84
-
85
def _get_output_tensor_info(self, outputs_runtime_shape, output_devices):
    """
    Build one ``TensorInfo`` per output argument.

    Args:
        outputs_runtime_shape: Per-output shapes used to fill in dimensions that are
            negative in the corresponding output memref type.
        output_devices: Per-output device; a falsy entry means the device is derived from
            the memref type's address space instead.

    Returns:
        A list of ``TensorInfo(rank, shape, dtype, device)`` objects, in output order.
    """
    infos = []
    for index in range(self.num_output_args):
        memref_type = self.output_memrefs[index]
        tripy_dtype = convert_runtime_dtype_to_tripy_dtype(memref_type.dtype)

        # Fall back to the address space only when the caller gave no device for this output.
        placement = output_devices[index] or device(
            ("gpu" if memref_type.address_space == runtime.PointerType.device else "cpu", 0)
        )

        # Keep non-negative dimensions from the memref type; take the rest from the runtime shape.
        resolved = tuple(
            static_dim if static_dim >= 0 else runtime_dim
            for static_dim, runtime_dim in zip(memref_type.shape, outputs_runtime_shape[index])
        )
        infos.append(TensorInfo(len(resolved), resolved, tripy_dtype, placement))
    return infos
105
-
106
def get_output_tensor_runtime_info(self, inputs, output_devices: "List[device] | None" = None):
    """
    Compute the tensor info for every output.

    When any output shape has an unknown dimension, the input shapes are gathered and the
    shape function is executed to obtain the output shapes before the tensor info is built.

    Args:
        inputs: The inputs whose shapes feed shape inference when it is needed.
        output_devices: One device per output. A falsy entry lets the device be derived
            from the output memref type. When omitted, every output is treated that way.

    Returns:
        The list produced by ``_get_output_tensor_info``.
    """
    # The default used to be the typing alias `List[device]` itself (an annotation written as
    # a default value), which cannot be indexed per output. Use a real "not provided" default.
    if output_devices is None:
        output_devices = [None] * self.num_output_args

    outputs_shape, all_outputs_known = self._get_outputs_shape()
    if not all_outputs_known:
        inputs_shape = self._get_inputs_runtime_shape(inputs)
        outputs_shape = self._execute_shape_inference(inputs_shape, outputs_shape)
    return self._get_output_tensor_info(outputs_shape, output_devices)
113
43
114
44
def execute (self , output_devices : List [device ], inputs : List ["Tensor" ] = []) -> List [runtime .MemRefValue ]:
115
45
in_args = []
@@ -129,49 +59,10 @@ def execute(self, output_devices: List[device], inputs: List["Tensor"] = []) ->
129
59
)
130
60
in_args .append (memref )
131
61
132
- # HACK (#155): Remove `get_devices` once executable output tensor location matches Trace IR.
133
- out_tensor_info = self .get_output_tensor_runtime_info (inputs , output_devices )
134
-
135
- # Allocate output memory and store buffer pointers.
136
- outputs = [
137
- create_empty_memref (
138
- shape = info .shape ,
139
- dtype = info .dtype ,
140
- device = info .device ,
141
- stream = self .stream ._active_cuda_stream ,
142
- use_cache = False ,
143
- )
144
- for info in out_tensor_info
145
- ]
146
-
147
- out_args = []
148
- for out in outputs :
149
- memref = out
150
- # HACK (#155): MLIR-TensorRT requires inputs to be on device.
151
- # Remove explicit copy to device once #155 is addressed.
152
- if memref .address_space != runtime .PointerType .device :
153
- memref = self .runtime_client .copy_to_device (
154
- host_memref = memref ,
155
- device = self .runtime_client .get_devices ()[0 ],
156
- stream = self .stream ._active_cuda_stream ,
157
- )
158
- if not memref :
159
- raise_error ("Could not allocate output memref" , details = memref .error_details )
160
- out_args .append (memref )
161
-
162
62
# Execute and populate device pointers.
163
- self .session .execute_function (
164
- "main" , in_args = in_args , out_args = out_args , stream = self .stream ._active_cuda_stream
63
+ outputs = self .session .execute_function (
64
+ "main" , in_args = in_args , stream = self .stream ._active_cuda_stream , client = self . runtime_client
165
65
)
166
66
167
- # For outputs that were on the host, do the copy back
168
- # TODO(#155): MLIR-TensorRT should allow output tensor placements on host.
169
- for idx , out_info in enumerate (out_tensor_info ):
170
- if out_info .device .kind != "gpu" :
171
- self .runtime_client .copy_to_host (
172
- device_memref = out_args [idx ],
173
- existing_host_memref = outputs [idx ],
174
- stream = self .stream ._active_cuda_stream ,
175
- )
176
-
67
+ # For now return results on GPU.
177
68
return outputs
0 commit comments