11//! Pass for removing dead code, i.e. that computes values that are then discarded
22
33use hugr_core:: hugr:: internal:: HugrInternals ;
4- use hugr_core:: { HugrView , hugr:: hugrmut:: HugrMut , ops:: OpType } ;
4+ use hugr_core:: { HugrView , Visibility , hugr:: hugrmut:: HugrMut , ops:: OpType } ;
55use std:: convert:: Infallible ;
66use std:: fmt:: { Debug , Formatter } ;
77use std:: {
88 collections:: { HashMap , HashSet , VecDeque } ,
99 sync:: Arc ,
1010} ;
1111
12- use crate :: ComposablePass ;
12+ use crate :: { ComposablePass , IncludeExports } ;
1313
14- /// Configuration for Dead Code Elimination pass
14+ /// Configuration for Dead Code Elimination pass, i.e. which removes nodes
15+ /// beneath the [HugrView::entrypoint] that compute only unneeded values.
1516#[ derive( Clone ) ]
1617pub struct DeadCodeElimPass < H : HugrView > {
1718 /// Nodes that are definitely needed - e.g. `FuncDefns`, but could be anything.
18- /// Hugr Root is assumed to be an entry point even if not mentioned here.
19+ /// [HugrView::entrypoint] is assumed to be needed even if not mentioned here.
1920 entry_points : Vec < H :: Node > ,
2021 /// Callback identifying nodes that must be preserved even if their
2122 /// results are not used. Defaults to [`PreserveNode::default_for`].
2223 preserve_callback : Arc < PreserveCallback < H > > ,
24+ include_exports : IncludeExports ,
2325}
2426
2527impl < H : HugrView + ' static > Default for DeadCodeElimPass < H > {
2628 fn default ( ) -> Self {
2729 Self {
2830 entry_points : Default :: default ( ) ,
2931 preserve_callback : Arc :: new ( PreserveNode :: default_for) ,
32+ include_exports : IncludeExports :: default ( ) ,
3033 }
3134 }
3235}
@@ -39,11 +42,13 @@ impl<H: HugrView> Debug for DeadCodeElimPass<H> {
3942 #[ derive( Debug ) ]
4043 struct DCEDebug < ' a , N > {
4144 entry_points : & ' a Vec < N > ,
45+ include_exports : IncludeExports ,
4246 }
4347
4448 Debug :: fmt (
4549 & DCEDebug {
4650 entry_points : & self . entry_points ,
51+ include_exports : self . include_exports ,
4752 } ,
4853 f,
4954 )
@@ -69,12 +74,12 @@ pub enum PreserveNode {
6974
7075impl PreserveNode {
7176 /// A conservative default for a given node. Just examines the node's [`OpType`]:
72- /// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn`, but would
73- /// also need to check for cycles in the [ `CallGraph`](super::call_graph::CallGraph) .)
77+ /// * Assumes all Calls must be preserved. (One could scan the called `FuncDefn` for
78+ /// termination, but would also need to check for cycles in the `CallGraph`.)
7479 /// * Assumes all CFGs must be preserved. (One could, for example, allow acyclic
7580 /// CFGs to be removed.)
76- /// * Assumes all `TailLoops` must be preserved. (One could, for example, use dataflow
77- /// analysis to allow removal of `TailLoops` that never [Continue](hugr_core::ops::TailLoop::CONTINUE_TAG) .)
81+ /// * Assumes all `TailLoops` must be preserved. (One could use some analysis, e.g.
82+ /// dataflow, to allow removal of `TailLoops` with a bounded number of iterations .)
7883 pub fn default_for < H : HugrView > ( h : & H , n : H :: Node ) -> PreserveNode {
7984 match h. get_optype ( n) {
8085 OpType :: CFG ( _) | OpType :: TailLoop ( _) | OpType :: Call ( _) => PreserveNode :: MustKeep ,
@@ -91,16 +96,33 @@ impl<H: HugrView> DeadCodeElimPass<H> {
9196 self
9297 }
9398
94- /// Mark some nodes as entry points to the Hugr, i.e. so we cannot eliminate any code
95- /// used to evaluate these nodes.
96- /// [`HugrView::entrypoint`] is assumed to be an entry point;
97- /// for Module roots the client will want to mark some of the `FuncDefn` children
98- /// as entry points too.
99+ /// Mark some nodes as reachable, i.e. so we cannot eliminate any code used to
100+ /// evaluate their results. The [`HugrView::entrypoint`] is assumed to be reachable;
101+ /// if that is the [`HugrView::module_root`], then any public [FuncDefn] and
102+ /// [FuncDecl]s are also considered reachable by default,
103+ /// but this can be change by [`Self::include_module_exports`].
104+ ///
105+ /// [FuncDecl]: OpType::FuncDecl
106+ /// [FuncDefn]: OpType::FuncDefn
99107 pub fn with_entry_points ( mut self , entry_points : impl IntoIterator < Item = H :: Node > ) -> Self {
100108 self . entry_points . extend ( entry_points) ;
101109 self
102110 }
103111
112+ /// Sets whether the exported [FuncDefn](OpType::FuncDefn)s and
113+ /// [FuncDecl](OpType::FuncDecl)s are considered reachable.
114+ ///
115+ /// Note that for non-module-entry Hugrs this has no effect, since we only remove
116+ /// code beneath the entrypoint: this cannot be affected by other module children.
117+ ///
118+ /// So, for module-rooted-Hugrs: [IncludeExports::OnlyIfEntrypointIsModuleRoot] is
119+ /// equivalent to [IncludeExports::Always]; and [IncludeExports::Never] will remove
120+ /// all children, unless some are explicity added by [Self::with_entry_points].
121+ pub fn include_module_exports ( mut self , include : IncludeExports ) -> Self {
122+ self . include_exports = include;
123+ self
124+ }
125+
104126 fn find_needed_nodes ( & self , h : & H ) -> HashSet < H :: Node > {
105127 let mut must_preserve = HashMap :: new ( ) ;
106128 let mut needed = HashSet :: new ( ) ;
@@ -111,19 +133,23 @@ impl<H: HugrView> DeadCodeElimPass<H> {
111133 continue ;
112134 }
113135 for ch in h. children ( n) {
114- if self . must_preserve ( h, & mut must_preserve, ch)
115- || matches ! (
116- h. get_optype( ch) ,
136+ let must_keep = match h. get_optype ( ch) {
117137 OpType :: Case ( _) // Include all Cases in Conditionals
118138 | OpType :: DataflowBlock ( _) // and all Basic Blocks in CFGs
119139 | OpType :: ExitBlock ( _)
120140 | OpType :: AliasDecl ( _) // and all Aliases (we do not track their uses in types)
121141 | OpType :: AliasDefn ( _)
122142 | OpType :: Input ( _) // Also Dataflow input/output, these are necessary for legality
123- | OpType :: Output ( _) // Do not include FuncDecl / FuncDefn / Const unless reachable by static edges
124- // (from Call/LoadConst/LoadFunction):
125- )
126- {
143+ | OpType :: Output ( _) => true ,
144+ // FuncDefns (as children of Module) only if public and including exports
145+ // (will be included if static predecessors of Call/LoadFunction below,
146+ // regardless of Visibility or self.include_exports)
147+ OpType :: FuncDefn ( fd) => fd. visibility ( ) == & Visibility :: Public && self . include_exports . for_hugr ( h) ,
148+ OpType :: FuncDecl ( fd) => fd. visibility ( ) == & Visibility :: Public && self . include_exports . for_hugr ( h) ,
149+ // No Const, unless reached along static edges
150+ _ => false
151+ } ;
152+ if must_keep || self . must_preserve ( h, & mut must_preserve, ch) {
127153 q. push_back ( ch) ;
128154 }
129155 }
@@ -141,7 +167,6 @@ impl<H: HugrView> DeadCodeElimPass<H> {
141167 if let Some ( res) = cache. get ( & n) {
142168 return * res;
143169 }
144- #[ allow( deprecated) ]
145170 let res = match self . preserve_callback . as_ref ( ) ( h, n) {
146171 PreserveNode :: MustKeep => true ,
147172 PreserveNode :: CanRemoveIgnoringChildren => false ,
@@ -174,18 +199,57 @@ impl<H: HugrMut> ComposablePass<H> for DeadCodeElimPass<H> {
174199mod test {
175200 use std:: sync:: Arc ;
176201
177- use hugr_core:: Hugr ;
178- use hugr_core:: builder:: { CFGBuilder , Container , Dataflow , DataflowSubContainer , HugrBuilder } ;
202+ use hugr_core:: builder:: {
203+ CFGBuilder , Container , Dataflow , DataflowSubContainer , HugrBuilder , ModuleBuilder ,
204+ } ;
179205 use hugr_core:: extension:: prelude:: { ConstUsize , usize_t} ;
180- use hugr_core:: ops:: { OpTag , OpTrait , handle:: NodeHandle } ;
181- use hugr_core:: types:: Signature ;
182- use hugr_core:: { HugrView , ops:: Value , type_row} ;
206+ use hugr_core:: ops:: { OpTag , OpTrait , Value , handle:: NodeHandle } ;
207+ use hugr_core:: { Hugr , HugrView , type_row, types:: Signature } ;
183208 use itertools:: Itertools ;
209+ use rstest:: rstest;
184210
185- use crate :: ComposablePass ;
211+ use crate :: { ComposablePass , IncludeExports } ;
186212
187213 use super :: { DeadCodeElimPass , PreserveNode } ;
188214
215+ #[ rstest]
216+ #[ case( false , IncludeExports :: Never , true ) ]
217+ #[ case( false , IncludeExports :: OnlyIfEntrypointIsModuleRoot , false ) ]
218+ #[ case( false , IncludeExports :: Always , false ) ]
219+ #[ case( true , IncludeExports :: Never , true ) ]
220+ #[ case( true , IncludeExports :: OnlyIfEntrypointIsModuleRoot , false ) ]
221+ #[ case( true , IncludeExports :: Always , false ) ]
222+ fn test_module_exports (
223+ #[ case] include_dfn : bool ,
224+ #[ case] module_exports : IncludeExports ,
225+ #[ case] decl_removed : bool ,
226+ ) {
227+ let mut mb = ModuleBuilder :: new ( ) ;
228+ let dfn = mb
229+ . define_function ( "foo" , Signature :: new_endo ( usize_t ( ) ) )
230+ . unwrap ( ) ;
231+ let ins = dfn. input_wires ( ) ;
232+ let dfn = dfn. finish_with_outputs ( ins) . unwrap ( ) ;
233+ let dcl = mb
234+ . declare ( "bar" , Signature :: new_endo ( usize_t ( ) ) . into ( ) )
235+ . unwrap ( ) ;
236+ let mut h = mb. finish_hugr ( ) . unwrap ( ) ;
237+ let mut dce = DeadCodeElimPass :: < Hugr > :: default ( ) . include_module_exports ( module_exports) ;
238+ if include_dfn {
239+ dce = dce. with_entry_points ( [ dfn. node ( ) ] ) ;
240+ }
241+ dce. run ( & mut h) . unwrap ( ) ;
242+ let defn_retained = include_dfn;
243+ let decl_retained = !decl_removed;
244+ let children = h. children ( h. module_root ( ) ) . collect_vec ( ) ;
245+ assert_eq ! ( defn_retained, children. iter( ) . contains( & dfn. node( ) ) ) ;
246+ assert_eq ! ( decl_retained, children. iter( ) . contains( & dcl. node( ) ) ) ;
247+ assert_eq ! (
248+ children. len( ) ,
249+ ( defn_retained as usize ) + ( decl_retained as usize )
250+ ) ;
251+ }
252+
189253 #[ test]
190254 fn test_cfg_callback ( ) {
191255 let mut cb = CFGBuilder :: new ( Signature :: new_endo ( type_row ! [ ] ) ) . unwrap ( ) ;
0 commit comments