Skip to content

Commit

Permalink
Merge branch 'main' into handle_device_pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
sacpis authored Jul 21, 2024
2 parents 23480f6 + 79d0c06 commit 0269611
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 15 deletions.
32 changes: 27 additions & 5 deletions python/cudaq/kernel/kernel_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,23 +188,40 @@ def supportCommonCast(mlirType, otherTy, arg, FromType, ToType, PyType):

def __generalCustomOperation(self, opName, *args):
"""
Utility function for adding a generic quantum operation to the MLIR representation for the PyKernel.
Utility function for adding a generic quantum operation to the MLIR
representation for the PyKernel.
A controlled version can be invoked by passing additional arguments
to the operation. For an N-qubit operation, the last N arguments are
treated as `targets` and excess arguments as `controls`.
"""

global globalRegisteredOperations
unitary = globalRegisteredOperations[opName]

numTargets = int(np.log2(np.sqrt(unitary.size)))

targets = []
qubits = []
with self.insertPoint, self.loc:
for arg in args:
if isinstance(arg, QuakeValue):
targets.append(arg.mlirValue)
qubits.append(arg.mlirValue)
else:
emitFatalError(f"invalid argument type passed to {opName}.")

assert (numTargets == len(targets))
targets = []
controls = []

if numTargets == len(qubits):
targets = qubits
elif numTargets < len(qubits):
numControls = len(qubits) - numTargets
targets = qubits[-numTargets:]
controls = qubits[:numControls]
else:
emitFatalError(
f"too few arguments passed to {opName}, expected ({numTargets})"
)

globalName = f'{nvqppPrefix}{opName}_generator_{numTargets}.rodata'
currentST = SymbolTable(self.module.operation)
Expand All @@ -216,7 +233,7 @@ def __generalCustomOperation(self, opName, *args):
quake.CustomUnitarySymbolOp([],
generator=FlatSymbolRefAttr.get(globalName),
parameters=[],
controls=[],
controls=controls,
targets=targets,
is_adj=False)
return
Expand Down Expand Up @@ -1520,6 +1537,11 @@ def getListType(eleType: type):

cudaq_runtime.pyAltLaunchKernel(self.name, self.module, *processedArgs)

def __getattr__(self, attr_name):
if hasattr(self, attr_name):
return getattr(self, attr_name)
raise AttributeError(f"'{attr_name}' is not supported on PyKernel")


setattr(PyKernel, 'h', partialmethod(__singleTargetOperation, 'h'))
setattr(PyKernel, 'x', partialmethod(__singleTargetOperation, 'x'))
Expand Down
16 changes: 7 additions & 9 deletions python/cudaq/kernel/register_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,17 @@ def kernel():
if isinstance(unitary, Callable):
raise RuntimeError("parameterized custom operations not yet supported.")

if isinstance(unitary, np.ndarray):
if (len(unitary.shape) != unitary.ndim):
raise RuntimeError(
"provide a 1D array for the matrix representation in row-major format."
)
matrix = unitary
elif isinstance(unitary, List):
if isinstance(unitary, np.matrix) or isinstance(unitary, List):
matrix = np.array(unitary)
elif isinstance(unitary, np.ndarray):
matrix = unitary
else:
raise RuntimeError("unknown type of unitary.")

# TODO: Flatten the matrix if not flattened
assert (matrix.ndim == len(matrix.shape))
matrix = matrix.flatten()
assert (
matrix.ndim == len(matrix.shape),
"provide a 1D array for the matrix representation in row-major format.")

# Size must be a power of 2
assert (matrix.size != 0)
Expand Down
15 changes: 14 additions & 1 deletion python/tests/custom/test_custom_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def kernel():


def test_builder_mode():
"""Builder-mode API """
"""Builder-mode API"""

kernel = cudaq.make_kernel()
cudaq.register_operation("custom_h",
Expand All @@ -192,6 +192,19 @@ def test_builder_mode():
check_bell(kernel)


def test_builder_mode_control():
"""Controlled operation in builder-mode"""

kernel = cudaq.make_kernel()
cudaq.register_operation("custom_x", np.array([0, 1, 1, 0]))

qubits = kernel.qalloc(2)
kernel.h(qubits[0])
kernel.custom_x(qubits[0], qubits[1])

check_bell(kernel)


def test_invalid_ctrl():
cudaq.register_operation("custom_x", np.array([0, 1, 1, 0]))

Expand Down
85 changes: 85 additions & 0 deletions python/tests/mlir/custom_op_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# ============================================================================ #
# Copyright (c) 2022 - 2024 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
# the terms of the Apache License 2.0 which accompanies this distribution. #
# ============================================================================ #

# RUN: PYTHONPATH=../../ pytest -rP %s | FileCheck %s

import numpy as np
import cudaq


def test_builder_look_up():
"""A custom operation can be looked up by its name in builder mode"""

base_name = 'foo'
op_count = 3

def register_custom_operations(matrix):
prev = np.identity(2)
for t in range(op_count):
new = prev @ matrix
cudaq.register_operation(f'{base_name}_{t}', new)
prev = new

register_custom_operations(
np.array([[1, 0], [0, np.exp(np.pi * 1j * 1 / 3)]]))

kernel = cudaq.make_kernel()

qubit = kernel.qalloc(1)
ancilla = kernel.qalloc(2)

kernel.x(qubit)
kernel.h(ancilla)

for i in range(op_count):
kernel.__getattr__(f'{base_name}_{i}')(ancilla, qubit)

print(kernel)
counts = cudaq.sample(kernel)


# CHECK-LABEL: func.func @__nvqpp__mlirgen____nvqppBuilderKernel_{{.*}}() attributes {"cudaq-entrypoint"} {
# CHECK: %[[VAL_0:.*]] = arith.constant 2 : i64
# CHECK: %[[VAL_1:.*]] = arith.constant 1 : i64
# CHECK: %[[VAL_2:.*]] = arith.constant 0 : i64
# CHECK: %[[VAL_3:.*]] = quake.alloca !quake.veq<1>
# CHECK: %[[VAL_4:.*]] = quake.alloca !quake.veq<2>
# CHECK: %[[VAL_5:.*]] = cc.loop while ((%[[VAL_6:.*]] = %[[VAL_2]]) -> (i64)) {
# CHECK: %[[VAL_7:.*]] = arith.cmpi slt, %[[VAL_6]], %[[VAL_1]] : i64
# CHECK: cc.condition %[[VAL_7]](%[[VAL_6]] : i64)
# CHECK: } do {
# CHECK: ^bb0(%[[VAL_8:.*]]: i64):
# CHECK: %[[VAL_9:.*]] = quake.extract_ref %[[VAL_3]]{{\[}}%[[VAL_8]]] : (!quake.veq<1>, i64) -> !quake.ref
# CHECK: quake.x %[[VAL_9]] : (!quake.ref) -> ()
# CHECK: cc.continue %[[VAL_8]] : i64
# CHECK: } step {
# CHECK: ^bb0(%[[VAL_10:.*]]: i64):
# CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_11]] : i64
# CHECK: } {invariant}
# CHECK: %[[VAL_12:.*]] = cc.loop while ((%[[VAL_13:.*]] = %[[VAL_2]]) -> (i64)) {
# CHECK: %[[VAL_14:.*]] = arith.cmpi slt, %[[VAL_13]], %[[VAL_0]] : i64
# CHECK: cc.condition %[[VAL_14]](%[[VAL_13]] : i64)
# CHECK: } do {
# CHECK: ^bb0(%[[VAL_15:.*]]: i64):
# CHECK: %[[VAL_16:.*]] = quake.extract_ref %[[VAL_4]]{{\[}}%[[VAL_15]]] : (!quake.veq<2>, i64) -> !quake.ref
# CHECK: quake.h %[[VAL_16]] : (!quake.ref) -> ()
# CHECK: cc.continue %[[VAL_15]] : i64
# CHECK: } step {
# CHECK: ^bb0(%[[VAL_17:.*]]: i64):
# CHECK: %[[VAL_18:.*]] = arith.addi %[[VAL_17]], %[[VAL_1]] : i64
# CHECK: cc.continue %[[VAL_18]] : i64
# CHECK: } {invariant}
# CHECK: quake.custom_op @__nvqpp__mlirgen__foo_0_generator_1.rodata {{\[}}%[[VAL_4]]] %[[VAL_3]] : (!quake.veq<2>, !quake.veq<1>) -> ()
# CHECK: quake.custom_op @__nvqpp__mlirgen__foo_1_generator_1.rodata {{\[}}%[[VAL_4]]] %[[VAL_3]] : (!quake.veq<2>, !quake.veq<1>) -> ()
# CHECK: quake.custom_op @__nvqpp__mlirgen__foo_2_generator_1.rodata {{\[}}%[[VAL_4]]] %[[VAL_3]] : (!quake.veq<2>, !quake.veq<1>) -> ()
# CHECK: return
# CHECK: }
# CHECK-DAG: cc.global constant @__nvqpp__mlirgen__foo_0_generator_1.rodata (dense<[{{.*}}]> : tensor<4xcomplex<f64>>) : !cc.array<complex<f64> x 4>
# CHECK-DAG: cc.global constant @__nvqpp__mlirgen__foo_1_generator_1.rodata (dense<[{{.*}}]> : tensor<4xcomplex<f64>>) : !cc.array<complex<f64> x 4>
# CHECK-DAG: cc.global constant @__nvqpp__mlirgen__foo_2_generator_1.rodata (dense<[{{.*}}]> : tensor<4xcomplex<f64>>) : !cc.array<complex<f64> x 4>

0 comments on commit 0269611

Please sign in to comment.