@@ -168,6 +168,16 @@ pub fn normalize_cfg<H: HugrMut>(
168168 _ => unreachable ! ( ) , // Checked at entry to normalize_cfg
169169 }
170170 }
171+ let ancestor_block = |h : & H , mut n : H :: Node | {
172+ while let Some ( p) = h. get_parent ( n) {
173+ if p == cfg_node {
174+ return Some ( n) ;
175+ }
176+ n = p;
177+ }
178+ None
179+ } ;
180+
171181 // Further normalizations with effects outside the CFG
172182 let [ entry, exit] = h. children ( cfg_node) . take ( 2 ) . collect_array ( ) . unwrap ( ) ;
173183 let entry_blk = h. get_optype ( entry) . as_dataflow_block ( ) . unwrap ( ) ;
@@ -209,49 +219,60 @@ pub fn normalize_cfg<H: HugrMut>(
209219 }
210220 // 1b. Move entry block outside/before the CFG into a DFG; its successor becomes the entry block.
211221 let new_cfg_inputs = entry_blk. successor_input ( 0 ) . unwrap ( ) ;
212- let dfg = h. add_node_with_parent (
213- cfg_parent,
214- DFG {
215- signature : Signature :: new ( entry_blk. inputs . clone ( ) , new_cfg_inputs. clone ( ) ) ,
216- } ,
217- ) ;
218- let [ _, entry_output] = h. get_io ( entry) . unwrap ( ) ;
219- while let Some ( n) = h. first_child ( entry) {
220- h. set_parent ( n, dfg) ;
221- }
222- h. move_before_sibling ( succ, entry) ;
223- h. remove_node ( entry) ;
222+ // Look for nonlocal `Dom` edges from the entry block. (Ignore `Ext` edges.)
223+ let dests = h. children ( entry) . flat_map ( |n| h. output_neighbours ( n) ) ;
224+ let has_dom_outs = dests. dedup ( ) . any ( |succ| {
225+ ancestor_block ( h, succ) . expect ( "Dom edges within entry, Ext within CFG" ) != entry
226+ } ) ;
227+ if !has_dom_outs {
228+ // Move entry block contents into DFG.
229+ let dfg = h. add_node_with_parent (
230+ cfg_parent,
231+ DFG {
232+ signature : Signature :: new ( entry_blk. inputs . clone ( ) , new_cfg_inputs. clone ( ) ) ,
233+ } ,
234+ ) ;
235+ let [ _, entry_output] = h. get_io ( entry) . unwrap ( ) ;
236+ while let Some ( n) = h. first_child ( entry) {
237+ h. set_parent ( n, dfg) ;
238+ }
239+ h. move_before_sibling ( succ, entry) ;
240+ h. remove_node ( entry) ;
224241
225- unpack_before_output ( h, entry_output, new_cfg_inputs. clone ( ) ) ;
242+ unpack_before_output ( h, entry_output, new_cfg_inputs. clone ( ) ) ;
226243
227- // Inputs to CFG go directly to DFG
228- for inp in h. node_inputs ( cfg_node) . collect :: < Vec < _ > > ( ) {
229- for src in h. linked_outputs ( cfg_node, inp) . collect :: < Vec < _ > > ( ) {
230- h. connect ( src. 0 , src. 1 , dfg, inp. index ( ) ) ;
244+ // Inputs to CFG go directly to DFG
245+ for inp in h. node_inputs ( cfg_node) . collect :: < Vec < _ > > ( ) {
246+ for src in h. linked_outputs ( cfg_node, inp) . collect :: < Vec < _ > > ( ) {
247+ h. connect ( src. 0 , src. 1 , dfg, inp. index ( ) ) ;
248+ }
249+ h. disconnect ( cfg_node, inp) ;
231250 }
232- h. disconnect ( cfg_node, inp) ;
233- }
234251
235- // Update input ports
236- let cfg_ty = cfg_ty_mut ( h, cfg_node) ;
237- let inputs_to_add = new_cfg_inputs. len ( ) as isize - cfg_ty. signature . input . len ( ) as isize ;
238- cfg_ty. signature . input = new_cfg_inputs;
239- h. add_ports ( cfg_node, Direction :: Incoming , inputs_to_add) ;
252+ // Update input ports
253+ let cfg_ty = cfg_ty_mut ( h, cfg_node) ;
254+ let inputs_to_add =
255+ new_cfg_inputs. len ( ) as isize - cfg_ty. signature . input . len ( ) as isize ;
256+ cfg_ty. signature . input = new_cfg_inputs;
257+ h. add_ports ( cfg_node, Direction :: Incoming , inputs_to_add) ;
240258
241- // Wire outputs of DFG directly to CFG
242- for src in h. node_outputs ( dfg) . collect :: < Vec < _ > > ( ) {
243- h. connect ( dfg, src, cfg_node, src. index ( ) ) ;
259+ // Wire outputs of DFG directly to CFG
260+ for src in h. node_outputs ( dfg) . collect :: < Vec < _ > > ( ) {
261+ h. connect ( dfg, src, cfg_node, src. index ( ) ) ;
262+ }
263+ entry_dfg = Some ( dfg) ;
244264 }
245- entry_dfg = Some ( dfg) ;
246265 }
247266 // 2. If the exit node has a single predecessor and that predecessor has no other successors...
248267 let mut exit_dfg = None ;
249- if let Some ( pred) = h
250- . input_neighbours ( exit)
251- . exactly_one ( )
252- . ok ( )
253- . filter ( |pred| h. output_neighbours ( * pred) . count ( ) == 1 )
254- {
268+ if let Some ( pred) = h. input_neighbours ( exit) . exactly_one ( ) . ok ( ) . filter ( |pred| {
269+ // Allow only if there are no `Dom` edges into `pred`. (Ignore `Ext` edges.)
270+ let src_nodes = h. children ( * pred) . flat_map ( |ch| h. input_neighbours ( ch) ) ;
271+ h. output_neighbours ( * pred) . count ( ) == 1
272+ && src_nodes. dedup ( ) . all ( |src| {
273+ ancestor_block ( h, src) . is_none_or ( |src| src == * pred) // Nones are `Ext` edges.
274+ } )
275+ } ) {
255276 // Code in that predecessor can be moved outside (into a new DFG after the CFG),
256277 // and the predecessor deleted
257278 let [ _, output] = h. get_io ( pred) . unwrap ( ) ;
@@ -866,17 +887,19 @@ mod test {
866887 Ok ( ( ) )
867888 }
868889
869- #[ test ]
870- fn nested_cfgs_pass ( ) {
890+ #[ rstest ]
891+ fn nested_cfgs_pass ( # [ values ( true , false ) ] nonlocal : bool ) {
871892 // --> Entry --> Loop --> Tail --> EXIT
872893 // | / \
873894 // (E->X) \<-/
874895 let e = extension ( ) ;
875896 let tst_op = e. instantiate_extension_op ( "Test" , [ ] ) . unwrap ( ) ;
876- let qqu = vec ! [ qb_t( ) , qb_t( ) , usize_t( ) ] ;
897+ let qqu = TypeRow :: from ( vec ! [ qb_t( ) , qb_t( ) , usize_t( ) ] ) ;
877898 let qq = TypeRow :: from ( vec ! [ qb_t( ) ; 2 ] ) ;
878899 let mut outer = CFGBuilder :: new ( inout_sig ( qqu. clone ( ) , vec ! [ usize_t( ) , qb_t( ) ] ) ) . unwrap ( ) ;
879- let mut entry = outer. entry_builder ( vec ! [ qq. clone( ) ] , type_row ! [ ] ) . unwrap ( ) ;
900+ let mut entry = outer
901+ . entry_builder ( vec ! [ qq. clone( ) ] , usize_t ( ) . into ( ) )
902+ . unwrap ( ) ;
880903 let [ q1, q2, u] = entry. input_wires_arr ( ) ;
881904 let ( inner, inner_pred) = {
882905 let mut inner = entry
@@ -897,38 +920,41 @@ mod test {
897920 . add_dataflow_op ( Tag :: new ( 0 , vec ! [ qq. clone( ) ] ) , [ q1, q2] )
898921 . unwrap ( )
899922 . outputs_arr ( ) ;
900- let entry = entry. finish_with_outputs ( entry_pred, [ ] ) . unwrap ( ) ;
923+ let entry = entry. finish_with_outputs ( entry_pred, [ u ] ) . unwrap ( ) ;
901924
902925 let loop_b = {
926+ let qu = [ qb_t ( ) , usize_t ( ) ] ;
903927 let mut loop_b = outer
904- . block_builder ( qq . clone ( ) , [ qb_t ( ) . into ( ) , usize_t ( ) . into ( ) ] , qb_t ( ) . into ( ) )
928+ . block_builder ( qqu , qu . clone ( ) . map ( TypeRow :: from ) , Vec :: from ( qu ) . into ( ) )
905929 . unwrap ( ) ;
906- let [ q1, q2] = loop_b. input_wires_arr ( ) ;
930+ let [ q1, q2, u_local ] = loop_b. input_wires_arr ( ) ;
907931 // u here is `dom` edge from entry block
908932 let [ pred] = loop_b
909- . add_dataflow_op ( tst_op, [ q1, u ] )
933+ . add_dataflow_op ( tst_op, [ q1, if nonlocal { u } else { u_local } ] )
910934 . unwrap ( )
911935 . outputs_arr ( ) ;
912- loop_b. finish_with_outputs ( pred, [ q2] ) . unwrap ( )
936+ loop_b. finish_with_outputs ( pred, [ q2, u_local ] ) . unwrap ( )
913937 } ;
914938 outer. branch ( & entry, 0 , & loop_b) . unwrap ( ) ;
915939 outer. branch ( & loop_b, 0 , & loop_b) . unwrap ( ) ;
916940
917941 let ( tail_b, tail_pred) = {
918942 let uq = TypeRow :: from ( vec ! [ usize_t( ) , qb_t( ) ] ) ;
943+ let uqu = vec ! [ usize_t( ) , qb_t( ) , usize_t( ) ] . into ( ) ;
919944 let mut tail_b = outer
920- . block_builder ( uq . clone ( ) , vec ! [ uq. clone( ) ] , type_row ! [ ] )
945+ . block_builder ( uqu , vec ! [ uq. clone( ) ] , type_row ! [ ] )
921946 . unwrap ( ) ;
922- let [ u, q] = tail_b. input_wires_arr ( ) ;
947+ let [ u, q, _ ] = tail_b. input_wires_arr ( ) ;
923948 let [ br] = tail_b
924- . add_dataflow_op ( Tag :: new ( 0 , vec ! [ uq. clone ( ) ] ) , [ u, q] )
949+ . add_dataflow_op ( Tag :: new ( 0 , vec ! [ uq] ) , [ u, q] )
925950 . unwrap ( )
926951 . outputs_arr ( ) ;
927952 ( tail_b. finish_with_outputs ( br, [ ] ) . unwrap ( ) , br. node ( ) )
928953 } ;
929954 outer. branch ( & loop_b, 1 , & tail_b) . unwrap ( ) ;
930955 outer. branch ( & tail_b, 0 , & outer. exit_block ( ) ) . unwrap ( ) ;
931956 let mut h = outer. finish_hugr ( ) . unwrap ( ) ;
957+ // Sanity checks:
932958 assert_eq ! (
933959 h. get_parent( h. get_parent( inner_pred) . unwrap( ) ) ,
934960 Some ( inner. node( ) )
@@ -943,36 +969,59 @@ mod test {
943969 Some ( NormalizeCFGResult :: CFGToDFG )
944970 ) ;
945971 let Some ( NormalizeCFGResult :: CFGPreserved {
946- entry_dfg : Some ( entry_dfg ) ,
972+ entry_dfg,
947973 exit_dfg : Some ( tail_dfg) ,
948974 num_merged : 0 ,
949975 } ) = res. remove ( & h. entrypoint ( ) )
950976 else {
951977 panic ! ( "Unexpected result" )
952978 } ;
979+
953980 assert ! ( res. is_empty( ) ) ;
954- // Now contains only one CFG with one BB (self-loop)
981+
955982 assert_eq ! (
956983 h. nodes( )
957984 . filter( |n| h. get_optype( * n) . is_cfg( ) )
958- . exactly_one( )
959- . ok( ) ,
960- Some ( h. entrypoint( ) )
985+ . collect_vec( ) ,
986+ vec![ h. entrypoint( ) ]
961987 ) ;
962- let [ entry, exit] = h. children ( h. entrypoint ( ) ) . collect_array ( ) . unwrap ( ) ;
963- assert_eq ! ( h. output_neighbours( entry) . collect_vec( ) , [ entry, exit] ) ;
988+ let [ loop_, exit] = if nonlocal {
989+ let [ entry, exit, loop_] = h. children ( h. entrypoint ( ) ) . collect_array ( ) . unwrap ( ) ;
990+ assert_eq ! ( h. get_parent( entry_pred. node( ) ) , Some ( entry) ) ;
991+ [ loop_, exit]
992+ } else {
993+ h. children ( h. entrypoint ( ) ) . collect_array ( ) . unwrap ( )
994+ } ;
995+
996+ assert_eq ! ( h. output_neighbours( loop_) . collect_vec( ) , [ loop_, exit] ) ;
997+
964998 // Inner CFG is now a DFG (and still sibling of entry_pred)...
965999 assert_eq ! ( h. get_parent( inner_pred) , Some ( inner. node( ) ) ) ;
9661000 assert_eq ! ( h. get_optype( inner. node( ) ) . tag( ) , OpTag :: Dfg ) ;
9671001 assert_eq ! ( h. get_parent( inner. node( ) ) , h. get_parent( entry_pred. node( ) ) ) ;
9681002 // Predicates lifted appropriately...
969- for ( n, parent) in [ ( entry_pred. node ( ) , entry_dfg) , ( tail_pred. node ( ) , tail_dfg) ] {
970- assert_eq ! ( h. get_parent( n) , Some ( parent) ) ;
971- assert_eq ! ( h. get_optype( parent) . tag( ) , OpTag :: Dfg ) ;
972- assert_eq ! ( h. get_parent( parent) , h. get_parent( h. entrypoint( ) ) ) ;
973- }
1003+ let func = h. get_parent ( h. entrypoint ( ) ) . unwrap ( ) ;
1004+
1005+ assert_eq ! ( h. get_parent( tail_pred. node( ) ) , Some ( tail_dfg) ) ;
1006+ assert_eq ! ( h. get_optype( tail_dfg) . tag( ) , OpTag :: Dfg ) ;
1007+ assert_eq ! ( h. get_parent( tail_dfg) , Some ( func) ) ;
1008+ let lifted_preds = if nonlocal {
1009+ assert ! ( entry_dfg. is_none( ) ) ;
1010+ // entry_pred not lifted, still connected to output
1011+ let [ output] = h
1012+ . output_neighbours ( entry_pred. node ( ) )
1013+ . collect_array ( )
1014+ . unwrap ( ) ;
1015+ assert_eq ! ( h. get_optype( output) . tag( ) , OpTag :: Output ) ;
1016+ vec ! [ inner_pred. node( ) , tail_pred. node( ) ]
1017+ } else {
1018+ assert_eq ! ( h. get_parent( entry_dfg. unwrap( ) ) , Some ( func) ) ;
1019+ assert_eq ! ( h. get_parent( entry_pred. node( ) ) , entry_dfg) ;
1020+ vec ! [ inner_pred. node( ) , entry_pred. node( ) , tail_pred. node( ) ]
1021+ } ;
1022+
9741023 // ...and followed by UnpackTuple's
975- for n in [ inner_pred , entry_pred . node ( ) , tail_pred . node ( ) ] {
1024+ for n in lifted_preds {
9761025 let [ unpack] = h. output_neighbours ( n) . collect_array ( ) . unwrap ( ) ;
9771026 assert ! (
9781027 h. get_optype( unpack)
0 commit comments