Skip to content

Commit

Permalink
Add RIR type check pass (#1397)
Browse files Browse the repository at this point in the history
This change introduces a type check pass for RIR that verifies and
panics on any mismatches found in the types. It adds the check to the
default passes and the SSA transform tests.
  • Loading branch information
swernli authored Apr 18, 2024
1 parent edc5941 commit 4db51ce
Show file tree
Hide file tree
Showing 4 changed files with 400 additions and 13 deletions.
5 changes: 5 additions & 0 deletions compiler/qsc_rir/src/passes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod reindex_qubits;
mod remap_block_ids;
mod ssa_check;
mod ssa_transform;
mod type_check;
mod unreachable_code_check;

use build_dominator_graph::build_dominator_graph;
Expand All @@ -15,6 +16,7 @@ use reindex_qubits::reindex_qubits;
use remap_block_ids::remap_block_ids;
use ssa_check::check_ssa_form;
use ssa_transform::transform_to_ssa;
pub use type_check::check_types;
pub use unreachable_code_check::check_unreachable_code;

use crate::{rir::Program, utils::build_predecessors_map};
Expand All @@ -27,11 +29,14 @@ use crate::{rir::Program, utils::build_predecessors_map};
/// - Checking that the program is in SSA form
pub fn check_and_transform(program: &mut Program) {
check_unreachable_code(program);
check_types(program);
remap_block_ids(program);
let preds = build_predecessors_map(program);
transform_to_ssa(program, &preds);
let doms = build_dominator_graph(program, &preds);
check_ssa_form(program, &preds, &doms);
check_unreachable_code(program);
check_types(program);
}

/// Run the RIR passes that are necessary for targets with no mid-program measurement.
Expand Down
22 changes: 9 additions & 13 deletions compiler/qsc_rir/src/passes/ssa_transform/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,25 @@
use expect_test::expect;

use crate::{
builder::{bell_program, new_program},
passes::{build_dominator_graph, check_ssa_form, check_unreachable_code, remap_block_ids},
builder::{bell_program, new_program, teleport_program},
passes::check_and_transform,
rir::{
Block, BlockId, Callable, CallableId, CallableType, Instruction, Operand, Program, Ty,
Variable, VariableId,
},
utils::build_predecessors_map,
};

use super::transform_to_ssa;

fn transform_program(program: &mut Program) {
check_unreachable_code(program);
remap_block_ids(program);
let preds = build_predecessors_map(program);
transform_to_ssa(program, &preds);
let doms = build_dominator_graph(program, &preds);
check_ssa_form(program, &preds, &doms);
// When this configuration is replaced by target capabilities, set them to "all" here.
program.config.defer_measurements = false;
program.config.remap_qubits_on_reuse = false;
check_and_transform(program);
}

#[test]
fn ssa_transform_leaves_program_without_store_instruction_unchanged() {
let mut program = bell_program();
program.config.defer_measurements = false;
program.config.remap_qubits_on_reuse = false;
let program_string_orignal = program.to_string();

transform_program(&mut program);
Expand All @@ -38,7 +34,7 @@ fn ssa_transform_leaves_program_without_store_instruction_unchanged() {

#[test]
fn ssa_transform_leaves_branching_program_without_store_instruction_unchanged() {
let mut program = bell_program();
let mut program = teleport_program();
let program_string_orignal = program.to_string();

transform_program(&mut program);
Expand Down
71 changes: 71 additions & 0 deletions compiler/qsc_rir/src/passes/type_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use crate::rir::{Callable, Instruction, Operand, Program, Ty, Variable};

#[cfg(test)]
mod tests;

pub fn check_types(program: &Program) {
for (_, block) in program.blocks.iter() {
for instr in &block.0 {
check_instr_types(program, instr);
}
}
}

fn check_instr_types(program: &Program, instr: &Instruction) {
match instr {
Instruction::Call(id, args, var) => check_call_types(program.get_callable(*id), args, *var),

Instruction::Branch(var, _, _) => assert_eq!(var.ty, Ty::Boolean),

Instruction::Add(opr1, opr2, var)
| Instruction::Sub(opr1, opr2, var)
| Instruction::Mul(opr1, opr2, var)
| Instruction::Sdiv(opr1, opr2, var)
| Instruction::Srem(opr1, opr2, var)
| Instruction::Shl(opr1, opr2, var)
| Instruction::Ashr(opr1, opr2, var)
| Instruction::LogicalAnd(opr1, opr2, var)
| Instruction::LogicalOr(opr1, opr2, var)
| Instruction::BitwiseAnd(opr1, opr2, var)
| Instruction::BitwiseOr(opr1, opr2, var)
| Instruction::BitwiseXor(opr1, opr2, var)
| Instruction::Icmp(_, opr1, opr2, var) => {
assert_eq!(opr1.get_type(), opr2.get_type());
assert_eq!(opr1.get_type(), var.ty);
}

Instruction::Store(opr, var)
| Instruction::LogicalNot(opr, var)
| Instruction::BitwiseNot(opr, var) => {
assert_eq!(opr.get_type(), var.ty);
}

Instruction::Phi(args, var) => {
for (opr, _) in args {
assert_eq!(opr.get_type(), var.ty);
}
}

Instruction::Jump(_) | Instruction::Return => {}
}
}

fn check_call_types(callable: &Callable, args: &[Operand], var: Option<Variable>) {
assert_eq!(
callable.input_type.len(),
args.len(),
"incorrect number of arguments"
);
for (arg, ty) in args.iter().zip(callable.input_type.iter()) {
assert_eq!(arg.get_type(), *ty);
}

match (var, callable.output_type) {
(Some(var), Some(ty)) => assert_eq!(ty, var.ty),
(None, None) => {}
_ => panic!("expected return type to be present in both the instruction and the callable"),
}
}
Loading

0 comments on commit 4db51ce

Please sign in to comment.