
QCBM algorithm #271

Merged · 11 commits · Jun 12, 2024
42 changes: 42 additions & 0 deletions examples/QCBM/README.md
@@ -0,0 +1,42 @@
# Quantum Circuit Born Machine
(Implementation by: [Gopal Ramesh Dahale](https://github.com/Gopal-Dahale))

Quantum Circuit Born Machine (QCBM) is a generative modeling algorithm that uses the Born rule from quantum mechanics to sample from a quantum state $|\psi \rangle$ learned by training an ansatz $U(\theta)$ [1][2]. In this tutorial we show how `torchquantum` can be used to model a Gaussian mixture with a QCBM.
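
Concretely, the Born rule assigns each computational basis state $x$ the probability $p_\theta(x) = |\langle x | U(\theta) | 0 \rangle|^2$, and training tunes $\theta$ so that $p_\theta$ matches a target distribution $\pi(x)$. Following [1], this example minimizes the squared maximum mean discrepancy (MMD) between the two distributions.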

## Setup

Below is the usage of `qcbm_gaussian_mixture.py` which can be obtained by running `python qcbm_gaussian_mixture.py -h`.

```
usage: qcbm_gaussian_mixture.py [-h] [--n_wires N_WIRES] [--epochs EPOCHS] [--n_blocks N_BLOCKS] [--n_layers_per_block N_LAYERS_PER_BLOCK] [--plot] [--optimizer OPTIMIZER] [--lr LR]

options:
  -h, --help            show this help message and exit
  --n_wires N_WIRES     Number of wires used in the circuit
  --epochs EPOCHS       Number of training epochs
  --n_blocks N_BLOCKS   Number of blocks in ansatz
  --n_layers_per_block N_LAYERS_PER_BLOCK
                        Number of layers per block in ansatz
  --plot                Visualize the predicted probability distribution
  --optimizer OPTIMIZER
                        optimizer class from torch.optim
  --lr LR
```

For example:

```
python qcbm_gaussian_mixture.py --plot --epochs 100 --optimizer RMSprop --lr 0.01 --n_blocks 6 --n_layers_per_block 2 --n_wires 6
```
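
For reference, here is a minimal sketch of the core training loop the command above runs (hyperparameter values match the example flags; `target_probs` is the Gaussian mixture built by the script):

```
import torch
import torchquantum as tq
from torchquantum.algorithm import QCBM, MMDLoss

n_wires = 6

# Ansatz: RX/RY/RZ rotations with CX entanglers
layers = tq.RXYZCXLayer0(
    {"n_blocks": 6, "n_wires": n_wires, "n_layers_per_block": 2}
)
qcbm = QCBM(n_wires, layers)

# MMD loss with an RBF kernel over the 2^n_wires basis states
mmd = MMDLoss(torch.tensor([0.25, 60]), torch.arange(2**n_wires))

optimizer = torch.optim.RMSprop(qcbm.parameters(), lr=0.01)
for _ in range(100):
    optimizer.zero_grad()
    pred_probs = qcbm()  # Born-rule probabilities |<x|U(theta)|0>|^2
    loss = mmd(pred_probs, target_probs)
    loss.backward()
    optimizer.step()
```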

Using the command above gives an output similar to the plot below.

<p align="center">
<img src='./assets/sample_output.png' width=500 alt='sample output of QCBM'>
</p>


## References

1. Liu, Jin-Guo, and Lei Wang. “Differentiable learning of quantum circuit born machines.” Physical Review A 98.6 (2018): 062324.
2. Gili, Kaitlin, et al. "Do quantum circuit born machines generalize?." Quantum Science and Technology 8.3 (2023): 035021.
Binary file added examples/QCBM/assets/sample_output.png
255 changes: 255 additions & 0 deletions examples/QCBM/qcbm_gaussian_mixture.ipynb

Large diffs are not rendered by default.

129 changes: 129 additions & 0 deletions examples/QCBM/qcbm_gaussian_mixture.py
@@ -0,0 +1,129 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
from torchquantum.algorithm import QCBM, MMDLoss
import torchquantum as tq
import argparse
import os
from pprint import pprint


# Reproducibility
def set_seed(seed: int = 42) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")


def _setup_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--n_wires", type=int, default=6, help="Number of wires used in the circuit"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of training epochs"
    )
    parser.add_argument(
        "--n_blocks", type=int, default=6, help="Number of blocks in ansatz"
    )
    parser.add_argument(
        "--n_layers_per_block",
        type=int,
        default=1,
        help="Number of layers per block in ansatz",
    )
    parser.add_argument(
        "--plot",
        action="store_true",
        help="Visualize the predicted probability distribution",
    )
    parser.add_argument(
        "--optimizer", type=str, default="Adam", help="optimizer class from torch.optim"
    )
    parser.add_argument("--lr", type=float, default=1e-2)
    return parser


# Function to create a Gaussian mixture
def gaussian_mixture_pdf(x, mus, sigmas):
    mus, sigmas = np.array(mus), np.array(sigmas)
    variances = sigmas**2  # avoid shadowing the built-in vars()
    values = [
        (1 / np.sqrt(2 * np.pi * v)) * np.exp(-((x - m) ** 2) / (2 * v))
        for m, v in zip(mus, variances)
    ]
    # Normalize each component, sum them, then renormalize the mixture
    values = np.sum([val / sum(val) for val in values], axis=0)
    return values / np.sum(values)


def main():
    set_seed()
    parser = _setup_parser()
    args = parser.parse_args()

    print("Configuration:")
    pprint(vars(args))

    # Create a Gaussian mixture
    n_wires = args.n_wires
    assert n_wires >= 1, "Number of wires must be at least 1"

    x_max = 2**n_wires
    x_input = np.arange(x_max)
    mus = [(2 / 8) * x_max, (5 / 8) * x_max]
    sigmas = [x_max / 10] * 2
    data = gaussian_mixture_pdf(x_input, mus, sigmas)

    # This is the target distribution that the QCBM will learn
    target_probs = torch.tensor(data, dtype=torch.float32)

    # Ansatz
    layers = tq.RXYZCXLayer0(
        {
            "n_blocks": args.n_blocks,
            "n_wires": n_wires,
            "n_layers_per_block": args.n_layers_per_block,
        }
    )

    qcbm = QCBM(n_wires, layers)

    # To train QCBMs, we use MMDLoss with a radial basis function kernel.
    bandwidth = torch.tensor([0.25, 60])
    space = torch.arange(2**n_wires)
    mmd = MMDLoss(bandwidth, space)

    # Optimization
    optimizer_class = getattr(torch.optim, args.optimizer)
    optimizer = optimizer_class(qcbm.parameters(), lr=args.lr)

    for i in range(args.epochs):
        optimizer.zero_grad(set_to_none=True)
        pred_probs = qcbm()
        loss = mmd(pred_probs, target_probs)
        loss.backward()
        optimizer.step()
        print(i, loss.item())

    # Visualize the results
    if args.plot:
        with torch.no_grad():
            pred_probs = qcbm()

        plt.plot(x_input, target_probs, linestyle="-.", label=r"$\pi(x)$")
        plt.bar(x_input, pred_probs, color="green", alpha=0.5, label="samples")
        plt.xlabel("Samples")
        plt.ylabel("Prob. Distribution")

        plt.legend()
        plt.show()


if __name__ == "__main__":
    main()
31 changes: 31 additions & 0 deletions test/algorithm/test_qcbm.py
@@ -0,0 +1,31 @@
from torchquantum.algorithm.qcbm import QCBM, MMDLoss
import torchquantum as tq
import torch


def test_qcbm_forward():
    n_wires = 3
    n_layers = 3
    ops = []
    for l in range(n_layers):
        for q in range(n_wires):
            ops.append({"name": "rx", "wires": q, "params": 0.0, "trainable": True})
        for q in range(n_wires - 1):
            ops.append({"name": "cnot", "wires": [q, q + 1]})

    qmodule = tq.QuantumModule.from_op_history(ops)
    qcbm = QCBM(n_wires, qmodule)
    probs = qcbm()
    # All rotation angles are zero, so the state stays |000> and the full
    # probability mass sits on the first basis state.
    expected = torch.tensor([1.0, 0, 0, 0, 0, 0, 0, 0])
    assert torch.allclose(probs, expected)


def test_mmd_loss():
    n_wires = 2
    bandwidth = torch.tensor([0.1, 1.0])
    space = torch.arange(2**n_wires)

    mmd = MMDLoss(bandwidth, space)
    # Identical distributions must have zero discrepancy
    loss = mmd(torch.zeros(4), torch.zeros(4))
    assert torch.isclose(loss, torch.tensor(0.0), rtol=1e-5)
9 changes: 5 additions & 4 deletions torchquantum/algorithm/__init__.py
@@ -22,7 +22,8 @@
SOFTWARE.
"""

-from .vqe import *
-from .hamiltonian import *
-from .qft import *
-from .grover import *
+from .vqe import VQE
+from .hamiltonian import Hamiltonian
+from .qft import QFT
+from .grover import Grover
+from .qcbm import QCBM, MMDLoss
96 changes: 96 additions & 0 deletions torchquantum/algorithm/qcbm.py
@@ -0,0 +1,96 @@
import torch
import torch.nn as nn

import torchquantum as tq

__all__ = ["QCBM", "MMDLoss"]


class MMDLoss(nn.Module):
    """Squared maximum mean discrepancy with radial basis function kernel."""

    def __init__(self, scales, space):
        """
        Initialize MMDLoss object. Calculates and stores the kernel matrix.

        Args:
            scales: Bandwidth parameters.
            space: Basis input space.
        """
        super().__init__()

        gammas = 1 / (2 * (scales**2))

        # squared Euclidean distance
        sq_dists = torch.abs(space[:, None] - space[None, :]) ** 2

        # Kernel matrix
        self.K = sum(torch.exp(-gamma * sq_dists) for gamma in gammas) / len(scales)
        self.scales = scales

    def k_expval(self, px, py):
        """
        Kernel expectation value

        Args:
            px: First probability distribution
            py: Second probability distribution

        Returns:
            Expectation value of the RBF Kernel.
        """
        return px @ self.K @ py

    def forward(self, px, py):
        """
        Squared MMD loss.

        Args:
            px: First probability distribution
            py: Second probability distribution

        Returns:
            Squared MMD loss.
        """
        pxy = px - py
        return self.k_expval(pxy, pxy)
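
# Note: for distributions p and q on the same basis, the squared MMD
# expands as p^T K p + q^T K q - 2 p^T K q = (p - q)^T K (p - q),
# which is exactly what forward() computes via k_expval(pxy, pxy).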


class QCBM(nn.Module):
    """
    Quantum Circuit Born Machine (QCBM)

    Attributes:
        ansatz: An Ansatz object
        n_wires: Number of wires in the ansatz used.

    Methods:
        __init__: Initialize the QCBM object.
        forward: Returns the probability distribution (output from measurement).
    """

    def __init__(self, n_wires, ansatz):
        """
        Initialize QCBM object

        Args:
            ansatz (Ansatz): An Ansatz object
            n_wires (int): Number of wires in the ansatz used.
        """
        super().__init__()

        self.ansatz = ansatz
        self.n_wires = n_wires

    def forward(self):
        """
        Execute and obtain the probability distribution

        Returns:
            Probabilities (torch.Tensor)
        """
        qdev = tq.QuantumDevice(n_wires=self.n_wires, bsz=1, device="cpu")
        self.ansatz(qdev)
        probs = torch.abs(qdev.states.flatten()) ** 2
        return probs