Skip to content

Commit

Permalink
improve copy from GPU to Rust data structure
Browse files Browse the repository at this point in the history
  • Loading branch information
dloghin committed Sep 5, 2024
1 parent 7347098 commit 73ca009
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ members = ["field", "maybe_rayon", "plonky2", "starky", "util", "gen", "u32", "e
resolver = "2"

[workspace.dependencies]
cryptography_cuda = { git = "ssh://[email protected]/okx/cryptography_cuda.git", rev = "f2ed17c3086b9ca538272974e42b47e4bf7970e2" }
cryptography_cuda = { git = "ssh://[email protected]/okx/cryptography_cuda.git", rev = "2422995ca4d1d93c798676eeb55fc75d1bc0565e" }
ahash = { version = "0.8.7", default-features = false, features = [
"compile-time-rng",
] } # NOTE: Be sure to keep this version the same as the dependency in `hashbrown`.
Expand Down
77 changes: 34 additions & 43 deletions plonky2/src/hash/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,6 @@ fn fill_digests_buf<F: RichField, H: Hasher<F>>(
*/
}

/// Overlays a 32-byte array and four `u64` words on the same storage so that
/// a value written through one field can be read back through the other.
///
/// With `#[repr(C)]` both fields start at offset 0. The bytes observed through
/// `f1` are the in-memory (native-endian) representation of the words stored
/// in `f2`. Reading a union field requires `unsafe`.
#[cfg(feature = "cuda")]
#[repr(C)]
union U8U64 {
// Byte view of the 32 bytes of storage.
f1: [u8; 32],
// Word view of the same storage: four 64-bit values.
f2: [u64; 4],
}

#[cfg(feature = "cuda")]
fn fill_digests_buf_gpu<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
Expand Down Expand Up @@ -382,17 +375,15 @@ fn fill_digests_buf_gpu_ptr<F: RichField, H: Hasher<F>>(
}
}
print_time(now, "fill init");

let mut host_digests: Vec<F> = vec![F::ZERO; digests_size];
let mut host_caps: Vec<F> = vec![F::ZERO; caps_size];

let stream1 = CudaStream::create().unwrap();
let stream2 = CudaStream::create().unwrap();

gpu_digests_buf
.copy_to_host_async(host_digests.as_mut_slice(), &stream1)
.copy_to_host_ptr_async(digests_buf.as_mut_ptr() as *mut core::ffi::c_void, digests_size, &stream1)
.expect("copy digests");
gpu_cap_buf
.copy_to_host_async(host_caps.as_mut_slice(), &stream2)
.copy_to_host_ptr_async(cap_buf.as_mut_ptr() as *mut core::ffi::c_void, caps_size, &stream2)
.expect("copy caps");
stream1.synchronize().expect("cuda sync");
stream2.synchronize().expect("cuda sync");
Expand All @@ -401,39 +392,38 @@ fn fill_digests_buf_gpu_ptr<F: RichField, H: Hasher<F>>(

let now = Instant::now();

if digests_buf.len() > 0 {
host_digests
.chunks_exact(4)
.zip(digests_buf)
.for_each(|(x, y)| {
unsafe {
let mut parts = U8U64 { f1: [0; 32] };
parts.f2[0] = x[0].to_canonical_u64();
parts.f2[1] = x[1].to_canonical_u64();
parts.f2[2] = x[2].to_canonical_u64();
parts.f2[3] = x[3].to_canonical_u64();
let (slice, _) = parts.f1.split_at(H::HASH_SIZE);
let h: H::Hash = H::Hash::from_bytes(slice);
y.write(h);
};
});
}
print_time(now, "copy results");
}

if cap_buf.len() > 0 {
host_caps.chunks_exact(4).zip(cap_buf).for_each(|(x, y)| {
unsafe {
let mut parts = U8U64 { f1: [0; 32] };
parts.f2[0] = x[0].to_canonical_u64();
parts.f2[1] = x[1].to_canonical_u64();
parts.f2[2] = x[2].to_canonical_u64();
parts.f2[3] = x[3].to_canonical_u64();
let (slice, _) = parts.f1.split_at(H::HASH_SIZE);
let h: H::Hash = H::Hash::from_bytes(slice);
y.write(h);
};
});
/// Fills `digests_buf` and `cap_buf` by delegating to
/// `cryptography_cuda::fill_digests_buf_linear_cpu`.
///
/// The two output buffers and `leaves` are passed as untyped raw pointers
/// together with their element counts. The number of leaves is computed as
/// `leaves.len() / leaf_size`, and the hasher is selected by passing
/// `H::HASHER_TYPE as u64`.
///
/// Only compiled with the `cuda` feature. `#[allow(dead_code)]` is present
/// because the only call to this function visible in this view is commented out.
///
/// # Panics
/// Panics if `leaf_size` is zero (integer division by zero), or if a buffer
/// length or `cap_height` cannot be converted to `u64`.
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn fill_digests_buf_cpu<F: RichField, H: Hasher<F>>(
digests_buf: &mut [MaybeUninit<H::Hash>],
cap_buf: &mut [MaybeUninit<H::Hash>],
leaves: &Vec<F>,
leaf_size: usize,
cap_height: usize,
) {
use cryptography_cuda::fill_digests_buf_linear_cpu;

// Integer division: any trailing elements that do not form a full leaf are
// not counted.
let leaves_count = (leaves.len() / leaf_size) as u64;
// The callee takes every size as `u64`; convert with a checked conversion.
let digests_count: u64 = digests_buf.len().try_into().unwrap();
let caps_count: u64 = cap_buf.len().try_into().unwrap();
let cap_height: u64 = cap_height.try_into().unwrap();

// SAFETY: NOTE(review): soundness depends on the callee's contract, which is
// not visible here. Presumably it writes at most `digests_count` and
// `caps_count` hashes into the respective buffers and reads
// `leaves_count * leaf_size` elements from `leaves` — confirm.
unsafe {
fill_digests_buf_linear_cpu(
digests_buf.as_mut_ptr() as *mut core::ffi::c_void,
cap_buf.as_mut_ptr() as *mut core::ffi::c_void,
leaves.as_ptr() as *const core::ffi::c_void,
digests_count,
caps_count,
leaves_count,
leaf_size as u64,
cap_height,
H::HASHER_TYPE as u64,
);
}
// NOTE(review): no `now` binding is defined in this function as shown; this
// line looks like residue from the preceding function in the rendered diff.
// Confirm against the actual file.
print_time(now, "copy results");
}

#[cfg(feature = "cuda")]
Expand All @@ -449,6 +439,7 @@ fn fill_digests_buf_meta<F: RichField, H: Hasher<F>>(
fill_digests_buf::<F, H>(digests_buf, cap_buf, leaves, leaf_size, cap_height);
} else {
fill_digests_buf_gpu::<F, H>(digests_buf, cap_buf, leaves, leaf_size, cap_height);
// fill_digests_buf_cpu::<F, H>(digests_buf, cap_buf, leaves, leaf_size, cap_height);
}
}

Expand Down

0 comments on commit 73ca009

Please sign in to comment.