Skip to content

Commit b3062e9

Browse files
authored
refactor: NodeType constructors, adding new_auto (#635)
* Rename NodeType::open_extensions to NodeType::new_open * Rename NodeType::pure to NodeType::new_pure * Add NodeType::new_auto, which uses Pure for module-ops and Open for others * Remove special-case in infer.rs solving some module-ops to empty set * Switch builder/HugrMut methods from new_open to new_auto
1 parent 1a07cd9 commit b3062e9

File tree

14 files changed

+65
-59
lines changed

14 files changed

+65
-59
lines changed

src/builder.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ pub(crate) mod test {
146146
/// inference. Using DFGBuilder will default to a root node with an open
147147
/// extension variable
148148
pub(crate) fn closed_dfg_root_hugr(signature: FunctionType) -> Hugr {
149-
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
149+
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
150150
signature: signature.clone(),
151151
}));
152152
hugr.add_op_with_parent(

src/builder/build_traits.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ pub trait Dataflow: Container {
200200
op: impl Into<OpType>,
201201
input_wires: impl IntoIterator<Item = Wire>,
202202
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
203-
self.add_dataflow_node(NodeType::open_extensions(op), input_wires)
203+
self.add_dataflow_node(NodeType::new_auto(op), input_wires)
204204
}
205205

206206
/// Add a dataflow [`NodeType`] to the sibling graph, wiring up the `input_wires` to the
@@ -628,7 +628,7 @@ fn add_op_with_wires<T: Dataflow + ?Sized>(
628628
optype: impl Into<OpType>,
629629
inputs: Vec<Wire>,
630630
) -> Result<(Node, usize), BuildError> {
631-
add_node_with_wires(data_builder, NodeType::open_extensions(optype), inputs)
631+
add_node_with_wires(data_builder, NodeType::new_auto(optype), inputs)
632632
}
633633

634634
fn add_node_with_wires<T: Dataflow + ?Sized>(

src/builder/cfg.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ impl CFGBuilder<Hugr> {
6262
signature: signature.clone(),
6363
};
6464

65-
let base = Hugr::new(NodeType::open_extensions(cfg_op));
65+
let base = Hugr::new(NodeType::new_open(cfg_op));
6666
let cfg_node = base.root();
6767
CFGBuilder::create(base, cfg_node, signature.input, signature.output)
6868
}

src/builder/conditional.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ impl ConditionalBuilder<Hugr> {
176176
extension_delta,
177177
};
178178
// TODO: Allow input extensions to be specified
179-
let base = Hugr::new(NodeType::open_extensions(op));
179+
let base = Hugr::new(NodeType::new_open(op));
180180
let conditional_node = base.root();
181181

182182
Ok(ConditionalBuilder {
@@ -194,7 +194,7 @@ impl CaseBuilder<Hugr> {
194194
let op = ops::Case {
195195
signature: signature.clone(),
196196
};
197-
let base = Hugr::new(NodeType::open_extensions(op));
197+
let base = Hugr::new(NodeType::new_open(op));
198198
let root = base.root();
199199
let dfg_builder = DFGBuilder::create_with_io(base, root, signature, None)?;
200200

src/builder/dataflow.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ impl DFGBuilder<Hugr> {
7979
let dfg_op = ops::DFG {
8080
signature: signature.clone(),
8181
};
82-
let base = Hugr::new(NodeType::open_extensions(dfg_op));
82+
let base = Hugr::new(NodeType::new_open(dfg_op));
8383
let root = base.root();
8484
DFGBuilder::create_with_io(base, root, signature, None)
8585
}

src/builder/module.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
9090
};
9191
self.hugr_mut().replace_op(
9292
f_node,
93-
NodeType::pure(ops::FuncDefn {
93+
NodeType::new_pure(ops::FuncDefn {
9494
name,
9595
signature: signature.clone(),
9696
}),

src/builder/tail_loop.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl TailLoopBuilder<Hugr> {
8282
rest: inputs_outputs.into(),
8383
};
8484
// TODO: Allow input extensions to be specified
85-
let base = Hugr::new(NodeType::open_extensions(tail_loop.clone()));
85+
let base = Hugr::new(NodeType::new_open(tail_loop.clone()));
8686
let root = base.root();
8787
Self::create_with_io(base, root, &tail_loop)
8888
}

src/extension/infer.rs

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -316,12 +316,6 @@ impl UnificationContext {
316316
m_output,
317317
node_type.op_signature().extension_reqs,
318318
);
319-
if matches!(
320-
node_type.tag(),
321-
OpTag::Alias | OpTag::Function | OpTag::FuncDefn
322-
) {
323-
self.add_solution(m_input, ExtensionSet::new());
324-
}
325319
}
326320
// We have a solution for everything!
327321
Some(sig) => {
@@ -723,7 +717,7 @@ mod test {
723717
signature: main_sig,
724718
};
725719

726-
let root_node = NodeType::open_extensions(op);
720+
let root_node = NodeType::new_open(op);
727721
let mut hugr = Hugr::new(root_node);
728722

729723
let input = ops::Input::new(type_row![NAT, NAT]);
@@ -833,21 +827,21 @@ mod test {
833827
// This generates a solution that causes validation to fail
834828
// because of a missing lift node
835829
fn missing_lift_node() -> Result<(), Box<dyn Error>> {
836-
let mut hugr = Hugr::new(NodeType::pure(ops::DFG {
830+
let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG {
837831
signature: FunctionType::new(type_row![NAT], type_row![NAT])
838832
.with_extension_delta(&ExtensionSet::singleton(&A)),
839833
}));
840834

841835
let input = hugr.add_node_with_parent(
842836
hugr.root(),
843-
NodeType::pure(ops::Input {
837+
NodeType::new_pure(ops::Input {
844838
types: type_row![NAT],
845839
}),
846840
)?;
847841

848842
let output = hugr.add_node_with_parent(
849843
hugr.root(),
850-
NodeType::pure(ops::Output {
844+
NodeType::new_pure(ops::Output {
851845
types: type_row![NAT],
852846
}),
853847
)?;
@@ -1049,7 +1043,7 @@ mod test {
10491043
extension_delta: rs.clone(),
10501044
};
10511045

1052-
let mut hugr = Hugr::new(NodeType::pure(op));
1046+
let mut hugr = Hugr::new(NodeType::new_pure(op));
10531047
let conditional_node = hugr.root();
10541048

10551049
let case_op = ops::Case {
@@ -1084,7 +1078,7 @@ mod test {
10841078
fn extension_adding_sequence() -> Result<(), Box<dyn Error>> {
10851079
let df_sig = FunctionType::new(type_row![NAT], type_row![NAT]);
10861080

1087-
let mut hugr = Hugr::new(NodeType::open_extensions(ops::DFG {
1081+
let mut hugr = Hugr::new(NodeType::new_open(ops::DFG {
10881082
signature: df_sig
10891083
.clone()
10901084
.with_extension_delta(&ExtensionSet::from_iter([A, B])),
@@ -1255,7 +1249,7 @@ mod test {
12551249
let b = ExtensionSet::singleton(&B);
12561250
let c = ExtensionSet::singleton(&C);
12571251

1258-
let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
1252+
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
12591253
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc),
12601254
}));
12611255

@@ -1353,7 +1347,7 @@ mod test {
13531347
/// +--------------------+
13541348
#[test]
13551349
fn multi_entry() -> Result<(), Box<dyn Error>> {
1356-
let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
1350+
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
13571351
signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions?
13581352
}));
13591353
let cfg = hugr.root();
@@ -1436,7 +1430,7 @@ mod test {
14361430
) -> Result<Hugr, Box<dyn Error>> {
14371431
let hugr_delta = entry_ext.clone().union(&bb1_ext).union(&bb2_ext);
14381432

1439-
let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
1433+
let mut hugr = Hugr::new(NodeType::new_open(ops::CFG {
14401434
signature: FunctionType::new(type_row![NAT], type_row![NAT])
14411435
.with_extension_delta(&hugr_delta),
14421436
}));

src/hugr.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ impl NodeType {
8282
}
8383

8484
/// Instantiate an OpType with no input extensions
85-
pub fn pure(op: impl Into<OpType>) -> Self {
85+
pub fn new_pure(op: impl Into<OpType>) -> Self {
8686
NodeType {
8787
op: op.into(),
8888
input_extensions: Some(ExtensionSet::new()),
@@ -91,13 +91,24 @@ impl NodeType {
9191

9292
/// Instantiate an OpType with an unknown set of input extensions
9393
/// (to be inferred later)
94-
pub fn open_extensions(op: impl Into<OpType>) -> Self {
94+
pub fn new_open(op: impl Into<OpType>) -> Self {
9595
NodeType {
9696
op: op.into(),
9797
input_extensions: None,
9898
}
9999
}
100100

101+
/// Instantiate an [OpType] with the default set of input extensions
102+
/// for that OpType.
103+
pub fn new_auto(op: impl Into<OpType>) -> Self {
104+
let op = op.into();
105+
if OpTag::ModuleOp.is_superset(op.tag()) {
106+
Self::new_pure(op)
107+
} else {
108+
Self::new_open(op)
109+
}
110+
}
111+
101112
/// Use the input extensions to calculate the concrete signature of the node
102113
pub fn signature(&self) -> Option<Signature> {
103114
self.input_extensions
@@ -119,9 +130,7 @@ impl NodeType {
119130
pub fn input_extensions(&self) -> Option<&ExtensionSet> {
120131
self.input_extensions.as_ref()
121132
}
122-
}
123133

124-
impl NodeType {
125134
/// Gets the underlying [OpType] i.e. without any [input_extensions]
126135
///
127136
/// [input_extensions]: NodeType::input_extensions
@@ -153,7 +162,7 @@ impl OpType {
153162

154163
impl Default for Hugr {
155164
fn default() -> Self {
156-
Self::new(NodeType::pure(crate::ops::Module))
165+
Self::new(NodeType::new_pure(crate::ops::Module))
157166
}
158167
}
159168

@@ -239,7 +248,7 @@ impl Hugr {
239248

240249
/// Add a node to the graph, with the default conversion from OpType to NodeType
241250
pub(crate) fn add_op(&mut self, op: impl Into<OpType>) -> Node {
242-
self.add_node(NodeType::open_extensions(op))
251+
self.add_node(NodeType::new_auto(op))
243252
}
244253

245254
/// Add a node to the graph.

src/hugr/hugrmut.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ pub trait HugrMut: HugrMutInternals {
3737
parent: Node,
3838
op: impl Into<OpType>,
3939
) -> Result<Node, HugrError> {
40-
self.add_node_with_parent(parent, NodeType::open_extensions(op))
40+
self.add_node_with_parent(parent, NodeType::new_auto(op))
4141
}
4242

4343
/// Add a node to the graph with a parent in the hierarchy.
@@ -217,7 +217,7 @@ impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
217217
}
218218

219219
fn add_op_before(&mut self, sibling: Node, op: impl Into<OpType>) -> Result<Node, HugrError> {
220-
self.add_node_before(sibling, NodeType::open_extensions(op))
220+
self.add_node_before(sibling, NodeType::new_auto(op))
221221
}
222222

223223
fn add_node_before(&mut self, sibling: Node, nodetype: NodeType) -> Result<Node, HugrError> {
@@ -620,7 +620,7 @@ mod test {
620620

621621
{
622622
let f_in = hugr
623-
.add_node_with_parent(f, NodeType::pure(ops::Input::new(type_row![NAT])))
623+
.add_node_with_parent(f, NodeType::new_pure(ops::Input::new(type_row![NAT])))
624624
.unwrap();
625625
let f_out = hugr
626626
.add_op_with_parent(f, ops::Output::new(type_row![NAT, NAT]))

0 commit comments

Comments
 (0)