diff --git a/Cargo.toml b/Cargo.toml index 6d9c89ac0e..f062218275 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["field", "maybe_rayon", "plonky2", "starky", "util", "gen", "u32", "e resolver = "2" [workspace.dependencies] -cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "f2ed17c3086b9ca538272974e42b47e4bf7970e2" } +cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "2422995ca4d1d93c798676eeb55fc75d1bc0565e" } ahash = { version = "0.8.7", default-features = false, features = [ "compile-time-rng", ] } # NOTE: Be sure to keep this version the same as the dependency in `hashbrown`. diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index f3b4f50488..785e25d4fa 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -257,13 +257,6 @@ fn fill_digests_buf>( */ } -#[cfg(feature = "cuda")] -#[repr(C)] -union U8U64 { - f1: [u8; 32], - f2: [u64; 4], -} - #[cfg(feature = "cuda")] fn fill_digests_buf_gpu>( digests_buf: &mut [MaybeUninit], @@ -382,17 +375,15 @@ fn fill_digests_buf_gpu_ptr>( } } print_time(now, "fill init"); - - let mut host_digests: Vec = vec![F::ZERO; digests_size]; - let mut host_caps: Vec = vec![F::ZERO; caps_size]; + let stream1 = CudaStream::create().unwrap(); let stream2 = CudaStream::create().unwrap(); gpu_digests_buf - .copy_to_host_async(host_digests.as_mut_slice(), &stream1) + .copy_to_host_ptr_async(digests_buf.as_mut_ptr() as *mut core::ffi::c_void, digests_size, &stream1) .expect("copy digests"); gpu_cap_buf - .copy_to_host_async(host_caps.as_mut_slice(), &stream2) + .copy_to_host_ptr_async(cap_buf.as_mut_ptr() as *mut core::ffi::c_void, caps_size, &stream2) .expect("copy caps"); stream1.synchronize().expect("cuda sync"); stream2.synchronize().expect("cuda sync"); @@ -401,39 +392,38 @@ fn fill_digests_buf_gpu_ptr>( let now = Instant::now(); - if digests_buf.len() > 0 { - host_digests - .chunks_exact(4) - .zip(digests_buf) - .for_each(|(x, y)| { - unsafe { - let mut parts = U8U64 { f1: [0; 32] }; - parts.f2[0] = x[0].to_canonical_u64(); - parts.f2[1] = x[1].to_canonical_u64(); - parts.f2[2] = x[2].to_canonical_u64(); - parts.f2[3] = x[3].to_canonical_u64(); - let (slice, _) = parts.f1.split_at(H::HASH_SIZE); - let h: H::Hash = H::Hash::from_bytes(slice); - y.write(h); - }; - }); - } + print_time(now, "copy results"); +} - if cap_buf.len() > 0 { - host_caps.chunks_exact(4).zip(cap_buf).for_each(|(x, y)| { - unsafe { - let mut parts = U8U64 { f1: [0; 32] }; - parts.f2[0] = x[0].to_canonical_u64(); - parts.f2[1] = x[1].to_canonical_u64(); - parts.f2[2] = x[2].to_canonical_u64(); - parts.f2[3] = x[3].to_canonical_u64(); - let (slice, _) = parts.f1.split_at(H::HASH_SIZE); - let h: H::Hash = H::Hash::from_bytes(slice); - y.write(h); - }; - }); +#[cfg(feature = "cuda")] +#[allow(dead_code)] +fn fill_digests_buf_cpu>( + digests_buf: &mut [MaybeUninit], + cap_buf: &mut [MaybeUninit], + leaves: &Vec, + leaf_size: usize, + cap_height: usize, +) { + use cryptography_cuda::fill_digests_buf_linear_cpu; + + let leaves_count = (leaves.len() / leaf_size) as u64; + let digests_count: u64 = digests_buf.len().try_into().unwrap(); + let caps_count: u64 = cap_buf.len().try_into().unwrap(); + let cap_height: u64 = cap_height.try_into().unwrap(); + + unsafe { + fill_digests_buf_linear_cpu( + digests_buf.as_mut_ptr() as *mut core::ffi::c_void, + cap_buf.as_mut_ptr() as *mut core::ffi::c_void, + leaves.as_ptr() as *const core::ffi::c_void, + digests_count, + caps_count, + leaves_count, + leaf_size as u64, + cap_height, + H::HASHER_TYPE as u64, + ); } - print_time(now, "copy results"); } #[cfg(feature = "cuda")] @@ -449,6 +439,7 @@ fn fill_digests_buf_meta>( fill_digests_buf::(digests_buf, cap_buf, leaves, leaf_size, cap_height); } else { fill_digests_buf_gpu::(digests_buf, cap_buf, leaves, leaf_size, cap_height); + // fill_digests_buf_cpu::(digests_buf, cap_buf, leaves, leaf_size, cap_height); } }