@@ -72,16 +72,15 @@ void AMOStandardTester<T>::launchKernel(dim3 gridsize, dim3 blocksize, int loop,
7272 _ret_val, _type, _shmem_context);
7373
7474 _gridSize = gridsize;
75- num_msgs = (loop + args.skip ) * gridsize.x ;
76- num_timed_msgs = loop;
75+ num_msgs = (loop + args.skip ) * gridsize.x * blocksize. x ;
76+ num_timed_msgs = loop * gridsize. x * blocksize. x ;
7777}
7878
7979template <typename T>
8080void AMOStandardTester<T>::verifyResults(size_t size) {
8181 T ret;
8282 if (args.myid == 0 ) {
8383 T expected_val = 0 ;
84-
8584 switch (_type) {
8685 case AMO_FAddTestType:
8786 expected_val = 2 * (num_msgs - 1 );
@@ -102,7 +101,7 @@ void AMOStandardTester<T>::verifyResults(size_t size) {
102101 break ;
103102 }
104103
105- int fetch_op = (_type == AMO_FAddTestType || _type == AMO_FIncTestType || _type == AMO_FCswapTestType) ? 1 : 0 ;
104+ int fetch_op = (_type == AMO_FAddTestType || _type == AMO_FIncTestType || _type == AMO_FCswapTestType) ? 1 : 0 ;
106105
107106 if (fetch_op == 1 ) {
108107 ret = *std::max_element (_ret_val, _ret_val + args.num_wgs );
@@ -117,55 +116,71 @@ void AMOStandardTester<T>::verifyResults(size_t size) {
117116 }
118117}
119118
120- #define AMO_STANDARD_DEF_GEN (T, TNAME ) \
121- template <> \
122- __global__ void AMOStandardTest<T>( \
123- int loop, int skip, long long int *start_time, \
124- long long int *end_time, char *r_buf, T *s_buf, T *ret_val, \
125- TestType type, ShmemContextType ctx_type) { \
126- __shared__ rocshmem_ctx_t ctx; \
127- int wg_id = get_flat_grid_id (); \
128- rocshmem_wg_init (); \
129- rocshmem_wg_ctx_create (ctx_type, &ctx); \
130- if (hipThreadIdx_x == 0 ) { \
131- T ret = 0 ; \
132- T cond = 0 ; \
133- for (int i = 0 ; i < loop + skip; i++) { \
134- if (i == skip) { \
135- start_time[wg_id] = wall_clock64 (); \
136- } \
137- switch (type) { \
138- case AMO_FAddTestType: \
139- ret = rocshmem_ctx_##TNAME##_atomic_fetch_add (ctx, (T *)r_buf, 2 , \
140- 1 ); \
141- break ; \
142- case AMO_FIncTestType: \
143- ret = \
144- rocshmem_ctx_##TNAME##_atomic_fetch_inc (ctx, (T *)r_buf, 1 ); \
145- break ; \
146- case AMO_FCswapTestType: \
147- ret = rocshmem_ctx_##TNAME##_atomic_compare_swap (ctx, (T *)r_buf, \
148- cond, (T)i, 1 ); \
149- cond = i; \
150- break ; \
151- case AMO_AddTestType: \
152- rocshmem_ctx_##TNAME##_atomic_add (ctx, (T *)r_buf, 2 , 1 ); \
153- break ; \
154- case AMO_IncTestType: \
155- rocshmem_ctx_##TNAME##_atomic_inc (ctx, (T *)r_buf, 1 ); \
156- break ; \
157- default : \
158- break ; \
159- } \
160- } \
161- rocshmem_ctx_quiet (ctx); \
162- end_time[wg_id] = wall_clock64 (); \
163- ret_val[wg_id] = ret; \
164- rocshmem_ctx_getmem (ctx, &s_buf[wg_id], r_buf, sizeof (T), 1 ); \
165- } \
166- rocshmem_wg_ctx_destroy (&ctx); \
167- rocshmem_wg_finalize (); \
168- } \
119+ #define AMO_STANDARD_DEF_GEN (T, TNAME ) \
120+ template <> \
121+ __global__ void AMOStandardTest<T>(int loop, int skip, long long int *start_time, long long int *end_time, char *r_buf, \
122+ T *s_buf, T *ret_val, TestType type, ShmemContextType ctx_type) { \
123+ __shared__ rocshmem_ctx_t ctx; \
124+ rocshmem_wg_init (); \
125+ rocshmem_wg_ctx_create (ctx_type, &ctx); \
126+ __shared__ long long int wf_start_time[16 ]; \
127+ __shared__ long long int wf_ret_val[16 ]; \
128+ int wg_id = get_flat_grid_id (); \
129+ int t_id = get_flat_block_id (); \
130+ int wf_size = 64 ; \
131+ int wf_id = t_id / wf_size; \
132+ wf_ret_val[wf_id] = 0 ; \
133+ T ret = 0 ; \
134+ T cond = 0 ; \
135+ for (int i = 0 ; i < loop + skip; i++) { \
136+ if (i == skip) { \
137+ wf_start_time[wf_id] = wall_clock64 (); \
138+ } \
139+ switch (type) { \
140+ case AMO_FAddTestType: \
141+ ret = rocshmem_ctx_##TNAME##_atomic_fetch_add (ctx, (T *)r_buf, 2 , 1 ); \
142+ break ; \
143+ case AMO_FCswapTestType: \
144+ ret = rocshmem_ctx_##TNAME##_atomic_compare_swap (ctx, (T *)r_buf, cond, (T)i, 1 ); \
145+ cond = i; \
146+ break ; \
147+ case AMO_FIncTestType: \
148+ ret = rocshmem_ctx_##TNAME##_atomic_fetch_inc (ctx, (T *)r_buf, 1 ); \
149+ break ; \
150+ case AMO_AddTestType: \
151+ rocshmem_ctx_##TNAME##_atomic_add (ctx, (T *)r_buf, 2 , 1 ); \
152+ break ; \
153+ case AMO_IncTestType: \
154+ rocshmem_ctx_##TNAME##_atomic_inc (ctx, (T *)r_buf, 1 ); \
155+ break ; \
156+ default : \
157+ break ; \
158+ } \
159+ } \
160+ rocshmem_ctx_quiet (ctx); \
161+ end_time[wg_id] = wall_clock64 (); \
162+ rocshmem_ctx_getmem (ctx, &s_buf[wg_id], r_buf, sizeof (T), 1 ); \
163+ __hip_atomic_fetch_max (&wf_ret_val[wf_id], ret, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_WORKGROUP); \
164+ __syncthreads (); \
165+ int num_wfs = (get_flat_block_size () - 1 ) / wf_size + 1 ; \
166+ for (int i = num_wfs / 2 ; i > 0 ; i >>= 1 ) { \
167+ if (t_id < i) { \
168+ wf_ret_val[t_id] = max (wf_ret_val[t_id], wf_ret_val[t_id + i]); \
169+ } \
170+ } \
171+ ret_val[wg_id] = wf_ret_val[0 ]; \
172+ for (int i = num_wfs / 2 ; i > 0 ; i >>= 1 ) { \
173+ if (t_id < i) { \
174+ wf_start_time[t_id] = min (wf_start_time[t_id], wf_start_time[t_id + i]); \
175+ } \
176+ } \
177+ __syncthreads (); \
178+ if (t_id == 0 ) { \
179+ start_time[wg_id] = wf_start_time[0 ]; \
180+ } \
181+ rocshmem_wg_ctx_destroy (&ctx); \
182+ rocshmem_wg_finalize (); \
183+ } \
169184 template class AMOStandardTester <T>;
170185
171186AMO_STANDARD_DEF_GEN (int , int )
0 commit comments