Skip to content

Commit 1426b1e

Browse files
committed
Revert "Revert hugr-passes changes"
This reverts commit 3a460a4.
1 parent 3a460a4 commit 1426b1e

File tree

8 files changed

+288
-92
lines changed

8 files changed

+288
-92
lines changed

hugr-passes/src/call_graph.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub enum CallGraphNode<N = Node> {
1818
FuncDecl(N),
1919
/// petgraph-node corresponds to a [`FuncDefn`](OpType::FuncDefn) node (specified) in the Hugr
2020
FuncDefn(N),
21-
/// petgraph-node corresponds to the root node of the hugr, that is not
21+
/// petgraph-node corresponds to the entrypoint node of the hugr, that is not
2222
/// a [`FuncDefn`](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
2323
/// either, as such a node could not have outgoing edges, so is not represented in the petgraph.
2424
NonFuncRoot,

hugr-passes/src/const_fold.rs

Lines changed: 59 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,38 +15,44 @@ use hugr_core::{
1515
};
1616
use value_handle::ValueHandle;
1717

18-
use crate::dataflow::{
19-
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
20-
partial_from_const,
21-
};
2218
use crate::dead_code::{DeadCodeElimPass, PreserveNode};
2319
use crate::{ComposablePass, composable::validate_if_test};
20+
use crate::{
21+
IncludeExports,
22+
dataflow::{
23+
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
24+
partial_from_const,
25+
},
26+
};
2427

2528
#[derive(Debug, Clone, Default)]
2629
/// A configuration for the Constant Folding pass.
30+
///
31+
/// Note that by default we assume that only the entrypoint is reachable and
32+
/// only if it is not the module root; see [Self::with_inputs]. Mutation
33+
/// occurs anywhere beneath the entrypoint.
2734
pub struct ConstantFoldPass {
2835
allow_increase_termination: bool,
2936
/// Each outer key Node must be either:
30-
/// - a `FuncDefn` child of the root, if the root is a module; or
31-
/// - the root, if the root is not a Module
37+
/// - a `FuncDefn` child of the module-root
38+
/// - the entrypoint
3239
inputs: HashMap<Node, HashMap<IncomingPort, Value>>,
3340
}
3441

3542
#[derive(Clone, Debug, Error, PartialEq)]
3643
#[non_exhaustive]
3744
/// Errors produced by [`ConstantFoldPass`].
3845
pub enum ConstFoldError {
39-
/// Error raised when a Node is specified as an entry-point but
40-
/// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor
41-
/// a [Conditional](OpType::Conditional).
46+
/// Error raised when inputs are provided for a Node that is neither a dataflow
47+
/// parent, nor a [CFG](OpType::CFG), nor a [Conditional](OpType::Conditional).
4248
#[error("{node} has OpType {op} which cannot be an entry-point")]
4349
InvalidEntryPoint {
4450
/// The node which was specified as an entry-point
4551
node: Node,
4652
/// The `OpType` of the node
4753
op: OpType,
4854
},
49-
/// The chosen entrypoint is not in the hugr.
55+
/// Inputs were provided for a node that is not in the hugr.
5056
#[error("Entry-point {node} is not part of the Hugr")]
5157
MissingEntryPoint {
5258
/// The missing node
@@ -67,15 +73,25 @@ impl ConstantFoldPass {
6773
}
6874

6975
/// Specifies a number of external inputs to an entry point of the Hugr.
70-
/// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` child of the root;
71-
/// or for non-Module-rooted Hugrs, `node` is the root of the Hugr. (This is not
76+
/// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` (child of the root);
77+
/// for non-Module-rooted Hugrs, `node` is the [HugrView::entrypoint]. (This is not
7278
/// enforced, but it must be a container and not a module itself.)
7379
///
7480
/// Multiple calls for the same entry-point combine their values, with later
7581
/// values on the same in-port replacing earlier ones.
7682
///
77-
/// Note that if `inputs` is empty, this still marks the node as an entry-point, i.e.
78-
/// we must preserve nodes required to compute its result.
83+
/// Note that providing empty `inputs` indicates that we must preserve the ability
84+
/// to compute the result of `node` for all possible inputs.
85+
/// * If the entrypoint is the module-root, this method should be called for every
86+
/// [FuncDefn] that is externally callable
87+
/// * Otherwise, i.e. if the entrypoint is not the module-root,
88+
/// * The default is to assume the entrypoint is callable with any inputs;
89+
/// * If `node` is the entrypoint, this method allows to restrict the possible inputs
90+
/// * If `node` is beneath the entrypoint, this merely degrades the analysis. (We
91+
/// will mutate only beneath the entrypoint, but using results of analysing the
92+
/// whole Hugr wrt. the specified/any inputs too).
93+
///
94+
/// [FuncDefn]: hugr_core::ops::FuncDefn
7995
pub fn with_inputs(
8096
mut self,
8197
node: Node,
@@ -97,8 +113,7 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
97113
///
98114
/// # Errors
99115
///
100-
/// [`ConstFoldError::InvalidEntryPoint`] if an entry-point added by [`Self::with_inputs`]
101-
/// was of an invalid [`OpType`]
116+
/// [ConstFoldError] if inputs were provided via [`Self::with_inputs`] for an invalid node.
102117
fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> {
103118
let fresh_node = Node::from(portgraph::NodeIndex::new(
104119
hugr.nodes().max().map_or(0, |n| n.index() + 1),
@@ -184,25 +199,51 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
184199
}
185200
}
186201

202+
const NO_INPUTS: [(IncomingPort, Value); 0] = [];
203+
187204
/// Exhaustively apply constant folding to a HUGR.
188205
/// If the Hugr's entrypoint is its [`Module`], assumes all [`FuncDefn`] children are reachable.
206+
/// Otherwise, assume that the [HugrView::entrypoint] is itself reachable.
189207
///
190208
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
191209
/// [`Module`]: hugr_core::ops::OpType::Module
210+
#[deprecated(note = "Use fold_constants, or manually configure ConstantFoldPass")]
192211
pub fn constant_fold_pass<H: HugrMut<Node = Node> + 'static>(mut h: impl AsMut<H>) {
193212
let h = h.as_mut();
194213
let c = ConstantFoldPass::default();
195214
let c = if h.get_optype(h.entrypoint()).is_module() {
196-
let no_inputs: [(IncomingPort, _); 0] = [];
197215
h.children(h.entrypoint())
198216
.filter(|n| h.get_optype(*n).is_func_defn())
199-
.fold(c, |c, n| c.with_inputs(n, no_inputs.iter().cloned()))
217+
.fold(c, |c, n| c.with_inputs(n, NO_INPUTS.clone()))
200218
} else {
201219
c
202220
};
203221
validate_if_test(c, h).unwrap();
204222
}
205223

224+
/// Exhaustively apply constant folding to a HUGR.
225+
/// Assumes that the Hugr's entrypoint is reachable (if it is not a [`Module`]).
226+
/// Also uses `policy` to determine which public [`FuncDefn`] children of the [`HugrView::module_root`] are reachable.
227+
///
228+
/// [`Module`]: hugr_core::ops::OpType::Module
229+
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
230+
pub fn fold_constants(h: &mut (impl HugrMut<Node = Node> + 'static), policy: IncludeExports) {
231+
let mut funcs = Vec::new();
232+
if !h.entrypoint_optype().is_module() {
233+
funcs.push(h.entrypoint());
234+
}
235+
if policy.for_hugr(&h) {
236+
funcs.extend(
237+
h.children(h.module_root())
238+
.filter(|n| h.get_optype(*n).is_func_defn()),
239+
)
240+
}
241+
let c = funcs.into_iter().fold(ConstantFoldPass::default(), |c, n| {
242+
c.with_inputs(n, NO_INPUTS.clone())
243+
});
244+
validate_if_test(c, h).unwrap();
245+
}
246+
206247
struct ConstFoldContext;
207248

208249
impl ConstLoader<ValueHandle<Node>> for ConstFoldContext {

hugr-passes/src/const_fold/test.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,14 @@ use hugr_core::std_extensions::logic::LogicOp;
2929
use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
3030
use hugr_core::{Hugr, HugrView, IncomingPort, Node, type_row};
3131

32-
use crate::ComposablePass as _;
3332
use crate::dataflow::{DFContext, PartialValue, partial_from_const};
33+
use crate::{ComposablePass as _, IncludeExports};
3434

35-
use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, constant_fold_pass};
35+
use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, fold_constants};
36+
37+
fn constant_fold_pass(h: &mut (impl HugrMut<Node = Node> + 'static)) {
38+
fold_constants(h, IncludeExports::Always);
39+
}
3640

3741
#[rstest]
3842
#[case(ConstInt::new_u(4, 2).unwrap(), true)]

hugr-passes/src/dataflow/datalog.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
116116
} else {
117117
let ep = self.0.entrypoint();
118118
let mut p = in_values.into_iter().peekable();
119-
// We must provide some inputs to the root so that they are Top rather than Bottom.
119+
// We must provide some inputs to the entrypoint so that they are Top rather than Bottom.
120120
// (However, this test will fail for DataflowBlock or Case roots, i.e. if no
121121
// inputs have been provided they will still see Bottom. We could store the "input"
122122
// values for even these nodes in self.1 and then convert to actual Wire values

hugr-passes/src/dead_code.rs

Lines changed: 91 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,35 @@
11
//! Pass for removing dead code, i.e. that computes values that are then discarded
22
33
use hugr_core::hugr::internal::HugrInternals;
4-
use hugr_core::{HugrView, hugr::hugrmut::HugrMut, ops::OpType};
4+
use hugr_core::{HugrView, Visibility, hugr::hugrmut::HugrMut, ops::OpType};
55
use std::convert::Infallible;
66
use std::fmt::{Debug, Formatter};
77
use std::{
88
collections::{HashMap, HashSet, VecDeque},
99
sync::Arc,
1010
};
1111

12-
use crate::ComposablePass;
12+
use crate::{ComposablePass, IncludeExports};
1313

14-
/// Configuration for Dead Code Elimination pass
14+
/// Configuration for Dead Code Elimination pass, i.e. which removes nodes
15+
/// beneath the [HugrView::entrypoint] that compute only unneeded values.
1516
#[derive(Clone)]
1617
pub struct DeadCodeElimPass<H: HugrView> {
1718
/// Nodes that are definitely needed - e.g. `FuncDefns`, but could be anything.
18-
/// Hugr Root is assumed to be an entry point even if not mentioned here.
19+
/// [HugrView::entrypoint] is assumed to be needed even if not mentioned here.
1920
entry_points: Vec<H::Node>,
2021
/// Callback identifying nodes that must be preserved even if their
2122
/// results are not used. Defaults to [`PreserveNode::default_for`].
2223
preserve_callback: Arc<PreserveCallback<H>>,
24+
include_exports: IncludeExports,
2325
}
2426

2527
impl<H: HugrView + 'static> Default for DeadCodeElimPass<H> {
2628
fn default() -> Self {
2729
Self {
2830
entry_points: Default::default(),
2931
preserve_callback: Arc::new(PreserveNode::default_for),
32+
include_exports: IncludeExports::default(),
3033
}
3134
}
3235
}
@@ -39,11 +42,13 @@ impl<H: HugrView> Debug for DeadCodeElimPass<H> {
3942
#[derive(Debug)]
4043
struct DCEDebug<'a, N> {
4144
entry_points: &'a Vec<N>,
45+
include_exports: IncludeExports,
4246
}
4347

4448
Debug::fmt(
4549
&DCEDebug {
4650
entry_points: &self.entry_points,
51+
include_exports: self.include_exports,
4752
},
4853
f,
4954
)
@@ -69,12 +74,12 @@ pub enum PreserveNode {
6974

7075
impl PreserveNode {
7176
/// A conservative default for a given node. Just examines the node's [`OpType`]:
72-
/// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn`, but would
73-
/// also need to check for cycles in the [`CallGraph`](super::call_graph::CallGraph).)
77+
/// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn` for
78+
/// termination, but would also need to check for cycles in the `CallGraph`.)
7479
/// * Assumes all CFGs must be preserved. (One could, for example, allow acyclic
7580
/// CFGs to be removed.)
76-
/// * Assumes all `TailLoops` must be preserved. (One could, for example, use dataflow
77-
/// analysis to allow removal of `TailLoops` that never [Continue](hugr_core::ops::TailLoop::CONTINUE_TAG).)
81+
/// * Assumes all `TailLoops` must be preserved. (One could use some analysis, e.g.
82+
/// dataflow, to allow removal of `TailLoops` with a bounded number of iterations.)
7883
pub fn default_for<H: HugrView>(h: &H, n: H::Node) -> PreserveNode {
7984
match h.get_optype(n) {
8085
OpType::CFG(_) | OpType::TailLoop(_) | OpType::Call(_) => PreserveNode::MustKeep,
@@ -91,16 +96,33 @@ impl<H: HugrView> DeadCodeElimPass<H> {
9196
self
9297
}
9398

94-
/// Mark some nodes as entry points to the Hugr, i.e. so we cannot eliminate any code
95-
/// used to evaluate these nodes.
96-
/// [`HugrView::entrypoint`] is assumed to be an entry point;
97-
/// for Module roots the client will want to mark some of the `FuncDefn` children
98-
/// as entry points too.
99+
/// Mark some nodes as reachable, i.e. so we cannot eliminate any code used to
100+
/// evaluate their results. The [`HugrView::entrypoint`] is assumed to be reachable;
101+
/// if that is the [`HugrView::module_root`], then any public [FuncDefn] and
102+
/// [FuncDecl]s are also considered reachable by default,
103+
/// but this can be change by [`Self::include_module_exports`].
104+
///
105+
/// [FuncDecl]: OpType::FuncDecl
106+
/// [FuncDefn]: OpType::FuncDefn
99107
pub fn with_entry_points(mut self, entry_points: impl IntoIterator<Item = H::Node>) -> Self {
100108
self.entry_points.extend(entry_points);
101109
self
102110
}
103111

112+
/// Sets whether the exported [FuncDefn](OpType::FuncDefn)s and
113+
/// [FuncDecl](OpType::FuncDecl)s are considered reachable.
114+
///
115+
/// Note that for non-module-entry Hugrs this has no effect, since we only remove
116+
/// code beneath the entrypoint: this cannot be affected by other module children.
117+
///
118+
/// So, for module-rooted-Hugrs: [IncludeExports::OnlyIfEntrypointIsModuleRoot] is
119+
/// equivalent to [IncludeExports::Always]; and [IncludeExports::Never] will remove
120+
/// all children, unless some are explicity added by [Self::with_entry_points].
121+
pub fn include_module_exports(mut self, include: IncludeExports) -> Self {
122+
self.include_exports = include;
123+
self
124+
}
125+
104126
fn find_needed_nodes(&self, h: &H) -> HashSet<H::Node> {
105127
let mut must_preserve = HashMap::new();
106128
let mut needed = HashSet::new();
@@ -111,19 +133,23 @@ impl<H: HugrView> DeadCodeElimPass<H> {
111133
continue;
112134
}
113135
for ch in h.children(n) {
114-
if self.must_preserve(h, &mut must_preserve, ch)
115-
|| matches!(
116-
h.get_optype(ch),
136+
let must_keep = match h.get_optype(ch) {
117137
OpType::Case(_) // Include all Cases in Conditionals
118138
| OpType::DataflowBlock(_) // and all Basic Blocks in CFGs
119139
| OpType::ExitBlock(_)
120140
| OpType::AliasDecl(_) // and all Aliases (we do not track their uses in types)
121141
| OpType::AliasDefn(_)
122142
| OpType::Input(_) // Also Dataflow input/output, these are necessary for legality
123-
| OpType::Output(_) // Do not include FuncDecl / FuncDefn / Const unless reachable by static edges
124-
// (from Call/LoadConst/LoadFunction):
125-
)
126-
{
143+
| OpType::Output(_) => true,
144+
// FuncDefns (as children of Module) only if public and including exports
145+
// (will be included if static predecessors of Call/LoadFunction below,
146+
// regardless of Visibility or self.include_exports)
147+
OpType::FuncDefn(fd) => fd.visibility() == &Visibility::Public && self.include_exports.for_hugr(h),
148+
OpType::FuncDecl(fd) => fd.visibility() == &Visibility::Public && self.include_exports.for_hugr(h),
149+
// No Const, unless reached along static edges
150+
_ => false
151+
};
152+
if must_keep || self.must_preserve(h, &mut must_preserve, ch) {
127153
q.push_back(ch);
128154
}
129155
}
@@ -141,7 +167,6 @@ impl<H: HugrView> DeadCodeElimPass<H> {
141167
if let Some(res) = cache.get(&n) {
142168
return *res;
143169
}
144-
#[allow(deprecated)]
145170
let res = match self.preserve_callback.as_ref()(h, n) {
146171
PreserveNode::MustKeep => true,
147172
PreserveNode::CanRemoveIgnoringChildren => false,
@@ -174,18 +199,57 @@ impl<H: HugrMut> ComposablePass<H> for DeadCodeElimPass<H> {
174199
mod test {
175200
use std::sync::Arc;
176201

177-
use hugr_core::Hugr;
178-
use hugr_core::builder::{CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder};
202+
use hugr_core::builder::{
203+
CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder,
204+
};
179205
use hugr_core::extension::prelude::{ConstUsize, usize_t};
180-
use hugr_core::ops::{OpTag, OpTrait, handle::NodeHandle};
181-
use hugr_core::types::Signature;
182-
use hugr_core::{HugrView, ops::Value, type_row};
206+
use hugr_core::ops::{OpTag, OpTrait, Value, handle::NodeHandle};
207+
use hugr_core::{Hugr, HugrView, type_row, types::Signature};
183208
use itertools::Itertools;
209+
use rstest::rstest;
184210

185-
use crate::ComposablePass;
211+
use crate::{ComposablePass, IncludeExports};
186212

187213
use super::{DeadCodeElimPass, PreserveNode};
188214

215+
#[rstest]
216+
#[case(false, IncludeExports::Never, true)]
217+
#[case(false, IncludeExports::OnlyIfEntrypointIsModuleRoot, false)]
218+
#[case(false, IncludeExports::Always, false)]
219+
#[case(true, IncludeExports::Never, true)]
220+
#[case(true, IncludeExports::OnlyIfEntrypointIsModuleRoot, false)]
221+
#[case(true, IncludeExports::Always, false)]
222+
fn test_module_exports(
223+
#[case] include_dfn: bool,
224+
#[case] module_exports: IncludeExports,
225+
#[case] decl_removed: bool,
226+
) {
227+
let mut mb = ModuleBuilder::new();
228+
let dfn = mb
229+
.define_function("foo", Signature::new_endo(usize_t()))
230+
.unwrap();
231+
let ins = dfn.input_wires();
232+
let dfn = dfn.finish_with_outputs(ins).unwrap();
233+
let dcl = mb
234+
.declare("bar", Signature::new_endo(usize_t()).into())
235+
.unwrap();
236+
let mut h = mb.finish_hugr().unwrap();
237+
let mut dce = DeadCodeElimPass::<Hugr>::default().include_module_exports(module_exports);
238+
if include_dfn {
239+
dce = dce.with_entry_points([dfn.node()]);
240+
}
241+
dce.run(&mut h).unwrap();
242+
let defn_retained = include_dfn;
243+
let decl_retained = !decl_removed;
244+
let children = h.children(h.module_root()).collect_vec();
245+
assert_eq!(defn_retained, children.iter().contains(&dfn.node()));
246+
assert_eq!(decl_retained, children.iter().contains(&dcl.node()));
247+
assert_eq!(
248+
children.len(),
249+
(defn_retained as usize) + (decl_retained as usize)
250+
);
251+
}
252+
189253
#[test]
190254
fn test_cfg_callback() {
191255
let mut cb = CFGBuilder::new(Signature::new_endo(type_row![])).unwrap();

0 commit comments

Comments
 (0)