Skip to content

Commit

Permalink
Improve aggregate sum (#81)
Browse files (browse the repository at this point in the history)
  • Loading branch information
psvri authored Aug 28, 2024
1 parent 4e3042e commit 4460c54
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 28 deletions.
15 changes: 10 additions & 5 deletions crates/arithmetic/compute_shaders/f32/aggregate.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,21 @@ fn sum(
@builtin(workgroup_id) wg_id: vec3<u32>
) {
if global_id.x >= arrayLength(&input_data) {
return;
shared_data[local_id.x] = 0.0;
} else {
shared_data[local_id.x] = input_data[global_id.x];
}

shared_data[local_id.x] = input_data[global_id.x];

workgroupBarrier();

for (var s = 1u; s < wg_size; s *= 2u) {
if (local_id.x % (2u*s) == 0) && (global_id.x + s < arrayLength(&input_data)) {
shared_data[local_id.x] += shared_data[local_id.x + s];

var index = 2 * s * local_id.x;

if (index < wg_size && (index + s) < wg_size) {
shared_data[index] += shared_data[index + s];
}

workgroupBarrier();
}

Expand Down
15 changes: 10 additions & 5 deletions crates/arithmetic/compute_shaders/i32/aggregate.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,21 @@ fn sum(
@builtin(workgroup_id) wg_id: vec3<u32>
) {
if global_id.x >= arrayLength(&input_data) {
return;
shared_data[local_id.x] = 0;
} else {
shared_data[local_id.x] = input_data[global_id.x];
}

shared_data[local_id.x] = input_data[global_id.x];

workgroupBarrier();

for (var s = 1u; s < wg_size; s *= 2u) {
if (local_id.x % (2u*s) == 0) && (global_id.x + s < arrayLength(&input_data)) {
shared_data[local_id.x] += shared_data[local_id.x + s];

var index = 2 * s * local_id.x;

if (index < wg_size && (index + s) < wg_size) {
shared_data[index] += shared_data[index + s];
}

workgroupBarrier();
}

Expand Down
15 changes: 10 additions & 5 deletions crates/arithmetic/compute_shaders/u32/aggregate.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,21 @@ fn sum(
@builtin(workgroup_id) wg_id: vec3<u32>
) {
if global_id.x >= arrayLength(&input_data) {
return;
shared_data[local_id.x] = 0u;
} else {
shared_data[local_id.x] = input_data[global_id.x];
}

shared_data[local_id.x] = input_data[global_id.x];

workgroupBarrier();

for (var s = 1u; s < wg_size; s *= 2u) {
if (local_id.x % (2u*s) == 0) && (global_id.x + s < arrayLength(&input_data)) {
shared_data[local_id.x] += shared_data[local_id.x + s];

var index = 2 * s * local_id.x;

if (index < wg_size && (index + s) < wg_size) {
shared_data[index] += shared_data[index + s];
}

workgroupBarrier();
}

Expand Down
34 changes: 21 additions & 13 deletions crates/benchmarks/benches/compare_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use arrow::compute::kernels::aggregate::sum;
use arrow_gpu::kernels::broadcast::Broadcast;
use arrow_gpu::kernels::Sum;
use arrow_gpu::{array::*, gpu_utils::*};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput};

fn bench_cpu_u32_add(data: &mut UInt32Array) -> u32 {
sum(data).unwrap()
Expand All @@ -16,20 +16,28 @@ fn bench_gpu_u32_add(data: &mut UInt32ArrayGPU) -> UInt32ArrayGPU {
}

/// Criterion entry point that benchmarks summing a `u32` array on the GPU
/// (`bench_gpu_u32_add`) against the CPU (`bench_cpu_u32_add`).
///
/// NOTE(review): this listing appears to be a rendered diff whose +/- markers
/// were lost, so superseded and replacement statements sit next to each other
/// (e.g. the two `let device` and two `let size` bindings, and the plain
/// `c.bench_function` calls followed by the grouped loop). Read it as a
/// before/after view rather than as one compilable body — confirm against the
/// actual file.
pub fn criterion_benchmark(c: &mut Criterion) {
let device = GpuDevice::new();
// Wrapped in `Arc` once so the loop below can hand out `device.clone()`.
let device = Arc::new(GpuDevice::new());

let size = 4 * 1024 * 1024;
// Base element count; the loop below scales it by 1 and by 10.
let size = 1 * 1024 * 1024;
// Every element of both input arrays is filled with this value.
let base_value = 2;
let input = vec![base_value; size];

let mut gpu_data = UInt32ArrayGPU::broadcast(base_value, size, Arc::new(device));
let mut cpu_data = UInt32Array::from(input);
c.bench_function("sum gpu u32", |b| {
b.iter(|| bench_gpu_u32_add(black_box(&mut gpu_data)))
});
c.bench_function("sum cpu u32", |b| {
b.iter(|| bench_cpu_u32_add(black_box(&mut cpu_data)))
});

// Grouped benchmarks so the GPU and CPU runs are reported under one name.
// NOTE(review): no `group.finish()` call is visible in this hunk; Criterion's
// documentation recommends calling it once all benchmarks are registered —
// confirm whether it should be added after the loop.
let mut group = c.benchmark_group("u32_array_sum");

for i in [1, 10] {
// Number of u32 elements for this iteration (size * 1, then size * 10).
let input_size = size * i;

let mut gpu_data = UInt32ArrayGPU::broadcast(base_value, input_size, device.clone());
let mut cpu_data = UInt32Array::from(vec![base_value; input_size]);

// NOTE(review): `input_size` is an element count, not a byte count. With
// 4-byte u32 elements the reported bytes/s would be 4x too low; consider
// `Throughput::Elements(input_size as u64)` or multiplying by
// `std::mem::size_of::<u32>()` — confirm intent.
group.throughput(Throughput::Bytes(input_size as u64));

// NOTE(review): the label says "million", but `size` is 1024 * 1024
// (1,048,576) elements, so the label is an approximation.
group.bench_function(format!("sum gpu u32 {} million", i), |b| {
b.iter(|| bench_gpu_u32_add(black_box(&mut gpu_data)))
});
group.bench_function(format!("sum cpu u32 {} million", i), |b| {
b.iter(|| bench_cpu_u32_add(black_box(&mut cpu_data)))
});
}
}

criterion_group!(benches, criterion_benchmark);
Expand Down

0 comments on commit 4460c54

Please sign in to comment.