From 1da9edf59e22e527e83ddbea4b1b62b2295d37be Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Jul 2025 16:12:57 +0100 Subject: [PATCH 01/15] Add replace_parametrized_op_with taking better callback --- hugr-passes/src/replace_types.rs | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 6acd91c13f..321872836e 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -204,7 +204,8 @@ pub struct ReplaceTypes { param_types: HashMap Option>>, linearize: DelegatingLinearizer, op_map: HashMap, - param_ops: HashMap Option>>, + param_ops: + HashMap Option>>, consts: HashMap< CustomType, Arc Result>, @@ -352,10 +353,27 @@ impl ReplaceTypes { /// fit the bounds of the original op). /// /// If the Callback returns None, the new typeargs will be applied to the original op. + #[deprecated(note = "use replace_parametrized_op_with")] // When removed, consider renaming back over this. pub fn replace_parametrized_op( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + ) { + self.param_ops + .insert(src.into(), Arc::new(move |args, _| dest_fn(args))); + } + + /// Configures this instance to change occurrences of a parametrized op `src` + /// via a callback that builds the replacement type given the [`TypeArg`]s + /// (and a handle to the [ReplaceTypes], e.g. allowing access to [Self::linearizer]). + /// Note that the `TypeArgs` will already have been updated (e.g. they may not + /// fit the bounds of the original op). + /// + /// If the Callback returns None, the new typeargs will be applied to the original op. + pub fn replace_parametrized_op_with( + &mut self, + src: &OpDef, + dest_fn: impl Fn(&[TypeArg], &ReplaceTypes) -> Option + 'static, ) { self.param_ops.insert(src.into(), Arc::new(dest_fn)); } @@ -461,7 +479,7 @@ impl ReplaceTypes { if let Some(replacement) = self .param_ops .get(&def.as_ref().into()) - .and_then(|rep_fn| rep_fn(&args)) + .and_then(|rep_fn| rep_fn(&args, self)) { replacement .replace(hugr, n) From d2668b9e6c0d30f49fdd879d72b4900901dc7ee8 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Jul 2025 16:19:24 +0100 Subject: [PATCH 02/15] Fix deprecations in tests --- hugr-passes/src/linearize_array.rs | 20 ++++++++++---------- hugr-passes/src/replace_types.rs | 10 +++++----- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index 64daf21f4d..a22137a36e 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -53,9 +53,9 @@ impl Default for LinearizeArrayPass { Ok(Some(ArrayValue::new(ty, contents).into())) }); for op_def in ArrayOpDef::iter() { - pass.replace_parametrized_op( + pass.replace_parametrized_op_with( value_array::EXTENSION.get_op(&op_def.opdef_id()).unwrap(), - move |args| { + move |args, _| { // `get` is only allowed for copyable elements. Assuming the Hugr was // valid when we started, the only way for the element to become linear // is if it used to contain nested `value_array`s. In that case, we @@ -76,38 +76,38 @@ impl Default for LinearizeArrayPass { }, ); } - pass.replace_parametrized_op( + pass.replace_parametrized_op_with( value_array::EXTENSION.get_op(&ARRAY_REPEAT_OP_ID).unwrap(), - |args| { + |args, _| { Some(NodeTemplate::SingleOp( ArrayRepeatDef::new().instantiate(args).unwrap().into(), )) }, ); - pass.replace_parametrized_op( + pass.replace_parametrized_op_with( value_array::EXTENSION.get_op(&ARRAY_SCAN_OP_ID).unwrap(), - |args| { + |args, _| { Some(NodeTemplate::SingleOp( ArrayScanDef::new().instantiate(args).unwrap().into(), )) }, ); - pass.replace_parametrized_op( + pass.replace_parametrized_op_with( value_array::EXTENSION .get_op(&VArrayFromArrayDef::new().opdef_id()) .unwrap(), - |args| { + |args, _| { let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap(); Some(NodeTemplate::SingleOp( Noop::new(array_ty).to_extension_op().unwrap().into(), )) }, ); - pass.replace_parametrized_op( + pass.replace_parametrized_op_with( value_array::EXTENSION .get_op(&VArrayToArrayDef::new().opdef_id()) .unwrap(), - |args| { + |args, _| { let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap(); Some(NodeTemplate::SingleOp( Noop::new(array_ty).to_extension_op().unwrap().into(), diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 321872836e..e030e1b1b1 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -748,7 +748,7 @@ mod test { .into(), ), ); - lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + lw.replace_parametrized_op_with(ext.get_op(READ).unwrap().as_ref(), |type_args, _| { Some(NodeTemplate::CompoundOp(Box::new( lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) .finish_hugr() @@ -1028,9 +1028,9 @@ mod test { option_contents(just_elem_type(args)).map(list_type) }); // and read> to get - the latter has the expected option return type - lowerer.replace_parametrized_op( + lowerer.replace_parametrized_op_with( e.get_op(READ).unwrap().as_ref(), - Box::new(|args: &[TypeArg]| { + |args: &[TypeArg], _| { option_contents(just_elem_type(args)).map(|elem| { NodeTemplate::SingleOp( ListOp::get @@ -1040,7 +1040,7 @@ mod test { .into(), ) }) - }), + }, ); assert!(lowerer.run(&mut h).unwrap()); // list -> read -> usz just becomes list -> read -> qb @@ -1134,7 +1134,7 @@ mod test { .inserted_entrypoint; let mut lw = lowerer(&e); - lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + lw.replace_parametrized_op_with(e.get_op(READ).unwrap().as_ref(), move |args, _| { Some(NodeTemplate::Call(read_func, args.to_owned())) }); lw.run(&mut h).unwrap(); From dc4fdb2393b58893d8562ce6dfa5b250fd23958b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Jul 2025 17:27:29 +0100 Subject: [PATCH 03/15] closure returns Result --- hugr-passes/src/linearize_array.rs | 20 ++++++++++---------- hugr-passes/src/replace_types.rs | 25 +++++++++++++++---------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index a22137a36e..c42c563d39 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -70,26 +70,26 @@ impl Default for LinearizeArrayPass { "Cannot linearise arrays in this Hugr: \ Contains a `get` operation on nested value arrays" ); - Some(NodeTemplate::SingleOp( + Ok(Some(NodeTemplate::SingleOp( op_def.instantiate(args).unwrap().into(), - )) + ))) }, ); } pass.replace_parametrized_op_with( value_array::EXTENSION.get_op(&ARRAY_REPEAT_OP_ID).unwrap(), |args, _| { - Some(NodeTemplate::SingleOp( + Ok(Some(NodeTemplate::SingleOp( ArrayRepeatDef::new().instantiate(args).unwrap().into(), - )) + ))) }, ); pass.replace_parametrized_op_with( value_array::EXTENSION.get_op(&ARRAY_SCAN_OP_ID).unwrap(), |args, _| { - Some(NodeTemplate::SingleOp( + Ok(Some(NodeTemplate::SingleOp( ArrayScanDef::new().instantiate(args).unwrap().into(), - )) + ))) }, ); pass.replace_parametrized_op_with( @@ -98,9 +98,9 @@ impl Default for LinearizeArrayPass { .unwrap(), |args, _| { let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap(); - Some(NodeTemplate::SingleOp( + Ok(Some(NodeTemplate::SingleOp( Noop::new(array_ty).to_extension_op().unwrap().into(), - )) + ))) }, ); pass.replace_parametrized_op_with( @@ -109,9 +109,9 @@ impl Default for LinearizeArrayPass { .unwrap(), |args, _| { let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap(); - Some(NodeTemplate::SingleOp( + Ok(Some(NodeTemplate::SingleOp( Noop::new(array_ty).to_extension_op().unwrap().into(), - )) + ))) }, ); pass.linearizer() diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e030e1b1b1..74852b1103 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -204,8 +204,10 @@ pub struct ReplaceTypes { param_types: HashMap Option>>, linearize: DelegatingLinearizer, op_map: HashMap, - param_ops: - HashMap Option>>, + param_ops: HashMap< + ParametricOp, + Arc Result, ReplaceTypesError>>, + >, consts: HashMap< CustomType, Arc Result>, @@ -360,7 +362,7 @@ impl ReplaceTypes { dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, ) { self.param_ops - .insert(src.into(), Arc::new(move |args, _| dest_fn(args))); + .insert(src.into(), Arc::new(move |args, _| Ok(dest_fn(args)))); } /// Configures this instance to change occurrences of a parametrized op `src` @@ -373,7 +375,8 @@ impl ReplaceTypes { pub fn replace_parametrized_op_with( &mut self, src: &OpDef, - dest_fn: impl Fn(&[TypeArg], &ReplaceTypes) -> Option + 'static, + dest_fn: impl Fn(&[TypeArg], &ReplaceTypes) -> Result, ReplaceTypesError> + + 'static, ) { self.param_ops.insert(src.into(), Arc::new(dest_fn)); } @@ -479,7 +482,9 @@ impl ReplaceTypes { if let Some(replacement) = self .param_ops .get(&def.as_ref().into()) - .and_then(|rep_fn| rep_fn(&args, self)) + .map(|rep_fn| rep_fn(&args, self)) + .transpose()? + .flatten() { replacement .replace(hugr, n) @@ -749,11 +754,11 @@ mod test { ), ); lw.replace_parametrized_op_with(ext.get_op(READ).unwrap().as_ref(), |type_args, _| { - Some(NodeTemplate::CompoundOp(Box::new( + Ok(Some(NodeTemplate::CompoundOp(Box::new( lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) .finish_hugr() .unwrap(), - ))) + )))) }); lw } @@ -1031,7 +1036,7 @@ mod test { lowerer.replace_parametrized_op_with( e.get_op(READ).unwrap().as_ref(), |args: &[TypeArg], _| { - option_contents(just_elem_type(args)).map(|elem| { + Ok(option_contents(just_elem_type(args)).map(|elem| { NodeTemplate::SingleOp( ListOp::get .with_type(elem) @@ -1039,7 +1044,7 @@ mod test { .unwrap() .into(), ) - }) + })) }, ); assert!(lowerer.run(&mut h).unwrap()); @@ -1135,7 +1140,7 @@ mod test { let mut lw = lowerer(&e); lw.replace_parametrized_op_with(e.get_op(READ).unwrap().as_ref(), move |args, _| { - Some(NodeTemplate::Call(read_func, args.to_owned())) + Ok(Some(NodeTemplate::Call(read_func, args.to_owned()))) }); lw.run(&mut h).unwrap(); From 24d3b3b28113b88b7eb1a5669f666d9a1eb77f13 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Jul 2025 17:32:18 +0100 Subject: [PATCH 04/15] Deprecate+rename linearizer->linearizer_mut --- hugr-passes/src/linearize_array.rs | 4 ++-- hugr-passes/src/replace_types.rs | 8 +++++++- hugr-passes/src/replace_types/linearize.rs | 14 +++++++------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index c42c563d39..bd4296d454 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -114,7 +114,7 @@ impl Default for LinearizeArrayPass { ))) }, ); - pass.linearizer() + pass.linearizer_mut() .register_callback(array_type_def(), copy_discard_array); Self(pass) } @@ -139,7 +139,7 @@ impl LinearizeArrayPass { /// Allows to configure how to clone and discard arrays that are nested /// inside opaque extension values. pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { - self.0.linearizer() + self.0.linearizer_mut() } } diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 74852b1103..fdd19410c7 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -326,6 +326,12 @@ impl ReplaceTypes { self.param_types.insert(src.into(), Arc::new(dest_fn)); } + /// Deprecated + Renamed to [Self::linearizer_mut] + #[deprecated(note = "Use linearizer_mut")] + pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { + self.linearizer_mut() + } + /// Allows to configure how to deal with types/wires that were [Copyable] /// but have become linear as a result of type-changing. Specifically, /// the [Linearizer] is used whenever lowering produces an outport which both @@ -335,7 +341,7 @@ impl ReplaceTypes { /// /// [Copyable]: hugr_core::types::TypeBound::Copyable /// [`array`]: hugr_core::std_extensions::collections::array::array_type - pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { + pub fn linearizer_mut(&mut self) -> &mut DelegatingLinearizer { &mut self.linearize } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5cc8a64b66..a75ed49594 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -455,7 +455,7 @@ mod test { let usize_custom_t = usize_t().as_extension().unwrap().clone(); lowerer.replace_type(usize_custom_t, Type::new_extension(lin_custom_t.clone())); lowerer - .linearizer() + .linearizer_mut() .register_simple( lin_custom_t, NodeTemplate::SingleOp(copy_op.into()), @@ -577,7 +577,7 @@ mod test { let opdef = e.get_op("copy").unwrap(); let opdef2 = opdef.clone(); lowerer - .linearizer() + .linearizer_mut() .register_callback(lin_t_def, move |args, num_outs, _| { assert!(args.is_empty()); Ok(NodeTemplate::SingleOp( @@ -639,7 +639,7 @@ mod test { let mut replacer = ReplaceTypes::default(); replacer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); - let bad_copy = replacer.linearizer().register_simple( + let bad_copy = replacer.linearizer_mut().register_simple( lin_ct.clone(), NodeTemplate::SingleOp(copy3.clone()), NodeTemplate::SingleOp(discard.clone().into()), @@ -654,7 +654,7 @@ mod test { }) ); - let bad_discard = replacer.linearizer().register_simple( + let bad_discard = replacer.linearizer_mut().register_simple( lin_ct.clone(), NodeTemplate::SingleOp(copy2.into()), NodeTemplate::SingleOp(copy3.clone()), @@ -671,7 +671,7 @@ mod test { // Try parametrized instead, but this version always returns 3 outports replacer - .linearizer() + .linearizer_mut() .register_callback(ext.get_type(LIN_T).unwrap(), move |_args, _, _| { Ok(NodeTemplate::SingleOp(copy3.clone())) }); @@ -699,7 +699,7 @@ mod test { let (e, mut lowerer) = ext_lowerer(); lowerer - .linearizer() + .linearizer_mut() .register_callback(value_array_type_def(), linearize_value_array); let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); let opt_lin_ty = Type::from(option_type(lin_t.clone())); @@ -817,7 +817,7 @@ mod test { let mut lower_discard_to_call = ReplaceTypes::default(); lower_discard_to_call - .linearizer() + .linearizer_mut() .register_simple( lin_ct.clone(), NodeTemplate::Call(backup.entrypoint(), vec![]), // Arbitrary, unused From 8eca11ccbc52d0b09242c9ab75507fe45c229671 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Jul 2025 17:33:43 +0100 Subject: [PATCH 05/15] ...and add get_linearizer --- hugr-passes/src/replace_types.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index fdd19410c7..ff1b46dbad 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -327,11 +327,16 @@ impl ReplaceTypes { } /// Deprecated + Renamed to [Self::linearizer_mut] - #[deprecated(note = "Use linearizer_mut")] + #[deprecated(note = "Use linearizer_mut")] // When removed, rename get_linearizer to linearizer pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { self.linearizer_mut() } + /// Allows access to the [Linearizer], e.g. for building [NodeTemplate::CompoundOp]s + pub fn get_linearizer(&self) -> &DelegatingLinearizer { + &self.linearize + } + /// Allows to configure how to deal with types/wires that were [Copyable] /// but have become linear as a result of type-changing. Specifically, /// the [Linearizer] is used whenever lowering produces an outport which both From d3fb9e2fb15457275d1c82e6c165f181d03c0d50 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Jul 2025 17:38:33 +0100 Subject: [PATCH 06/15] driveby: shorten lifetimes --- hugr-passes/src/replace_types/linearize.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index a75ed49594..bd0587350c 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -52,8 +52,6 @@ pub trait Linearizer { src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { - let sig = hugr.signature(src.node()).unwrap(); - let typ = sig.port_type(src.source()).unwrap(); let (tgt_node, tgt_inport) = if targets.len() == 1 { *targets.first().unwrap() } else { @@ -74,7 +72,8 @@ pub trait Linearizer { tgt_parent, }); } - let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it + let sig = hugr.signature(src.node()).unwrap(); + let typ = sig.port_type(src.source()).unwrap().clone(); let copy_discard_op = self .copy_discard_op(&typ, targets.len())? .add_hugr(hugr, src_parent) From 4686448db173cddd799d41dca8cc71d0b1116265 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Jul 2025 17:40:37 +0100 Subject: [PATCH 07/15] driveby: note about LocalizeEdges --- hugr-passes/src/replace_types/linearize.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index bd0587350c..09e6f9eabb 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -55,7 +55,7 @@ pub trait Linearizer { let (tgt_node, tgt_inport) = if targets.len() == 1 { *targets.first().unwrap() } else { - // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + // Fail fast if the edges are nonlocal. let src_parent = hugr .get_parent(src.node()) .expect("Root node cannot have out edges"); @@ -147,7 +147,8 @@ pub enum LinearizeError { sig: Option>, }, #[error( - "Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})" + "Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent}). + Try using LocalizeEdges pass first." )] NoLinearNonLocalEdges { src: Node, From 617e03cfd2ce3f0d08ddfa9af29cc992052677b4 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Jul 2025 18:31:50 +0100 Subject: [PATCH 08/15] test --- hugr-passes/src/replace_types/linearize.rs | 59 ++++++++++++++++++++-- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 09e6f9eabb..99ced3e52f 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -371,7 +371,7 @@ mod test { inout_sig, }; - use hugr_core::extension::prelude::{option_type, usize_t}; + use hugr_core::extension::prelude::{option_type, qb_t, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; use hugr_core::extension::{ CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version, @@ -385,14 +385,14 @@ mod test { }; use hugr_core::types::type_param::TypeParam; use hugr_core::types::{ - FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeEnum, TypeRow, + FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow, }; use hugr_core::{Extension, Hugr, HugrView, Node, hugr::IdentList, type_row}; use itertools::Itertools; use rstest::rstest; use crate::replace_types::handlers::linearize_value_array; - use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; + use crate::replace_types::{LinearizeError, Linearizer, NodeTemplate, ReplaceTypesError}; use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; @@ -855,4 +855,57 @@ mod test { panic!("Expected error"); } } + + #[test] + fn use_in_op_callback() { + let (e, mut lowerer) = ext_lowerer(); + let drop_ext = Extension::new_arc( + IdentList::new_unchecked("DropExt"), + Version::new(0, 0, 0), + |e, w| { + e.add_op( + "drop".into(), + String::new(), + PolyFuncTypeRV::new( + [TypeBound::Linear.into()], // It won't *lower* for any type tho! + Signature::new(Type::new_var_use(0, TypeBound::Linear), vec![]), + ), + w, + ) + .unwrap(); + }, + ); + let drop_op = drop_ext.get_op("drop").unwrap(); + lowerer.replace_parametrized_op_with(drop_op, |args, rt| { + let [TypeArg::Runtime(ty)] = args else { + panic!("Expected just one type") + }; + Ok(rt.get_linearizer().copy_discard_op(ty, 0).map(Some)?) + }); + + let build_hugr = |ty: Type| { + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + let [inp] = dfb.input_wires_arr(); + let drop_op = drop_ext + .instantiate_extension_op("drop", [ty.into()]) + .unwrap(); + dfb.add_dataflow_op(drop_op, [inp]).unwrap(); + dfb.finish_hugr().unwrap() + }; + // We can drop a tuple of 2* lin_t + let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); + let mut h = build_hugr(Type::new_tuple(vec![lin_t; 2])); + lowerer.run(&mut h).unwrap(); + h.validate().unwrap(); + let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op()); + assert_eq!(exts.clone().count(), 2); + assert!(exts.all(|eo| eo.qualified_id() == "TestExt.discard")); + + // We cannot drop a qubit + let mut h = build_hugr(qb_t()); + assert_eq!( + lowerer.run(&mut h).unwrap_err(), + ReplaceTypesError::LinearizeError(LinearizeError::NeedCopyDiscard(Box::new(qb_t()))) + ); + } } From 6e0d68a251957b6511671f119730fb838578c9b6 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Jul 2025 17:41:52 +0100 Subject: [PATCH 09/15] Instead, add ReplaceTypes::process_replacements: bool; refactor +change_subtree --- hugr-passes/src/linearize_array.rs | 44 ++++----- hugr-passes/src/replace_types.rs | 110 ++++++++++----------- hugr-passes/src/replace_types/linearize.rs | 29 +++--- 3 files changed, 89 insertions(+), 94 deletions(-) diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index bd4296d454..64daf21f4d 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -53,9 +53,9 @@ impl Default for LinearizeArrayPass { Ok(Some(ArrayValue::new(ty, contents).into())) }); for op_def in ArrayOpDef::iter() { - pass.replace_parametrized_op_with( + pass.replace_parametrized_op( value_array::EXTENSION.get_op(&op_def.opdef_id()).unwrap(), - move |args, _| { + move |args| { // `get` is only allowed for copyable elements. Assuming the Hugr was // valid when we started, the only way for the element to become linear // is if it used to contain nested `value_array`s. In that case, we @@ -70,51 +70,51 @@ impl Default for LinearizeArrayPass { "Cannot linearise arrays in this Hugr: \ Contains a `get` operation on nested value arrays" ); - Ok(Some(NodeTemplate::SingleOp( + Some(NodeTemplate::SingleOp( op_def.instantiate(args).unwrap().into(), - ))) + )) }, ); } - pass.replace_parametrized_op_with( + pass.replace_parametrized_op( value_array::EXTENSION.get_op(&ARRAY_REPEAT_OP_ID).unwrap(), - |args, _| { - Ok(Some(NodeTemplate::SingleOp( + |args| { + Some(NodeTemplate::SingleOp( ArrayRepeatDef::new().instantiate(args).unwrap().into(), - ))) + )) }, ); - pass.replace_parametrized_op_with( + pass.replace_parametrized_op( value_array::EXTENSION.get_op(&ARRAY_SCAN_OP_ID).unwrap(), - |args, _| { - Ok(Some(NodeTemplate::SingleOp( + |args| { + Some(NodeTemplate::SingleOp( ArrayScanDef::new().instantiate(args).unwrap().into(), - ))) + )) }, ); - pass.replace_parametrized_op_with( + pass.replace_parametrized_op( value_array::EXTENSION .get_op(&VArrayFromArrayDef::new().opdef_id()) .unwrap(), - |args, _| { + |args| { let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap(); - Ok(Some(NodeTemplate::SingleOp( + Some(NodeTemplate::SingleOp( Noop::new(array_ty).to_extension_op().unwrap().into(), - ))) + )) }, ); - pass.replace_parametrized_op_with( + pass.replace_parametrized_op( value_array::EXTENSION .get_op(&VArrayToArrayDef::new().opdef_id()) .unwrap(), - |args, _| { + |args| { let array_ty = array_type_parametric(args[0].clone(), args[1].clone()).unwrap(); - Ok(Some(NodeTemplate::SingleOp( + Some(NodeTemplate::SingleOp( Noop::new(array_ty).to_extension_op().unwrap().into(), - ))) + )) }, ); - pass.linearizer_mut() + pass.linearizer() .register_callback(array_type_def(), copy_discard_array); Self(pass) } @@ -139,7 +139,7 @@ impl LinearizeArrayPass { /// Allows to configure how to clone and discard arrays that are nested /// inside opaque extension values. pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { - self.0.linearizer_mut() + self.0.linearizer() } } diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ff1b46dbad..20d5ad0bc1 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -204,10 +204,7 @@ pub struct ReplaceTypes { param_types: HashMap Option>>, linearize: DelegatingLinearizer, op_map: HashMap, - param_ops: HashMap< - ParametricOp, - Arc Result, ReplaceTypesError>>, - >, + param_ops: HashMap Option>>, consts: HashMap< CustomType, Arc Result>, @@ -216,6 +213,7 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, + process_replacements: bool, } impl Default for ReplaceTypes { @@ -277,6 +275,7 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), + process_replacements: false, } } @@ -326,17 +325,6 @@ impl ReplaceTypes { self.param_types.insert(src.into(), Arc::new(dest_fn)); } - /// Deprecated + Renamed to [Self::linearizer_mut] - #[deprecated(note = "Use linearizer_mut")] // When removed, rename get_linearizer to linearizer - pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { - self.linearizer_mut() - } - - /// Allows access to the [Linearizer], e.g. for building [NodeTemplate::CompoundOp]s - pub fn get_linearizer(&self) -> &DelegatingLinearizer { - &self.linearize - } - /// Allows to configure how to deal with types/wires that were [Copyable] /// but have become linear as a result of type-changing. Specifically, /// the [Linearizer] is used whenever lowering produces an outport which both @@ -346,7 +334,7 @@ impl ReplaceTypes { /// /// [Copyable]: hugr_core::types::TypeBound::Copyable /// [`array`]: hugr_core::std_extensions::collections::array::array_type - pub fn linearizer_mut(&mut self) -> &mut DelegatingLinearizer { + pub fn linearizer(&mut self) -> &mut DelegatingLinearizer { &mut self.linearize } @@ -366,28 +354,10 @@ impl ReplaceTypes { /// fit the bounds of the original op). /// /// If the Callback returns None, the new typeargs will be applied to the original op. - #[deprecated(note = "use replace_parametrized_op_with")] // When removed, consider renaming back over this. pub fn replace_parametrized_op( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, - ) { - self.param_ops - .insert(src.into(), Arc::new(move |args, _| Ok(dest_fn(args)))); - } - - /// Configures this instance to change occurrences of a parametrized op `src` - /// via a callback that builds the replacement type given the [`TypeArg`]s - /// (and a handle to the [ReplaceTypes], e.g. allowing access to [Self::linearizer]). - /// Note that the `TypeArgs` will already have been updated (e.g. they may not - /// fit the bounds of the original op). - /// - /// If the Callback returns None, the new typeargs will be applied to the original op. - pub fn replace_parametrized_op_with( - &mut self, - src: &OpDef, - dest_fn: impl Fn(&[TypeArg], &ReplaceTypes) -> Result, ReplaceTypesError> - + 'static, ) { self.param_ops.insert(src.into(), Arc::new(dest_fn)); } @@ -419,6 +389,14 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } + /// Configures this instance to (recursively) process the RHS of any replacement + /// ops or subtrees - registered with [Self::replace_op] or from callbacks + /// registered with [Self::replace_parametrized_op]. + /// The default is `false`, i.e. do not recurse on such. + pub fn process_replacements(&mut self, r: bool) { + self.process_replacements = r; + } + fn change_node( &self, hugr: &mut impl HugrMut, @@ -480,11 +458,13 @@ impl ReplaceTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => Ok( - // Copy/discard insertion done by caller if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { replacement .replace(hugr, n) .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; + if self.process_replacements { + self.change_subtree(hugr, n, true)?; + } true } else { let def = ext_op.def_arc(); @@ -493,13 +473,14 @@ impl ReplaceTypes { if let Some(replacement) = self .param_ops .get(&def.as_ref().into()) - .map(|rep_fn| rep_fn(&args, self)) - .transpose()? - .flatten() + .and_then(|rep_fn| rep_fn(&args)) { replacement .replace(hugr, n) .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; + if self.process_replacements { + self.change_subtree(hugr, n, true)?; + } true } else { if ch { @@ -552,21 +533,21 @@ impl ReplaceTypes { Value::Function { hugr } => self.run(&mut **hugr), } } -} - -impl> ComposablePass for ReplaceTypes { - type Error = ReplaceTypesError; - type Result = bool; - fn run(&self, hugr: &mut H) -> Result { + fn change_subtree>( + &self, + hugr: &mut H, + root: H::Node, + linearize_if_no_change: bool, + ) -> Result { let mut changed = false; - for n in hugr.entry_descendants().collect::>() { + for n in hugr.descendants(root).collect::>() { changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.entrypoint()) - .map(Cow::into_owned) - { + if n == root || !(changed | linearize_if_no_change) { + continue; + } + if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() { + let new_sig = new_sig.into_owned(); for outp in new_sig.output_ports() { if !new_sig.out_port_type(outp).unwrap().copyable() { let targets = hugr.linked_inputs(n, outp).collect::>(); @@ -583,6 +564,15 @@ impl> ComposablePass for ReplaceTypes { } } +impl> ComposablePass for ReplaceTypes { + type Error = ReplaceTypesError; + type Result = bool; + + fn run(&self, hugr: &mut H) -> Result { + self.change_subtree(hugr, hugr.entrypoint(), false) + } +} + pub mod handlers; #[derive(Clone, Hash, PartialEq, Eq)] @@ -764,12 +754,12 @@ mod test { .into(), ), ); - lw.replace_parametrized_op_with(ext.get_op(READ).unwrap().as_ref(), |type_args, _| { - Ok(Some(NodeTemplate::CompoundOp(Box::new( + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + Some(NodeTemplate::CompoundOp(Box::new( lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) .finish_hugr() .unwrap(), - )))) + ))) }); lw } @@ -1044,10 +1034,10 @@ mod test { option_contents(just_elem_type(args)).map(list_type) }); // and read> to get - the latter has the expected option return type - lowerer.replace_parametrized_op_with( + lowerer.replace_parametrized_op( e.get_op(READ).unwrap().as_ref(), - |args: &[TypeArg], _| { - Ok(option_contents(just_elem_type(args)).map(|elem| { + Box::new(|args: &[TypeArg]| { + option_contents(just_elem_type(args)).map(|elem| { NodeTemplate::SingleOp( ListOp::get .with_type(elem) @@ -1055,8 +1045,8 @@ mod test { .unwrap() .into(), ) - })) - }, + }) + }), ); assert!(lowerer.run(&mut h).unwrap()); // list -> read -> usz just becomes list -> read -> qb @@ -1150,8 +1140,8 @@ mod test { .inserted_entrypoint; let mut lw = lowerer(&e); - lw.replace_parametrized_op_with(e.get_op(READ).unwrap().as_ref(), move |args, _| { - Ok(Some(NodeTemplate::Call(read_func, args.to_owned()))) + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + Some(NodeTemplate::Call(read_func, args.to_owned())) }); lw.run(&mut h).unwrap(); diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 99ced3e52f..40f0da0f90 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -367,8 +367,8 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, - inout_sig, + BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, inout_sig, }; use hugr_core::extension::prelude::{option_type, qb_t, usize_t}; @@ -392,7 +392,7 @@ mod test { use rstest::rstest; use crate::replace_types::handlers::linearize_value_array; - use crate::replace_types::{LinearizeError, Linearizer, NodeTemplate, ReplaceTypesError}; + use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; @@ -455,7 +455,7 @@ mod test { let usize_custom_t = usize_t().as_extension().unwrap().clone(); lowerer.replace_type(usize_custom_t, Type::new_extension(lin_custom_t.clone())); lowerer - .linearizer_mut() + .linearizer() .register_simple( lin_custom_t, NodeTemplate::SingleOp(copy_op.into()), @@ -577,7 +577,7 @@ mod test { let opdef = e.get_op("copy").unwrap(); let opdef2 = opdef.clone(); lowerer - .linearizer_mut() + .linearizer() .register_callback(lin_t_def, move |args, num_outs, _| { assert!(args.is_empty()); Ok(NodeTemplate::SingleOp( @@ -639,7 +639,7 @@ mod test { let mut replacer = ReplaceTypes::default(); replacer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); - let bad_copy = replacer.linearizer_mut().register_simple( + let bad_copy = replacer.linearizer().register_simple( lin_ct.clone(), NodeTemplate::SingleOp(copy3.clone()), NodeTemplate::SingleOp(discard.clone().into()), @@ -654,7 +654,7 @@ mod test { }) ); - let bad_discard = replacer.linearizer_mut().register_simple( + let bad_discard = replacer.linearizer().register_simple( lin_ct.clone(), NodeTemplate::SingleOp(copy2.into()), NodeTemplate::SingleOp(copy3.clone()), @@ -671,7 +671,7 @@ mod test { // Try parametrized instead, but this version always returns 3 outports replacer - .linearizer_mut() + .linearizer() .register_callback(ext.get_type(LIN_T).unwrap(), move |_args, _, _| { Ok(NodeTemplate::SingleOp(copy3.clone())) }); @@ -699,7 +699,7 @@ mod test { let (e, mut lowerer) = ext_lowerer(); lowerer - .linearizer_mut() + .linearizer() .register_callback(value_array_type_def(), linearize_value_array); let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap()); let opt_lin_ty = Type::from(option_type(lin_t.clone())); @@ -817,7 +817,7 @@ mod test { let mut lower_discard_to_call = ReplaceTypes::default(); lower_discard_to_call - .linearizer_mut() + .linearizer() .register_simple( lin_ct.clone(), NodeTemplate::Call(backup.entrypoint(), vec![]), // Arbitrary, unused @@ -876,12 +876,17 @@ mod test { }, ); let drop_op = drop_ext.get_op("drop").unwrap(); - lowerer.replace_parametrized_op_with(drop_op, |args, rt| { + lowerer.replace_parametrized_op(drop_op, |args| { let [TypeArg::Runtime(ty)] = args else { panic!("Expected just one type") }; - Ok(rt.get_linearizer().copy_discard_op(ty, 0).map(Some)?) + // The Hugr here is invalid, so we have to pull it out manually + let mut h = Hugr::new(); + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + std::mem::swap(&mut h, dfb.hugr_mut()); + Some(NodeTemplate::CompoundOp(Box::new(h))) }); + lowerer.process_replacements(true); let build_hugr = |ty: Type| { let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); From 5076ecfeb7aa94e4bd948135f093fd74c5d6e578 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Jul 2025 17:55:11 +0100 Subject: [PATCH 10/15] Common up replace + checks+change_subtrees...driveby remove a clone --- hugr-passes/src/replace_types.rs | 51 +++++++++++++++----------------- 1 file changed, 24 insertions(+), 27 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 20d5ad0bc1..2494b47535 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -107,9 +107,9 @@ impl NodeTemplate { } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { + fn replace(self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); - let new_optype = match self.clone() { + let new_optype = match self { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(new_h) => { let new_entrypoint = hugr.insert_hugr(n, *new_h).inserted_entrypoint; @@ -457,8 +457,25 @@ impl ReplaceTypes { | rest.transform(self)?), OpType::Const(Const { value, .. }) => self.change_value(value), - OpType::ExtensionOp(ext_op) => Ok( - if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { + OpType::ExtensionOp(ext_op) => Ok({ + let def = ext_op.def_arc(); + let mut changed = false; + let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) { + r @ Some(_) => r.cloned(), + None => { + let mut args = ext_op.args().to_vec(); + changed = args.transform(self)?; + let r2 = self + .param_ops + .get(&def.as_ref().into()) + .and_then(|rep_fn| rep_fn(&args)); + if r2.is_none() && changed { + *ext_op = ExtensionOp::new(def.clone(), args)?; + } + r2 + } + }; + if let Some(replacement) = replacement { replacement .replace(hugr, n) .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; @@ -467,29 +484,9 @@ impl ReplaceTypes { } true } else { - let def = ext_op.def_arc(); - let mut args = ext_op.args().to_vec(); - let ch = args.transform(self)?; - if let Some(replacement) = self - .param_ops - .get(&def.as_ref().into()) - .and_then(|rep_fn| rep_fn(&args)) - { - replacement - .replace(hugr, n) - .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; - if self.process_replacements { - self.change_subtree(hugr, n, true)?; - } - true - } else { - if ch { - *ext_op = ExtensionOp::new(def.clone(), args)?; - } - ch - } - }, - ), + changed + } + }), OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"), From ea3e9f4e0359708ab61cc5f73a9c78c63ba959bb Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Thu, 10 Jul 2025 21:36:26 +0100 Subject: [PATCH 11/15] remove process_replacements, add replace_parametrized_op_recursive --- hugr-passes/src/replace_types.rs | 36 ++++++++++++---------- hugr-passes/src/replace_types/linearize.rs | 3 +- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 2494b47535..f115268389 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -204,7 +204,7 @@ pub struct ReplaceTypes { param_types: HashMap Option>>, linearize: DelegatingLinearizer, op_map: HashMap, - param_ops: HashMap Option>>, + param_ops: HashMap Option>, bool)>, consts: HashMap< CustomType, Arc Result>, @@ -213,7 +213,6 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, - process_replacements: bool, } impl Default for ReplaceTypes { @@ -275,7 +274,6 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), - process_replacements: false, } } @@ -354,12 +352,26 @@ impl ReplaceTypes { /// fit the bounds of the original op). /// /// If the Callback returns None, the new typeargs will be applied to the original op. + /// + /// See also [Self::replace_parametrized_op_recursive] pub fn replace_parametrized_op( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, ) { - self.param_ops.insert(src.into(), Arc::new(dest_fn)); + self.param_ops + .insert(src.into(), (Arc::new(dest_fn), false)); + } + + /// Like [Self::replace_parametrized_op] but the contents of any [NodeTemplate] + /// returned by the callback will be transformed (recursively) by the same + /// ReplaceTypes instance after insertion into the target Hugr. + pub fn replace_parametrized_op_recursive( + &mut self, + src: &OpDef, + dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + ) { + self.param_ops.insert(src.into(), (Arc::new(dest_fn), true)); } /// Configures this instance to change [Const]s of type `src_ty`, using @@ -389,14 +401,6 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - /// Configures this instance to (recursively) process the RHS of any replacement - /// ops or subtrees - registered with [Self::replace_op] or from callbacks - /// registered with [Self::replace_parametrized_op]. - /// The default is `false`, i.e. do not recurse on such. - pub fn process_replacements(&mut self, r: bool) { - self.process_replacements = r; - } - fn change_node( &self, hugr: &mut impl HugrMut, @@ -461,25 +465,25 @@ impl ReplaceTypes { let def = ext_op.def_arc(); let mut changed = false; let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - r @ Some(_) => r.cloned(), + Some(r) => Some((r.clone(), false)), None => { let mut args = ext_op.args().to_vec(); changed = args.transform(self)?; let r2 = self .param_ops .get(&def.as_ref().into()) - .and_then(|rep_fn| rep_fn(&args)); + .and_then(|(rep_fn, rec)| rep_fn(&args).map(|nt| (nt, *rec))); if r2.is_none() && changed { *ext_op = ExtensionOp::new(def.clone(), args)?; } r2 } }; - if let Some(replacement) = replacement { + if let Some((replacement, process_recursive)) = replacement { replacement .replace(hugr, n) .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; - if self.process_replacements { + if process_recursive { self.change_subtree(hugr, n, true)?; } true diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 40f0da0f90..03f93f233f 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -876,7 +876,7 @@ mod test { }, ); let drop_op = drop_ext.get_op("drop").unwrap(); - lowerer.replace_parametrized_op(drop_op, |args| { + lowerer.replace_parametrized_op_recursive(drop_op, |args| { let [TypeArg::Runtime(ty)] = args else { panic!("Expected just one type") }; @@ -886,7 +886,6 @@ mod test { std::mem::swap(&mut h, dfb.hugr_mut()); Some(NodeTemplate::CompoundOp(Box::new(h))) }); - lowerer.process_replacements(true); let build_hugr = |ty: Type| { let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); From af53713d1191283ade4152a40167ec440a020568 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 11 Jul 2025 14:47:59 +0100 Subject: [PATCH 12/15] ReplacementOptions, replace_(parametrized_)op_with, break out linearize_outputs --- hugr-passes/src/replace_types.rs | 134 +++++++++++++++------ hugr-passes/src/replace_types/linearize.rs | 27 +++-- 2 files changed, 115 insertions(+), 46 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index f115268389..097540c9f6 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -171,6 +171,37 @@ fn call>( Ok(Call::try_new(func_sig, type_args)?) } +/// Options for how the replacement for an op is processed. May be specified by +/// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with]. +/// Otherwise (the default), replacements are inserted as is (without further processing). +// TODO would be good to migrate to default being process_recursive: true +#[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension +pub struct ReplacementOptions { + process_recursive: bool, + linearize: bool, +} + +impl ReplacementOptions { + /// Specifies that the replacement should be processed by the same [ReplaceTypes]. + /// This increases compositionality (in that replacements for different ops do not + /// need to account for each other), but would lead to an infinite loop if e.g. + /// changing an op for a DFG containing an instance of the same op. + pub fn with_recursive_replacement(mut self, rec: bool) -> Self { + self.process_recursive = rec; + self + } + + /// Specifies that the replacement should be linearized. + /// If [Self::with_recursive_replacement] has been set, this applies linearization + /// even to ops (within the original replacement) that are not altered by the + /// recursive processing. Otherwise, can be used to apply linearization without + /// changing any other ops. + pub fn with_linearization(mut self, lin: bool) -> Self { + self.linearize = lin; + self + } +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [`Self::run`]. /// @@ -203,8 +234,14 @@ pub struct ReplaceTypes { type_map: HashMap, param_types: HashMap Option>>, linearize: DelegatingLinearizer, - op_map: HashMap, - param_ops: HashMap Option>, bool)>, + op_map: HashMap, + param_ops: HashMap< + ParametricOp, + ( + Arc Option>, + ReplacementOptions, + ), + >, consts: HashMap< CustomType, Arc Result>, @@ -337,41 +374,54 @@ impl ReplaceTypes { } /// Configures this instance to change occurrences of `src` to `dest`. + /// Equivalent to [Self::replace_op_with] with default [ReplacementOptions]. + pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { + self.replace_op_with(src, dest, ReplacementOptions::default()) + } + + /// Configures this instance to change occurrences of `src` to `dest`. + /// /// Note that if `src` is an instance of a *parametrized* [`OpDef`], this takes /// precedence over [`Self::replace_parametrized_op`] where the `src`s overlap. Thus, /// this should only be used on already-*[monomorphize](super::monomorphize())d* /// Hugrs, as substitution (parametric polymorphism) happening later will not respect /// this replacement. - pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) { - self.op_map.insert(OpHashWrapper::from(src), dest); + pub fn replace_op_with( + &mut self, + src: &ExtensionOp, + dest: NodeTemplate, + opts: ReplacementOptions, + ) { + self.op_map.insert(OpHashWrapper::from(src), (dest, opts)); } /// Configures this instance to change occurrences of a parametrized op `src` /// via a callback that builds the replacement type given the [`TypeArg`]s. - /// Note that the `TypeArgs` will already have been updated (e.g. they may not - /// fit the bounds of the original op). - /// - /// If the Callback returns None, the new typeargs will be applied to the original op. - /// - /// See also [Self::replace_parametrized_op_recursive] + /// Equivalent to [Self::replace_parametrized_op_with] with default [ReplacementOptions]. pub fn replace_parametrized_op( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, ) { - self.param_ops - .insert(src.into(), (Arc::new(dest_fn), false)); + self.param_ops.insert( + src.into(), + (Arc::new(dest_fn), ReplacementOptions::default()), + ); } - /// Like [Self::replace_parametrized_op] but the contents of any [NodeTemplate] - /// returned by the callback will be transformed (recursively) by the same - /// ReplaceTypes instance after insertion into the target Hugr. - pub fn replace_parametrized_op_recursive( + /// Configures this instance to change occurrences of a parametrized op `src` + /// via a callback that builds the replacement type given the [`TypeArg`]s. + /// Note that the `TypeArgs` will already have been updated (e.g. they may not + /// fit the bounds of the original op). + /// + /// If the Callback returns None, the new typeargs will be applied to the original op. + pub fn replace_parametrized_op_with( &mut self, src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + opts: ReplacementOptions, ) { - self.param_ops.insert(src.into(), (Arc::new(dest_fn), true)); + self.param_ops.insert(src.into(), (Arc::new(dest_fn), opts)); } /// Configures this instance to change [Const]s of type `src_ty`, using @@ -465,26 +515,32 @@ impl ReplaceTypes { let def = ext_op.def_arc(); let mut changed = false; let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - Some(r) => Some((r.clone(), false)), + r @ Some(_) => r.cloned(), None => { let mut args = ext_op.args().to_vec(); changed = args.transform(self)?; let r2 = self .param_ops .get(&def.as_ref().into()) - .and_then(|(rep_fn, rec)| rep_fn(&args).map(|nt| (nt, *rec))); + .and_then(|(rep_fn, opts)| rep_fn(&args).map(|nt| (nt, opts.clone()))); if r2.is_none() && changed { *ext_op = ExtensionOp::new(def.clone(), args)?; } r2 } }; - if let Some((replacement, process_recursive)) = replacement { + if let Some((replacement, opts)) = replacement { replacement .replace(hugr, n) .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; - if process_recursive { - self.change_subtree(hugr, n, true)?; + if opts.process_recursive { + self.change_subtree(hugr, n, opts.linearize)?; + } else if opts.linearize { + for d in hugr.descendants(n).collect::>() { + if d != n { + self.linearize_outputs(hugr, d)?; + } + } } true } else { @@ -544,24 +600,32 @@ impl ReplaceTypes { let mut changed = false; for n in hugr.descendants(root).collect::>() { changed |= self.change_node(hugr, n)?; - if n == root || !(changed | linearize_if_no_change) { - continue; + if n != root && (changed || linearize_if_no_change) { + self.linearize_outputs(hugr, n)?; } - if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() { - let new_sig = new_sig.into_owned(); - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } + } + Ok(changed) + } + + fn linearize_outputs>( + &self, + hugr: &mut H, + n: H::Node, + ) -> Result<(), LinearizeError> { + if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() { + let new_sig = new_sig.into_owned(); + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; } } } } - Ok(changed) + Ok(()) } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 03f93f233f..0029683f47 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -392,7 +392,9 @@ mod test { use rstest::rstest; use crate::replace_types::handlers::linearize_value_array; - use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; + use crate::replace_types::{ + LinearizeError, NodeTemplate, ReplaceTypesError, ReplacementOptions, + }; use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; @@ -876,16 +878,19 @@ mod test { }, ); let drop_op = drop_ext.get_op("drop").unwrap(); - lowerer.replace_parametrized_op_recursive(drop_op, |args| { - let [TypeArg::Runtime(ty)] = args else { - panic!("Expected just one type") - }; - // The Hugr here is invalid, so we have to pull it out manually - let mut h = Hugr::new(); - let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); - std::mem::swap(&mut h, dfb.hugr_mut()); - Some(NodeTemplate::CompoundOp(Box::new(h))) - }); + lowerer.replace_parametrized_op_with( + drop_op, + |args| { + let [TypeArg::Runtime(ty)] = args else { + panic!("Expected just one type") + }; + // The Hugr here is invalid, so we have to pull it out manually + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); + let h = std::mem::take(dfb.hugr_mut()); + Some(NodeTemplate::CompoundOp(Box::new(h))) + }, + ReplacementOptions::default().with_linearization(true), + ); let build_hugr = |ty: Type| { let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); From af79d814f466bf4a64c3cde7fb926f6ef3a84611 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Jul 2025 13:23:48 +0100 Subject: [PATCH 13/15] Review comment --- hugr-passes/src/replace_types.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 097540c9f6..8eadfaadd5 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -403,10 +403,7 @@ impl ReplaceTypes { src: &OpDef, dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, ) { - self.param_ops.insert( - src.into(), - (Arc::new(dest_fn), ReplacementOptions::default()), - ); + self.replace_parametrized_op_with(src, dest_fn, ReplacementOptions::default()) } /// Configures this instance to change occurrences of a parametrized op `src` From 5f53767fde420434487a73071ed3e97a30a1a74c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 15 Jul 2025 15:06:11 +0100 Subject: [PATCH 14/15] doc note --- hugr-passes/src/replace_types.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 8eadfaadd5..2f09476f39 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -185,7 +185,10 @@ impl ReplacementOptions { /// Specifies that the replacement should be processed by the same [ReplaceTypes]. /// This increases compositionality (in that replacements for different ops do not /// need to account for each other), but would lead to an infinite loop if e.g. - /// changing an op for a DFG containing an instance of the same op. + /// changing an op for a DFG containing an instance of the same op. Also, note + /// that if the recursive processing changes the signature of the replacement, + /// this may break surrounding wires (e.g. from [Input] or to [Output] nodes) + /// because types are not subject to recursive replacement. pub fn with_recursive_replacement(mut self, rec: bool) -> Self { self.process_recursive = rec; self From f93fb614f540733e2d8a7625155f787aac789a9c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Jul 2025 12:42:56 +0100 Subject: [PATCH 15/15] remove replace_recursive, undo change_subtree --- hugr-passes/src/replace_types.rs | 50 +++++++------------------------- 1 file changed, 11 insertions(+), 39 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 2f09476f39..0b5cca8f6a 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -174,31 +174,14 @@ fn call>( /// Options for how the replacement for an op is processed. May be specified by /// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with]. /// Otherwise (the default), replacements are inserted as is (without further processing). -// TODO would be good to migrate to default being process_recursive: true #[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension pub struct ReplacementOptions { - process_recursive: bool, linearize: bool, } impl ReplacementOptions { - /// Specifies that the replacement should be processed by the same [ReplaceTypes]. - /// This increases compositionality (in that replacements for different ops do not - /// need to account for each other), but would lead to an infinite loop if e.g. - /// changing an op for a DFG containing an instance of the same op. Also, note - /// that if the recursive processing changes the signature of the replacement, - /// this may break surrounding wires (e.g. from [Input] or to [Output] nodes) - /// because types are not subject to recursive replacement. - pub fn with_recursive_replacement(mut self, rec: bool) -> Self { - self.process_recursive = rec; - self - } - - /// Specifies that the replacement should be linearized. - /// If [Self::with_recursive_replacement] has been set, this applies linearization - /// even to ops (within the original replacement) that are not altered by the - /// recursive processing. Otherwise, can be used to apply linearization without - /// changing any other ops. + /// Specifies that all operations within the replacement should have their + /// output ports linearized. pub fn with_linearization(mut self, lin: bool) -> Self { self.linearize = lin; self @@ -533,9 +516,7 @@ impl ReplaceTypes { replacement .replace(hugr, n) .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; - if opts.process_recursive { - self.change_subtree(hugr, n, opts.linearize)?; - } else if opts.linearize { + if opts.linearize { for d in hugr.descendants(n).collect::>() { if d != n { self.linearize_outputs(hugr, d)?; @@ -591,22 +572,6 @@ impl ReplaceTypes { } } - fn change_subtree>( - &self, - hugr: &mut H, - root: H::Node, - linearize_if_no_change: bool, - ) -> Result { - let mut changed = false; - for n in hugr.descendants(root).collect::>() { - changed |= self.change_node(hugr, n)?; - if n != root && (changed || linearize_if_no_change) { - self.linearize_outputs(hugr, n)?; - } - } - Ok(changed) - } - fn linearize_outputs>( &self, hugr: &mut H, @@ -634,7 +599,14 @@ impl> ComposablePass for ReplaceTypes { type Result = bool; fn run(&self, hugr: &mut H) -> Result { - self.change_subtree(hugr, hugr.entrypoint(), false) + let mut changed = false; + for n in hugr.entry_descendants().collect::>() { + changed |= self.change_node(hugr, n)?; + if n != hugr.entrypoint() && changed { + self.linearize_outputs(hugr, n)?; + } + } + Ok(changed) } }