diff --git a/compiler/qsc_rir/src/passes.rs b/compiler/qsc_rir/src/passes.rs index 6079204eb4..d6cd307665 100644 --- a/compiler/qsc_rir/src/passes.rs +++ b/compiler/qsc_rir/src/passes.rs @@ -7,6 +7,7 @@ mod reindex_qubits; mod remap_block_ids; mod ssa_check; mod ssa_transform; +mod type_check; mod unreachable_code_check; use build_dominator_graph::build_dominator_graph; @@ -15,6 +16,7 @@ use reindex_qubits::reindex_qubits; use remap_block_ids::remap_block_ids; use ssa_check::check_ssa_form; use ssa_transform::transform_to_ssa; +pub use type_check::check_types; pub use unreachable_code_check::check_unreachable_code; use crate::{rir::Program, utils::build_predecessors_map}; @@ -27,11 +29,14 @@ use crate::{rir::Program, utils::build_predecessors_map}; /// - Checking that the program is in SSA form pub fn check_and_transform(program: &mut Program) { check_unreachable_code(program); + check_types(program); remap_block_ids(program); let preds = build_predecessors_map(program); transform_to_ssa(program, &preds); let doms = build_dominator_graph(program, &preds); check_ssa_form(program, &preds, &doms); + check_unreachable_code(program); + check_types(program); } /// Run the RIR passes that are necessary for targets with no mid-program measurement. diff --git a/compiler/qsc_rir/src/passes/ssa_transform/tests.rs b/compiler/qsc_rir/src/passes/ssa_transform/tests.rs index 41f3e1a015..da9d64064d 100644 --- a/compiler/qsc_rir/src/passes/ssa_transform/tests.rs +++ b/compiler/qsc_rir/src/passes/ssa_transform/tests.rs @@ -6,29 +6,25 @@ use expect_test::expect; use crate::{ - builder::{bell_program, new_program}, - passes::{build_dominator_graph, check_ssa_form, check_unreachable_code, remap_block_ids}, + builder::{bell_program, new_program, teleport_program}, + passes::check_and_transform, rir::{ Block, BlockId, Callable, CallableId, CallableType, Instruction, Operand, Program, Ty, Variable, VariableId, }, - utils::build_predecessors_map, }; - -use super::transform_to_ssa; - fn transform_program(program: &mut Program) { - check_unreachable_code(program); - remap_block_ids(program); - let preds = build_predecessors_map(program); - transform_to_ssa(program, &preds); - let doms = build_dominator_graph(program, &preds); - check_ssa_form(program, &preds, &doms); + // When this configuration is replaced by target capabilities, set them to "all" here. + program.config.defer_measurements = false; + program.config.remap_qubits_on_reuse = false; + check_and_transform(program); } #[test] fn ssa_transform_leaves_program_without_store_instruction_unchanged() { let mut program = bell_program(); + program.config.defer_measurements = false; + program.config.remap_qubits_on_reuse = false; let program_string_orignal = program.to_string(); transform_program(&mut program); @@ -38,7 +34,7 @@ fn ssa_transform_leaves_program_without_store_instruction_unchanged() { #[test] fn ssa_transform_leaves_branching_program_without_store_instruction_unchanged() { - let mut program = bell_program(); + let mut program = teleport_program(); let program_string_orignal = program.to_string(); transform_program(&mut program); diff --git a/compiler/qsc_rir/src/passes/type_check.rs b/compiler/qsc_rir/src/passes/type_check.rs new file mode 100644 index 0000000000..757657ff40 --- /dev/null +++ b/compiler/qsc_rir/src/passes/type_check.rs @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::rir::{Callable, Instruction, Operand, Program, Ty, Variable}; + +#[cfg(test)] +mod tests; + +pub fn check_types(program: &Program) { + for (_, block) in program.blocks.iter() { + for instr in &block.0 { + check_instr_types(program, instr); + } + } +} + +fn check_instr_types(program: &Program, instr: &Instruction) { + match instr { + Instruction::Call(id, args, var) => check_call_types(program.get_callable(*id), args, *var), + + Instruction::Branch(var, _, _) => assert_eq!(var.ty, Ty::Boolean), + + Instruction::Add(opr1, opr2, var) + | Instruction::Sub(opr1, opr2, var) + | Instruction::Mul(opr1, opr2, var) + | Instruction::Sdiv(opr1, opr2, var) + | Instruction::Srem(opr1, opr2, var) + | Instruction::Shl(opr1, opr2, var) + | Instruction::Ashr(opr1, opr2, var) + | Instruction::LogicalAnd(opr1, opr2, var) + | Instruction::LogicalOr(opr1, opr2, var) + | Instruction::BitwiseAnd(opr1, opr2, var) + | Instruction::BitwiseOr(opr1, opr2, var) + | Instruction::BitwiseXor(opr1, opr2, var) + | Instruction::Icmp(_, opr1, opr2, var) => { + assert_eq!(opr1.get_type(), opr2.get_type()); + assert_eq!(opr1.get_type(), var.ty); + } + + Instruction::Store(opr, var) + | Instruction::LogicalNot(opr, var) + | Instruction::BitwiseNot(opr, var) => { + assert_eq!(opr.get_type(), var.ty); + } + + Instruction::Phi(args, var) => { + for (opr, _) in args { + assert_eq!(opr.get_type(), var.ty); + } + } + + Instruction::Jump(_) | Instruction::Return => {} + } +} + +fn check_call_types(callable: &Callable, args: &[Operand], var: Option) { + assert_eq!( + callable.input_type.len(), + args.len(), + "incorrect number of arguments" + ); + for (arg, ty) in args.iter().zip(callable.input_type.iter()) { + assert_eq!(arg.get_type(), *ty); + } + + match (var, callable.output_type) { + (Some(var), Some(ty)) => assert_eq!(ty, var.ty), + (None, None) => {} + _ => panic!("expected return type to be present in both the instruction and the callable"), + } +} diff --git a/compiler/qsc_rir/src/passes/type_check/tests.rs b/compiler/qsc_rir/src/passes/type_check/tests.rs new file mode 100644 index 0000000000..79da706d0c --- /dev/null +++ b/compiler/qsc_rir/src/passes/type_check/tests.rs @@ -0,0 +1,315 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::rir::{ + BlockId, Callable, CallableId, CallableType, Instruction, Literal, Operand, Program, Ty, + Variable, VariableId, +}; + +use super::check_instr_types; + +#[test] +fn binop_instr_matching_types_passes_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr1 = Operand::Variable(var); + let opr2 = Operand::Literal(Literal::Integer(0)); + + check_instr_types(&Program::new(), &Instruction::Add(opr1, opr2, var)); +} + +#[test] +#[should_panic(expected = "assertion `left == right` failed")] +fn binop_instr_mismatching_types_fails_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr1 = Operand::Variable(var); + let opr2 = Operand::Literal(Literal::Bool(false)); + + check_instr_types(&Program::new(), &Instruction::Add(opr1, opr2, var)); +} + +#[test] +fn unop_instr_matching_types_passes_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Boolean, + }; + let opr = Operand::Variable(var); + + check_instr_types(&Program::new(), &Instruction::BitwiseNot(opr, var)); +} + +#[test] +#[should_panic(expected = "assertion `left == right` failed")] +fn unop_instr_mismatching_types_fails_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + check_instr_types( + &Program::new(), + &Instruction::BitwiseNot( + opr, + Variable { + variable_id: VariableId(1), + ty: Ty::Boolean, + }, + ), + ); +} + +#[test] +fn phi_instr_matching_types_passes_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + check_instr_types( + &Program::new(), + &Instruction::Phi(vec![(opr, BlockId(0)), (opr, BlockId(1))], var), + ); +} + +#[test] +#[should_panic(expected = "assertion `left == right` failed")] +fn phi_instr_mismatching_types_fails_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + check_instr_types( + &Program::new(), + &Instruction::Phi( + vec![(opr, BlockId(0)), (opr, BlockId(1))], + Variable { + variable_id: VariableId(1), + ty: Ty::Boolean, + }, + ), + ); +} + +#[test] +fn call_instr_matching_types_passes_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + let mut program = Program::new(); + program.callables.insert( + CallableId(0), + Callable { + name: "foo".to_string(), + input_type: vec![Ty::Integer], + output_type: Some(Ty::Integer), + call_type: CallableType::Regular, + body: None, + }, + ); + + check_instr_types( + &program, + &Instruction::Call( + CallableId(0), + vec![opr], + Some(Variable { + variable_id: VariableId(1), + ty: Ty::Integer, + }), + ), + ); +} + +#[test] +#[should_panic(expected = "assertion `left == right` failed")] +fn call_instr_mismatching_output_types_fails_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + let mut program = Program::new(); + program.callables.insert( + CallableId(0), + Callable { + name: "foo".to_string(), + input_type: vec![Ty::Integer], + output_type: Some(Ty::Integer), + call_type: CallableType::Regular, + body: None, + }, + ); + + check_instr_types( + &program, + &Instruction::Call( + CallableId(0), + vec![opr], + Some(Variable { + variable_id: VariableId(1), + ty: Ty::Boolean, + }), + ), + ); +} + +#[test] +#[should_panic(expected = "assertion `left == right` failed")] +fn call_instr_mismatching_input_types_fails_check() { + let mut program = Program::new(); + program.callables.insert( + CallableId(0), + Callable { + name: "foo".to_string(), + input_type: vec![Ty::Integer], + output_type: Some(Ty::Integer), + call_type: CallableType::Regular, + body: None, + }, + ); + + check_instr_types( + &program, + &Instruction::Call( + CallableId(0), + vec![Operand::Literal(Literal::Bool(true))], + Some(Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }), + ), + ); +} + +#[test] +#[should_panic(expected = "assertion `left == right` failed")] +fn call_instr_too_many_args_fails_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + let mut program = Program::new(); + program.callables.insert( + CallableId(0), + Callable { + name: "foo".to_string(), + input_type: vec![Ty::Integer], + output_type: Some(Ty::Integer), + call_type: CallableType::Regular, + body: None, + }, + ); + + check_instr_types( + &program, + &Instruction::Call( + CallableId(0), + vec![opr, opr], + Some(Variable { + variable_id: VariableId(1), + ty: Ty::Integer, + }), + ), + ); +} + +#[test] +fn call_instr_no_return_type_no_output_var_passes_check() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + let mut program = Program::new(); + program.callables.insert( + CallableId(0), + Callable { + name: "foo".to_string(), + input_type: vec![Ty::Integer], + output_type: None, + call_type: CallableType::Regular, + body: None, + }, + ); + + check_instr_types(&program, &Instruction::Call(CallableId(0), vec![opr], None)); +} + +#[test] +#[should_panic( + expected = "expected return type to be present in both the instruction and the callable" +)] +fn call_instr_return_type_without_output_var_fails() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + let mut program = Program::new(); + program.callables.insert( + CallableId(0), + Callable { + name: "foo".to_string(), + input_type: vec![Ty::Integer], + output_type: Some(Ty::Integer), + call_type: CallableType::Regular, + body: None, + }, + ); + + check_instr_types(&program, &Instruction::Call(CallableId(0), vec![opr], None)); +} + +#[test] +#[should_panic( + expected = "expected return type to be present in both the instruction and the callable" +)] +fn call_instr_output_var_without_return_type_fails() { + let var = Variable { + variable_id: VariableId(0), + ty: Ty::Integer, + }; + let opr = Operand::Variable(var); + + let mut program = Program::new(); + program.callables.insert( + CallableId(0), + Callable { + name: "foo".to_string(), + input_type: vec![Ty::Integer], + output_type: None, + call_type: CallableType::Regular, + body: None, + }, + ); + + check_instr_types( + &program, + &Instruction::Call( + CallableId(0), + vec![opr], + Some(Variable { + variable_id: VariableId(1), + ty: Ty::Integer, + }), + ), + ); +}