Skip to content

Commit 19f14cd

Browse files
authored
Merge pull request #1311 from yenong-amd/release/rocm-rel-5.6
Hotfix for rocblas initialization time issues
2 parents 089c30b + 4f24b81 commit 19f14cd

File tree

6 files changed

+87
-34
lines changed

6 files changed

+87
-34
lines changed

clients/benchmarks/client.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ int run_bench_test(bool init,
941941
{
942942
if(init)
943943
{
944-
static int runOnce = (rocblas_client_initialize(), 0); // Initialize rocBLAS
944+
static int runOnce = (rocblas_parallel_initialize(1), 0); // Initialize rocBLAS
945945
}
946946

947947
rocblas_cout << std::setiosflags(std::ios::fixed)
@@ -1177,8 +1177,6 @@ void gpu_thread_init_device(int id,
11771177
{
11781178
CHECK_HIP_ERROR(hipSetDevice(id));
11791179

1180-
rocblas_client_initialize();
1181-
11821180
Arguments a(arg);
11831181
std::string name_filter = "";
11841182
a.cold_iters = 1;
@@ -1207,6 +1205,9 @@ int run_bench_gpu_test(int parallel_devices,
12071205
return 1;
12081206

12091207
// initialization
1208+
rocblas_parallel_initialize(parallel_devices);
1209+
1210+
// run cold call on each device
12101211
auto thread_init = std::make_unique<std::thread[]>(parallel_devices);
12111212

12121213
for(int id = 0; id < parallel_devices; ++id)
@@ -1215,7 +1216,7 @@ int run_bench_gpu_test(int parallel_devices,
12151216
for(int id = 0; id < parallel_devices; ++id)
12161217
thread_init[id].join();
12171218

1218-
// synchronzied launch of cold & hot calls
1219+
// synchronized launch of cold & hot calls
12191220
auto thread = std::make_unique<std::thread[]>(parallel_devices);
12201221

12211222
for(int id = 0; id < parallel_devices; ++id)

clients/common/utility.cpp

Lines changed: 64 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -352,42 +352,86 @@ void rocblas_local_handle::rocblas_stream_end_capture()
352352
#endif
353353
}
354354

355+
void rocblas_parallel_initialize_thread(int id, size_t& memory_used)
356+
{
357+
size_t before_init, after_init, total_memory;
358+
CHECK_HIP_ERROR(hipSetDevice(id));
359+
CHECK_HIP_ERROR(hipMemGetInfo(&before_init, &total_memory));
360+
rocblas_initialize();
361+
CHECK_HIP_ERROR(hipMemGetInfo(&after_init, &total_memory));
362+
memory_used = before_init - after_init;
363+
}
364+
355365
/*!
356-
* Initialize rocBLAS for the current HIP device and report
357-
* the time taken to complete the initialization. This is to
358-
* avoid costly startup time at the first call on that device.
359-
* Internal use for benchmark & testing.
366+
* Initialize rocBLAS for the requested number of HIP devices
367+
* and report the time taken to complete the initialization.
368+
* This is to avoid costly startup time at the first call on
369+
* that device. Internal use for benchmark & testing.
370+
* Initializes devices indexed from 0 to parallel_devices-1.
371+
* If parallel_devices is 1, hipSetDevice should be called
372+
* before calling this function.
360373
*/
361-
void rocblas_client_initialize()
374+
void rocblas_parallel_initialize(int parallel_devices)
362375
{
363-
// when executed on a CPU under normal load( Disk I/O, memory etc.),
364-
// this routine completes execution under max limit of 12 seconds.
365-
// The minimum time it takes to complete varies based on
366-
// the architecture & build options used while building the library.
367-
// Setting a max duration of 5 seconds for rocblas library initialization to complete.
368-
constexpr static int max_duration = 5;
376+
auto thread = std::make_unique<std::thread[]>(parallel_devices);
377+
std::vector<size_t> init_memory(parallel_devices);
369378

370379
// Store the start timepoint of rocblas initialize
371380
auto start_time = std::chrono::steady_clock::now();
372381

373-
rocblas_initialize();
382+
if(parallel_devices == 1)
383+
{
384+
size_t before_init, after_init, total_memory;
385+
CHECK_HIP_ERROR(hipMemGetInfo(&before_init, &total_memory));
386+
rocblas_initialize();
387+
CHECK_HIP_ERROR(hipMemGetInfo(&after_init, &total_memory));
388+
init_memory[0] = before_init - after_init;
389+
}
390+
else
391+
{
392+
393+
for(int id = 0; id < parallel_devices; ++id)
394+
thread[id]
395+
= std::thread(rocblas_parallel_initialize_thread, id, std::ref(init_memory[id]));
396+
for(int id = 0; id < parallel_devices; ++id)
397+
thread[id].join();
398+
}
374399

375400
// Store the end timepoint of rocblas initialize
376401
auto end_time = std::chrono::steady_clock::now();
377402

378-
// Compute the time taken to load the Tensile kernels (in seconds).
379-
auto total_library_initialize_time
380-
= std::chrono::duration_cast<std::chrono::seconds>(end_time - start_time).count();
381-
382403
// Compute the time taken to load the Tensile kernels (in milliseconds).
383404
auto init_time_in_ms
384405
= std::chrono::duration_cast<std::chrono::milliseconds>(end_time - start_time).count();
385406

386407
rocblas_cout << "\nrocBLAS info: Time taken to complete rocBLAS library initialization is "
387408
<< init_time_in_ms << " milliseconds." << std::endl;
388409

389-
// If initialization time exceeds the max duration, display the following info message.
390-
if(total_library_initialize_time > max_duration)
391-
rocblas_cerr << "\nrocBLAS info: rocBLAS initialization exceeded the max duration of "
392-
<< max_duration << " seconds. Check CPU's load metrics." << std::endl;
410+
// Calculate average initialization time per GPU
411+
auto avg_init_time_in_ms = init_time_in_ms / parallel_devices;
412+
if(parallel_devices > 1)
413+
{
414+
rocblas_cout
415+
<< "\nrocBLAS info: Average time taken to complete rocBLAS library initialization "
416+
"per device is "
417+
<< avg_init_time_in_ms << " milliseconds." << std::endl;
418+
}
419+
420+
// If average initialization time exceeds the max duration, display the following info message.
421+
constexpr static int max_duration = 5000;
422+
if(avg_init_time_in_ms > max_duration)
423+
rocblas_cerr << "\nrocBLAS info: average time to initialize each device exceeded the max "
424+
"duration of "
425+
<< max_duration << " milliseconds. Check CPU's load metrics." << std::endl;
426+
427+
constexpr static float max_memory = 1.0;
428+
auto max_library_size
429+
= *std::max_element(std::begin(init_memory), std::end(init_memory)) * 1.0e-9;
430+
431+
rocblas_cout << "\nrocBLAS info: maximum library size per device is " << max_library_size
432+
<< " GB." << std::endl;
433+
if(max_library_size > max_memory)
434+
rocblas_cerr << "\nrocBLAS info: max kernel library size " << max_library_size
435+
<< " GB exceeds the max recommended memory " << max_memory
436+
<< " GB. Check library logic file sizes." << std::endl;
393437
}

clients/gtest/multiheaded_gtest.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/* ************************************************************************
2-
* Copyright (C) 2020-2022 Advanced Micro Devices, Inc. All rights reserved.
2+
* Copyright (C) 2020-2023 Advanced Micro Devices, Inc. All rights reserved.
33
*
44
* Permission is hereby granted, free of charge, to any person obtaining a copy
55
* of this software and associated documentation files (the "Software"), to deal
@@ -66,9 +66,6 @@ namespace
6666
{
6767
CHECK_HIP_ERROR(hipSetDevice(id));
6868

69-
//Initialize rocblas
70-
rocblas_client_initialize();
71-
7269
rocblas_operation transa = rocblas_operation_none, transb = rocblas_operation_transpose;
7370
float alpha = 1.1, beta = 0.9;
7471
rocblas_int m = 1023, n = 1024, k = 1025;
@@ -191,6 +188,9 @@ namespace
191188
<< std::endl;
192189
return;
193190
}
191+
192+
rocblas_parallel_initialize(count);
193+
194194
auto thread = std::make_unique<std::thread[]>(count);
195195

196196
for(int id = 0; id < count; ++id)

clients/include/utility.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,12 +96,15 @@
9696
#define NOOP (void)0
9797

9898
/*!
99-
* Initialize rocBLAS for the current HIP device and report
100-
* the time taken to complete the initialization. This is used to
101-
* avoid costly startup time at the first call on that device.
102-
* Internal use for benchmark & testing.
99+
* Initialize rocBLAS for the requested number of HIP devices
100+
* and report the time taken to complete the initialization.
101+
* This is to avoid costly startup time at the first call on
102+
* that device. Internal use for benchmark & testing.
103+
* Initializes devices indexed from 0 to parallel_devices-1.
104+
* If parallel_devices is 1, hipSetDevice should be called
105+
* before calling this function.
103106
*/
104-
void rocblas_client_initialize();
107+
void rocblas_parallel_initialize(int parallel_devices);
105108

106109
/* ============================================================================================ */
107110
/*! \brief local handle which is automatically created and destroyed */

clients/samples/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ foreach( exe ${sample_list_fortran} )
8080
endforeach( )
8181

8282
foreach( exe ${sample_list_all} )
83-
target_link_libraries( ${exe} PRIVATE roc::rocblas )
83+
target_link_libraries( ${exe} PRIVATE roc::rocblas Threads::Threads )
8484

8585
set_target_properties( ${exe} PROPERTIES
8686
CXX_STANDARD 14

library/src/tensile_host.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,9 @@ namespace
685685
if(!skip_xnack.empty()
686686
&& codeObjectFile.find(skip_xnack) != std::string::npos)
687687
continue;
688+
// Skip experimental libraries
689+
if(codeObjectFile.find("Experimental") != std::string::npos)
690+
continue;
688691
adapter.loadCodeObjectFile(codeObjectFile.c_str());
689692
} while(FindNextFileA(hfine, &finddata));
690693
}
@@ -703,6 +706,8 @@ namespace
703706
std::string cofile = glob_result.gl_pathv[i];
704707
if(!skip_xnack.empty() && cofile.find(skip_xnack) != std::string::npos)
705708
continue;
709+
if(cofile.find("Experimental") != std::string::npos)
710+
continue;
706711
adapter.loadCodeObjectFile(cofile);
707712
}
708713
}

0 commit comments

Comments
 (0)