Skip to content

Commit 24c1a0f

Browse files
committed
Implement Send + Sync for CudaStream
1 parent a6b4de8 commit 24c1a0f

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

src/driver/safe/core.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,9 @@ pub struct CudaStream {
484484
device: Arc<CudaDevice>,
485485
}
486486

487+
unsafe impl Send for CudaStream {}
488+
unsafe impl Sync for CudaStream {}
489+
487490
impl CudaDevice {
488491
/// Allocates a new stream that can execute kernels concurrently to the default stream.
489492
///
@@ -906,6 +909,8 @@ impl<R: RangeBounds<usize>> RangeHelper for R {
906909

907910
#[cfg(test)]
908911
mod tests {
912+
use std::thread;
913+
909914
use super::*;
910915

911916
#[test]
@@ -930,4 +935,17 @@ mod tests {
930935
assert!(unsafe { slice.transmute_mut::<f32>(25) }.is_some());
931936
assert!(unsafe { slice.transmute_mut::<f32>(26) }.is_none());
932937
}
938+
939+
#[test]
940+
fn test_send_dev() {
941+
let dev = CudaDevice::new(0).unwrap();
942+
thread::spawn(|| dev);
943+
}
944+
945+
#[test]
946+
fn test_send_stream() {
947+
let dev = CudaDevice::new(0).unwrap();
948+
let stream = dev.fork_default_stream().unwrap();
949+
thread::spawn(|| stream);
950+
}
933951
}

0 commit comments

Comments
 (0)