-
Notifications
You must be signed in to change notification settings - Fork 208
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'unitary-hack' into qiskit2tq-parameterexpression
- Loading branch information
Showing
76 changed files
with
4,469 additions
and
919 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,8 +9,8 @@ jobs: | |
pre-commit: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: actions/setup-python@v4 | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
with: | ||
python-version: ${{ env.PYTHON_VERSION }} | ||
- uses: pre-commit/[email protected] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# Quantum Circuit Born Machine | ||
(Implementation by: [Gopal Ramesh Dahale](https://github.com/Gopal-Dahale)) | ||
|
||
Quantum Circuit Born Machine (QCBM) [1] is a generative modeling algorithm which uses Born rule from quantum mechanics to sample from a quantum state $|\psi \rangle$ learned by training an ansatz $U(\theta)$ [1][2]. In this tutorial we show how `torchquantum` can be used to model a Gaussian mixture with QCBM. | ||
|
||
## Setup | ||
|
||
Below is the usage of `qcbm_gaussian_mixture.py` which can be obtained by running `python qcbm_gaussian_mixture.py -h`. | ||
|
||
``` | ||
usage: qcbm_gaussian_mixture.py [-h] [--n_wires N_WIRES] [--epochs EPOCHS] [--n_blocks N_BLOCKS] [--n_layers_per_block N_LAYERS_PER_BLOCK] [--plot] [--optimizer OPTIMIZER] [--lr LR] | ||
options: | ||
-h, --help show this help message and exit | ||
--n_wires N_WIRES Number of wires used in the circuit | ||
--epochs EPOCHS Number of training epochs | ||
--n_blocks N_BLOCKS Number of blocks in ansatz | ||
--n_layers_per_block N_LAYERS_PER_BLOCK | ||
Number of layers per block in ansatz | ||
--plot Visualize the predicted probability distribution | ||
--optimizer OPTIMIZER | ||
optimizer class from torch.optim | ||
--lr LR | ||
``` | ||
|
||
For example: | ||
|
||
``` | ||
python qcbm_gaussian_mixture.py --plot --epochs 100 --optimizer RMSprop --lr 0.01 --n_blocks 6 --n_layers_per_block 2 --n_wires 6 | ||
``` | ||
|
||
Using the command above gives an output similar to the plot below. | ||
|
||
<p align="center"> | ||
<img src ='./assets/sample_output.png' width-500 alt='sample output of QCBM'> | ||
</p> | ||
|
||
|
||
## References | ||
|
||
1. Liu, Jin-Guo, and Lei Wang. “Differentiable learning of quantum circuit born machines.” Physical Review A 98.6 (2018): 062324. | ||
2. Gili, Kaitlin, et al. "Do quantum circuit born machines generalize?." Quantum Science and Technology 8.3 (2023): 035021. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
import torch | ||
from torchquantum.algorithm import QCBM, MMDLoss | ||
import torchquantum as tq | ||
import argparse | ||
import os | ||
from pprint import pprint | ||
|
||
|
||
# Reproducibility | ||
def set_seed(seed: int = 42) -> None: | ||
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed(seed) | ||
# When running on the CuDNN backend, two further options must be set | ||
torch.backends.cudnn.deterministic = True | ||
torch.backends.cudnn.benchmark = False | ||
# Set a fixed value for the hash seed | ||
os.environ["PYTHONHASHSEED"] = str(seed) | ||
print(f"Random seed set as {seed}") | ||
|
||
|
||
def _setup_parser(): | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--n_wires", type=int, default=6, help="Number of wires used in the circuit" | ||
) | ||
parser.add_argument( | ||
"--epochs", type=int, default=10, help="Number of training epochs" | ||
) | ||
parser.add_argument( | ||
"--n_blocks", type=int, default=6, help="Number of blocks in ansatz" | ||
) | ||
parser.add_argument( | ||
"--n_layers_per_block", | ||
type=int, | ||
default=1, | ||
help="Number of layers per block in ansatz", | ||
) | ||
parser.add_argument( | ||
"--plot", | ||
action="store_true", | ||
help="Visualize the predicted probability distribution", | ||
) | ||
parser.add_argument( | ||
"--optimizer", type=str, default="Adam", help="optimizer class from torch.optim" | ||
) | ||
parser.add_argument("--lr", type=float, default=1e-2) | ||
return parser | ||
|
||
|
||
# Function to create a gaussian mixture | ||
def gaussian_mixture_pdf(x, mus, sigmas): | ||
mus, sigmas = np.array(mus), np.array(sigmas) | ||
vars = sigmas**2 | ||
values = [ | ||
(1 / np.sqrt(2 * np.pi * v)) * np.exp(-((x - m) ** 2) / (2 * v)) | ||
for m, v in zip(mus, vars) | ||
] | ||
values = np.sum([val / sum(val) for val in values], axis=0) | ||
return values / np.sum(values) | ||
|
||
|
||
def main(): | ||
set_seed() | ||
parser = _setup_parser() | ||
args = parser.parse_args() | ||
|
||
print("Configuration:") | ||
pprint(vars(args)) | ||
|
||
# Create a gaussian mixture | ||
n_wires = args.n_wires | ||
assert n_wires >= 1, "Number of wires must be at least 1" | ||
|
||
x_max = 2**n_wires | ||
x_input = np.arange(x_max) | ||
mus = [(2 / 8) * x_max, (5 / 8) * x_max] | ||
sigmas = [x_max / 10] * 2 | ||
data = gaussian_mixture_pdf(x_input, mus, sigmas) | ||
|
||
# This is the target distribution that the QCBM will learn | ||
target_probs = torch.tensor(data, dtype=torch.float32) | ||
|
||
# Ansatz | ||
layers = tq.RXYZCXLayer0( | ||
{ | ||
"n_blocks": args.n_blocks, | ||
"n_wires": n_wires, | ||
"n_layers_per_block": args.n_layers_per_block, | ||
} | ||
) | ||
|
||
qcbm = QCBM(n_wires, layers) | ||
|
||
# To train QCBMs, we use MMDLoss with radial basis function kernel. | ||
bandwidth = torch.tensor([0.25, 60]) | ||
space = torch.arange(2**n_wires) | ||
mmd = MMDLoss(bandwidth, space) | ||
|
||
# Optimization | ||
optimizer_class = getattr(torch.optim, args.optimizer) | ||
optimizer = optimizer_class(qcbm.parameters(), lr=args.lr) | ||
|
||
for i in range(args.epochs): | ||
optimizer.zero_grad(set_to_none=True) | ||
pred_probs = qcbm() | ||
loss = mmd(pred_probs, target_probs) | ||
loss.backward() | ||
optimizer.step() | ||
print(i, loss.item()) | ||
|
||
# Visualize the results | ||
if args.plot: | ||
with torch.no_grad(): | ||
pred_probs = qcbm() | ||
|
||
plt.plot(x_input, target_probs, linestyle="-.", label=r"$\pi(x)$") | ||
plt.bar(x_input, pred_probs, color="green", alpha=0.5, label="samples") | ||
plt.xlabel("Samples") | ||
plt.ylabel("Prob. Distribution") | ||
|
||
plt.legend() | ||
plt.show() | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.