Skip to content

Commit

Permalink
Merge branch 'dev' into nlocal
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanrui-Wang authored Aug 23, 2023
2 parents 4172965 + b39ac31 commit e15d756
Show file tree
Hide file tree
Showing 49 changed files with 7,856 additions and 235 deletions.
8 changes: 8 additions & 0 deletions examples/grover/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Grover's Search Algorithm

Grover's Search Algorithm [1] is an algorithm which can speed up an unstructured search problem quadratically using amplitude amplification. A detailed walkthrough can be found [here](https://quantum-computing.ibm.com/composer/docs/iqx/guide/grovers-algorithm). The file `grover_example_sudoku.py` provides an example of how to use the algorithm to solve a sudoku puzzle of size 2x2.


## References

1. Grover, Lov K.. “A fast quantum mechanical algorithm for database search.” Symposium on the Theory of Computing (1996).
93 changes: 93 additions & 0 deletions examples/grover/grover_example_sudoku.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
This example is based on Qiskt's Texbook: https://learn.qiskit.org/course/ch-algorithms/grovers-algorithm#sudoku
We will now tackle a 2x2 binary sudoku problem using Grover's algorithm, where we don't necessarily possess
prior knowledge of the solution. The problem adheres to two simple rules:
1. No column can have the same value repeated.
2. No row can have the same value repeated.
We will use the following 4 variables:
---------
| a | b |
---------
| c | d |
---------
Please keep in mind that while utilizing Grover's algorithm to solve this particular problem may not be practical
(as the solution can likely be determined mentally), the intention of this example is to showcase the process of
transforming classical decision problems into oracles suitable for Grover's algorithm.
We need to check for four conditions:
1. a != b
2. c != d
3. a != c
4. b != d
"""

import torchquantum as tq
from torchquantum.algorithms import Grover


# To simplify the process, we can compile this set of comparisons into a list of clauses for convenience.
clauses = [ [0, 1], [0, 2], [1, 3], [2, 3] ]

# This circuit checks if input0 is equal to input1 and stores the output in output.
# The output of each comparison is stored in a new bit.
def XOR(input0, input1, output):
op1 = {'name': 'cnot', 'wires': [input0, output]}
op2 = {'name': 'cnot', 'wires': [input1, output]}
return [op1, op2]

# To verify each clause, we repeat the above circuit for every pairing in the `clauses`.
ops = []
clause_qubits = [4, 5, 6, 7]
for i, clause in enumerate(clauses):
ops += XOR(clause[0], clause[1], clause_qubits[i])

# To determine if the assignments of a, b, c, d are a solution to the sudoku, we examine the final state
# of the `clause_qubits`. Only when all of these qubits are 1, it indicates that the clauses are satisfied.
# To achieve this, we incorporate a multi-controlled Toffoli gate in our checking circuit. This gate
# ensures that a single output bit will be set to 1 if and only if all the clauses are satisfied,
# allowing us to easily determine if our assignment is a solution.
ops += [{'name': 'multicnot', 'n_wires': 5, 'wires': [4,5,6,7,8]}]

# In order to transform our checking circuit into a Grover oracle, it is crucial to ensure that the `clause_qubits`
# are always returned to their initial state after the computation. This guarantees that `clause_qubits` are all
# set to 0 once our circuit has finished running. To achieve this, we include a step called "uncomputation"
# where we repeat the segment of the circuit that computes the clauses. This uncomputation step ensures the
# desired state restoration, enabling us to effectively use the circuit as a Grover oracle.
for qubit, clause in enumerate(clauses):
ops += XOR(clause[0], clause[1], qubit + 4)

# Full Algorithm
# We can combine all the components we have discussed so far

qmodule = tq.QuantumModule.from_op_history(ops)
iterations = 2
qdev = tq.QuantumDevice(n_wires=9, device="cpu")

# Initialize output qubit (last qubit) in state |->
qdev.x(wires=8)
qdev.h(wires=8)

# Perform Grover's Search
grover = Grover(qmodule, iterations, 4)
result = grover.execute(qdev)
bitstring = result.bitstring[0]

# Extract the top two most likely solutions
res = {k: v for k, v in sorted(bitstring.items(), key=lambda item: item[1], reverse=True)}

# Print the top two most likely solutions
top_keys = list(res.keys())[:2]
print("Top two most likely solutions:")
for key in top_keys:
print("Solution: ", key)
print("a = ", key[0])
print("b = ", key[1])
print("c = ", key[2])
print("d = ", key[3])
print("")
1 change: 1 addition & 0 deletions examples/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ TorchQuantum Examples
param_shift_onchip_training/param_shift_onchip_training.ipynb
quantum_kernel_method/quantum_kernel_method.ipynb
quanvolution/quanvolution.ipynb
superdense_coding/superdense_coding_torchquantum.ipynb
13 changes: 13 additions & 0 deletions examples/superdense_coding/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Superdense Coding

Superdense coding is a quantum communication protocol that allows the transmission of two classical bits of information using only one qubit`[1]`. It takes advantage of quantum entanglement and the ability to manipulate qubits in superposition. Charles H. Bennett and Stephen Wiesner proposed this technique in 1970(though it was not published until 1992`[2]`) and it was experimentally realised in 1996 by Klaus Mattle, Harald Weinfurter, Paul G. Kwiat, and Anton Zeilinger utilising entangled photon pairs.

## Author

[Soham Bopardikar](https://github.com/bopardikarsoham)

## References

[1] Bennett, C.H., Brassard, G., Crépeau, C., Jozsa, R., Peres, A. and Wootters, W.K., 1993. Teleporting an unknown quantum state via dual classical and Einstein-Podolsky-Rosen channels. Physical review letters, 70(13), p.1895.

[2] Bennett, C.H. and Wiesner, S.J., 1992. Communication via one-and two-particle operators on Einstein-Podolsky-Rosen states. Physical review letters, 69(20), p.2881.
538 changes: 538 additions & 0 deletions examples/superdense_coding/superdense_coding_torchquantum.ipynb

Large diffs are not rendered by default.

49 changes: 49 additions & 0 deletions examples/superdense_coding/superdense_coding_torchquantum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torchquantum as tq
import torchquantum.functional as tqf
import argparse
import tqdm
import time

# Preparing the entangled state/ 2 qubit bell pair.
def bell_pair():
qdev = tq.QuantumDevice(n_wires=2, bsz=1, device="cpu")
qdev.h(wires=0)
qdev.cnot(wires=[0, 1])
return qdev

# Encoding the message
def encode_message(qdev, qubit, msg):
if len(msg) != 2 or not set(msg).issubset({"0","1"}):
raise ValueError(f"message '{msg}' is invalid")
if msg[1] == "1":
qdev.x(wires=qubit)
if msg[0] == "1":
qdev.z(wires=qubit)
return qdev

# Decoding the message
def decode_message(qdev):
qdev.cx(wires=[0, 1])
qdev.h(wires=0)
return qdev

# Putting all these functions together
def main():
# Creating the entangled pair between Alice and Bob
qdev = bell_pair()
# Encoding the message at Alice's end
message = '10'
qdev = encode_message(qdev, 1, message)
# Decoding the original message at Bob's end
qdev = decode_message(qdev)
# Finally, Bob measures his qubits to read Alice's message
print(tq.measure(qdev, n_shots=1024))

if __name__ == "__main__":
main()
17 changes: 17 additions & 0 deletions test/operator/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from tqdm import tqdm

import qiskit.circuit.library.standard_gates as qiskit_gate
import qiskit.circuit.library as qiskit_library


RND_TIMES = 100

Expand Down Expand Up @@ -73,6 +75,21 @@
# {'qiskit': qiskit_gate.?, 'tq': tq.CU2},
{"qiskit": qiskit_gate.CU3Gate, "tq": tq.CU3},
{"qiskit": qiskit_gate.ECRGate, "tq": tq.ECR},
{"qiskit": qiskit_library.QFT, "tq": tq.QFT},
{"qiskit": qiskit_gate.SdgGate, "tq": tq.SDG},
{"qiskit": qiskit_gate.TDgGate, "tq": tq.TDG},
{"qiskit": qiskit_gate.SXdgGate, "tq": tq.SXDG},
{"qiskit": qiskit_gate.CHGate, "tq": tq.CH},
{"qiskit": qiskit_gate.CCZGate, "tq": tq.CCZ},
{"qiskit": qiskit_gate.iSwapGate, "tq": tq.ISWAP},
{"qiskit": qiskit_gate.CSGate, "tq": tq.CS},
{"qiskit": qiskit_gate.CSdgGate, "tq": tq.CSDG},
{"qiskit": qiskit_gate.CSXGate, "tq": tq.CSX},
{"qiskit": qiskit_gate.DCXGate, "tq": tq.DCX},
{'qiskit': qiskit_gate.XXMinusYYGate, 'tq': tq.XXMINYY},
{'qiskit': qiskit_gate.XXPlusYYGate, 'tq': tq.XXPLUSYY},
{"qiskit": qiskit_gate.C3XGate, "tq": tq.C3X},
{"qiskit": qiskit_gate.RGate, "tq": tq.R},
]

import os
Expand Down
3 changes: 2 additions & 1 deletion torchquantum/algorithm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
"""

from .vqe import *
from .hamiltonian import *
from .hamiltonian import *
from .qft import *
116 changes: 116 additions & 0 deletions torchquantum/algorithm/grover.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import torchquantum as tq

__all__ = ["Grover"]

class GroverResult(object):
"""Result class for Grover algorithm"""
def __init__(self) -> None:
self.iterations: int

class Grover(object):
"""Grover's search algorithm based the paper "A fast quantum mechanical algorithm for database search" by Lov K. Grover
https://arxiv.org/abs/quant-ph/9605043
"""

def __init__(self, oracle: tq.module.QuantumModule, iterations: int, n_wires:int) -> None:
"""
Args:
oracle (tq.module.QuantumModule): The oracle is a quantum module that adds a negative phase to the
solution states.
iterations (int): The number of iterations to run the algorithm for.
n_wires (int): The number of qubits used in the quantum circuit.
"""
super().__init__()
self._oracle = oracle
self._iterations = iterations
self._n_wires = n_wires

def initial_state_prep(self):
"""
Prepares the initial state of a quantum circuit by applying a Hadamard gate to each qubit.
Returns:
a `QuantumModule` object that represents the initial state preparation circuit.
"""
ops = []
for i in range(self._n_wires):
ops.append({'name': 'hadamard', 'wires': i})
return tq.QuantumModule.from_op_history(ops)

def diffusion_operator(self):
"""
Prepares the diffusion operator for the grover's circuit.
Returns:
a quantum module that represents the diffusion operator for a quantum circuit.
"""
ops = []
hadamards = [{'name': 'hadamard', 'wires': i} for i in range(self._n_wires)]
flips = [{'name': 'x', 'wires': i} for i in range(self._n_wires)]

ops += hadamards
ops += flips

if self._n_wires == 1:
ops += [{'name': 'z', 'wires': 0}]
else:
ops += [{'name': 'hadamard', 'wires': self._n_wires - 1}]
ops += [{'name': 'multicnot', 'n_wires': self._n_wires, 'wires': range(self._n_wires)}]
ops += [{'name': 'hadamard', 'wires': self._n_wires - 1}]

ops += flips
ops += hadamards

return tq.QuantumModule.from_op_history(ops)


def construct_grover_circuit(self, qdev: tq.QuantumDevice):
"""
Constructs a Grover's algorithm circuit with an initial state preparation, oracle,
and diffusion operator, and iterates through them a specified number of times.
Args:
qdev (tq.QuantumDevice): tq.QuantumDevice is an object representing a quantum device or
simulator on which quantum circuits can be executed.
Returns:
the modified quantum device `qdev` after applying the Grover's algorithm circuit with the
specified number of iterations.
"""

self.initial_state_prep()(qdev)
for _ in range(self._iterations):
self._oracle(qdev)
self.diffusion_operator()(qdev)

return qdev

def execute(self, qdev: tq.QuantumDevice, n_shots: int =1024):
"""
Executes a Grover search algorithm on a given quantum device and returns the result.
Args:
qdev (tq.QuantumDevice): tq.QuantumDevice is an object representing a quantum device or
simulator on which quantum circuits can be executed.
n_shots (int): The number of times the circuit is run to obtain measurement statistics.
Defaults to 1024
Returns:
an instance of the `GroverResult` class, which contains information about the results of
running the Grover search algorithm on a quantum device. The `GroverResult` object includes the
number of iterations performed, the measured bitstring, the top measurement (i.e. the most
frequently measured bitstring), and the maximum probability of measuring the top measurement.
"""

qdev = self.construct_grover_circuit(qdev)
bitstring = tq.measure(qdev, n_shots=n_shots)
top_measurement, max_probability = max(bitstring[0].items(), key=lambda x: x[1])
max_probability /= n_shots

result = GroverResult()
result.iterations = self._iterations
result.bitstring = bitstring
result.top_measurement = top_measurement
result.max_probability = max_probability

return result
35 changes: 35 additions & 0 deletions torchquantum/algorithm/qft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torchquantum as tq
from typing import Iterable

__all__ = ["QFT"]


class QFT(object):
def __init__(
self, n_wires: int = None, wires: Iterable = None, do_swaps=True
) -> None:
"""Init function for QFT class
Args:
n_wires (int): Number of wires for the QFT as an integer
wires (Iterable): Wires to perform the QFT as an Iterable
add_swaps (bool): Whether or not to add the final swaps in a boolean format
inverse (bool): Whether to create an inverse QFT layer in a boolean format
"""
super().__init__()

self.n_wires = n_wires
self.wires = wires
self.do_swaps = do_swaps

def construct_qft_circuit(self) -> tq.QuantumModule:
"""Construct the QFT circuit."""
return tq.layer.QFTLayer(
n_wires=self.n_wires, wires=self.wires, do_swaps=self.do_swaps
)

def construct_inverse_qft_circuit(self) -> tq.QuantumModule:
"""Construct the inverse of a QFT circuit."""
return tq.layer.QFTLayer(
n_wires=self.n_wires, wires=self.wires, do_swaps=self.do_swaps, inverse=True
)
Loading

0 comments on commit e15d756

Please sign in to comment.