@@ -352,10 +352,14 @@ void rocblas_local_handle::rocblas_stream_end_capture()
352
352
#endif
353
353
}
354
354
355
- void rocblas_parallel_initialize_thread (int id)
355
+ void rocblas_parallel_initialize_thread (int id, size_t & memory_used )
356
356
{
357
+ size_t before_init, after_init, total_memory;
357
358
CHECK_HIP_ERROR (hipSetDevice (id));
359
+ CHECK_HIP_ERROR (hipMemGetInfo (&before_init, &total_memory));
358
360
rocblas_initialize ();
361
+ CHECK_HIP_ERROR (hipMemGetInfo (&after_init, &total_memory));
362
+ memory_used = before_init - after_init;
359
363
}
360
364
361
365
/* !
@@ -369,17 +373,26 @@ void rocblas_parallel_initialize_thread(int id)
369
373
*/
370
374
void rocblas_parallel_initialize (int parallel_devices)
371
375
{
372
- auto thread = std::make_unique<std::thread[]>(parallel_devices);
376
+ auto thread = std::make_unique<std::thread[]>(parallel_devices);
377
+ std::vector<size_t > init_memory (parallel_devices);
373
378
374
379
// Store the start timepoint of rocblas initialize
375
380
auto start_time = std::chrono::steady_clock::now ();
376
381
377
382
if (parallel_devices == 1 )
383
+ {
384
+ size_t before_init, after_init, total_memory;
385
+ CHECK_HIP_ERROR (hipMemGetInfo (&before_init, &total_memory));
378
386
rocblas_initialize ();
387
+ CHECK_HIP_ERROR (hipMemGetInfo (&after_init, &total_memory));
388
+ init_memory[0 ] = before_init - after_init;
389
+ }
379
390
else
380
391
{
392
+
381
393
for (int id = 0 ; id < parallel_devices; ++id)
382
- thread[id] = std::thread (rocblas_parallel_initialize_thread, id);
394
+ thread[id]
395
+ = std::thread (rocblas_parallel_initialize_thread, id, std::ref (init_memory[id]));
383
396
for (int id = 0 ; id < parallel_devices; ++id)
384
397
thread[id].join ();
385
398
}
@@ -410,4 +423,15 @@ void rocblas_parallel_initialize(int parallel_devices)
410
423
rocblas_cerr << " \n rocBLAS info: average time to initialize each device exceeded the max "
411
424
" duration of "
412
425
<< max_duration << " milliseconds. Check CPU's load metrics." << std::endl;
426
+
427
+ constexpr static float max_memory = 1.0 ;
428
+ auto max_library_size
429
+ = *std::max_element (std::begin (init_memory), std::end (init_memory)) * 1.0e-9 ;
430
+
431
+ rocblas_cout << " \n rocBLAS info: maximum library size per device is " << max_library_size
432
+ << " GB." << std::endl;
433
+ if (max_library_size > max_memory)
434
+ rocblas_cerr << " \n rocBLAS info: max kernel library size " << max_library_size
435
+ << " GB exceeds the max recommended memory " << max_memory
436
+ << " GB. Check library logic file sizes." << std::endl;
413
437
}
0 commit comments