Skip to content

Commit 01e7ac0

Browse files
authored
fix: validation outside entrypoint, normalize_cfgs w/ nonlocal edges (#2633)
fixes #2598 The bug in validation was simple, `validate_subtree` was only called beneath the entrypoint. Change that... Of course there were test failures - Hugrs we thought were valid but actually weren't. * Many of these were `ReplaceTypes` tests on DFG-entrypoint Hugrs: `ReplaceTypes` acted only on/beneath the entrypoint, so (the Input/Output of) the FuncDefn containing the DFG no longer matches it. I've changed these to make the FuncDefn be entrypoint, without a DFG, so all relevant nodes are ReplaceType'd. * NormalizeCFGs can move the Entry block of a CFG into a DFG outside it. If there were nonlocal `Dom` edges from inside that entry block, this would make the Hugr invalid. (Thus, a second bug, which we had not noticed because of the validation bug.) I've fixed this by only transforming if there are no `Dom` edges from the entry block. * A similar case (when the predecessor of the Exit block can be moved outside the CFG) is now also done only if there are no *incoming* `Dom` edges.
1 parent 53f5072 commit 01e7ac0

File tree

5 files changed

+121
-76
lines changed

5 files changed

+121
-76
lines changed

hugr-core/src/hugr/validate.rs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ impl<'a, H: HugrView> ValidationContext<'a, H> {
5757
self.validate_node(node)?;
5858
}
5959

60-
// Hierarchy and children. No type variables declared outside the root.
61-
self.validate_subtree(self.hugr.entrypoint(), &[])?;
60+
// Hierarchy and children. No type variables declared by the module root.
61+
self.validate_subtree(self.hugr.module_root(), &[])?;
6262

6363
self.validate_linkage()?;
6464
// In tests we take the opportunity to verify that the hugr
@@ -600,13 +600,9 @@ impl<'a, H: HugrView> ValidationContext<'a, H> {
600600
}
601601

602602
// Check port connections.
603-
//
604-
// Root nodes are ignored, as they cannot have connected edges.
605-
if node != self.hugr.entrypoint() {
606-
for dir in Direction::BOTH {
607-
for port in self.hugr.node_ports(node, dir) {
608-
self.validate_port(node, port, op_type, var_decls)?;
609-
}
603+
for dir in Direction::BOTH {
604+
for port in self.hugr.node_ports(node, dir) {
605+
self.validate_port(node, port, op_type, var_decls)?;
610606
}
611607
}
612608

hugr-passes/src/linearize_array.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ impl LinearizeArrayPass {
141141

142142
#[cfg(test)]
143143
mod test {
144-
use hugr_core::builder::ModuleBuilder;
144+
use hugr_core::builder::{FunctionBuilder, ModuleBuilder};
145145
use hugr_core::extension::prelude::{ConstUsize, Noop};
146146
use hugr_core::ops::handle::NodeHandle;
147147
use hugr_core::ops::{Const, OpType};
@@ -287,7 +287,7 @@ mod test {
287287
),
288288
};
289289
let sig = Signature::new(src, tgt);
290-
let mut builder = DFGBuilder::new(sig).unwrap();
290+
let mut builder = FunctionBuilder::new("main", sig).unwrap();
291291
let [arr] = builder.input_wires_arr();
292292
let op: OpType = match dir {
293293
INTO => VArrayToArray::new(elem_ty.clone(), size).into(),
@@ -313,7 +313,7 @@ mod test {
313313
#[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))]
314314
fn implicit_clone(#[case] array_ty: Type) {
315315
let sig = Signature::new(array_ty.clone(), vec![array_ty; 2]);
316-
let mut builder = DFGBuilder::new(sig).unwrap();
316+
let mut builder = FunctionBuilder::new("main", sig).unwrap();
317317
let [arr] = builder.input_wires_arr();
318318
builder.set_outputs(vec![arr, arr]).unwrap();
319319

@@ -329,7 +329,7 @@ mod test {
329329
#[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))]
330330
fn implicit_discard(#[case] array_ty: Type) {
331331
let sig = Signature::new(array_ty, Type::EMPTY_TYPEROW);
332-
let mut builder = DFGBuilder::new(sig).unwrap();
332+
let mut builder = FunctionBuilder::new("main", sig).unwrap();
333333
builder.set_outputs(vec![]).unwrap();
334334

335335
let mut hugr = builder.finish_hugr().unwrap();

hugr-passes/src/normalize_cfgs.rs

Lines changed: 108 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,16 @@ pub fn normalize_cfg<H: HugrMut>(
168168
_ => unreachable!(), // Checked at entry to normalize_cfg
169169
}
170170
}
171+
let ancestor_block = |h: &H, mut n: H::Node| {
172+
while let Some(p) = h.get_parent(n) {
173+
if p == cfg_node {
174+
return Some(n);
175+
}
176+
n = p;
177+
}
178+
None
179+
};
180+
171181
// Further normalizations with effects outside the CFG
172182
let [entry, exit] = h.children(cfg_node).take(2).collect_array().unwrap();
173183
let entry_blk = h.get_optype(entry).as_dataflow_block().unwrap();
@@ -209,49 +219,60 @@ pub fn normalize_cfg<H: HugrMut>(
209219
}
210220
// 1b. Move entry block outside/before the CFG into a DFG; its successor becomes the entry block.
211221
let new_cfg_inputs = entry_blk.successor_input(0).unwrap();
212-
let dfg = h.add_node_with_parent(
213-
cfg_parent,
214-
DFG {
215-
signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()),
216-
},
217-
);
218-
let [_, entry_output] = h.get_io(entry).unwrap();
219-
while let Some(n) = h.first_child(entry) {
220-
h.set_parent(n, dfg);
221-
}
222-
h.move_before_sibling(succ, entry);
223-
h.remove_node(entry);
222+
// Look for nonlocal `Dom` edges from the entry block. (Ignore `Ext` edges.)
223+
let dests = h.children(entry).flat_map(|n| h.output_neighbours(n));
224+
let has_dom_outs = dests.dedup().any(|succ| {
225+
ancestor_block(h, succ).expect("Dom edges within entry, Ext within CFG") != entry
226+
});
227+
if !has_dom_outs {
228+
// Move entry block contents into DFG.
229+
let dfg = h.add_node_with_parent(
230+
cfg_parent,
231+
DFG {
232+
signature: Signature::new(entry_blk.inputs.clone(), new_cfg_inputs.clone()),
233+
},
234+
);
235+
let [_, entry_output] = h.get_io(entry).unwrap();
236+
while let Some(n) = h.first_child(entry) {
237+
h.set_parent(n, dfg);
238+
}
239+
h.move_before_sibling(succ, entry);
240+
h.remove_node(entry);
224241

225-
unpack_before_output(h, entry_output, new_cfg_inputs.clone());
242+
unpack_before_output(h, entry_output, new_cfg_inputs.clone());
226243

227-
// Inputs to CFG go directly to DFG
228-
for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() {
229-
for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() {
230-
h.connect(src.0, src.1, dfg, inp.index());
244+
// Inputs to CFG go directly to DFG
245+
for inp in h.node_inputs(cfg_node).collect::<Vec<_>>() {
246+
for src in h.linked_outputs(cfg_node, inp).collect::<Vec<_>>() {
247+
h.connect(src.0, src.1, dfg, inp.index());
248+
}
249+
h.disconnect(cfg_node, inp);
231250
}
232-
h.disconnect(cfg_node, inp);
233-
}
234251

235-
// Update input ports
236-
let cfg_ty = cfg_ty_mut(h, cfg_node);
237-
let inputs_to_add = new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize;
238-
cfg_ty.signature.input = new_cfg_inputs;
239-
h.add_ports(cfg_node, Direction::Incoming, inputs_to_add);
252+
// Update input ports
253+
let cfg_ty = cfg_ty_mut(h, cfg_node);
254+
let inputs_to_add =
255+
new_cfg_inputs.len() as isize - cfg_ty.signature.input.len() as isize;
256+
cfg_ty.signature.input = new_cfg_inputs;
257+
h.add_ports(cfg_node, Direction::Incoming, inputs_to_add);
240258

241-
// Wire outputs of DFG directly to CFG
242-
for src in h.node_outputs(dfg).collect::<Vec<_>>() {
243-
h.connect(dfg, src, cfg_node, src.index());
259+
// Wire outputs of DFG directly to CFG
260+
for src in h.node_outputs(dfg).collect::<Vec<_>>() {
261+
h.connect(dfg, src, cfg_node, src.index());
262+
}
263+
entry_dfg = Some(dfg);
244264
}
245-
entry_dfg = Some(dfg);
246265
}
247266
// 2. If the exit node has a single predecessor and that predecessor has no other successors...
248267
let mut exit_dfg = None;
249-
if let Some(pred) = h
250-
.input_neighbours(exit)
251-
.exactly_one()
252-
.ok()
253-
.filter(|pred| h.output_neighbours(*pred).count() == 1)
254-
{
268+
if let Some(pred) = h.input_neighbours(exit).exactly_one().ok().filter(|pred| {
269+
// Allow only if there are no `Dom` edges into `pred`. (Ignore `Ext` edges.)
270+
let src_nodes = h.children(*pred).flat_map(|ch| h.input_neighbours(ch));
271+
h.output_neighbours(*pred).count() == 1
272+
&& src_nodes.dedup().all(|src| {
273+
ancestor_block(h, src).is_none_or(|src| src == *pred) // Nones are `Ext` edges.
274+
})
275+
}) {
255276
// Code in that predecessor can be moved outside (into a new DFG after the CFG),
256277
// and the predecessor deleted
257278
let [_, output] = h.get_io(pred).unwrap();
@@ -866,17 +887,19 @@ mod test {
866887
Ok(())
867888
}
868889

869-
#[test]
870-
fn nested_cfgs_pass() {
890+
#[rstest]
891+
fn nested_cfgs_pass(#[values(true, false)] nonlocal: bool) {
871892
// --> Entry --> Loop --> Tail --> EXIT
872893
// | / \
873894
// (E->X) \<-/
874895
let e = extension();
875896
let tst_op = e.instantiate_extension_op("Test", []).unwrap();
876-
let qqu = vec![qb_t(), qb_t(), usize_t()];
897+
let qqu = TypeRow::from(vec![qb_t(), qb_t(), usize_t()]);
877898
let qq = TypeRow::from(vec![qb_t(); 2]);
878899
let mut outer = CFGBuilder::new(inout_sig(qqu.clone(), vec![usize_t(), qb_t()])).unwrap();
879-
let mut entry = outer.entry_builder(vec![qq.clone()], type_row![]).unwrap();
900+
let mut entry = outer
901+
.entry_builder(vec![qq.clone()], usize_t().into())
902+
.unwrap();
880903
let [q1, q2, u] = entry.input_wires_arr();
881904
let (inner, inner_pred) = {
882905
let mut inner = entry
@@ -897,38 +920,41 @@ mod test {
897920
.add_dataflow_op(Tag::new(0, vec![qq.clone()]), [q1, q2])
898921
.unwrap()
899922
.outputs_arr();
900-
let entry = entry.finish_with_outputs(entry_pred, []).unwrap();
923+
let entry = entry.finish_with_outputs(entry_pred, [u]).unwrap();
901924

902925
let loop_b = {
926+
let qu = [qb_t(), usize_t()];
903927
let mut loop_b = outer
904-
.block_builder(qq.clone(), [qb_t().into(), usize_t().into()], qb_t().into())
928+
.block_builder(qqu, qu.clone().map(TypeRow::from), Vec::from(qu).into())
905929
.unwrap();
906-
let [q1, q2] = loop_b.input_wires_arr();
930+
let [q1, q2, u_local] = loop_b.input_wires_arr();
907931
// u here is `dom` edge from entry block
908932
let [pred] = loop_b
909-
.add_dataflow_op(tst_op, [q1, u])
933+
.add_dataflow_op(tst_op, [q1, if nonlocal { u } else { u_local }])
910934
.unwrap()
911935
.outputs_arr();
912-
loop_b.finish_with_outputs(pred, [q2]).unwrap()
936+
loop_b.finish_with_outputs(pred, [q2, u_local]).unwrap()
913937
};
914938
outer.branch(&entry, 0, &loop_b).unwrap();
915939
outer.branch(&loop_b, 0, &loop_b).unwrap();
916940

917941
let (tail_b, tail_pred) = {
918942
let uq = TypeRow::from(vec![usize_t(), qb_t()]);
943+
let uqu = vec![usize_t(), qb_t(), usize_t()].into();
919944
let mut tail_b = outer
920-
.block_builder(uq.clone(), vec![uq.clone()], type_row![])
945+
.block_builder(uqu, vec![uq.clone()], type_row![])
921946
.unwrap();
922-
let [u, q] = tail_b.input_wires_arr();
947+
let [u, q, _] = tail_b.input_wires_arr();
923948
let [br] = tail_b
924-
.add_dataflow_op(Tag::new(0, vec![uq.clone()]), [u, q])
949+
.add_dataflow_op(Tag::new(0, vec![uq]), [u, q])
925950
.unwrap()
926951
.outputs_arr();
927952
(tail_b.finish_with_outputs(br, []).unwrap(), br.node())
928953
};
929954
outer.branch(&loop_b, 1, &tail_b).unwrap();
930955
outer.branch(&tail_b, 0, &outer.exit_block()).unwrap();
931956
let mut h = outer.finish_hugr().unwrap();
957+
// Sanity checks:
932958
assert_eq!(
933959
h.get_parent(h.get_parent(inner_pred).unwrap()),
934960
Some(inner.node())
@@ -943,36 +969,59 @@ mod test {
943969
Some(NormalizeCFGResult::CFGToDFG)
944970
);
945971
let Some(NormalizeCFGResult::CFGPreserved {
946-
entry_dfg: Some(entry_dfg),
972+
entry_dfg,
947973
exit_dfg: Some(tail_dfg),
948974
num_merged: 0,
949975
}) = res.remove(&h.entrypoint())
950976
else {
951977
panic!("Unexpected result")
952978
};
979+
953980
assert!(res.is_empty());
954-
// Now contains only one CFG with one BB (self-loop)
981+
955982
assert_eq!(
956983
h.nodes()
957984
.filter(|n| h.get_optype(*n).is_cfg())
958-
.exactly_one()
959-
.ok(),
960-
Some(h.entrypoint())
985+
.collect_vec(),
986+
vec![h.entrypoint()]
961987
);
962-
let [entry, exit] = h.children(h.entrypoint()).collect_array().unwrap();
963-
assert_eq!(h.output_neighbours(entry).collect_vec(), [entry, exit]);
988+
let [loop_, exit] = if nonlocal {
989+
let [entry, exit, loop_] = h.children(h.entrypoint()).collect_array().unwrap();
990+
assert_eq!(h.get_parent(entry_pred.node()), Some(entry));
991+
[loop_, exit]
992+
} else {
993+
h.children(h.entrypoint()).collect_array().unwrap()
994+
};
995+
996+
assert_eq!(h.output_neighbours(loop_).collect_vec(), [loop_, exit]);
997+
964998
// Inner CFG is now a DFG (and still sibling of entry_pred)...
965999
assert_eq!(h.get_parent(inner_pred), Some(inner.node()));
9661000
assert_eq!(h.get_optype(inner.node()).tag(), OpTag::Dfg);
9671001
assert_eq!(h.get_parent(inner.node()), h.get_parent(entry_pred.node()));
9681002
// Predicates lifted appropriately...
969-
for (n, parent) in [(entry_pred.node(), entry_dfg), (tail_pred.node(), tail_dfg)] {
970-
assert_eq!(h.get_parent(n), Some(parent));
971-
assert_eq!(h.get_optype(parent).tag(), OpTag::Dfg);
972-
assert_eq!(h.get_parent(parent), h.get_parent(h.entrypoint()));
973-
}
1003+
let func = h.get_parent(h.entrypoint()).unwrap();
1004+
1005+
assert_eq!(h.get_parent(tail_pred.node()), Some(tail_dfg));
1006+
assert_eq!(h.get_optype(tail_dfg).tag(), OpTag::Dfg);
1007+
assert_eq!(h.get_parent(tail_dfg), Some(func));
1008+
let lifted_preds = if nonlocal {
1009+
assert!(entry_dfg.is_none());
1010+
// entry_pred not lifted, still connected to output
1011+
let [output] = h
1012+
.output_neighbours(entry_pred.node())
1013+
.collect_array()
1014+
.unwrap();
1015+
assert_eq!(h.get_optype(output).tag(), OpTag::Output);
1016+
vec![inner_pred.node(), tail_pred.node()]
1017+
} else {
1018+
assert_eq!(h.get_parent(entry_dfg.unwrap()), Some(func));
1019+
assert_eq!(h.get_parent(entry_pred.node()), entry_dfg);
1020+
vec![inner_pred.node(), entry_pred.node(), tail_pred.node()]
1021+
};
1022+
9741023
// ...and followed by UnpackTuple's
975-
for n in [inner_pred, entry_pred.node(), tail_pred.node()] {
1024+
for n in lifted_preds {
9761025
let [unpack] = h.output_neighbours(n).collect_array().unwrap();
9771026
assert!(
9781027
h.get_optype(unpack)

hugr-passes/src/replace_types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,8 +1117,8 @@ mod test {
11171117
where
11181118
GenericArrayValue<AK>: CustomConst,
11191119
{
1120-
let mut dfb =
1121-
DFGBuilder::new(inout_sig(type_row![], AK::ty(vals.len() as _, usize_t()))).unwrap();
1120+
let sig = inout_sig(type_row![], AK::ty(vals.len() as _, usize_t()));
1121+
let mut dfb = FunctionBuilder::new("main", sig).unwrap();
11221122
let c = dfb.add_load_value(GenericArrayValue::<AK>::new(
11231123
usize_t(),
11241124
vals.iter().map(|u| ConstUsize::new(*u).into()),

hugr-passes/src/replace_types/linearize.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ mod test {
375375

376376
use hugr_core::builder::{
377377
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
378-
HugrBuilder, inout_sig,
378+
FunctionBuilder, HugrBuilder, inout_sig,
379379
};
380380

381381
use hugr_core::extension::prelude::{option_type, qb_t, usize_t};
@@ -912,7 +912,7 @@ mod test {
912912
);
913913

914914
let build_hugr = |ty: Type| {
915-
let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap();
915+
let mut dfb = FunctionBuilder::new("main", Signature::new(ty.clone(), vec![])).unwrap();
916916
let [inp] = dfb.input_wires_arr();
917917
let drop_op = drop_ext
918918
.instantiate_extension_op("drop", [ty.into()])

0 commit comments

Comments
 (0)