@@ -352,42 +352,86 @@ void rocblas_local_handle::rocblas_stream_end_capture()
352
352
#endif
353
353
}
354
354
355
+ void rocblas_parallel_initialize_thread (int id, size_t & memory_used)
356
+ {
357
+ size_t before_init, after_init, total_memory;
358
+ CHECK_HIP_ERROR (hipSetDevice (id));
359
+ CHECK_HIP_ERROR (hipMemGetInfo (&before_init, &total_memory));
360
+ rocblas_initialize ();
361
+ CHECK_HIP_ERROR (hipMemGetInfo (&after_init, &total_memory));
362
+ memory_used = before_init - after_init;
363
+ }
364
+
355
365
/* !
356
- * Initialize rocBLAS for the current HIP device and report
357
- * the time taken to complete the initialization. This is to
358
- * avoid costly startup time at the first call on that device.
359
- * Internal use for benchmark & testing.
366
+ * Initialize rocBLAS for the requested number of HIP devices
367
+ * and report the time taken to complete the initialization.
368
+ * This is to avoid costly startup time at the first call on
369
+ * that device. Internal use for benchmark & testing.
370
+ * Initializes devices indexed from 0 to parallel_devices-1.
371
+ * If parallel_devices is 1, hipSetDevice should be called
372
+ * before calling this function.
360
373
*/
361
- void rocblas_client_initialize ( )
374
+ void rocblas_parallel_initialize ( int parallel_devices )
362
375
{
363
- // when executed on a CPU under normal load( Disk I/O, memory etc.),
364
- // this routine completes execution under max limit of 12 seconds.
365
- // The minimum time it takes to complete varies based on
366
- // the architecture & build options used while building the library.
367
- // Setting a max duration of 5 seconds for rocblas library initialization to complete.
368
- constexpr static int max_duration = 5 ;
376
+ auto thread = std::make_unique<std::thread[]>(parallel_devices);
377
+ std::vector<size_t > init_memory (parallel_devices);
369
378
370
379
// Store the start timepoint of rocblas initialize
371
380
auto start_time = std::chrono::steady_clock::now ();
372
381
373
- rocblas_initialize ();
382
+ if (parallel_devices == 1 )
383
+ {
384
+ size_t before_init, after_init, total_memory;
385
+ CHECK_HIP_ERROR (hipMemGetInfo (&before_init, &total_memory));
386
+ rocblas_initialize ();
387
+ CHECK_HIP_ERROR (hipMemGetInfo (&after_init, &total_memory));
388
+ init_memory[0 ] = before_init - after_init;
389
+ }
390
+ else
391
+ {
392
+
393
+ for (int id = 0 ; id < parallel_devices; ++id)
394
+ thread[id]
395
+ = std::thread (rocblas_parallel_initialize_thread, id, std::ref (init_memory[id]));
396
+ for (int id = 0 ; id < parallel_devices; ++id)
397
+ thread[id].join ();
398
+ }
374
399
375
400
// Store the end timepoint of rocblas initialize
376
401
auto end_time = std::chrono::steady_clock::now ();
377
402
378
- // Compute the time taken to load the Tensile kernels (in seconds).
379
- auto total_library_initialize_time
380
- = std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time).count ();
381
-
382
403
// Compute the time taken to load the Tensile kernels (in milliseconds).
383
404
auto init_time_in_ms
384
405
= std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count ();
385
406
386
407
rocblas_cout << " \n rocBLAS info: Time taken to complete rocBLAS library initialization is "
387
408
<< init_time_in_ms << " milliseconds." << std::endl;
388
409
389
- // If initialization time exceeds the max duration, display the following info message.
390
- if (total_library_initialize_time > max_duration)
391
- rocblas_cerr << " \n rocBLAS info: rocBLAS initialization exceeded the max duration of "
392
- << max_duration << " seconds. Check CPU's load metrics." << std::endl;
410
+ // Calculate average initialization time per GPU
411
+ auto avg_init_time_in_ms = init_time_in_ms / parallel_devices;
412
+ if (parallel_devices > 1 )
413
+ {
414
+ rocblas_cout
415
+ << " \n rocBLAS info: Average time taken to complete rocBLAS library initialization "
416
+ " per device is "
417
+ << avg_init_time_in_ms << " milliseconds." << std::endl;
418
+ }
419
+
420
+ // If average initialization time exceeds the max duration, display the following info message.
421
+ constexpr static int max_duration = 5000 ;
422
+ if (avg_init_time_in_ms > max_duration)
423
+ rocblas_cerr << " \n rocBLAS info: average time to initialize each device exceeded the max "
424
+ " duration of "
425
+ << max_duration << " milliseconds. Check CPU's load metrics." << std::endl;
426
+
427
+ constexpr static float max_memory = 1.0 ;
428
+ auto max_library_size
429
+ = *std::max_element (std::begin (init_memory), std::end (init_memory)) * 1.0e-9 ;
430
+
431
+ rocblas_cout << " \n rocBLAS info: maximum library size per device is " << max_library_size
432
+ << " GB." << std::endl;
433
+ if (max_library_size > max_memory)
434
+ rocblas_cerr << " \n rocBLAS info: max kernel library size " << max_library_size
435
+ << " GB exceeds the max recommended memory " << max_memory
436
+ << " GB. Check library logic file sizes." << std::endl;
393
437
}
0 commit comments