Skip to content

Commit

Permalink
Merge branch 'main' into handle_device_pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
sacpis authored Jul 23, 2024
2 parents 85a5099 + f0df672 commit d162912
Show file tree
Hide file tree
Showing 3 changed files with 220 additions and 5 deletions.
15 changes: 10 additions & 5 deletions lib/Optimizer/Transforms/ApplyOpSpecialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,21 @@ struct ApplyVariants {
// any variants are added to this set.
bool merge(ApplyVariants that) {
bool rv = false;

auto checkAndSet = [&](bool &bit0, bool bit1) {
rv = !bit0 & bit1;
bit0 = bit0 | bit1;
};

checkAndSet(needsControlVariant, that.needsControlVariant);
checkAndSet(needsAdjointVariant, that.needsAdjointVariant);
checkAndSet(needsAdjointControlVariant, that.needsAdjointControlVariant);
// `that` has control and uses `this` which has adjoint, or `that` has
// adjoint and uses `this` which has control, so generate a `.adj.ctrl`
// variant for `this`, if not already present
checkAndSet(needsAdjointControlVariant,
that.needsAdjointControlVariant ||
(that.needsControlVariant && needsAdjointVariant) ||
(that.needsAdjointVariant && needsControlVariant));
return rv;
}
};
Expand Down Expand Up @@ -76,10 +84,7 @@ struct ApplyOpAnalysis {
variant.needsAdjointVariant = true;
else if (!apply.getControls().empty())
variant.needsControlVariant = true;
if (iter != infoMap.end())
infoMap[callee.getOperation()] = variant;
else
infoMap.insert(std::make_pair(callee.getOperation(), variant));
infoMap[callee.getOperation()] = variant;
});

// Propagate the transitive closure over the call tree.
Expand Down
76 changes: 76 additions & 0 deletions python/tests/mlir/bug_1871.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

# RUN: PYTHONPATH=../../ pytest -rP %s | FileCheck %s

import cudaq

import pytest


def test_control_on_adjoint():

@cudaq.kernel
def my_func(q: cudaq.qubit, theta: float):
ry(theta, q)
rz(theta, q)

@cudaq.kernel
def adj_func(q: cudaq.qubit, theta: float):
cudaq.adjoint(my_func, q, theta)

@cudaq.kernel
def kernel(theta: float):
ancilla = cudaq.qubit()
q = cudaq.qubit()

h(ancilla)
cudaq.control(my_func, ancilla, q, theta)
cudaq.control(adj_func, ancilla, q, theta)

print(kernel)
theta = 1.5
# also test that this compiles and runs
cudaq.sample(kernel, theta).dump()


# CHECK-LABEL: func.func @__nvqpp__mlirgen__my_func(
# CHECK-SAME: %[[VAL_0:.*]]: !quake.ref,
# CHECK-SAME: %[[VAL_1:.*]]: f64) {
# CHECK: %[[VAL_2:.*]] = cc.alloca f64
# CHECK: cc.store %[[VAL_1]], %[[VAL_2]] : !cc.ptr<f64>
# CHECK: %[[VAL_3:.*]] = cc.load %[[VAL_2]] : !cc.ptr<f64>
# CHECK: quake.ry (%[[VAL_3]]) %[[VAL_0]] : (f64, !quake.ref) -> ()
# CHECK: %[[VAL_4:.*]] = cc.load %[[VAL_2]] : !cc.ptr<f64>
# CHECK: quake.rz (%[[VAL_4]]) %[[VAL_0]] : (f64, !quake.ref) -> ()
# CHECK: return
# CHECK: }

# CHECK-LABEL: func.func @__nvqpp__mlirgen__adj_func(
# CHECK-SAME: %[[VAL_0:.*]]: !quake.ref,
# CHECK-SAME: %[[VAL_1:.*]]: f64) {
# CHECK: %[[VAL_2:.*]] = cc.alloca f64
# CHECK: cc.store %[[VAL_1]], %[[VAL_2]] : !cc.ptr<f64>
# CHECK: %[[VAL_3:.*]] = cc.load %[[VAL_2]] : !cc.ptr<f64>
# CHECK: quake.apply<adj> @__nvqpp__mlirgen__my_func %[[VAL_0]], %[[VAL_3]] : (!quake.ref, f64) -> ()
# CHECK: return
# CHECK: }

# CHECK-LABEL: func.func @__nvqpp__mlirgen__kernel(
# CHECK-SAME: %[[VAL_0:.*]]: f64) attributes {"cudaq-entrypoint"} {
# CHECK: %[[VAL_1:.*]] = cc.alloca f64
# CHECK: cc.store %[[VAL_0]], %[[VAL_1]] : !cc.ptr<f64>
# CHECK: %[[VAL_2:.*]] = quake.alloca !quake.ref
# CHECK: %[[VAL_3:.*]] = quake.alloca !quake.ref
# CHECK: quake.h %[[VAL_2]] : (!quake.ref) -> ()
# CHECK: %[[VAL_4:.*]] = cc.load %[[VAL_1]] : !cc.ptr<f64>
# CHECK: quake.apply @__nvqpp__mlirgen__my_func {{\[}}%[[VAL_2]]] %[[VAL_3]], %[[VAL_4]] : (!quake.ref, !quake.ref, f64) -> ()
# CHECK: %[[VAL_5:.*]] = cc.load %[[VAL_1]] : !cc.ptr<f64>
# CHECK: quake.apply @__nvqpp__mlirgen__adj_func {{\[}}%[[VAL_2]]] %[[VAL_3]], %[[VAL_5]] : (!quake.ref, !quake.ref, f64) -> ()
# CHECK: return
# CHECK: }
134 changes: 134 additions & 0 deletions test/Quake/control_on_adjoint.qke
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// ========================================================================== //
// Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. //
// All rights reserved. //
// //
// This source code and the accompanying materials are made available under //
// the terms of the Apache License 2.0 which accompanies this distribution. //
// ========================================================================== //

// RUN: cudaq-opt --apply-op-specialization --canonicalize %s | FileCheck %s

func.func @__nvqpp__mlirgen__my_func(%arg0: !quake.ref, %arg1: f64) {
%0 = cc.alloca f64
cc.store %arg1, %0 : !cc.ptr<f64>
%1 = cc.load %0 : !cc.ptr<f64>
quake.ry (%1) %arg0 : (f64, !quake.ref) -> ()
%2 = cc.load %0 : !cc.ptr<f64>
quake.rz (%2) %arg0 : (f64, !quake.ref) -> ()
return
}

func.func @__nvqpp__mlirgen__adj_func(%arg0: !quake.ref, %arg1: f64) {
%0 = cc.alloca f64
cc.store %arg1, %0 : !cc.ptr<f64>
%1 = cc.load %0 : !cc.ptr<f64>
quake.apply<adj> @__nvqpp__mlirgen__my_func %arg0, %1 : (!quake.ref, f64) -> ()
return
}

func.func @__nvqpp__mlirgen__kernel(%arg0: f64) attributes {"cudaq-entrypoint"} {
%0 = cc.alloca f64
cc.store %arg0, %0 : !cc.ptr<f64>
%1 = quake.alloca !quake.ref
%2 = quake.alloca !quake.ref
quake.h %1 : (!quake.ref) -> ()
%3 = cc.load %0 : !cc.ptr<f64>
quake.apply @__nvqpp__mlirgen__my_func [%1] %2, %3 : (!quake.ref, !quake.ref, f64) -> ()
%4 = cc.load %0 : !cc.ptr<f64>
quake.apply @__nvqpp__mlirgen__adj_func [%1] %2, %4 : (!quake.ref, !quake.ref, f64) -> ()
return
}

// CHECK-LABEL: func.func private @__nvqpp__mlirgen__adj_func.ctrl(
// CHECK-SAME: %[[VAL_0:.*]]: !quake.veq<?>,
// CHECK-SAME: %[[VAL_1:.*]]: !quake.ref,
// CHECK-SAME: %[[VAL_2:.*]]: f64) {
// CHECK: %[[VAL_3:.*]] = cc.alloca f64
// CHECK: cc.store %[[VAL_2]], %[[VAL_3]] : !cc.ptr<f64>
// CHECK: %[[VAL_4:.*]] = cc.load %[[VAL_3]] : !cc.ptr<f64>
// CHECK: %[[VAL_5:.*]] = quake.concat %[[VAL_0]] : (!quake.veq<?>) -> !quake.veq<?>
// CHECK: call @__nvqpp__mlirgen__my_func.adj.ctrl(%[[VAL_5]], %[[VAL_1]], %[[VAL_4]]) : (!quake.veq<?>, !quake.ref, f64) -> ()
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func private @__nvqpp__mlirgen__my_func.adj.ctrl(
// CHECK-SAME: %[[VAL_0:.*]]: !quake.veq<?>,
// CHECK-SAME: %[[VAL_1:.*]]: !quake.ref,
// CHECK-SAME: %[[VAL_2:.*]]: f64) {
// CHECK: %[[VAL_3:.*]] = cc.alloca f64
// CHECK: cc.store %[[VAL_2]], %[[VAL_3]] : !cc.ptr<f64>
// CHECK: %[[VAL_4:.*]] = cc.load %[[VAL_3]] : !cc.ptr<f64>
// CHECK: %[[VAL_5:.*]] = cc.load %[[VAL_3]] : !cc.ptr<f64>
// CHECK: %[[VAL_6:.*]] = arith.negf %[[VAL_5]] : f64
// CHECK: quake.rz (%[[VAL_6]]) {{\[}}%[[VAL_0]]] %[[VAL_1]] : (f64, !quake.veq<?>, !quake.ref) -> ()
// CHECK: %[[VAL_7:.*]] = arith.negf %[[VAL_4]] : f64
// CHECK: quake.ry (%[[VAL_7]]) {{\[}}%[[VAL_0]]] %[[VAL_1]] : (f64, !quake.veq<?>, !quake.ref) -> ()
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func private @__nvqpp__mlirgen__my_func.adj(
// CHECK-SAME: %[[VAL_0:.*]]: !quake.ref,
// CHECK-SAME: %[[VAL_1:.*]]: f64) {
// CHECK: %[[VAL_2:.*]] = cc.alloca f64
// CHECK: cc.store %[[VAL_1]], %[[VAL_2]] : !cc.ptr<f64>
// CHECK: %[[VAL_3:.*]] = cc.load %[[VAL_2]] : !cc.ptr<f64>
// CHECK: %[[VAL_4:.*]] = cc.load %[[VAL_2]] : !cc.ptr<f64>
// CHECK: %[[VAL_5:.*]] = arith.negf %[[VAL_4]] : f64
// CHECK: quake.rz (%[[VAL_5]]) %[[VAL_0]] : (f64, !quake.ref) -> ()
// CHECK: %[[VAL_6:.*]] = arith.negf %[[VAL_3]] : f64
// CHECK: quake.ry (%[[VAL_6]]) %[[VAL_0]] : (f64, !quake.ref) -> ()
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func private @__nvqpp__mlirgen__my_func.ctrl(
// CHECK-SAME: %[[VAL_0:.*]]: !quake.veq<?>,
// CHECK-SAME: %[[VAL_1:.*]]: !quake.ref,
// CHECK-SAME: %[[VAL_2:.*]]: f64) {
// CHECK: %[[VAL_3:.*]] = cc.alloca f64
// CHECK: cc.store %[[VAL_2]], %[[VAL_3]] : !cc.ptr<f64>
// CHECK: %[[VAL_4:.*]] = cc.load %[[VAL_3]] : !cc.ptr<f64>
// CHECK: quake.ry (%[[VAL_4]]) {{\[}}%[[VAL_0]]] %[[VAL_1]] : (f64, !quake.veq<?>, !quake.ref) -> ()
// CHECK: %[[VAL_5:.*]] = cc.load %[[VAL_3]] : !cc.ptr<f64>
// CHECK: quake.rz (%[[VAL_5]]) {{\[}}%[[VAL_0]]] %[[VAL_1]] : (f64, !quake.veq<?>, !quake.ref) -> ()
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @__nvqpp__mlirgen__my_func(
// CHECK-SAME: %[[VAL_0:.*]]: !quake.ref,
// CHECK-SAME: %[[VAL_1:.*]]: f64) {
// CHECK: %[[VAL_2:.*]] = cc.alloca f64
// CHECK: cc.store %[[VAL_1]], %[[VAL_2]] : !cc.ptr<f64>
// CHECK: %[[VAL_3:.*]] = cc.load %[[VAL_2]] : !cc.ptr<f64>
// CHECK: quake.ry (%[[VAL_3]]) %[[VAL_0]] : (f64, !quake.ref) -> ()
// CHECK: %[[VAL_4:.*]] = cc.load %[[VAL_2]] : !cc.ptr<f64>
// CHECK: quake.rz (%[[VAL_4]]) %[[VAL_0]] : (f64, !quake.ref) -> ()
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @__nvqpp__mlirgen__adj_func(
// CHECK-SAME: %[[VAL_0:.*]]: !quake.ref,
// CHECK-SAME: %[[VAL_1:.*]]: f64) {
// CHECK: %[[VAL_2:.*]] = cc.alloca f64
// CHECK: cc.store %[[VAL_1]], %[[VAL_2]] : !cc.ptr<f64>
// CHECK: %[[VAL_3:.*]] = cc.load %[[VAL_2]] : !cc.ptr<f64>
// CHECK: call @__nvqpp__mlirgen__my_func.adj(%[[VAL_0]], %[[VAL_3]]) : (!quake.ref, f64) -> ()
// CHECK: return
// CHECK: }

// CHECK-LABEL: func.func @__nvqpp__mlirgen__kernel(
// CHECK-SAME: %[[VAL_0:.*]]: f64) attributes {"cudaq-entrypoint"} {
// CHECK: %[[VAL_1:.*]] = cc.alloca f64
// CHECK: cc.store %[[VAL_0]], %[[VAL_1]] : !cc.ptr<f64>
// CHECK: %[[VAL_2:.*]] = quake.alloca !quake.ref
// CHECK: %[[VAL_3:.*]] = quake.alloca !quake.ref
// CHECK: quake.h %[[VAL_2]] : (!quake.ref) -> ()
// CHECK: %[[VAL_4:.*]] = cc.load %[[VAL_1]] : !cc.ptr<f64>
// CHECK: %[[VAL_5:.*]] = quake.concat %[[VAL_2]] : (!quake.ref) -> !quake.veq<1>
// CHECK: %[[VAL_6:.*]] = quake.relax_size %[[VAL_5]] : (!quake.veq<1>) -> !quake.veq<?>
// CHECK: call @__nvqpp__mlirgen__my_func.ctrl(%[[VAL_6]], %[[VAL_3]], %[[VAL_4]]) : (!quake.veq<?>, !quake.ref, f64) -> ()
// CHECK: %[[VAL_7:.*]] = cc.load %[[VAL_1]] : !cc.ptr<f64>
// CHECK: %[[VAL_8:.*]] = quake.concat %[[VAL_2]] : (!quake.ref) -> !quake.veq<1>
// CHECK: %[[VAL_9:.*]] = quake.relax_size %[[VAL_8]] : (!quake.veq<1>) -> !quake.veq<?>
// CHECK: call @__nvqpp__mlirgen__adj_func.ctrl(%[[VAL_9]], %[[VAL_3]], %[[VAL_7]]) : (!quake.veq<?>, !quake.ref, f64) -> ()
// CHECK: return
// CHECK: }

0 comments on commit d162912

Please sign in to comment.