Skip to content

Commit

Permalink
Support closures in partial evaluation (#1482)
Browse files Browse the repository at this point in the history
This adds support for closures with bound local variables in partial
evaluation so that they work for QIR codegen. This helps with generating
base profile compliant programs with the new QIR codegen.
  • Loading branch information
swernli authored May 6, 2024
1 parent df8ce8a commit cd2cabb
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 8 deletions.
2 changes: 1 addition & 1 deletion compiler/qsc_eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1424,7 +1424,7 @@ fn spec_from_functor_app(functor: FunctorApp) -> Spec {
}
}

fn resolve_closure(
pub fn resolve_closure(
env: &Env,
package: PackageId,
span: Span,
Expand Down
41 changes: 35 additions & 6 deletions compiler/qsc_partial_eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use management::{QuantumIntrinsicsChecker, ResourceManager};
use miette::Diagnostic;
use qsc_data_structures::span::Span;
use qsc_data_structures::{functors::FunctorApp, target::TargetCapabilityFlags};
use qsc_eval::resolve_closure;
use qsc_eval::{
self, exec_graph_section,
output::GenericReceiver,
Expand Down Expand Up @@ -450,8 +451,16 @@ impl<'a> PartialEvaluator<'a> {
ExprKind::Call(callee_expr_id, args_expr_id) => {
self.eval_expr_call(*callee_expr_id, *args_expr_id)
}
ExprKind::Closure(_, _) => {
panic!("instruction generation for closure expressions is unsupported")
ExprKind::Closure(args, callable) => {
let closure = resolve_closure(
&self.eval_context.get_current_scope().env,
self.get_current_package_id(),
expr.span,
args,
*callable,
)
.map_err(|e| Error::EvaluationFailed(e.to_string(), e.span().span))?;
Ok(EvalControlFlow::Continue(closure))
}
ExprKind::Fail(_) => panic!("instruction generation for fail expression is invalid"),
ExprKind::Field(_, _) => Err(Error::Unimplemented("Field Expr".to_string(), expr.span)),
Expand Down Expand Up @@ -723,7 +732,11 @@ impl<'a> PartialEvaluator<'a> {
}

// Get the callable.
let (store_item_id, functor_app) = callee_control_flow.into_value().unwrap_global();
let (store_item_id, functor_app, fixed_args) = match callee_control_flow.into_value() {
Value::Closure(inner) => (inner.id, inner.functor, Some(inner.fixed_args)),
Value::Global(id, functor) => (id, functor, None),
_ => panic!("value is not callable"),
};
let global = self
.package_store
.get_global(store_item_id)
Expand Down Expand Up @@ -751,6 +764,7 @@ impl<'a> PartialEvaluator<'a> {
spec_impl,
callable_decl.input,
args_control_flow.into_value(),
fixed_args,
)?,
};
Ok(EvalControlFlow::Continue(value))
Expand Down Expand Up @@ -811,6 +825,7 @@ impl<'a> PartialEvaluator<'a> {
(store_item_id.package, callable_decl.input).into(),
args_value,
None,
None,
);
assert!(
ctls_arg.is_none(),
Expand All @@ -834,6 +849,7 @@ impl<'a> PartialEvaluator<'a> {
spec_impl: &SpecImpl,
args_pat: PatId,
args_value: Value,
fixed_args: Option<Rc<[Value]>>,
) -> Result<Value, Error> {
let spec_decl = get_spec_decl(spec_impl, functor_app);

Expand All @@ -858,6 +874,7 @@ impl<'a> PartialEvaluator<'a> {
(global_callable_id.package, args_pat).into(),
args_value,
ctls,
fixed_args,
);
let call_scope = Scope::new(
global_callable_id.package,
Expand Down Expand Up @@ -1344,6 +1361,7 @@ impl<'a> PartialEvaluator<'a> {
store_pat_id: StorePatId,
value: Value,
ctls: Option<(StorePatId, u8)>,
fixed_args: Option<Rc<[Value]>>,
) -> (Vec<Arg>, Option<Arg>) {
let mut value = value;
let ctls_arg = if let Some((ctls_pat_id, ctls_count)) = ctls {
Expand Down Expand Up @@ -1374,6 +1392,14 @@ impl<'a> PartialEvaluator<'a> {
None
};

let value = if let Some(fixed_args) = fixed_args {
let mut fixed_args = fixed_args.to_vec();
fixed_args.push(value);
Value::Tuple(fixed_args.into())
} else {
value
};

let pat = self.package_store.get_pat(store_pat_id);
let args = match &pat.kind {
PatKind::Discard => vec![Arg::Discard(value)],
Expand All @@ -1396,9 +1422,12 @@ impl<'a> PartialEvaluator<'a> {
let pat_value_tuples = pats.iter().zip(values.to_vec());
for (pat_id, value) in pat_value_tuples {
// At this point we should no longer have control qubits so pass None.
let (mut element_args, None) =
self.resolve_args((store_pat_id.package, *pat_id).into(), value, None)
else {
let (mut element_args, None) = self.resolve_args(
(store_pat_id.package, *pat_id).into(),
value,
None,
None,
) else {
panic!("no control qubit are expected at this point");
};
args.append(&mut element_args);
Expand Down
153 changes: 152 additions & 1 deletion compiler/qsc_partial_eval/tests/calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ fn call_to_boolean_function_using_result_literal_as_argument_yields_constant() {
}
if ResultAsBool(One) {
Op(q);
}
}
}
}
"#});
Expand Down Expand Up @@ -1049,3 +1049,154 @@ fn call_to_unitary_operation_using_multiple_controlled_functors() {
Return"#]],
);
}

#[test]
fn call_to_closue_with_no_bound_locals() {
let program = get_rir_program(indoc! {"
namespace Test {
operation Op() : (Qubit => Unit) {
X(_)
}
@EntryPoint()
operation Main() : Unit {
use q = Qubit();
(Op())(q);
}
}
"});
assert_callable(
&program,
CallableId(1),
&expect![[r#"
Callable:
name: __quantum__qis__x__body
call_type: Regular
input_type:
[0]: Qubit
output_type: <VOID>
body: <NONE>"#]],
);
assert_block_instructions(
&program,
BlockId(0),
&expect![[r#"
Block:
Call id(1), args( Qubit(0), )
Call id(2), args( Integer(0), Pointer, )
Return"#]],
);
}

#[test]
fn call_to_closue_with_one_bound_local() {
let program = get_rir_program(indoc! {"
namespace Test {
operation Op() : (Qubit => Unit) {
Rx(1.0, _)
}
@EntryPoint()
operation Main() : Unit {
use q = Qubit();
(Op())(q);
}
}
"});
assert_callable(
&program,
CallableId(1),
&expect![[r#"
Callable:
name: __quantum__qis__rx__body
call_type: Regular
input_type:
[0]: Double
[1]: Qubit
output_type: <VOID>
body: <NONE>"#]],
);
assert_block_instructions(
&program,
BlockId(0),
&expect![[r#"
Block:
Call id(1), args( Double(1), Qubit(0), )
Call id(2), args( Integer(0), Pointer, )
Return"#]],
);
}

#[test]
fn call_to_closue_with_two_bound_locals() {
let program = get_rir_program(indoc! {"
namespace Test {
operation Op() : (Qubit => Unit) {
R(PauliX, 1.0, _)
}
@EntryPoint()
operation Main() : Unit {
use q = Qubit();
(Op())(q);
}
}
"});
assert_callable(
&program,
CallableId(1),
&expect![[r#"
Callable:
name: __quantum__qis__rx__body
call_type: Regular
input_type:
[0]: Double
[1]: Qubit
output_type: <VOID>
body: <NONE>"#]],
);
assert_block_instructions(
&program,
BlockId(0),
&expect![[r#"
Block:
Call id(1), args( Double(1), Qubit(0), )
Call id(2), args( Integer(0), Pointer, )
Return"#]],
);
}

#[test]
fn call_to_closue_with_one_bound_local_two_unbound() {
let program = get_rir_program(indoc! {"
namespace Test {
operation Op() : ((Double, Qubit) => Unit) {
R(PauliX, _, _)
}
@EntryPoint()
operation Main() : Unit {
use q = Qubit();
(Op())(1.0, q);
}
}
"});
assert_callable(
&program,
CallableId(1),
&expect![[r#"
Callable:
name: __quantum__qis__rx__body
call_type: Regular
input_type:
[0]: Double
[1]: Qubit
output_type: <VOID>
body: <NONE>"#]],
);
assert_block_instructions(
&program,
BlockId(0),
&expect![[r#"
Block:
Call id(1), args( Double(1), Qubit(0), )
Call id(2), args( Integer(0), Pointer, )
Return"#]],
);
}

0 comments on commit cd2cabb

Please sign in to comment.