@@ -545,7 +545,7 @@ def test_jit_script_compatible(
 @pytest.mark.parametrize("include_transpose", [True, False])
 @pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
 @pytest.mark.parametrize("dtype", [torch.float32])
-def test_cuda_graph_compatible(
+def test_cuda_graph_compatible_forward(
     device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype
 ):
     if device == "cuda" and not torch.cuda.is_available():
@@ -563,7 +563,7 @@ def test_cuda_graph_compatible(
     # Ensure there is at least one pair
     pos[0, :] = torch.zeros(3)
     pos[1, :] = torch.zeros(3)
-    pos.requires_grad = True
+    pos.requires_grad_(True)
     if box_type is None:
         box = None
     else:
@@ -618,3 +618,86 @@ def test_cuda_graph_compatible(
     assert np.allclose(neighbors, ref_neighbors)
     assert np.allclose(distances, ref_distances)
     assert np.allclose(distance_vecs, ref_distance_vecs)
+
+@pytest.mark.parametrize("device", ["cuda"])
+@pytest.mark.parametrize("strategy", ["brute", "shared", "cell"])
+@pytest.mark.parametrize("n_batches", [1, 128])
+@pytest.mark.parametrize("cutoff", [1.0])
+@pytest.mark.parametrize("loop", [True, False])
+@pytest.mark.parametrize("include_transpose", [True, False])
+@pytest.mark.parametrize("box_type", [None, "triclinic", "rectangular"])
+@pytest.mark.parametrize("dtype", [torch.float32])
+def test_cuda_graph_compatible_backward(
+    device, strategy, n_batches, cutoff, loop, include_transpose, box_type, dtype
+):
+    if device == "cuda" and not torch.cuda.is_available():
+        pytest.skip("CUDA not available")
+    if box_type == "triclinic" and strategy == "cell":
+        pytest.skip("Triclinic only supported for brute force")
+    torch.manual_seed(4321)
+    n_atoms_per_batch = torch.randint(3, 100, size=(n_batches,))
+    batch = torch.repeat_interleave(
+        torch.arange(n_batches, dtype=torch.int64), n_atoms_per_batch
+    ).to(device)
+    cumsum = np.cumsum(np.concatenate([[0], n_atoms_per_batch]))
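+    # cumsum[-1] is the total number of atoms across all samples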
+    lbox = 10.0
+    pos = torch.rand(cumsum[-1], 3, device=device, dtype=dtype) * lbox
+    # Ensure there is at least one pair
+    pos[0, :] = torch.zeros(3)
+    pos[1, :] = torch.zeros(3)
+    pos.requires_grad_(True)
+    if box_type is None:
+        box = None
+    else:
+        box = (
+            torch.tensor([[lbox, 0.0, 0.0], [0.0, lbox, 0.0], [0.0, 0.0, lbox]])
+            .to(pos.dtype)
+            .to(device)
+        )
+    ref_neighbors, ref_distance_vecs, ref_distances = compute_ref_neighbors(
+        pos, batch, loop, include_transpose, cutoff, box
+    )
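+    # Use the reference pair count to size the fixed-capacity output buffers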
+    max_num_pairs = ref_neighbors.shape[1]
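+    # Standard CUDA graph warm-up pattern from the PyTorch docs: run the
+    # workload on a side stream first, synchronize, then capture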
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        nl = OptimizedDistance(
+            cutoff_lower=0.0,
+            loop=loop,
+            cutoff_upper=cutoff,
+            max_num_pairs=max_num_pairs,
+            strategy=strategy,
+            box=box,
+            return_vecs=True,
+            include_transpose=include_transpose,
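+            # Graph capture needs static output shapes and no host-side
+            # synchronization, hence no resizing and no error checking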
+            check_errors=False,
+            resize_to_fit=False,
+        )
+        batch = batch.to(device)
+
+        graph = torch.cuda.CUDAGraph()
+        # Warm up
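+        # (a few eager forward/backward passes let allocator and autograd
+        # state settle before capture)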
+        neighbors, distances, distance_vecs = nl(pos, batch)
+        for _ in range(10):
+            neighbors, distances, distance_vecs = nl(pos, batch)
+            distances.sum().backward()
+            pos.grad.data.zero_()
+    torch.cuda.synchronize()
+
+    # Capture
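+    # Kernels launched in this context are recorded into the graph instead of
+    # executed; graph.replay() launches the whole sequence in one shot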
+    with torch.cuda.graph(graph):
+        neighbors, distances, distance_vecs = nl(pos, batch)
+        distances.sum().backward()
+        pos.grad.data.zero_()
+    graph.replay()
+    torch.cuda.synchronize()