Skip to content

Commit

Permalink
Support remapping of RIR qubit ids with branching (#1404)
Browse files Browse the repository at this point in the history
This updates the RIR qubit reindexing pass to allow for running on RIR
that includes branches so long as all paths result in the same
remapping.

---------

Co-authored-by: Mine Starks <[email protected]>
  • Loading branch information
swernli and minestarks authored Apr 30, 2024
1 parent e8c8949 commit 09be343
Show file tree
Hide file tree
Showing 3 changed files with 830 additions and 107 deletions.
7 changes: 6 additions & 1 deletion compiler/qsc_rir/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,14 @@ use crate::{rir::Program, utils::build_predecessors_map};

/// Run the default set of RIR check and transformation passes.
/// This includes:
/// - Simplifying control flow
/// - Checking for unreachable code
/// - Checking types
/// - Remapping block IDs
/// - Transforming the program to SSA form
/// - Checking that the program is in SSA form
/// - If the target has no reset capability, reindexing qubit IDs and removing resets.
/// - If the target has no mid-program measurement capability, deferring measurements to the end of the program.
pub fn check_and_transform(program: &mut Program) {
simplify_control_flow(program);
check_unreachable_code(program);
Expand All @@ -44,7 +48,8 @@ pub fn check_and_transform(program: &mut Program) {

// Run the RIR passes that are necessary for targets with no mid-program measurement.
// This requires that qubits are not reused after measurement or reset, so qubit ids must be reindexed.
// This also requires that the program is a single block and will panic otherwise.
// This also requires that the program has no loops and block ids form a topological ordering on a
// directed acyclic graph.
if !program
.config
.capabilities
Expand Down
315 changes: 218 additions & 97 deletions compiler/qsc_rir/src/passes/reindex_qubits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,110 +4,226 @@
#[cfg(test)]
mod tests;

use std::collections::hash_map::Entry;

use qsc_data_structures::index_map::IndexMap;
use rustc_hash::FxHashMap;

use crate::rir::{
Block, Callable, CallableId, CallableType, Instruction, Literal, Operand, Program, Ty,
use crate::{
builder,
rir::{Block, BlockId, CallableId, CallableType, Instruction, Literal, Operand, Program, Ty},
utils::{build_predecessors_map, get_block_successors},
};

#[derive(Clone)]
struct BlockQubitMap {
map: FxHashMap<u32, u32>,
next_qubit_id: u32,
}

/// Reindexes qubits after they have been measured or reset. This ensures there is no qubit reuse in
/// the program. As part of the pass, reset callables are removed and mresetz calls are replaced with
/// mz calls.
/// Note that this pass has several assumptions:
/// 1. Only one callable has a body, which is the entry point callable.
/// 2. The entry point callable has a single block.
/// 2. The entry point callable is a directed acyclic graph where block ids have topological ordering.
/// 3. No dynamic qubits are used.
/// The pass will panic if the input program violates any of these assumptions.
pub fn reindex_qubits(program: &mut Program) {
validate_assumptions(program);

let (mut used_mz, mz_id) = match find_measurement_callable(program, "__quantum__qis__mz__body")
{
let (used_mz, mz_id) = match find_callable(program, "__quantum__qis__mz__body") {
Some(id) => (true, id),
None => (false, add_mz(program)),
};
let mresetz_id = find_measurement_callable(program, "__quantum__qis__mresetz__body");

let mut qubit_map = FxHashMap::default();
let mut next_qubit_id = program.num_qubits;
let mut highest_used_id = next_qubit_id - 1;
let mut new_block = Vec::new();
let (block_id, block) = program
.blocks
.drain()
.next()
let (used_cx, cx_id) = match find_callable(program, "__quantum__qis__cx__body") {
Some(id) => (true, id),
None => (false, add_cx(program)),
};
let mresetz_id = find_callable(program, "__quantum__qis__mresetz__body");
let mut pass = ReindexQubitPass {
used_mz,
mz_id,
used_cx,
cx_id,
mresetz_id,
highest_used_id: program.num_qubits - 1,
};

let pred_map = build_predecessors_map(program);

let mut block_maps = IndexMap::new();
block_maps.insert(
BlockId(0),
BlockQubitMap {
map: FxHashMap::default(),
next_qubit_id: program.num_qubits,
},
);

let mut all_blocks = program.blocks.drain().collect::<Vec<_>>();
let (entry_block, rest_blocks) = all_blocks
.split_first_mut()
.expect("program should have at least one block");
for instr in &block.0 {
// Assume qubits only appear in void call instructions.
match instr {
Instruction::Call(call_id, args, _)
if program.get_callable(*call_id).call_type == CallableType::Reset =>
{
// Generate any new qubit ids and skip adding the instruction.
for arg in args {
if let Operand::Literal(Literal::Qubit(qubit_id)) = arg {
qubit_map.insert(*qubit_id, next_qubit_id);
next_qubit_id += 1;

// Reindex qubits in the entry block.
pass.reindex_qubits_in_block(
program,
&mut entry_block.1,
block_maps
.get_mut(entry_block.0)
.expect("entry block with id 0 should be in block_maps"),
);

for (block_id, block) in rest_blocks {
// Use the predecessors to build the block's initial, inherited qubit map.
let pred_ids = pred_map
.get(*block_id)
.expect("block should have predecessors");

// Start from an empty map with the program's initial qubit ids.
let mut new_block_map = BlockQubitMap {
map: FxHashMap::default(),
next_qubit_id: program.num_qubits,
};

for pred_id in pred_ids {
let pred_qubit_map = block_maps
.get(*pred_id)
.expect("predecessor should be in block_maps");

// Across each predecessor, ensure that any mapped ids are same, otherwise
// panic because we can't know which id to use.
for (qubit_id, new_qubit_id) in &pred_qubit_map.map {
match new_block_map.map.entry(*qubit_id) {
Entry::Occupied(entry) if *entry.get() == *new_qubit_id => {}
Entry::Occupied(_) => {
panic!("Qubit id {qubit_id} has multiple mappings across predecessors");
}
}
}
Instruction::Call(call_id, args, None) => {
// Map the qubit args, if any, and copy over the instruction.
let new_args = args
.iter()
.map(|arg| match arg {
Operand::Literal(Literal::Qubit(qubit_id)) => {
if let Some(mapped_id) = qubit_map.get(qubit_id) {
highest_used_id = highest_used_id.max(*mapped_id);
Operand::Literal(Literal::Qubit(*mapped_id))
} else {
*arg
}
}
_ => *arg,
})
.collect::<Vec<_>>();

// If the call was to mresetz, replace with mz.
let call_id = if Some(*call_id) == mresetz_id {
used_mz = true;
mz_id
} else {
*call_id
};

new_block.push(Instruction::Call(call_id, new_args, None));

if program.get_callable(call_id).call_type == CallableType::Measurement {
// Generate any new qubit ids after a measurement.
for arg in args {
if let Operand::Literal(Literal::Qubit(qubit_id)) = arg {
qubit_map.insert(*qubit_id, next_qubit_id);
next_qubit_id += 1;
}
Entry::Vacant(entry) => {
entry.insert(*new_qubit_id);
}
}
}
_ => {
// Copy over the instruction.
new_block.push(instr.clone());
}
new_block_map.next_qubit_id = new_block_map
.next_qubit_id
.max(pred_qubit_map.next_qubit_id);
}

pass.reindex_qubits_in_block(program, block, &mut new_block_map);

block_maps.insert(*block_id, new_block_map);
}

program.num_qubits = highest_used_id + 1;
program.blocks.clear();
program.blocks.insert(block_id, Block(new_block));
program.blocks = all_blocks.into_iter().collect();
program.num_qubits = pass.highest_used_id + 1;

// All reset function calls should be removed, so remove them from the callables.
program
.callables
.retain(|id, callable| callable.call_type != CallableType::Reset && Some(id) != mresetz_id);

// If mz was added but not used, remove it.
if !used_mz {
// If mz or cx were added but not used, remove them.
if !pass.used_mz {
program.callables.remove(mz_id);
}
if !pass.used_cx {
program.callables.remove(cx_id);
}
}

struct ReindexQubitPass {
used_mz: bool,
mz_id: CallableId,
used_cx: bool,
cx_id: CallableId,
mresetz_id: Option<CallableId>,
highest_used_id: u32,
}

impl ReindexQubitPass {
fn reindex_qubits_in_block(
&mut self,
program: &Program,
block: &mut Block,
qubit_map: &mut BlockQubitMap,
) {
let instrs = std::mem::take(&mut block.0);
for instr in instrs {
// Assume qubits only appear in void call instructions.
match instr {
Instruction::Call(call_id, args, _)
if program.get_callable(call_id).call_type == CallableType::Reset =>
{
// Generate any new qubit ids and skip adding the instruction.
for arg in args {
if let Operand::Literal(Literal::Qubit(qubit_id)) = arg {
qubit_map.map.insert(qubit_id, qubit_map.next_qubit_id);
qubit_map.next_qubit_id += 1;
}
}
}
Instruction::Call(call_id, args, None) => {
// Map the qubit args, if any, and copy over the instruction.
let new_args = args
.iter()
.map(|arg| match arg {
Operand::Literal(Literal::Qubit(qubit_id)) => {
match qubit_map.map.get(qubit_id) {
Some(mapped_id) => {
// If the qubit has already been mapped, use the mapped id.
self.highest_used_id = self.highest_used_id.max(*mapped_id);
Operand::Literal(Literal::Qubit(*mapped_id))
}
None => *arg,
}
}
_ => *arg,
})
.collect::<Vec<_>>();

if call_id == self.mz_id {
// Since the call was to mz, the new qubit replacing this one must be conditionally flipped.
// Achieve this by adding a CNOT gate before the mz call.
self.used_cx = true;
block.0.push(Instruction::Call(
self.cx_id,
vec![
new_args[0],
Operand::Literal(Literal::Qubit(qubit_map.next_qubit_id)),
],
None,
));
self.highest_used_id = self.highest_used_id.max(qubit_map.next_qubit_id);
}

// If the call was to mresetz, replace with mz.
let call_id = if Some(call_id) == self.mresetz_id {
self.used_mz = true;
self.mz_id
} else {
call_id
};

block.0.push(Instruction::Call(call_id, new_args, None));

if program.get_callable(call_id).call_type == CallableType::Measurement {
// Generate any new qubit ids after a measurement.
for arg in args {
if let Operand::Literal(Literal::Qubit(qubit_id)) = arg {
qubit_map.map.insert(qubit_id, qubit_map.next_qubit_id);
qubit_map.next_qubit_id += 1;
}
}
}
}
_ => {
// Copy over the instruction.
block.0.push(instr.clone());
}
}
}
}
}

fn validate_assumptions(program: &Program) {
Expand All @@ -119,28 +235,28 @@ fn validate_assumptions(program: &Program) {
);
}

// Ensure entry point callable has a single block.
// Future enhancements may allow multiple blocks in the entry point callable.
assert!(
program.blocks.iter().count() == 1,
"Entry point callable must have a single block"
);

// Ensure that no dynamic qubits are used.
let Some((_, block)) = program.blocks.iter().next() else {
panic!("No blocks found in the program");
};
for instr in &block.0 {
// Ensure entry point callable blocks are a topologically ordered DAG.
// We can check this quickly by verifying that each block only has successors with higher ids.
for (block_id, block) in program.blocks.iter() {
assert!(
!matches!(instr, Instruction::Store(_, var) if var.ty == Ty::Qubit),
"Dynamic qubits are not supported"
get_block_successors(block)
.iter()
.all(|&succ_id| succ_id > block_id),
"blocks must form a topologically ordered DAG"
);
// Ensure that no dynamic qubits are used.
for instr in &block.0 {
assert!(
!matches!(instr, Instruction::Store(_, var) if var.ty == Ty::Qubit),
"Dynamic qubits are not supported"
);
}
}
}

fn find_measurement_callable(program: &Program, name: &str) -> Option<CallableId> {
fn find_callable(program: &Program, name: &str) -> Option<CallableId> {
for (callable_id, callable) in program.callables.iter() {
if callable.call_type == CallableType::Measurement && callable.name == name {
if callable.name == name {
return Some(callable_id);
}
}
Expand All @@ -157,15 +273,20 @@ fn add_mz(program: &mut Program) -> CallableId {
.expect("should be at least one callable")
+ 1,
);
program.callables.insert(
mz_id,
Callable {
name: "__quantum__qis__mz__body".to_string(),
input_type: vec![Ty::Qubit, Ty::Result],
output_type: None,
body: None,
call_type: CallableType::Measurement,
},
);
program.callables.insert(mz_id, builder::mz_decl());
mz_id
}

fn add_cx(program: &mut Program) -> CallableId {
let cx_id = CallableId(
program
.callables
.iter()
.map(|(id, _)| id.0)
.max()
.expect("should be at least one callable")
+ 1,
);
program.callables.insert(cx_id, builder::cx_decl());
cx_id
}
Loading

0 comments on commit 09be343

Please sign in to comment.