@@ -30,17 +30,12 @@ InferenceSession *session = nullptr;
3030
// Selector for the active importer/exporter pair. After this change
// setup_session stores 1 when the fp16 flag is set and 0 otherwise; the
// removed code stored 0 for the CPU pair, 1 for GPU float and 2 for GPU half.
// NOTE(review): readers of using_io are not visible here — confirm they were
// updated to the new numbering.
3131int using_io = 0 ;
3232
// Removed by this diff: CPU importer/exporter, previously allocated in
// setup_session when the reformatter flag was "cpu".
33- pixel_importer_cpu *importer_cpu = nullptr ;
34- pixel_exporter_cpu *exporter_cpu = nullptr ;
35-
// float importer/exporter; left null until setup_session allocates them in
// the branch taken when the fp16 flag is false.
3633pixel_importer_gpu<float > *importer_gpu = nullptr ;
3734pixel_exporter_gpu<float > *exporter_gpu = nullptr ;
3835
// half importer/exporter; left null until setup_session allocates them in
// the branch taken when the fp16 flag is true.
3936pixel_importer_gpu<half> *importer_gpu_fp16 = nullptr ;
4037pixel_exporter_gpu<half> *exporter_gpu_fp16 = nullptr ;
4138
// Removed by this diff: file-level scale ratios filled from
// session->detect_scale(); the new code reads session->scale_w and
// session->scale_h instead.
42- int32_t h_scale, w_scale;
43-
4439#if defined(__GNUC__)
4540extern " C" __attribute__((weak)) int32_t getInferLibVersion () noexcept {
4641 return NV_TENSORRT_VERSION;
@@ -56,11 +51,11 @@ static Logger gLogger;
5651
// Command-line flags (Abseil). Several of these are read with absl::GetFlag
// inside setup_session and passed into the session configuration.
5752ABSL_FLAG (bool , fp16, false , " use FP16 processing, allow FP16 in engine" );
5853ABSL_FLAG (bool , int8, false , " allow INT8 in engine" );
// Added by this diff; defaults to true and is passed in setup_session
// immediately before the fp16/int8/force_precision values.
54+ ABSL_FLAG (bool , strongly_typed, true , " enable strongly typed network definition" );
5955ABSL_FLAG (bool , force_precision, false , " Force precision config in model" );
5956ABSL_FLAG (bool , external, false , " use external algorithms from cuDNN and cuBLAS" );
6157ABSL_FLAG (bool , low_mem, false , " tweak configs to reduce memory consumption" );
// Default of -1. NOTE(review): the meaning of -1 is decided by the consumer
// of this value, which is not visible here — confirm.
6258ABSL_FLAG (int32_t , aux_stream, -1 , " Auxiliary streams to use" );
// Removed by this diff: setup_session no longer selects between cpu and gpu
// reformatters and always allocates the pixel_*_gpu importer/exporter.
63- ABSL_FLAG (std::string, reformatter, " auto" , " reformatter used to import and export pixels: cpu, gpu, auto" );
6459
6560ABSL_FLAG (uint32_t , tile_width, 512 , " tile width" );
6661ABSL_FLAG (uint32_t , tile_height, 512 , " tile height" );
@@ -149,6 +144,7 @@ void setup_session(bool handle_alpha) {
149144 int (max_height)},
150145 1 ,
151146 absl::GetFlag (FLAGS_aux_stream),
147+ absl::GetFlag (FLAGS_strongly_typed),
152148 absl::GetFlag (FLAGS_fp16),
153149 absl::GetFlag (FLAGS_int8),
154150 absl::GetFlag (FLAGS_force_precision),
@@ -183,47 +179,23 @@ void setup_session(bool handle_alpha) {
183179 if (!err.empty ()) {
184180 LOG (QFATAL) << " Failed allocate memory for context: " << err;
185181 }
186- std::tie (h_scale, w_scale) = session->detect_scale ();
187- if (h_scale == -1 || w_scale == -1 ) {
188- LOG (QFATAL) << " Bad model, can't detect scale ratio." ;
189- }
190-
191- if (h_scale != w_scale) {
192- LOG (QFATAL) << " different width and height scale ratio unimplemented." ;
193- }
194182
195183 // ------------------------------
196184 // Import & Export
197- auto max_size = size_t (max_width) * max_height;
185+ auto max_size = size_t (max_width) * max_height * (handle_alpha ? 4 : 3 );
186+ auto max_size_out = max_size * session->scale_w * session->scale_h ;
198187
199- if (absl::GetFlag (FLAGS_reformatter) == " auto" ) {
200- absl::SetFlag (&FLAGS_reformatter, absl::GetFlag (FLAGS_fp16) ? " gpu" : " cpu" );
201- }
202- if (absl::GetFlag (FLAGS_fp16) && absl::GetFlag (FLAGS_reformatter) == " cpu" ) {
203- LOG (QFATAL) << " CPU reformatter can not handle FP16." ;
204- }
205-
206- if (absl::GetFlag (FLAGS_reformatter) == " cpu" ) {
207- importer_cpu = new pixel_importer_cpu (max_size, handle_alpha);
208- exporter_cpu = new pixel_exporter_cpu (h_scale * w_scale * max_size, handle_alpha);
209- using_io = 0 ;
210- }
211- else if (absl::GetFlag (FLAGS_reformatter) == " gpu" ) {
212- if (absl::GetFlag (FLAGS_fp16)) {
213- importer_gpu_fp16 = new pixel_importer_gpu<half>(max_size, handle_alpha);
214- exporter_gpu_fp16 =
215- new pixel_exporter_gpu<half>(h_scale * w_scale * max_size, handle_alpha);
216- using_io = 2 ;
217- }
218- else {
219- importer_gpu = new pixel_importer_gpu<float >(max_size, handle_alpha);
220- exporter_gpu =
221- new pixel_exporter_gpu<float >(h_scale * w_scale * max_size, handle_alpha);
222- using_io = 1 ;
223- }
188+ if (absl::GetFlag (FLAGS_fp16)) {
189+ importer_gpu_fp16 = new pixel_importer_gpu<half>(max_size, 1 );
190+ exporter_gpu_fp16 =
191+ new pixel_exporter_gpu<half>(max_size_out, 1 );
192+ using_io = 1 ;
224193 }
225194 else {
226- LOG (QFATAL) << " Unknown reformatter." ;
195+ importer_gpu = new pixel_importer_gpu<float >(max_size, 1 );
196+ exporter_gpu =
197+ new pixel_exporter_gpu<float >(max_size_out, 1 );
198+ using_io = 0 ;
227199 }
228200}
229201
0 commit comments