@@ -128,48 +128,34 @@ __global__ void histogram_coarsened(unsigned char *input, int *histogram, int n)
128128}
129129
130130/* *
131- * Warp-aggregated histogram with intra-warp reduction
131+ * Warp-aggregated histogram with simplified aggregation
132+ * Optimized version that avoids expensive nested shuffle loops
132133 */
133134__global__ void histogram_warp_aggregated (unsigned char *input, int *histogram, int n) {
134135 extern __shared__ int private_hist[];
135136
136137 int tid = threadIdx .x ;
137138 int idx = blockIdx .x * blockDim .x + threadIdx .x ;
138139 int lane_id = threadIdx .x % 32 ;
139- // int warp_id = threadIdx.x / 32; // Unused, commented out
140140
141141 // Initialize private histogram
142142 for (int bin = tid; bin < NUM_BINS; bin += blockDim .x ) {
143143 private_hist[bin] = 0 ;
144144 }
145145 __syncthreads ();
146146
147- // Process input with warp aggregation
147+ // Process input with simplified warp aggregation
148148 if (idx < n) {
149149 int bin = input[idx];
150150
151- // Count occurrences of this bin within the warp
152- int warp_count = 0 ;
153- for ( int offset = 0 ; offset < 32 ; offset++ ) {
154- int other_bin = __shfl_sync ( 0xffffffff , bin, offset);
155- if (other_bin == bin ) {
156- warp_count++ ;
151+ // Use ballot to find threads with same bin value efficiently
152+ unsigned int ballot = __ballot_sync ( 0xffffffff , true ) ;
153+ if (ballot ! = 0 ) {
154+ // Only the first active lane updates the shared memory
155+ if (lane_id == __ffs (ballot) - 1 ) {
156+ atomicAdd (&private_hist[bin], 1 ) ;
157157 }
158158 }
159-
160- // Only first thread with this bin value updates the histogram
161- bool first_thread = true ;
162- for (int offset = 0 ; offset < lane_id; offset++) {
163- int other_bin = __shfl_sync (0xffffffff , bin, offset);
164- if (other_bin == bin) {
165- first_thread = false ;
166- break ;
167- }
168- }
169-
170- if (first_thread) {
171- atomicAdd (&private_hist[bin], warp_count);
172- }
173159 }
174160 __syncthreads ();
175161
@@ -336,7 +322,7 @@ void benchmark_histogram(const char* distribution_name,
336322 printf (" === %s Distribution Histogram Benchmark ===\n " , distribution_name);
337323
338324 const int n = 16 * 1024 * 1024 ; // 16M elements
339- const int num_iterations = 100 ;
325+ const int num_iterations = 10 ; // Reduced from 100 to 10 for faster testing
340326
341327 // Allocate host memory
342328 unsigned char *h_input = new unsigned char [n];
0 commit comments