Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Transition to PyTorch's complex API #120

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions qucumber/nn_states/density_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def pi_grad(self, v, vp, phase=False, expand=False):

if phase:
temp = (v.unsqueeze(1) - vp.unsqueeze(0)) if expand else (v - vp)
sig = cplx.scalar_mult(sig, cplx.I)
sig = 1j * sig

ab_grad_real = torch.zeros_like(self.rbm_ph.aux_bias).expand(
*batch_sizes, -1
Expand All @@ -215,7 +215,7 @@ def pi_grad(self, v, vp, phase=False, expand=False):
ab_grad_real = cplx.real(sig)
ab_grad_imag = cplx.imag(sig)

U_grad = 0.5 * torch.einsum("c...j,...k->c...jk", sig, temp)
U_grad = 0.5 * torch.einsum("...j,...k->...jk", sig, temp)
U_grad_real = cplx.real(U_grad)
U_grad_imag = cplx.imag(U_grad)

Expand Down
1 change: 0 additions & 1 deletion qucumber/nn_states/positive_wavefunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def psi(self, v):
each visible state
:rtype: torch.Tensor
"""
# vector/tensor of shape (2, len(v))
return cplx.make_complex(self.amplitude(v))

def gradient(self, v, *args, **kwargs):
Expand Down
85 changes: 40 additions & 45 deletions qucumber/utils/cplx.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np


I = torch.Tensor([0, 1]) # noqa: E741
I = 1j * torch.ones(1) # noqa: E741


def make_complex(x, y=None):
Expand All @@ -36,11 +36,13 @@ def make_complex(x, y=None):
:rtype: torch.Tensor
"""
if isinstance(x, np.ndarray):
return make_complex(torch.tensor(x.real), torch.tensor(x.imag)).contiguous()
x = torch.tensor(x)

if y is None:
y = torch.zeros_like(x)
return torch.cat((x.unsqueeze(0), y.unsqueeze(0)), dim=0)
elif isinstance(y, np.ndarray):
y = torch.tensor(y)
return x + 1j * y


def numpy(x):
Expand All @@ -53,7 +55,7 @@ def numpy(x):
:returns: A complex numpy array containing the data from `x`.
:rtype: numpy.ndarray
"""
return real(x).detach().cpu().numpy() + 1j * imag(x).detach().cpu().numpy()
return x.detach().cpu().numpy()


def real(x):
Expand All @@ -65,7 +67,7 @@ def real(x):
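:returns: The real part of `x`. With the old stacked layout (`x[0, ...]`) this has one less dimension than `x`; with a native complex tensor (`x.real`) it has the same shape as `x`.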
:rtype: torch.Tensor
"""
return x[0, ...]
return x.real


def imag(x):
Expand All @@ -77,7 +79,7 @@ def imag(x):
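:returns: The imaginary part of `x`. With the old stacked layout (`x[1, ...]`) this has one less dimension than `x`; with a native complex tensor (`x.imag`) it has the same shape as `x`.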
:rtype: torch.Tensor
"""
return x[1, ...]
return x.imag


def scalar_mult(x, y, out=None):
Expand All @@ -94,16 +96,16 @@ def scalar_mult(x, y, out=None):
:rtype: torch.Tensor
"""
y = y.to(x)
if out is None:
out = torch.zeros(2, *((real(x) * real(y)).shape)).to(x)
else:
if out is x or out is y:
raise RuntimeError("Can't overwrite an argument!")
# if out is None:
# out = torch.zeros(2, *((real(x) * real(y)).shape)).to(x)
# else:
if out is x or out is y:
raise RuntimeError("Can't overwrite an argument!")

torch.mul(real(x), real(y), out=real(out)).sub_(torch.mul(imag(x), imag(y)))
torch.mul(real(x), imag(y), out=imag(out)).add_(torch.mul(imag(x), real(y)))
# torch.mul(real(x), real(y), out=real(out)).sub_(torch.mul(imag(x), imag(y)))
# torch.mul(real(x), imag(y), out=imag(out)).add_(torch.mul(imag(x), real(y)))

return out
return torch.mul(x, y, out=out)


def matmul(x, y):
Expand All @@ -121,10 +123,11 @@ def matmul(x, y):
:rtype: torch.Tensor
"""
y = y.to(x)
re = torch.matmul(real(x), real(y)).sub_(torch.matmul(imag(x), imag(y)))
im = torch.matmul(real(x), imag(y)).add_(torch.matmul(imag(x), real(y)))
# re = torch.matmul(real(x), real(y)).sub_(torch.matmul(imag(x), imag(y)))
# im = torch.matmul(real(x), imag(y)).add_(torch.matmul(imag(x), real(y)))

return make_complex(re, im)
# return make_complex(re, im)
return torch.matmul(x, y)


def inner_prod(x, y):
Expand All @@ -144,16 +147,10 @@ def inner_prod(x, y):
"""
y = y.to(x)

if x.dim() == 2 and y.dim() == 2:
return make_complex(
torch.dot(real(x), real(y)) + torch.dot(imag(x), imag(y)),
torch.dot(real(x), imag(y)) - torch.dot(imag(x), real(y)),
)
elif x.dim() == 1 and y.dim() == 1:
return make_complex(
(real(x) * real(y)) + (imag(x) * imag(y)),
(real(x) * imag(y)) - (imag(x) * real(y)),
)
if x.dim() == 1 and y.dim() == 1:
return torch.dot(torch.conj(x), y)
elif x.dim() == 0 and y.dim() == 0:
return torch.conj(x) * y
else:
raise ValueError("Unsupported input shapes!")

Expand All @@ -174,14 +171,14 @@ def outer_prod(x, y):
:math:`\\vert x \\rangle\\langle y\\vert`.
:rtype: torch.Tensor
"""
if x.dim() != 2 or y.dim() != 2:
if x.dim() != 1 or y.dim() != 1:
raise ValueError("An input is not of the right dimension.")

z = torch.zeros(2, x.size()[1], y.size()[1], dtype=x.dtype, device=x.device)
z[0] = torch.ger(real(x), real(y)) - torch.ger(imag(x), -imag(y))
z[1] = torch.ger(real(x), -imag(y)) + torch.ger(imag(x), real(y))
# z = torch.zeros(2, x.size()[1], y.size()[1], dtype=x.dtype, device=x.device)
# z[0] = torch.ger(real(x), real(y)) - torch.ger(imag(x), -imag(y))
# z[1] = torch.ger(real(x), -imag(y)) + torch.ger(imag(x), real(y))

return z
return torch.einsum("i,j->ij", x, torch.conj(y))


def einsum(equation, a, b, real_part=True, imag_part=True):
Expand Down Expand Up @@ -236,12 +233,10 @@ def conjugate(x):
:returns: The conjugate of x.
:rtype: torch.Tensor
"""
if x.dim() < 3:
if x.dim() < 2:
return conj(x)
else:
return make_complex(
torch.transpose(real(x), 0, 1), -torch.transpose(imag(x), 0, 1)
)
return torch.transpose(torch.conj(x), 0, 1)


def conj(x):
Expand All @@ -253,7 +248,7 @@ def conj(x):
:returns: The complex conjugate of x.
:rtype: torch.Tensor
"""
return make_complex(real(x), -imag(x))
return torch.conj(x)


def elementwise_mult(x, y):
Expand Down Expand Up @@ -308,11 +303,11 @@ def kronecker_prod(x, y):
:returns: The Kronecker product of x and y, :math:`x \\otimes y`.
:rtype: torch.Tensor
"""
if not (x.dim() == y.dim() == 3):
if not (x.dim() == y.dim() == 2):
raise ValueError("Inputs must be complex matrices!")

return einsum("ab,cd->acbd", x, y).reshape(
2, x.shape[1] * y.shape[1], x.shape[2] * y.shape[2]
x.shape[0] * y.shape[0], x.shape[1] * y.shape[1]
)


Expand All @@ -326,12 +321,12 @@ def sigmoid(x, y):
:returns: The complex sigmoid of :math:`x + iy`
:rtype: torch.Tensor
"""
z = (x.cpu().numpy()) + 1j * (y.cpu().numpy())

out = np.exp(z) / (1 + np.exp(z))
out = torch.tensor([np.real(out), np.imag(out)]).to(x)
z = make_complex(x, y)
# exp = torch.exp(z)
# out = np.exp(z) / (1 + np.exp(z))
# out = torch.tensor([np.real(out), np.imag(out)]).to(x)

return out
return 1.0 / (1.0 + torch.exp(-z))


def scalar_divide(x, y):
Expand All @@ -347,7 +342,7 @@ def scalar_divide(x, y):
:returns: x / y
:rtype: torch.Tensor
"""
return scalar_mult(x, inverse(y))
return x / y


def inverse(z):
Expand Down
7 changes: 4 additions & 3 deletions qucumber/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def load_data(tr_samples_path, tr_psi_path=None, tr_bases_path=None, bases_path=

if tr_psi_path is not None:
target_psi_data = np.loadtxt(tr_psi_path, dtype="float32")
target_psi = torch.zeros(2, len(target_psi_data), dtype=torch.double)
target_psi[0] = torch.tensor(target_psi_data[:, 0], dtype=torch.double)
target_psi[1] = torch.tensor(target_psi_data[:, 1], dtype=torch.double)
target_psi = cplx.make_complex(
torch.tensor(target_psi_data[:, 0], dtype=torch.double),
torch.tensor(target_psi_data[:, 1], dtype=torch.double),
)
data.append(target_psi)

if tr_bases_path is not None:
Expand Down
58 changes: 36 additions & 22 deletions qucumber/utils/training_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@

from qucumber.nn_states import WaveFunctionBase
from qucumber.utils import cplx, deprecated_kwarg
from qucumber.utils.unitaries import rotate_psi, rotate_psi_inner_prod, rotate_rho_probs
from qucumber.utils.unitaries import (
rotate_psi,
rotate_rho,
rotate_psi_inner_prod,
rotate_rho_probs,
)


@deprecated_kwarg(target_psi="target", target_rho="target")
Expand Down Expand Up @@ -52,13 +57,13 @@ def fidelity(nn_state, target, space=None, **kwargs):
target = target.to(nn_state.device)

if isinstance(nn_state, WaveFunctionBase):
assert target.dim() == 2, "target must be a complex vector!"
assert target.dim() == 1, "target must be a complex vector!"

psi = nn_state.psi(space) / Z.sqrt()
F = cplx.inner_prod(target, psi)
return cplx.absolute_value(F).pow_(2).item()
else:
assert target.dim() == 3, "target must be a complex matrix!"
assert target.dim() == 2, "target must be a complex matrix!"

rho = nn_state.rho(space, space) / Z
rho_rbm_ = cplx.numpy(rho)
Expand Down Expand Up @@ -106,24 +111,33 @@ def NLL(nn_state, samples, space=None, sample_bases=None, **kwargs):
unique_bases, indices = np.unique(sample_bases, axis=0, return_inverse=True)
indices = torch.Tensor(indices).to(samples)

for i in range(unique_bases.shape[0]):
basis = unique_bases[i, :]
rot_sites = np.where(basis != "Z")[0]
if isinstance(nn_state, WaveFunctionBase):
for i in range(unique_bases.shape[0]):
basis = unique_bases[i, :]
rot_sites = np.where(basis != "Z")[0]

if rot_sites.size != 0:
if isinstance(nn_state, WaveFunctionBase):
if rot_sites.size != 0:
Upsi = rotate_psi_inner_prod(
nn_state, basis, samples[indices == i, :]
)
nn_probs = (cplx.absolute_value(Upsi) ** 2) / Z
else:
nn_probs = nn_state.probability(samples[indices == i, :], Z)

NLL_ -= torch.sum(probs_to_logits(nn_probs))
else:
for i in range(unique_bases.shape[0]):
basis = unique_bases[i, :]
rot_sites = np.where(basis != "Z")[0]

if rot_sites.size != 0:
nn_probs = (
rotate_rho_probs(nn_state, basis, samples[indices == i, :]) / Z
)
else:
nn_probs = nn_state.probability(samples[indices == i, :], Z)
else:
nn_probs = nn_state.probability(samples[indices == i, :], Z)

NLL_ -= torch.sum(probs_to_logits(nn_probs))
NLL_ -= torch.sum(probs_to_logits(nn_probs))

return NLL_ / float(len(samples))

Expand Down Expand Up @@ -168,9 +182,9 @@ def KL(nn_state, target, space=None, bases=None, **kwargs):
if bases is None:
bases = list(target.keys())
else:
assert set(bases) == set(
assert set(bases) <= set(
target.keys()
), "Given bases must match the keys of the target_psi dictionary."
), "Given bases must be a subset of the keys of the target_psi dictionary."
else:
target = target.to(nn_state.device)

Expand All @@ -179,16 +193,15 @@ def KL(nn_state, target, space=None, bases=None, **kwargs):
if bases is None:
target_probs = cplx.absolute_value(target) ** 2
nn_probs = nn_state.probability(space, Z)

KL += _single_basis_KL(target_probs, nn_probs)
return _single_basis_KL(target_probs, nn_probs)

elif isinstance(nn_state, WaveFunctionBase):
for basis in bases:
if isinstance(target, dict):
target_psi_r = target[basis]
assert target_psi_r.dim() == 2, "target must be a complex vector!"
assert target_psi_r.dim() == 1, "target must be a complex vector!"
else:
assert target.dim() == 2, "target must be a complex vector!"
assert target.dim() == 1, "target must be a complex vector!"
target_psi_r = rotate_psi(nn_state, basis, space, psi=target)

psi_r = rotate_psi(nn_state, basis, space)
Expand All @@ -202,13 +215,14 @@ def KL(nn_state, target, space=None, bases=None, **kwargs):
for basis in bases:
if isinstance(target, dict):
target_rho_r = target[basis]
assert target_rho_r.dim() == 3, "target must be a complex matrix!"
target_probs_r = torch.diagonal(cplx.real(target_rho_r))
assert target_rho_r.dim() == 2, "target must be a complex matrix!"

else:
assert target.dim() == 3, "target must be a complex matrix!"
target_probs_r = rotate_rho_probs(nn_state, basis, space, rho=target)
assert target.dim() == 2, "target must be a complex matrix!"
target_rho_r = rotate_rho(nn_state, basis, space, rho=target)

rho_r = rotate_rho_probs(nn_state, basis, space)
target_probs_r = torch.diagonal(cplx.real(target_rho_r))
rho_r = torch.diagonal(cplx.real(rotate_rho(nn_state, basis, space)))
nn_probs_r = rho_r / Z

KL += _single_basis_KL(target_probs_r, nn_probs_r)
Expand Down
Loading