diff --git a/compiler/qsc_rir/src/passes.rs b/compiler/qsc_rir/src/passes.rs index 412a772e07..2abd5b1733 100644 --- a/compiler/qsc_rir/src/passes.rs +++ b/compiler/qsc_rir/src/passes.rs @@ -26,10 +26,14 @@ use crate::{rir::Program, utils::build_predecessors_map}; /// Run the default set of RIR check and transformation passes. /// This includes: +/// - Simplifying control flow /// - Checking for unreachable code +/// - Checking types /// - Remapping block IDs /// - Transforming the program to SSA form /// - Checking that the program is in SSA form +/// - If the target has no reset capability, reindexing qubit IDs and removing resets. +/// - If the target has no mid-program measurement capability, deferring measurements to the end of the program. pub fn check_and_transform(program: &mut Program) { simplify_control_flow(program); check_unreachable_code(program); @@ -44,7 +48,8 @@ pub fn check_and_transform(program: &mut Program) { // Run the RIR passes that are necessary for targets with no mid-program measurement. // This requires that qubits are not reused after measurement or reset, so qubit ids must be reindexed. - // This also requires that the program is a single block and will panic otherwise. + // This also requires that the program has no loops and block ids form a topological ordering on a + // directed acyclic graph. if !program .config .capabilities diff --git a/compiler/qsc_rir/src/passes/reindex_qubits.rs b/compiler/qsc_rir/src/passes/reindex_qubits.rs index 0116907b86..1f44f1b1cd 100644 --- a/compiler/qsc_rir/src/passes/reindex_qubits.rs +++ b/compiler/qsc_rir/src/passes/reindex_qubits.rs @@ -4,110 +4,226 @@ #[cfg(test)] mod tests; +use std::collections::hash_map::Entry; + +use qsc_data_structures::index_map::IndexMap; use rustc_hash::FxHashMap; -use crate::rir::{ - Block, Callable, CallableId, CallableType, Instruction, Literal, Operand, Program, Ty, +use crate::{ + builder, + rir::{Block, BlockId, CallableId, CallableType, Instruction, Literal, Operand, Program, Ty}, + utils::{build_predecessors_map, get_block_successors}, }; +#[derive(Clone)] +struct BlockQubitMap { + map: FxHashMap, + next_qubit_id: u32, +} + /// Reindexes qubits after they have been measured or reset. This ensures there is no qubit reuse in /// the program. As part of the pass, reset callables are removed and mresetz calls are replaced with /// mz calls. /// Note that this pass has several assumptions: /// 1. Only one callable has a body, which is the entry point callable. -/// 2. The entry point callable has a single block. +/// 2. The entry point callable is a directed acyclic graph where block ids have topological ordering. /// 3. No dynamic qubits are used. /// The pass will panic if the input program violates any of these assumptions. pub fn reindex_qubits(program: &mut Program) { validate_assumptions(program); - let (mut used_mz, mz_id) = match find_measurement_callable(program, "__quantum__qis__mz__body") - { + let (used_mz, mz_id) = match find_callable(program, "__quantum__qis__mz__body") { Some(id) => (true, id), None => (false, add_mz(program)), }; - let mresetz_id = find_measurement_callable(program, "__quantum__qis__mresetz__body"); - - let mut qubit_map = FxHashMap::default(); - let mut next_qubit_id = program.num_qubits; - let mut highest_used_id = next_qubit_id - 1; - let mut new_block = Vec::new(); - let (block_id, block) = program - .blocks - .drain() - .next() + let (used_cx, cx_id) = match find_callable(program, "__quantum__qis__cx__body") { + Some(id) => (true, id), + None => (false, add_cx(program)), + }; + let mresetz_id = find_callable(program, "__quantum__qis__mresetz__body"); + let mut pass = ReindexQubitPass { + used_mz, + mz_id, + used_cx, + cx_id, + mresetz_id, + highest_used_id: program.num_qubits - 1, + }; + + let pred_map = build_predecessors_map(program); + + let mut block_maps = IndexMap::new(); + block_maps.insert( + BlockId(0), + BlockQubitMap { + map: FxHashMap::default(), + next_qubit_id: program.num_qubits, + }, + ); + + let mut all_blocks = program.blocks.drain().collect::>(); + let (entry_block, rest_blocks) = all_blocks + .split_first_mut() .expect("program should have at least one block"); - for instr in &block.0 { - // Assume qubits only appear in void call instructions. - match instr { - Instruction::Call(call_id, args, _) - if program.get_callable(*call_id).call_type == CallableType::Reset => - { - // Generate any new qubit ids and skip adding the instruction. - for arg in args { - if let Operand::Literal(Literal::Qubit(qubit_id)) = arg { - qubit_map.insert(*qubit_id, next_qubit_id); - next_qubit_id += 1; + + // Reindex qubits in the entry block. + pass.reindex_qubits_in_block( + program, + &mut entry_block.1, + block_maps + .get_mut(entry_block.0) + .expect("entry block with id 0 should be in block_maps"), + ); + + for (block_id, block) in rest_blocks { + // Use the predecessors to build the block's initial, inherited qubit map. + let pred_ids = pred_map + .get(*block_id) + .expect("block should have predecessors"); + + // Start from an empty map with the program's initial qubit ids. + let mut new_block_map = BlockQubitMap { + map: FxHashMap::default(), + next_qubit_id: program.num_qubits, + }; + + for pred_id in pred_ids { + let pred_qubit_map = block_maps + .get(*pred_id) + .expect("predecessor should be in block_maps"); + + // Across each predecessor, ensure that any mapped ids are same, otherwise + // panic because we can't know which id to use. + for (qubit_id, new_qubit_id) in &pred_qubit_map.map { + match new_block_map.map.entry(*qubit_id) { + Entry::Occupied(entry) if *entry.get() == *new_qubit_id => {} + Entry::Occupied(_) => { + panic!("Qubit id {qubit_id} has multiple mappings across predecessors"); } - } - } - Instruction::Call(call_id, args, None) => { - // Map the qubit args, if any, and copy over the instruction. - let new_args = args - .iter() - .map(|arg| match arg { - Operand::Literal(Literal::Qubit(qubit_id)) => { - if let Some(mapped_id) = qubit_map.get(qubit_id) { - highest_used_id = highest_used_id.max(*mapped_id); - Operand::Literal(Literal::Qubit(*mapped_id)) - } else { - *arg - } - } - _ => *arg, - }) - .collect::>(); - - // If the call was to mresetz, replace with mz. - let call_id = if Some(*call_id) == mresetz_id { - used_mz = true; - mz_id - } else { - *call_id - }; - - new_block.push(Instruction::Call(call_id, new_args, None)); - - if program.get_callable(call_id).call_type == CallableType::Measurement { - // Generate any new qubit ids after a measurement. - for arg in args { - if let Operand::Literal(Literal::Qubit(qubit_id)) = arg { - qubit_map.insert(*qubit_id, next_qubit_id); - next_qubit_id += 1; - } + Entry::Vacant(entry) => { + entry.insert(*new_qubit_id); } } } - _ => { - // Copy over the instruction. - new_block.push(instr.clone()); - } + new_block_map.next_qubit_id = new_block_map + .next_qubit_id + .max(pred_qubit_map.next_qubit_id); } + + pass.reindex_qubits_in_block(program, block, &mut new_block_map); + + block_maps.insert(*block_id, new_block_map); } - program.num_qubits = highest_used_id + 1; - program.blocks.clear(); - program.blocks.insert(block_id, Block(new_block)); + program.blocks = all_blocks.into_iter().collect(); + program.num_qubits = pass.highest_used_id + 1; // All reset function calls should be removed, so remove them from the callables. program .callables .retain(|id, callable| callable.call_type != CallableType::Reset && Some(id) != mresetz_id); - // If mz was added but not used, remove it. - if !used_mz { + // If mz or cx were added but not used, remove them. + if !pass.used_mz { program.callables.remove(mz_id); } + if !pass.used_cx { + program.callables.remove(cx_id); + } +} + +struct ReindexQubitPass { + used_mz: bool, + mz_id: CallableId, + used_cx: bool, + cx_id: CallableId, + mresetz_id: Option, + highest_used_id: u32, +} + +impl ReindexQubitPass { + fn reindex_qubits_in_block( + &mut self, + program: &Program, + block: &mut Block, + qubit_map: &mut BlockQubitMap, + ) { + let instrs = std::mem::take(&mut block.0); + for instr in instrs { + // Assume qubits only appear in void call instructions. + match instr { + Instruction::Call(call_id, args, _) + if program.get_callable(call_id).call_type == CallableType::Reset => + { + // Generate any new qubit ids and skip adding the instruction. + for arg in args { + if let Operand::Literal(Literal::Qubit(qubit_id)) = arg { + qubit_map.map.insert(qubit_id, qubit_map.next_qubit_id); + qubit_map.next_qubit_id += 1; + } + } + } + Instruction::Call(call_id, args, None) => { + // Map the qubit args, if any, and copy over the instruction. + let new_args = args + .iter() + .map(|arg| match arg { + Operand::Literal(Literal::Qubit(qubit_id)) => { + match qubit_map.map.get(qubit_id) { + Some(mapped_id) => { + // If the qubit has already been mapped, use the mapped id. + self.highest_used_id = self.highest_used_id.max(*mapped_id); + Operand::Literal(Literal::Qubit(*mapped_id)) + } + None => *arg, + } + } + _ => *arg, + }) + .collect::>(); + + if call_id == self.mz_id { + // Since the call was to mz, the new qubit replacing this one must be conditionally flipped. + // Achieve this by adding a CNOT gate before the mz call. + self.used_cx = true; + block.0.push(Instruction::Call( + self.cx_id, + vec![ + new_args[0], + Operand::Literal(Literal::Qubit(qubit_map.next_qubit_id)), + ], + None, + )); + self.highest_used_id = self.highest_used_id.max(qubit_map.next_qubit_id); + } + + // If the call was to mresetz, replace with mz. + let call_id = if Some(call_id) == self.mresetz_id { + self.used_mz = true; + self.mz_id + } else { + call_id + }; + + block.0.push(Instruction::Call(call_id, new_args, None)); + + if program.get_callable(call_id).call_type == CallableType::Measurement { + // Generate any new qubit ids after a measurement. + for arg in args { + if let Operand::Literal(Literal::Qubit(qubit_id)) = arg { + qubit_map.map.insert(qubit_id, qubit_map.next_qubit_id); + qubit_map.next_qubit_id += 1; + } + } + } + } + _ => { + // Copy over the instruction. + block.0.push(instr.clone()); + } + } + } + } } fn validate_assumptions(program: &Program) { @@ -119,28 +235,28 @@ fn validate_assumptions(program: &Program) { ); } - // Ensure entry point callable has a single block. - // Future enhancements may allow multiple blocks in the entry point callable. - assert!( - program.blocks.iter().count() == 1, - "Entry point callable must have a single block" - ); - - // Ensure that no dynamic qubits are used. - let Some((_, block)) = program.blocks.iter().next() else { - panic!("No blocks found in the program"); - }; - for instr in &block.0 { + // Ensure entry point callable blocks are a topologically ordered DAG. + // We can check this quickly by verifying that each block only has successors with higher ids. + for (block_id, block) in program.blocks.iter() { assert!( - !matches!(instr, Instruction::Store(_, var) if var.ty == Ty::Qubit), - "Dynamic qubits are not supported" + get_block_successors(block) + .iter() + .all(|&succ_id| succ_id > block_id), + "blocks must form a topologically ordered DAG" ); + // Ensure that no dynamic qubits are used. + for instr in &block.0 { + assert!( + !matches!(instr, Instruction::Store(_, var) if var.ty == Ty::Qubit), + "Dynamic qubits are not supported" + ); + } } } -fn find_measurement_callable(program: &Program, name: &str) -> Option { +fn find_callable(program: &Program, name: &str) -> Option { for (callable_id, callable) in program.callables.iter() { - if callable.call_type == CallableType::Measurement && callable.name == name { + if callable.name == name { return Some(callable_id); } } @@ -157,15 +273,20 @@ fn add_mz(program: &mut Program) -> CallableId { .expect("should be at least one callable") + 1, ); - program.callables.insert( - mz_id, - Callable { - name: "__quantum__qis__mz__body".to_string(), - input_type: vec![Ty::Qubit, Ty::Result], - output_type: None, - body: None, - call_type: CallableType::Measurement, - }, - ); + program.callables.insert(mz_id, builder::mz_decl()); mz_id } + +fn add_cx(program: &mut Program) -> CallableId { + let cx_id = CallableId( + program + .callables + .iter() + .map(|(id, _)| id.0) + .max() + .expect("should be at least one callable") + + 1, + ); + program.callables.insert(cx_id, builder::cx_decl()); + cx_id +} diff --git a/compiler/qsc_rir/src/passes/reindex_qubits/tests.rs b/compiler/qsc_rir/src/passes/reindex_qubits/tests.rs index e11d5b3ba0..5acbd6e262 100644 --- a/compiler/qsc_rir/src/passes/reindex_qubits/tests.rs +++ b/compiler/qsc_rir/src/passes/reindex_qubits/tests.rs @@ -6,8 +6,11 @@ use expect_test::expect; use crate::{ - builder::{cx_decl, h_decl, mresetz_decl, mz_decl, reset_decl, x_decl}, - rir::{Block, BlockId, CallableId, CallableType, Instruction, Literal, Operand, Program}, + builder::{cx_decl, h_decl, mresetz_decl, mz_decl, read_result_decl, reset_decl, x_decl}, + rir::{ + Block, BlockId, CallableId, CallableType, Instruction, Literal, Operand, Program, Ty, + Variable, VariableId, + }, }; use super::reindex_qubits; @@ -105,12 +108,14 @@ fn qubit_reindexed_after_mz() { expect![[r#" Block: Call id(0), args( Qubit(0), ) + Call id(2), args( Qubit(0), Qubit(1), ) Call id(1), args( Qubit(0), Result(0), ) Call id(0), args( Qubit(1), ) + Call id(2), args( Qubit(1), Qubit(2), ) Call id(1), args( Qubit(1), Result(1), ) Return"#]] .assert_eq(&program.get_block(BlockId(0)).to_string()); - assert_eq!(program.num_qubits, 2); + assert_eq!(program.num_qubits, 3); } #[test] @@ -119,8 +124,8 @@ fn qubit_reindexed_after_mresetz_and_changed_to_mz() { const MRESETZ: CallableId = CallableId(1); let mut program = Program::new(); program.num_qubits = 1; - program.callables.insert(CallableId(0), x_decl()); - program.callables.insert(CallableId(1), mresetz_decl()); + program.callables.insert(X, x_decl()); + program.callables.insert(MRESETZ, mresetz_decl()); program.blocks.insert( BlockId(0), Block(vec![ @@ -225,13 +230,13 @@ fn multiple_qubit_reindex() { } #[test] -fn qubit_reindexed_multiple_times() { +fn qubit_reindexed_multiple_times_with_mz_inserts_multiple_cx() { const X: CallableId = CallableId(0); const MZ: CallableId = CallableId(1); let mut program = Program::new(); program.num_qubits = 1; - program.callables.insert(CallableId(0), x_decl()); - program.callables.insert(CallableId(1), mz_decl()); + program.callables.insert(X, x_decl()); + program.callables.insert(MZ, mz_decl()); program.blocks.insert( BlockId(0), Block(vec![ @@ -294,14 +299,606 @@ fn qubit_reindexed_multiple_times() { expect![[r#" Block: Call id(0), args( Qubit(0), ) + Call id(2), args( Qubit(0), Qubit(1), ) Call id(1), args( Qubit(0), Result(0), ) Call id(0), args( Qubit(1), ) + Call id(2), args( Qubit(1), Qubit(2), ) Call id(1), args( Qubit(1), Result(1), ) Call id(0), args( Qubit(2), ) + Call id(2), args( Qubit(2), Qubit(3), ) Call id(1), args( Qubit(2), Result(2), ) Call id(0), args( Qubit(3), ) + Call id(2), args( Qubit(3), Qubit(4), ) Call id(1), args( Qubit(3), Result(3), ) Return"#]] .assert_eq(&program.get_block(BlockId(0)).to_string()); - assert_eq!(program.num_qubits, 4); + assert_eq!(program.num_qubits, 5); +} + +#[test] +fn qubit_reindexed_across_branches() { + const X: CallableId = CallableId(0); + const MZ: CallableId = CallableId(1); + const READ_RESULT: CallableId = CallableId(2); + let mut program = Program::new(); + program.num_qubits = 1; + program.num_results = 3; + program.callables.insert(X, x_decl()); + program.callables.insert(MZ, mz_decl()); + program.callables.insert(READ_RESULT, read_result_decl()); + + program.blocks.insert( + BlockId(0), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Call( + MZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(0)), + ], + None, + ), + Instruction::Call( + READ_RESULT, + vec![Operand::Literal(Literal::Result(0))], + Some(Variable { + variable_id: VariableId(0), + ty: Ty::Boolean, + }), + ), + Instruction::Branch( + Variable { + variable_id: VariableId(0), + ty: Ty::Boolean, + }, + BlockId(1), + BlockId(2), + ), + ]), + ); + program.blocks.insert( + BlockId(1), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Call( + MZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(1)), + ], + None, + ), + Instruction::Jump(BlockId(3)), + ]), + ); + program.blocks.insert( + BlockId(2), + Block(vec![ + Instruction::Call( + MZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(2)), + ], + None, + ), + Instruction::Jump(BlockId(3)), + ]), + ); + program.blocks.insert( + BlockId(3), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Return, + ]), + ); + + // Before + expect![[r#" + Program: + entry: 0 + callables: + Callable 0: Callable: + name: __quantum__qis__x__body + call_type: Regular + input_type: + [0]: Qubit + output_type: + body: + Callable 1: Callable: + name: __quantum__qis__mz__body + call_type: Measurement + input_type: + [0]: Qubit + [1]: Result + output_type: + body: + Callable 2: Callable: + name: __quantum__qis__read_result__body + call_type: Readout + input_type: + [0]: Result + output_type: Boolean + body: + blocks: + Block 0: Block: + Call id(0), args( Qubit(0), ) + Call id(1), args( Qubit(0), Result(0), ) + Variable(0, Boolean) = Call id(2), args( Result(0), ) + Branch Variable(0, Boolean), 1, 2 + Block 1: Block: + Call id(0), args( Qubit(0), ) + Call id(1), args( Qubit(0), Result(1), ) + Jump(3) + Block 2: Block: + Call id(1), args( Qubit(0), Result(2), ) + Jump(3) + Block 3: Block: + Call id(0), args( Qubit(0), ) + Return + config: Config: + capabilities: Base + num_qubits: 1 + num_results: 3"#]] + .assert_eq(&program.to_string()); + + // After + reindex_qubits(&mut program); + expect![[r#" + Program: + entry: 0 + callables: + Callable 0: Callable: + name: __quantum__qis__x__body + call_type: Regular + input_type: + [0]: Qubit + output_type: + body: + Callable 1: Callable: + name: __quantum__qis__mz__body + call_type: Measurement + input_type: + [0]: Qubit + [1]: Result + output_type: + body: + Callable 2: Callable: + name: __quantum__qis__read_result__body + call_type: Readout + input_type: + [0]: Result + output_type: Boolean + body: + Callable 3: Callable: + name: __quantum__qis__cx__body + call_type: Regular + input_type: + [0]: Qubit + [1]: Qubit + output_type: + body: + blocks: + Block 0: Block: + Call id(0), args( Qubit(0), ) + Call id(3), args( Qubit(0), Qubit(1), ) + Call id(1), args( Qubit(0), Result(0), ) + Variable(0, Boolean) = Call id(2), args( Result(0), ) + Branch Variable(0, Boolean), 1, 2 + Block 1: Block: + Call id(0), args( Qubit(1), ) + Call id(3), args( Qubit(1), Qubit(2), ) + Call id(1), args( Qubit(1), Result(1), ) + Jump(3) + Block 2: Block: + Call id(3), args( Qubit(1), Qubit(2), ) + Call id(1), args( Qubit(1), Result(2), ) + Jump(3) + Block 3: Block: + Call id(0), args( Qubit(2), ) + Return + config: Config: + capabilities: Base + num_qubits: 3 + num_results: 3"#]] + .assert_eq(&program.to_string()); +} + +#[test] +fn qubit_reindexed_across_branches_with_one_branch_longer() { + const X: CallableId = CallableId(0); + const MRESETZ: CallableId = CallableId(1); + const READ_RESULT: CallableId = CallableId(2); + let mut program = Program::new(); + program.num_qubits = 1; + program.num_results = 4; + program.callables.insert(X, x_decl()); + program.callables.insert(MRESETZ, mresetz_decl()); + program.callables.insert(READ_RESULT, read_result_decl()); + + program.blocks.insert( + BlockId(0), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Call( + MRESETZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(0)), + ], + None, + ), + Instruction::Call( + READ_RESULT, + vec![Operand::Literal(Literal::Result(0))], + Some(Variable { + variable_id: VariableId(0), + ty: Ty::Boolean, + }), + ), + Instruction::Branch( + Variable { + variable_id: VariableId(0), + ty: Ty::Boolean, + }, + BlockId(1), + BlockId(2), + ), + ]), + ); + program.blocks.insert( + BlockId(1), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Call( + MRESETZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(1)), + ], + None, + ), + Instruction::Jump(BlockId(5)), + ]), + ); + program.blocks.insert( + BlockId(2), + Block(vec![ + Instruction::Call( + MRESETZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(2)), + ], + None, + ), + Instruction::Call( + READ_RESULT, + vec![Operand::Literal(Literal::Result(2))], + Some(Variable { + variable_id: VariableId(1), + ty: Ty::Boolean, + }), + ), + Instruction::Branch( + Variable { + variable_id: VariableId(1), + ty: Ty::Boolean, + }, + BlockId(3), + BlockId(4), + ), + ]), + ); + program.blocks.insert( + BlockId(3), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Jump(BlockId(5)), + ]), + ); + program.blocks.insert( + BlockId(4), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Jump(BlockId(5)), + ]), + ); + program.blocks.insert( + BlockId(5), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Return, + ]), + ); + + // Before + expect![[r#" + Program: + entry: 0 + callables: + Callable 0: Callable: + name: __quantum__qis__x__body + call_type: Regular + input_type: + [0]: Qubit + output_type: + body: + Callable 1: Callable: + name: __quantum__qis__mresetz__body + call_type: Measurement + input_type: + [0]: Qubit + [1]: Result + output_type: + body: + Callable 2: Callable: + name: __quantum__qis__read_result__body + call_type: Readout + input_type: + [0]: Result + output_type: Boolean + body: + blocks: + Block 0: Block: + Call id(0), args( Qubit(0), ) + Call id(1), args( Qubit(0), Result(0), ) + Variable(0, Boolean) = Call id(2), args( Result(0), ) + Branch Variable(0, Boolean), 1, 2 + Block 1: Block: + Call id(0), args( Qubit(0), ) + Call id(1), args( Qubit(0), Result(1), ) + Jump(5) + Block 2: Block: + Call id(1), args( Qubit(0), Result(2), ) + Variable(1, Boolean) = Call id(2), args( Result(2), ) + Branch Variable(1, Boolean), 3, 4 + Block 3: Block: + Call id(0), args( Qubit(0), ) + Jump(5) + Block 4: Block: + Call id(0), args( Qubit(0), ) + Jump(5) + Block 5: Block: + Call id(0), args( Qubit(0), ) + Return + config: Config: + capabilities: Base + num_qubits: 1 + num_results: 4"#]] + .assert_eq(&program.to_string()); + + // After + reindex_qubits(&mut program); + expect![[r#" + Program: + entry: 0 + callables: + Callable 0: Callable: + name: __quantum__qis__x__body + call_type: Regular + input_type: + [0]: Qubit + output_type: + body: + Callable 2: Callable: + name: __quantum__qis__read_result__body + call_type: Readout + input_type: + [0]: Result + output_type: Boolean + body: + Callable 3: Callable: + name: __quantum__qis__mz__body + call_type: Measurement + input_type: + [0]: Qubit + [1]: Result + output_type: + body: + blocks: + Block 0: Block: + Call id(0), args( Qubit(0), ) + Call id(3), args( Qubit(0), Result(0), ) + Variable(0, Boolean) = Call id(2), args( Result(0), ) + Branch Variable(0, Boolean), 1, 2 + Block 1: Block: + Call id(0), args( Qubit(1), ) + Call id(3), args( Qubit(1), Result(1), ) + Jump(5) + Block 2: Block: + Call id(3), args( Qubit(1), Result(2), ) + Variable(1, Boolean) = Call id(2), args( Result(2), ) + Branch Variable(1, Boolean), 3, 4 + Block 3: Block: + Call id(0), args( Qubit(2), ) + Jump(5) + Block 4: Block: + Call id(0), args( Qubit(2), ) + Jump(5) + Block 5: Block: + Call id(0), args( Qubit(2), ) + Return + config: Config: + capabilities: Base + num_qubits: 3 + num_results: 4"#]] + .assert_eq(&program.to_string()); +} + +#[test] +#[should_panic(expected = "Qubit id 0 has multiple mappings across predecessors")] +fn qubit_reindex_fails_across_branches_with_one_branch_longer_different_usage_in_paths() { + const X: CallableId = CallableId(0); + const MRESETZ: CallableId = CallableId(1); + const READ_RESULT: CallableId = CallableId(2); + let mut program = Program::new(); + program.num_qubits = 1; + program.num_results = 4; + program.callables.insert(X, x_decl()); + program.callables.insert(MRESETZ, mresetz_decl()); + program.callables.insert(READ_RESULT, read_result_decl()); + + program.blocks.insert( + BlockId(0), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Call( + MRESETZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(0)), + ], + None, + ), + Instruction::Call( + READ_RESULT, + vec![Operand::Literal(Literal::Result(0))], + Some(Variable { + variable_id: VariableId(0), + ty: Ty::Boolean, + }), + ), + Instruction::Branch( + Variable { + variable_id: VariableId(0), + ty: Ty::Boolean, + }, + BlockId(1), + BlockId(2), + ), + ]), + ); + program.blocks.insert( + BlockId(1), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Call( + MRESETZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(1)), + ], + None, + ), + Instruction::Jump(BlockId(5)), + ]), + ); + program.blocks.insert( + BlockId(2), + Block(vec![ + Instruction::Call( + MRESETZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(2)), + ], + None, + ), + Instruction::Call( + READ_RESULT, + vec![Operand::Literal(Literal::Result(2))], + Some(Variable { + variable_id: VariableId(1), + ty: Ty::Boolean, + }), + ), + Instruction::Branch( + Variable { + variable_id: VariableId(1), + ty: Ty::Boolean, + }, + BlockId(3), + BlockId(4), + ), + ]), + ); + program.blocks.insert( + BlockId(3), + Block(vec![ + Instruction::Call( + MRESETZ, + vec![ + Operand::Literal(Literal::Qubit(0)), + Operand::Literal(Literal::Result(3)), + ], + None, + ), + Instruction::Jump(BlockId(5)), + ]), + ); + program.blocks.insert( + BlockId(4), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Jump(BlockId(5)), + ]), + ); + program.blocks.insert( + BlockId(5), + Block(vec![ + Instruction::Call(X, vec![Operand::Literal(Literal::Qubit(0))], None), + Instruction::Return, + ]), + ); + + // Before + expect![[r#" + Program: + entry: 0 + callables: + Callable 0: Callable: + name: __quantum__qis__x__body + call_type: Regular + input_type: + [0]: Qubit + output_type: + body: + Callable 1: Callable: + name: __quantum__qis__mresetz__body + call_type: Measurement + input_type: + [0]: Qubit + [1]: Result + output_type: + body: + Callable 2: Callable: + name: __quantum__qis__read_result__body + call_type: Readout + input_type: + [0]: Result + output_type: Boolean + body: + blocks: + Block 0: Block: + Call id(0), args( Qubit(0), ) + Call id(1), args( Qubit(0), Result(0), ) + Variable(0, Boolean) = Call id(2), args( Result(0), ) + Branch Variable(0, Boolean), 1, 2 + Block 1: Block: + Call id(0), args( Qubit(0), ) + Call id(1), args( Qubit(0), Result(1), ) + Jump(5) + Block 2: Block: + Call id(1), args( Qubit(0), Result(2), ) + Variable(1, Boolean) = Call id(2), args( Result(2), ) + Branch Variable(1, Boolean), 3, 4 + Block 3: Block: + Call id(1), args( Qubit(0), Result(3), ) + Jump(5) + Block 4: Block: + Call id(0), args( Qubit(0), ) + Jump(5) + Block 5: Block: + Call id(0), args( Qubit(0), ) + Return + config: Config: + capabilities: Base + num_qubits: 1 + num_results: 4"#]] + .assert_eq(&program.to_string()); + + // After + reindex_qubits(&mut program); }