From a6566234326db262c76c20723fdf11598a995943 Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Mon, 7 Sep 2020 22:16:16 -0400 Subject: [PATCH 1/5] Begin transitioning to PyTorch's complex API --- qucumber/nn_states/density_matrix.py | 4 +- qucumber/nn_states/positive_wavefunction.py | 1 - qucumber/utils/cplx.py | 86 +++++----- qucumber/utils/training_statistics.py | 16 +- qucumber/utils/unitaries.py | 54 +++---- tests/grads_utils.py | 13 +- tests/test_cplx.py | 171 +++++++++----------- tests/test_grads.py | 7 + tests/test_models_misc.py | 13 +- tox.ini | 10 +- 10 files changed, 183 insertions(+), 192 deletions(-) diff --git a/qucumber/nn_states/density_matrix.py b/qucumber/nn_states/density_matrix.py index f70c15fd..2c675e3a 100644 --- a/qucumber/nn_states/density_matrix.py +++ b/qucumber/nn_states/density_matrix.py @@ -203,7 +203,7 @@ def pi_grad(self, v, vp, phase=False, expand=False): if phase: temp = (v.unsqueeze(1) - vp.unsqueeze(0)) if expand else (v - vp) - sig = cplx.scalar_mult(sig, cplx.I) + sig = 1j * sig ab_grad_real = torch.zeros_like(self.rbm_ph.aux_bias).expand( *batch_sizes, -1 @@ -215,7 +215,7 @@ def pi_grad(self, v, vp, phase=False, expand=False): ab_grad_real = cplx.real(sig) ab_grad_imag = cplx.imag(sig) - U_grad = 0.5 * torch.einsum("c...j,...k->c...jk", sig, temp) + U_grad = 0.5 * torch.einsum("...j,...k->...jk", sig, temp) U_grad_real = cplx.real(U_grad) U_grad_imag = cplx.imag(U_grad) diff --git a/qucumber/nn_states/positive_wavefunction.py b/qucumber/nn_states/positive_wavefunction.py index 114ee705..c86843fd 100755 --- a/qucumber/nn_states/positive_wavefunction.py +++ b/qucumber/nn_states/positive_wavefunction.py @@ -120,7 +120,6 @@ def psi(self, v): each visible state :rtype: torch.Tensor """ - # vector/tensor of shape (2, len(v)) return cplx.make_complex(self.amplitude(v)) def gradient(self, v, *args, **kwargs): diff --git a/qucumber/utils/cplx.py b/qucumber/utils/cplx.py index 363dafc3..d27ba851 100755 --- a/qucumber/utils/cplx.py +++ b/qucumber/utils/cplx.py @@ -17,7 +17,7 @@ import numpy as np -I = torch.Tensor([0, 1]) # noqa: E741 +I = 1j * torch.ones(1) # noqa: E741 def make_complex(x, y=None): @@ -36,11 +36,14 @@ def make_complex(x, y=None): :rtype: torch.Tensor """ if isinstance(x, np.ndarray): - return make_complex(torch.tensor(x.real), torch.tensor(x.imag)).contiguous() + return torch.tensor(x) + # return make_complex(torch.tensor(x.real), torch.tensor(x.imag)).contiguous() if y is None: y = torch.zeros_like(x) - return torch.cat((x.unsqueeze(0), y.unsqueeze(0)), dim=0) + else: + y = torch.tensor(y) + return torch.tensor(x) + 1j * y def numpy(x): @@ -53,7 +56,7 @@ def numpy(x): :returns: A complex numpy array containing the data from `x`. :rtype: numpy.ndarray """ - return real(x).detach().cpu().numpy() + 1j * imag(x).detach().cpu().numpy() + return x.detach().cpu().numpy() def real(x): @@ -65,7 +68,7 @@ def real(x): :returns: The real part of `x`; will have one less dimension than `x`. :rtype: torch.Tensor """ - return x[0, ...] + return x.real def imag(x): @@ -77,7 +80,7 @@ def imag(x): :returns: The imaginary part of `x`; will have one less dimension than `x`. :rtype: torch.Tensor """ - return x[1, ...] + return x.imag def scalar_mult(x, y, out=None): @@ -94,16 +97,16 @@ def scalar_mult(x, y, out=None): :rtype: torch.Tensor """ y = y.to(x) - if out is None: - out = torch.zeros(2, *((real(x) * real(y)).shape)).to(x) - else: - if out is x or out is y: - raise RuntimeError("Can't overwrite an argument!") + # if out is None: + # out = torch.zeros(2, *((real(x) * real(y)).shape)).to(x) + # else: + if out is x or out is y: + raise RuntimeError("Can't overwrite an argument!") - torch.mul(real(x), real(y), out=real(out)).sub_(torch.mul(imag(x), imag(y))) - torch.mul(real(x), imag(y), out=imag(out)).add_(torch.mul(imag(x), real(y))) + # torch.mul(real(x), real(y), out=real(out)).sub_(torch.mul(imag(x), imag(y))) + # torch.mul(real(x), imag(y), out=imag(out)).add_(torch.mul(imag(x), real(y))) - return out + return torch.mul(x, y, out=out) def matmul(x, y): @@ -121,10 +124,11 @@ def matmul(x, y): :rtype: torch.Tensor """ y = y.to(x) - re = torch.matmul(real(x), real(y)).sub_(torch.matmul(imag(x), imag(y))) - im = torch.matmul(real(x), imag(y)).add_(torch.matmul(imag(x), real(y))) + # re = torch.matmul(real(x), real(y)).sub_(torch.matmul(imag(x), imag(y))) + # im = torch.matmul(real(x), imag(y)).add_(torch.matmul(imag(x), real(y))) - return make_complex(re, im) + # return make_complex(re, im) + return torch.matmul(x, y) def inner_prod(x, y): @@ -144,16 +148,10 @@ def inner_prod(x, y): """ y = y.to(x) - if x.dim() == 2 and y.dim() == 2: - return make_complex( - torch.dot(real(x), real(y)) + torch.dot(imag(x), imag(y)), - torch.dot(real(x), imag(y)) - torch.dot(imag(x), real(y)), - ) - elif x.dim() == 1 and y.dim() == 1: - return make_complex( - (real(x) * real(y)) + (imag(x) * imag(y)), - (real(x) * imag(y)) - (imag(x) * real(y)), - ) + if x.dim() == 1 and y.dim() == 1: + return torch.dot(torch.conj(x), y) + elif x.dim() == 0 and y.dim() == 0: + return torch.conj(x) * y else: raise ValueError("Unsupported input shapes!") @@ -174,14 +172,14 @@ def outer_prod(x, y): :math:`\\vert x \\rangle\\langle y\\vert`. :rtype: torch.Tensor """ - if x.dim() != 2 or y.dim() != 2: + if x.dim() != 1 or y.dim() != 1: raise ValueError("An input is not of the right dimension.") - z = torch.zeros(2, x.size()[1], y.size()[1], dtype=x.dtype, device=x.device) - z[0] = torch.ger(real(x), real(y)) - torch.ger(imag(x), -imag(y)) - z[1] = torch.ger(real(x), -imag(y)) + torch.ger(imag(x), real(y)) + # z = torch.zeros(2, x.size()[1], y.size()[1], dtype=x.dtype, device=x.device) + # z[0] = torch.ger(real(x), real(y)) - torch.ger(imag(x), -imag(y)) + # z[1] = torch.ger(real(x), -imag(y)) + torch.ger(imag(x), real(y)) - return z + return torch.einsum("i,j->ij", x, torch.conj(y)) def einsum(equation, a, b, real_part=True, imag_part=True): @@ -236,12 +234,10 @@ def conjugate(x): :returns: The conjugate of x. :rtype: torch.Tensor """ - if x.dim() < 3: + if x.dim() < 2: return conj(x) else: - return make_complex( - torch.transpose(real(x), 0, 1), -torch.transpose(imag(x), 0, 1) - ) + return torch.transpose(torch.conj(x), 0, 1) def conj(x): @@ -253,7 +249,7 @@ def conj(x): :returns: The complex conjugate of x. :rtype: torch.Tensor """ - return make_complex(real(x), -imag(x)) + return torch.conj(x) def elementwise_mult(x, y): @@ -308,11 +304,11 @@ def kronecker_prod(x, y): :returns: The Kronecker product of x and y, :math:`x \\otimes y`. :rtype: torch.Tensor """ - if not (x.dim() == y.dim() == 3): + if not (x.dim() == y.dim() == 2): raise ValueError("Inputs must be complex matrices!") return einsum("ab,cd->acbd", x, y).reshape( - 2, x.shape[1] * y.shape[1], x.shape[2] * y.shape[2] + x.shape[0] * y.shape[0], x.shape[1] * y.shape[1] ) @@ -326,12 +322,12 @@ def sigmoid(x, y): :returns: The complex sigmoid of :math:`x + iy` :rtype: torch.Tensor """ - z = (x.cpu().numpy()) + 1j * (y.cpu().numpy()) - - out = np.exp(z) / (1 + np.exp(z)) - out = torch.tensor([np.real(out), np.imag(out)]).to(x) + z = make_complex(x, y) + # exp = torch.exp(z) + # out = np.exp(z) / (1 + np.exp(z)) + # out = torch.tensor([np.real(out), np.imag(out)]).to(x) - return out + return 1.0 / (1.0 + torch.exp(-z)) def scalar_divide(x, y): @@ -347,7 +343,7 @@ def scalar_divide(x, y): :returns: x / y :rtype: torch.Tensor """ - return scalar_mult(x, inverse(y)) + return x / y def inverse(z): diff --git a/qucumber/utils/training_statistics.py b/qucumber/utils/training_statistics.py index 9f722d51..ac1fb870 100644 --- a/qucumber/utils/training_statistics.py +++ b/qucumber/utils/training_statistics.py @@ -52,13 +52,13 @@ def fidelity(nn_state, target, space=None, **kwargs): target = target.to(nn_state.device) if isinstance(nn_state, WaveFunctionBase): - assert target.dim() == 2, "target must be a complex vector!" + assert target.dim() == 1, "target must be a complex vector!" psi = nn_state.psi(space) / Z.sqrt() F = cplx.inner_prod(target, psi) return cplx.absolute_value(F).pow_(2).item() else: - assert target.dim() == 3, "target must be a complex matrix!" + assert target.dim() == 2, "target must be a complex matrix!" rho = nn_state.rho(space, space) / Z rho_rbm_ = cplx.numpy(rho) @@ -168,9 +168,9 @@ def KL(nn_state, target, space=None, bases=None, **kwargs): if bases is None: bases = list(target.keys()) else: - assert set(bases) == set( + assert set(bases) <= set( target.keys() - ), "Given bases must match the keys of the target_psi dictionary." + ), "Given bases must be a subset of the keys of the target_psi dictionary." else: target = target.to(nn_state.device) @@ -186,9 +186,9 @@ def KL(nn_state, target, space=None, bases=None, **kwargs): for basis in bases: if isinstance(target, dict): target_psi_r = target[basis] - assert target_psi_r.dim() == 2, "target must be a complex vector!" + assert target_psi_r.dim() == 1, "target must be a complex vector!" else: - assert target.dim() == 2, "target must be a complex vector!" + assert target.dim() == 1, "target must be a complex vector!" target_psi_r = rotate_psi(nn_state, basis, space, psi=target) psi_r = rotate_psi(nn_state, basis, space) @@ -202,10 +202,10 @@ def KL(nn_state, target, space=None, bases=None, **kwargs): for basis in bases: if isinstance(target, dict): target_rho_r = target[basis] - assert target_rho_r.dim() == 3, "target must be a complex matrix!" + assert target_rho_r.dim() == 2, "target must be a complex matrix!" target_probs_r = torch.diagonal(cplx.real(target_rho_r)) else: - assert target.dim() == 3, "target must be a complex matrix!" + assert target.dim() == 2, "target must be a complex matrix!" target_probs_r = rotate_rho_probs(nn_state, basis, space, rho=target) rho_r = rotate_rho_probs(nn_state, basis, space) diff --git a/qucumber/utils/unitaries.py b/qucumber/utils/unitaries.py index 2848da20..4bbbe43f 100644 --- a/qucumber/utils/unitaries.py +++ b/qucumber/utils/unitaries.py @@ -35,15 +35,15 @@ def create_dict(**kwargs): """ dictionary = { "X": torch.tensor( - [[[1.0, 1.0], [1.0, -1.0]], [[0.0, 0.0], [0.0, 0.0]]], dtype=torch.double + [[1.0 + 0.0j, 1.0 + 0.0j], [1.0 + 0.0j, -1.0 + 0.0j]], dtype=torch.cdouble ) / np.sqrt(2), "Y": torch.tensor( - [[[1.0, 0.0], [1.0, 0.0]], [[0.0, -1.0], [0.0, 1.0]]], dtype=torch.double + [[1.0 + 0.0j, 0.0 - 1.0j], [1.0 + 0.0j, 0.0 - 1.0j]], dtype=torch.cdouble ) / np.sqrt(2), "Z": torch.tensor( - [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]], dtype=torch.double + [[1.0 + 0.0j, 0.0 + 0.0j], [0.0 + 0.0j, 1.0 + 0.0j]], dtype=torch.cdouble ), } @@ -53,7 +53,7 @@ def create_dict(**kwargs): matrix.clone().detach() if isinstance(matrix, torch.Tensor) else torch.tensor(matrix) - ).to(dtype=torch.double) + ).to(dtype=torch.cdouble) for name, matrix in kwargs.items() } ) @@ -65,7 +65,7 @@ def _kron_mult(matrices, x): n = [m.size()[0] for m in matrices] l, r = np.prod(n), 1 # noqa: E741 - if l != x.shape[1]: # noqa: E741 + if l != x.shape[0]: # noqa: E741 raise ValueError("Incompatible sizes!") y = x.clone() @@ -76,8 +76,8 @@ def _kron_mult(matrices, x): for k in range(l): for i in range(r): slc = slice(k * n[s] * r + i, (k + 1) * n[s] * r + i, r) - temp = y[:, slc, ...] - y[:, slc, ...] = cplx.matmul(m, temp) + temp = y[slc, ...] + y[slc, ...] = cplx.matmul(m, temp) r *= n[s] return y @@ -106,7 +106,7 @@ def rotate_psi(nn_state, basis, space, unitaries=None, psi=None): psi = ( nn_state.psi(space) if psi is None - else psi.to(dtype=torch.double, device=nn_state.device) + else psi.to(dtype=torch.cdouble, device=nn_state.device) ) unitaries = unitaries if unitaries else nn_state.unitary_dict @@ -137,7 +137,7 @@ def rotate_rho(nn_state, basis, space, unitaries=None, rho=None): rho = ( nn_state.rho(space, space) if rho is None - else rho.to(dtype=torch.double, device=nn_state.device) + else rho.to(dtype=torch.cdouble, device=nn_state.device) ) unitaries = unitaries if unitaries else nn_state.unitary_dict @@ -153,30 +153,32 @@ def rotate_rho(nn_state, basis, space, unitaries=None, rho=None): # TODO: make this a generator function def _rotate_basis_state(nn_state, basis, states, unitaries=None): unitaries = unitaries if unitaries else nn_state.unitary_dict - unitaries = {k: v.to(device="cpu") for k, v in unitaries.items()} + unitaries = {k: v.to(device=nn_state.device) for k, v in unitaries.items()} basis = np.array(list(basis)) sites = np.where(basis != "Z")[0] if sites.size != 0: - Us = torch.stack([unitaries[b] for b in basis[sites]]).cpu().numpy() + Us = torch.stack([unitaries[b] for b in basis[sites]]) reps = [1 for _ in states.shape] v = states.unsqueeze(0).repeat(2 ** sites.size, *reps) v[..., sites] = nn_state.generate_hilbert_space(size=sites.size).unsqueeze(1) v = v.contiguous() - int_sample = states[..., sites].round().int().cpu().numpy() - ints_size = np.arange(sites.size) + int_sample = states[..., sites].round().long() + ints_size = torch.arange(sites.size, dtype=torch.long, device=nn_state.device) # overwrite rotated elements - int_vp = v[..., sites].long().cpu().numpy() - all_Us = Us[ints_size, :, int_sample, int_vp] + int_vp = v[..., sites].long() + all_Us = Us[ints_size, int_sample, int_vp] - Ut = np.prod(all_Us[..., 0] + (1j * all_Us[..., 1]), axis=-1) + Ut = torch.prod(all_Us.cpu(), dim=-1).to( + all_Us + ) # FIXME: prod is currently unsupported on GPU else: v = states.unsqueeze(0) - Ut = np.ones(v.shape[:-1], dtype=complex) + Ut = torch.ones(v.shape[:-1], dtype=torch.cdouble, device=nn_state.device) return Ut, v @@ -219,13 +221,12 @@ def rotate_psi_inner_prod( else: # pick out the entries of psi that we actually need idx = _convert_basis_element_to_index(v).long() - psi = psi[:, idx] + psi = psi[idx] - psi = cplx.numpy(psi.cpu()) - Ut *= psi + Upsi_v = Ut * psi - Upsi_v = cplx.make_complex(Ut).to(dtype=torch.double, device=nn_state.device) - Upsi = torch.sum(Upsi_v, dim=1) + # Upsi_v = cplx.make_complex(Ut).to(dtype=torch.cdouble, device=nn_state.device) + Upsi = torch.sum(Upsi_v, dim=0) if include_extras: return Upsi, Upsi_v, v @@ -260,19 +261,18 @@ def rotate_rho_probs( :rtype: torch.Tensor or tuple(torch.Tensor) """ Ut, v = _rotate_basis_state(nn_state, basis, states, unitaries=unitaries) - Ut = np.einsum("ib,jb->ijb", Ut, np.conj(Ut)) + Ut = torch.einsum("ib,jb->ijb", Ut, torch.conj(Ut)) if rho is None: rho = nn_state.rho(v).detach() else: # pick out the entries of rho that we actually need idx = _convert_basis_element_to_index(v).long() - rho = rho[:, idx.unsqueeze(0), idx.unsqueeze(1)] + rho = rho[idx.unsqueeze(0), idx.unsqueeze(1)] - rho = cplx.numpy(rho.cpu()) - Ut *= rho + UrhoU_v = Ut * rho - UrhoU_v = cplx.make_complex(Ut).to(dtype=torch.double, device=nn_state.device) + # UrhoU_v = cplx.make_complex(Ut).to(dtype=torch.cdouble, device=nn_state.device) UrhoU = torch.sum( cplx.real(UrhoU_v), dim=(0, 1) ) # imaginary parts will cancel out anyway diff --git a/tests/grads_utils.py b/tests/grads_utils.py index 3756685a..eace3ae8 100644 --- a/tests/grads_utils.py +++ b/tests/grads_utils.py @@ -39,7 +39,7 @@ def algorithmic_gradKL(self, target, space, **kwargs): ) for i in range(len(space)): sample_grad = self.nn_state.gradient(space[i])[0] - grad_KL += ((target[0, i]) ** 2) * sample_grad + grad_KL += (cplx.absolute_value(target[i]) ** 2) * sample_grad grad_KL -= self.nn_state.probability(space[i], Z) * sample_grad return [grad_KL] @@ -86,17 +86,16 @@ def numeric_gradNLL( class ComplexGradsUtils(PosGradsUtils): def load_target_psi(self, bases, psi_data): if isinstance(psi_data, torch.Tensor): - psi_data = psi_data.clone().detach().to(dtype=torch.double) + psi_data = psi_data.clone().detach().to(dtype=torch.cdouble) else: - psi_data = torch.tensor(psi_data, dtype=torch.double) + psi_data = torch.tensor(psi_data, dtype=torch.cdouble) psi_dict = {} - D = int(psi_data.shape[1] / float(len(bases))) + D = int(psi_data.shape[0] / float(len(bases))) for b in range(len(bases)): - psi = torch.zeros(2, D, dtype=torch.double) - psi[0, ...] = psi_data[0, b * D : (b + 1) * D] - psi[1, ...] = psi_data[1, b * D : (b + 1) * D] + psi = torch.zeros(D, dtype=torch.cdouble) + psi = psi_data[b * D : (b + 1) * D] psi_dict[bases[b]] = psi return psi_dict diff --git a/tests/test_cplx.py b/tests/test_cplx.py index 20348d1d..a86879b4 100644 --- a/tests/test_cplx.py +++ b/tests/test_cplx.py @@ -33,7 +33,7 @@ def test_make_complex_vector(self): y = torch.tensor([5, 6, 7, 8]) z = cplx.make_complex(x, y) - expect = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) + expect = torch.tensor([1 + 5j, 2 + 6j, 3 + 7j, 4 + 8j]) self.assertTensorsEqual(expect, z, msg="Make Complex Vector failed!") @@ -41,7 +41,7 @@ def test_make_complex_vector_with_zero_imaginary_part(self): x = torch.tensor([1, 2, 3, 4]) z = cplx.make_complex(x) - expect = torch.tensor([[1, 2, 3, 4], [0, 0, 0, 0]]) + expect = torch.tensor([1 + 0j, 2 + 0j, 3 + 0j, 4 + 0j]) self.assertTensorsEqual( expect, z, msg="Making a complex vector with zero imaginary part failed!" @@ -52,7 +52,7 @@ def test_make_complex_matrix(self): y = torch.tensor([[5, 6], [7, 8]]) z = cplx.make_complex(x, y) - expect = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + expect = torch.tensor([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]) self.assertTensorsEqual(expect, z, msg="Make Complex Matrix failed!") @@ -61,28 +61,36 @@ def test_real_part_of_vector(self): y = torch.tensor([5, 6]) z = cplx.make_complex(x, y) - self.assertTensorsEqual(x, cplx.real(z), msg="Real part of vector failed!") + self.assertTensorsEqual( + x, cplx.real(z).to(x), msg="Real part of vector failed!" + ) def test_imag_part_of_vector(self): x = torch.tensor([1, 2]) y = torch.tensor([5, 6]) z = cplx.make_complex(x, y) - self.assertTensorsEqual(y, cplx.imag(z), msg="Imaginary part of vector failed!") + self.assertTensorsEqual( + y, cplx.imag(z).to(y), msg="Imaginary part of vector failed!" + ) def test_real_part_of_matrix(self): x = torch.tensor([[1, 2], [3, 4]]) y = torch.tensor([[5, 6], [7, 8]]) z = cplx.make_complex(x, y) - self.assertTensorsEqual(x, cplx.real(z), msg="Real part of matrix failed!") + self.assertTensorsEqual( + x, cplx.real(z).to(x), msg="Real part of matrix failed!" + ) def test_imag_part_of_matrix(self): x = torch.tensor([[1, 2], [3, 4]]) y = torch.tensor([[5, 6], [7, 8]]) z = cplx.make_complex(x, y) - self.assertTensorsEqual(y, cplx.imag(z), msg="Imaginary part of matrix failed!") + self.assertTensorsEqual( + y, cplx.imag(z).to(y), msg="Imaginary part of matrix failed!" + ) def test_real_part_of_tensor(self): x = torch.randn(3, 3, 3) @@ -109,10 +117,10 @@ def test_bad_complex_matrix(self): return cplx.make_complex(x, y) def test_elementwise_mult(self): - z1 = torch.tensor([[2, 3, 5], [6, 7, 2]], dtype=torch.double) - z2 = torch.tensor([[1, 2, 2], [3, 4, 8]], dtype=torch.double) + z1 = torch.tensor([2 + 6j, 3 + 7j, 5 + 2j]) + z2 = torch.tensor([1 + 3j, 2 + 4j, 2 + 8j]) - expect = torch.tensor([[-16, -22, -6], [12, 26, 44]], dtype=torch.double) + expect = torch.tensor([-16 + 12j, -22 + 26j, -6 + 44j]) self.assertTensorsEqual( cplx.elementwise_mult(z1, z2), @@ -121,12 +129,10 @@ def test_elementwise_mult(self): ) def test_elementwise_div(self): - z1 = torch.tensor([[2, 3, 5], [6, 7, 2]], dtype=torch.double) - z2 = torch.tensor([[1, 2, 2], [3, 4, 8]], dtype=torch.double) + z1 = torch.tensor([2 + 6j, 3 + 7j, 5 + 2j]) + z2 = torch.tensor([1 + 3j, 2 + 4j, 2 + 8j]) - expect = torch.tensor( - [[2, (17 / 10), (13 / 34)], [0, (1 / 10), (-9 / 17)]], dtype=torch.double - ) + expect = torch.tensor([2 + 0j, (17 / 10) + (1j / 10), (13 / 34) + (-9j / 17)],) self.assertTensorsAlmostEqual( cplx.elementwise_division(z1, z2), @@ -141,10 +147,10 @@ def test_elementwise_div_fail(self): return cplx.elementwise_division(z1, z2) def test_scalar_vector_mult(self): - scalar = torch.tensor([2, 3], dtype=torch.double) - vector = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) + scalar = torch.tensor([2 + 3j]) + vector = torch.tensor([1 + 3j, 2 + 4j]) - expect = torch.tensor([[-7, -8], [9, 14]], dtype=torch.double) + expect = torch.tensor([-7 + 9j, -8 + 14j]) self.assertTensorsEqual( cplx.scalar_mult(scalar, vector), @@ -153,12 +159,10 @@ def test_scalar_vector_mult(self): ) def test_scalar_matrix_mult(self): - scalar = torch.tensor([2, 3], dtype=torch.double) - matrix = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.double) + scalar = torch.tensor([2 + 3j]) + matrix = torch.tensor([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]) - expect = torch.tensor( - [[[-13, -14], [-15, -16]], [[13, 18], [23, 28]]], dtype=torch.double - ) + expect = torch.tensor([[-13 + 13j, -14 + 18j], [-15 + 23j, -16 + 28j]]) self.assertTensorsEqual( cplx.scalar_mult(scalar, matrix), @@ -167,11 +171,12 @@ def test_scalar_matrix_mult(self): ) def test_scalar_mult_overwrite(self): - scalar = torch.tensor([2, 3], dtype=torch.double) - vector = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) - out = torch.zeros_like(vector) - expect = torch.tensor([[-7, -8], [9, 14]], dtype=torch.double) + scalar = torch.tensor([2 + 3j]) + vector = torch.tensor([1 + 3j, 2 + 4j]) + + expect = torch.tensor([-7 + 9j, -8 + 14j]) + out = torch.zeros_like(vector) cplx.scalar_mult(scalar, vector, out=out) self.assertTensorsEqual( @@ -181,8 +186,8 @@ def test_scalar_mult_overwrite(self): ) def test_scalar_mult_overwrite_fail(self): - scalar = torch.tensor([2, 3], dtype=torch.double) - vector = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) + scalar = torch.tensor([2 + 3j]) + vector = torch.tensor([1 + 3j, 2 + 4j]) with self.assertRaises(RuntimeError): cplx.scalar_mult(scalar, vector, out=vector) @@ -191,10 +196,10 @@ def test_scalar_mult_overwrite_fail(self): cplx.scalar_mult(scalar, vector, out=scalar) def test_matrix_vector_matmul(self): - matrix = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.double) - vector = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) + matrix = torch.tensor([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]) + vector = torch.tensor([1 + 3j, 2 + 4j]) - expect = torch.tensor([[-34, -42], [28, 48]], dtype=torch.double) + expect = torch.tensor([-34 + 28j, -42 + 48j]) self.assertTensorsEqual( cplx.matmul(matrix, vector), @@ -203,11 +208,9 @@ def test_matrix_vector_matmul(self): ) def test_matrix_matrix_matmul(self): - matrix1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.double) - matrix2 = torch.tensor([[[1, 0], [3, 0]], [[0, 6], [0, 8]]], dtype=torch.double) - expect = torch.tensor( - [[[7, -78], [15, -106]], [[23, 22], [31, 50]]], dtype=torch.double - ) + matrix1 = torch.tensor([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]) + matrix2 = torch.tensor([[1 + 0j, 0 + 6j], [3 + 0j, 0 + 8j]]) + expect = torch.tensor([[7 + 23j, -78 + 22j], [15 + 31j, -106 + 50j]]) self.assertTensorsEqual( cplx.matmul(matrix1, matrix2), expect, @@ -215,77 +218,64 @@ def test_matrix_matrix_matmul(self): ) def test_scalar_inner_prod(self): - scalar = torch.tensor([1, 2], dtype=torch.double) - expect = torch.tensor([5, 0], dtype=torch.double) + scalar = torch.tensor(1 + 2j) + expect = torch.tensor(5 + 0j) self.assertTensorsEqual( cplx.inner_prod(scalar, scalar), expect, msg="Scalar inner product failed!" ) def test_vector_inner_prod(self): - vector = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) - expect = torch.tensor([30, 0], dtype=torch.double) + vector = torch.tensor([1 + 3j, 2 + 4j]) + expect = torch.tensor(30 + 0j) self.assertTensorsEqual( cplx.inner_prod(vector, vector), expect, msg="Vector inner product failed!" ) def test_outer_prod(self): - vector = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) - expect = torch.tensor( - [[[10, 14], [14, 20]], [[0, 2], [-2, 0]]], dtype=torch.double - ) + vector = torch.tensor([1 + 3j, 2 + 4j]) + expect = torch.tensor([[10 + 0j, 14 + 2j], [14 - 2j, 20 + 0j]]) self.assertTensorsEqual( cplx.outer_prod(vector, vector), expect, msg="Outer product failed!" ) def test_outer_prod_error_small(self): - # take outer prod of 2 rank 1 tensors, instead of rank 2 - tensor = torch.tensor([1, 2], dtype=torch.double) + # take outer prod of 2 scalars, instead of vectors + scalar = torch.tensor(1 + 2j) with self.assertRaises(ValueError): - cplx.outer_prod(tensor, tensor) + cplx.outer_prod(scalar, scalar) def test_outer_prod_error_large(self): - # take outer prod of 2 rank 3 tensors, instead of rank 2 - tensor = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.double) + # take outer prod of 2 matrices, instead of vectors + matrix = torch.tensor([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]) with self.assertRaises(ValueError): - cplx.outer_prod(tensor, tensor) + cplx.outer_prod(matrix, matrix) def test_conjugate(self): - vector = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) + vector = torch.tensor([1 + 3j, 2 + 4j]) - expect = torch.tensor([[1, 2], [-3, -4]], dtype=torch.double) + expect = torch.tensor([1 - 3j, 2 - 4j]) self.assertTensorsEqual( cplx.conjugate(vector), expect, msg="Vector conjugate failed!" ) def test_matrix_conjugate(self): - matrix = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.double) - expect = torch.tensor( - [[[1, 3], [2, 4]], [[-5, -7], [-6, -8]]], dtype=torch.double - ) + matrix = torch.tensor([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]) + expect = torch.tensor([[1 - 5j, 3 - 7j], [2 - 6j, 4 - 8j]]) self.assertTensorsEqual( cplx.conjugate(matrix), expect, msg="Matrix conjugate failed!" ) def test_kronecker_prod(self): - matrix = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.double) + matrix = torch.tensor([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]) expect = torch.tensor( [ - [ - [-24, -28, -28, -32], - [-32, -36, -36, -40], - [-32, -36, -36, -40], - [-40, -44, -44, -48], - ], - [ - [10, 16, 16, 24], - [22, 28, 32, 40], - [22, 32, 28, 40], - [42, 52, 52, 64], - ], + [-24 + 10j, -28 + 16j, -28 + 16j, -32 + 24j], + [-32 + 22j, -36 + 28j, -36 + 32j, -40 + 40j], + [-32 + 22j, -36 + 32j, -36 + 28j, -40 + 40j], + [-40 + 42j, -44 + 52j, -44 + 52j, -48 + 64j], ], - dtype=torch.double, ) self.assertTensorsEqual( @@ -294,7 +284,7 @@ def test_kronecker_prod(self): def test_kronecker_prod_error_small(self): # take KronProd of 2 rank 2 tensors, instead of rank 3 - tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) + tensor = torch.tensor([1 + 3j, 2 + 4j]) with self.assertRaises(ValueError): cplx.kronecker_prod(tensor, tensor) @@ -305,9 +295,9 @@ def test_kronecker_prod_error_large(self): cplx.kronecker_prod(tensor, tensor) def test_vector_scalar_divide(self): - scalar = torch.tensor([1, 2], dtype=torch.double) - vector = torch.tensor([[1, 2], [3, 4]], dtype=torch.double) - expect = torch.tensor([[1.4, 2.0], [0.2, 0.0]], dtype=torch.double) + scalar = torch.tensor(1 + 2j) + vector = torch.tensor([1 + 3j, 2 + 4j]) + expect = torch.tensor([1.4 + 0.2j, 2.0 + 0.0j]) self.assertTensorsAlmostEqual( cplx.scalar_divide(vector, scalar), @@ -316,12 +306,10 @@ def test_vector_scalar_divide(self): ) def test_matrix_scalar_divide(self): - scalar = torch.tensor([1, 2], dtype=torch.double) - matrix = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.double) + scalar = torch.tensor(1 + 2j) + matrix = torch.tensor([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]]) - expect = torch.tensor( - [[[2.2, 2.8], [3.4, 4.0]], [[0.6, 0.4], [0.2, 0.0]]], dtype=torch.double - ) + expect = torch.tensor([[2.2 + 0.6j, 2.8 + 0.4j], [3.4 + 0.2j, 4.0 + 0.0j]]) self.assertTensorsAlmostEqual( cplx.scalar_divide(matrix, scalar), @@ -330,26 +318,27 @@ def test_matrix_scalar_divide(self): ) def test_norm_sqr(self): - scalar = torch.tensor([3, 4], dtype=torch.double) - expect = torch.tensor(25, dtype=torch.double) + scalar = torch.tensor(3 + 4j) + expect = torch.tensor(25) - self.assertTensorsEqual(cplx.norm_sqr(scalar), expect, msg="Norm failed!") + self.assertTensorsEqual( + cplx.norm_sqr(scalar).to(expect), expect, msg="Norm failed!" + ) def test_norm(self): - scalar = torch.tensor([3, 4], dtype=torch.double) - expect = torch.tensor(5, dtype=torch.double) + scalar = torch.tensor(3 + 4j) + expect = torch.tensor(5) - self.assertTensorsEqual(cplx.norm(scalar), expect, msg="Norm failed!") + self.assertTensorsEqual( + cplx.norm(scalar).to(expect), expect, msg="Norm failed!" + ) def test_absolute_value(self): tensor = torch.tensor( - [[[5, 5, -5, -5], [3, 6, -9, 1]], [[2, -2, 2, -2], [-7, 8, 0, 4]]], - dtype=torch.double, + [[5 + 2j, 5 - 2j, -5 + 2j, -5 - 2j], [3 - 7j, 6 + 8j, -9 + 0j, 1 + 4j]] ) - expect = torch.tensor( - [[[np.sqrt(29)] * 4, [np.sqrt(58), 10, 9, np.sqrt(17)]]], dtype=torch.double - ) + expect = torch.tensor([[np.sqrt(29)] * 4, [np.sqrt(58), 10, 9, np.sqrt(17)]]) self.assertTensorsAlmostEqual( cplx.absolute_value(tensor), expect, msg="Absolute Value failed!" diff --git a/tests/test_grads.py b/tests/test_grads.py index d276d172..ce5caf3b 100644 --- a/tests/test_grads.py +++ b/tests/test_grads.py @@ -22,6 +22,7 @@ import qucumber from qucumber.nn_states import PositiveWaveFunction, ComplexWaveFunction, DensityMatrix +from qucumber.utils import cplx from grads_utils import ComplexGradsUtils, PosGradsUtils, DensityGradsUtils from conftest import all_state_types, assertAlmostEqual, TOL @@ -42,6 +43,8 @@ def positive_wavefunction_data(request, gpu, num_hidden): data = torch.tensor(test_data["tfim1d"]["train_samples"], dtype=torch.double) target = torch.tensor(test_data["tfim1d"]["target_psi"], dtype=torch.double).t() + target = cplx.make_complex(target[0]) + num_visible = data.shape[-1] nn_state = PositiveWaveFunction(num_visible, num_hidden, gpu=gpu) @@ -79,6 +82,8 @@ def complex_wavefunction_data(request, gpu, num_hidden): test_data["2qubits"]["target_psi"], dtype=torch.double ).t() + target_psi_tmp = cplx.make_complex(target_psi_tmp[0], target_psi_tmp[1]) + num_visible = data_samples.shape[-1] nn_state = ComplexWaveFunction(num_visible, num_hidden, gpu=gpu) @@ -138,6 +143,8 @@ def density_matrix_data(request, gpu, num_hidden): test_data["density_matrix"]["density_matrix"], dtype=torch.double ) + target = cplx.make_complex(target[0], target[1]) + num_visible = data_samples.shape[-1] num_aux = num_visible + 1 diff --git a/tests/test_models_misc.py b/tests/test_models_misc.py index 64e00b29..948f858d 100644 --- a/tests/test_models_misc.py +++ b/tests/test_models_misc.py @@ -159,10 +159,13 @@ def test_density_matrix_tr1(): nn_state = DensityMatrix(5, gpu=False) space = nn_state.generate_hilbert_space(5) - matrix = nn_state.rho(space, space) / nn_state.normalization(space) + matrix = nn_state.rho(space, expand=True) + trace = torch.diagonal(matrix).sum() + # FIXME: should replace this^ with trace once it's implemented + Z = nn_state.normalization(space) msg = f"Trace of density matrix is not within {TOL} of 1!" - assertAlmostEqual(torch.trace(matrix[0]), torch.Tensor([1]), TOL, msg=msg) + assertAlmostEqual(trace / Z, torch.tensor(1.0).to(Z), TOL, msg=msg) def test_density_matrix_sizes(): @@ -172,7 +175,7 @@ def test_density_matrix_sizes(): rho = nn_state.rho(v, vp) - assert rho.shape == (2, v.shape[0], vp.shape[0]) + assert rho.shape == (v.shape[0], vp.shape[0]) ndo_matrices = [ @@ -219,7 +222,7 @@ def test_density_matrix_expansion(prop): msg = f"Diagonal of matrix {prop_name} is wrong!" - equation = "cii...->ci..." if is_complex else "ii->i" + equation = "ii...->i..." if is_complex else "ii->i" assertAlmostEqual(torch.einsum(equation, matrix), diag, TOL, msg=msg) @@ -231,7 +234,7 @@ def test_density_matrix_diagonal(): diag = nn_state.rho(v, expand=False) msg = "Diagonal of density matrix is wrong!" - assertAlmostEqual(torch.einsum("cii...->ci...", rho), diag, TOL, msg=msg) + assertAlmostEqual(torch.einsum("ii...->i...", rho), diag, TOL, msg=msg) @pytest.mark.parametrize("state_type", all_state_types) diff --git a/tox.ini b/tox.ini index 6bbafed9..8671aca0 100644 --- a/tox.ini +++ b/tox.ini @@ -21,15 +21,14 @@ skipsdist = true skip_missing_interpreters = true envlist = cov-init - py{36, 37}-torch{10, 11, 12, 13, 14, 15, 16, nightly} - py38-torch{14, 15, 16, nightly} + py{37}-torch{nightly} cov-report misc [travis] os = - linux: py{36, 37}-torch{10, 11, 12, 13, 14, 15, 16, nightly}, py38-torch{14, 15, 16, nightly} - osx: py{36, 37}-torch{10, 11, 12, 13, 14, 15, 16, nightly}, py38-torch{14, 15, 16, nightly} + linux: py{36,37,38}-torch{nightly} + osx: py{36,37,38}-torch{nightly} [travis:env] TORCH = @@ -73,8 +72,7 @@ commands = setenv = {[testenv:cov-init]setenv} depends = cov-init - py{36, 37}-torch{10, 11, 12, 13, 14, 15, 16, nightly} - py38-torch{14, 15, 16, nightly} + py{36,37,38}-torch{nightly} deps = {[testenv:cov-init]deps} commands = coverage combine From 495950d40f458df6a2df953571d027dd00b1b571 Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Mon, 7 Sep 2020 22:54:22 -0400 Subject: [PATCH 2/5] minor --- qucumber/utils/cplx.py | 7 +++---- qucumber/utils/unitaries.py | 2 +- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/qucumber/utils/cplx.py b/qucumber/utils/cplx.py index d27ba851..cf0bef24 100755 --- a/qucumber/utils/cplx.py +++ b/qucumber/utils/cplx.py @@ -36,14 +36,13 @@ def make_complex(x, y=None): :rtype: torch.Tensor """ if isinstance(x, np.ndarray): - return torch.tensor(x) - # return make_complex(torch.tensor(x.real), torch.tensor(x.imag)).contiguous() + x = torch.tensor(x) if y is None: y = torch.zeros_like(x) - else: + elif isinstance(y, np.ndarray): y = torch.tensor(y) - return torch.tensor(x) + 1j * y + return x + 1j * y def numpy(x): diff --git a/qucumber/utils/unitaries.py b/qucumber/utils/unitaries.py index 4bbbe43f..a6f696a0 100644 --- a/qucumber/utils/unitaries.py +++ b/qucumber/utils/unitaries.py @@ -175,7 +175,7 @@ def _rotate_basis_state(nn_state, basis, states, unitaries=None): Ut = torch.prod(all_Us.cpu(), dim=-1).to( all_Us - ) # FIXME: prod is currently unsupported on GPU + ) # FIXME: prod for torch.complex is currently unsupported on GPU else: v = states.unsqueeze(0) Ut = torch.ones(v.shape[:-1], dtype=torch.cdouble, device=nn_state.device) From 48dedee67663656071ad730ad186c741267c63d5 Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Tue, 8 Sep 2020 13:36:32 -0400 Subject: [PATCH 3/5] minor --- qucumber/utils/data.py | 7 ++++--- qucumber/utils/training_statistics.py | 3 +-- tests/grads_utils.py | 2 +- tox.ini | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/qucumber/utils/data.py b/qucumber/utils/data.py index 3b529fe3..8dbea15c 100644 --- a/qucumber/utils/data.py +++ b/qucumber/utils/data.py @@ -42,9 +42,10 @@ def load_data(tr_samples_path, tr_psi_path=None, tr_bases_path=None, bases_path= if tr_psi_path is not None: target_psi_data = np.loadtxt(tr_psi_path, dtype="float32") - target_psi = torch.zeros(2, len(target_psi_data), dtype=torch.double) - target_psi[0] = torch.tensor(target_psi_data[:, 0], dtype=torch.double) - target_psi[1] = torch.tensor(target_psi_data[:, 1], dtype=torch.double) + target_psi = cplx.make_complex( + torch.tensor(target_psi_data[:, 0], dtype=torch.double), + torch.tensor(target_psi_data[:, 1], dtype=torch.double), + ) data.append(target_psi) if tr_bases_path is not None: diff --git a/qucumber/utils/training_statistics.py b/qucumber/utils/training_statistics.py index ac1fb870..c22211fd 100644 --- a/qucumber/utils/training_statistics.py +++ b/qucumber/utils/training_statistics.py @@ -179,8 +179,7 @@ def KL(nn_state, target, space=None, bases=None, **kwargs): if bases is None: target_probs = cplx.absolute_value(target) ** 2 nn_probs = nn_state.probability(space, Z) - - KL += _single_basis_KL(target_probs, nn_probs) + return _single_basis_KL(target_probs, nn_probs) elif isinstance(nn_state, WaveFunctionBase): for basis in bases: diff --git a/tests/grads_utils.py b/tests/grads_utils.py index eace3ae8..77aaadff 100644 --- a/tests/grads_utils.py +++ b/tests/grads_utils.py @@ -165,7 +165,7 @@ def algorithmic_gradKL(self, target, space, all_bases, **kwargs): else: target_r = rotate_rho_probs( self.nn_state, all_bases[b], space, rho=target - ) + ).to(dtype=torch.double) for i in range(len(space)): rotated_grad = self.nn_state.gradient(space[i], all_bases[b]) diff --git a/tox.ini b/tox.ini index 8671aca0..4c179123 100644 --- a/tox.ini +++ b/tox.ini @@ -56,7 +56,7 @@ deps = -e .[coverage] commands = torchnightly: python -m pip uninstall -y torch - torchnightly: python -m pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + torchnightly: python -m pip install -U --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html pytest --cov=qucumber --no-cov-on-fail {posargs} From a993d32f1283e39ee72d9b29c9c52a96641c47c5 Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Wed, 13 Jul 2022 13:59:38 -0400 Subject: [PATCH 4/5] Add some more tests to try to figure out whats going wrong --- qucumber/utils/training_statistics.py | 39 +++++++++++----- qucumber/utils/unitaries.py | 6 +-- setup.py | 2 +- tests/grads_utils.py | 64 ++++++++++++++++++++++++--- tests/test_grads.py | 27 +++++++---- tox.ini | 36 +++++++-------- 6 files changed, 125 insertions(+), 49 deletions(-) diff --git a/qucumber/utils/training_statistics.py b/qucumber/utils/training_statistics.py index c22211fd..07273eaa 100644 --- a/qucumber/utils/training_statistics.py +++ b/qucumber/utils/training_statistics.py @@ -20,7 +20,12 @@ from qucumber.nn_states import WaveFunctionBase from qucumber.utils import cplx, deprecated_kwarg -from qucumber.utils.unitaries import rotate_psi, rotate_psi_inner_prod, rotate_rho_probs +from qucumber.utils.unitaries import ( + rotate_psi, + rotate_rho, + rotate_psi_inner_prod, + rotate_rho_probs, +) @deprecated_kwarg(target_psi="target", target_rho="target") @@ -106,24 +111,33 @@ def NLL(nn_state, samples, space=None, sample_bases=None, **kwargs): unique_bases, indices = np.unique(sample_bases, axis=0, return_inverse=True) indices = torch.Tensor(indices).to(samples) - for i in range(unique_bases.shape[0]): - basis = unique_bases[i, :] - rot_sites = np.where(basis != "Z")[0] + if isinstance(nn_state, WaveFunctionBase): + for i in range(unique_bases.shape[0]): + basis = unique_bases[i, :] + rot_sites = np.where(basis != "Z")[0] - if rot_sites.size != 0: - if isinstance(nn_state, WaveFunctionBase): + if rot_sites.size != 0: Upsi = rotate_psi_inner_prod( nn_state, basis, samples[indices == i, :] ) nn_probs = (cplx.absolute_value(Upsi) ** 2) / Z else: + nn_probs = nn_state.probability(samples[indices == i, :], Z) + + NLL_ -= torch.sum(probs_to_logits(nn_probs)) + else: + for i in range(unique_bases.shape[0]): + basis = unique_bases[i, :] + rot_sites = np.where(basis != "Z")[0] + + if rot_sites.size != 0: nn_probs = ( rotate_rho_probs(nn_state, basis, samples[indices == i, :]) / Z ) - else: - nn_probs = nn_state.probability(samples[indices == i, :], Z) + else: + nn_probs = nn_state.probability(samples[indices == i, :], Z) - NLL_ -= torch.sum(probs_to_logits(nn_probs)) + NLL_ -= torch.sum(probs_to_logits(nn_probs)) return NLL_ / float(len(samples)) @@ -202,12 +216,13 @@ def KL(nn_state, target, space=None, bases=None, **kwargs): if isinstance(target, dict): target_rho_r = target[basis] assert target_rho_r.dim() == 2, "target must be a complex matrix!" - target_probs_r = torch.diagonal(cplx.real(target_rho_r)) + else: assert target.dim() == 2, "target must be a complex matrix!" - target_probs_r = rotate_rho_probs(nn_state, basis, space, rho=target) + target_rho_r = rotate_rho(nn_state, basis, space, rho=target) - rho_r = rotate_rho_probs(nn_state, basis, space) + target_probs_r = torch.diagonal(cplx.real(target_rho_r)) + rho_r = torch.diagonal(cplx.real(rotate_rho(nn_state, basis, space))) nn_probs_r = rho_r / Z KL += _single_basis_KL(target_probs_r, nn_probs_r) diff --git a/qucumber/utils/unitaries.py b/qucumber/utils/unitaries.py index a6f696a0..6313383f 100644 --- a/qucumber/utils/unitaries.py +++ b/qucumber/utils/unitaries.py @@ -73,6 +73,8 @@ def _kron_mult(matrices, x): l //= n[s] # noqa: E741 m = matrices[s] + # TODO: skip multiplication if m is an identity matrix + for k in range(l): for i in range(r): slc = slice(k * n[s] * r + i, (k + 1) * n[s] * r + i, r) @@ -173,9 +175,7 @@ def _rotate_basis_state(nn_state, basis, states, unitaries=None): int_vp = v[..., sites].long() all_Us = Us[ints_size, int_sample, int_vp] - Ut = torch.prod(all_Us.cpu(), dim=-1).to( - all_Us - ) # FIXME: prod for torch.complex is currently unsupported on GPU + Ut = torch.exp(torch.sum(torch.log(all_Us), dim=-1)) else: v = states.unsqueeze(0) Ut = torch.ones(v.shape[:-1], dtype=torch.cdouble, device=nn_state.device) diff --git a/setup.py b/setup.py index 0c4ea443..094d3e11 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ exec(f.read(), version_file) install_requires = [ - "torch>=1.0,<1.7; sys_platform != 'win32'", + "torch>=1.6,<1.12; sys_platform != 'win32'", "tqdm>=4.23", "numpy>=1.13", "scipy>=1.3.3", diff --git a/tests/grads_utils.py b/tests/grads_utils.py index 77aaadff..ce095d17 100644 --- a/tests/grads_utils.py +++ b/tests/grads_utils.py @@ -17,7 +17,7 @@ from qucumber.utils import cplx from qucumber.utils import training_statistics as ts -from qucumber.utils.unitaries import rotate_psi_inner_prod, rotate_rho_probs +from qucumber.utils.unitaries import rotate_psi_inner_prod, rotate_rho_probs, rotate_rho class PosGradsUtils: @@ -27,6 +27,9 @@ def __init__(self, nn_state): def compute_numerical_KL(self, target, space, all_bases=None): return ts.KL(self.nn_state, target, space, bases=all_bases) + def compute_numerical_KL_norot(self, target, space): + return ts.KL(self.nn_state, target, space) + def compute_numerical_NLL(self, data_samples, space, data_bases=None): return ts.NLL(self.nn_state, data_samples, space, sample_bases=data_bases) @@ -43,6 +46,9 @@ def algorithmic_gradKL(self, target, space, **kwargs): grad_KL -= self.nn_state.probability(space[i], Z) * sample_grad return [grad_KL] + def algorithmic_gradKL_norot(self, target, space, **kwargs): + return self.algorithmic_gradKL(target, space, **kwargs) + def algorithmic_gradNLL(self, data_samples, space, data_bases=None, **kwargs): return self.nn_state.compute_exact_gradients( data_samples, space, bases_batch=data_bases @@ -62,6 +68,20 @@ def numeric_gradKL(self, target, param, space, eps, all_bases=None, **kwargs): return torch.tensor(num_gradKL, dtype=torch.double).to(param) + def numeric_gradKL_norot(self, target, param, space, eps, **kwargs): + num_gradKL = [] + for i in range(len(param)): + param[i] += eps + KL_p = self.compute_numerical_KL_norot(target, space) + + param[i] -= 2 * eps + KL_m = self.compute_numerical_KL_norot(target, space) + + param[i] += eps + num_gradKL.append((KL_p - KL_m) / (2 * eps)) + + return torch.tensor(num_gradKL, dtype=torch.double).to(param) + def numeric_gradNLL( self, param, data_samples, space, eps, data_bases=None, **kwargs ): @@ -119,6 +139,24 @@ def algorithmic_gradKL(self, target, space, all_bases, **kwargs): ] Z = self.nn_state.normalization(space) + if all_bases is None: + if isinstance(target, dict): + unrotated_basis = [k for k in target.keys() if all(c == "Z" for c in k)] + target_r = cplx.absolute_value(target[unrotated_basis[0]]) ** 2 + else: + target_r = cplx.absolute_value(target) ** 2 + + probs = self.nn_state.probability(space, Z) + all_grads = self.nn_state.rbm_am.effective_energy_gradient( + space, reduce=False + ) + + grad_KL[0] += torch.mv( + all_grads.t(), target_r - probs + ) # average the gradients, weighted by probs + + return grad_KL + for b in range(len(all_bases)): if isinstance(target, dict): target_r = target[all_bases[b]] @@ -158,14 +196,30 @@ def algorithmic_gradKL(self, target, space, all_bases, **kwargs): ] Z = self.nn_state.normalization(space) + if all_bases is None: + if isinstance(target, dict): + unrotated_basis = [k for k in target.keys() if all(c == "Z" for c in k)] + target_r = torch.diagonal(cplx.real(target[unrotated_basis[0]])) + else: + target_r = torch.diagonal(cplx.real(target)) + + probs = self.nn_state.probability(space, Z) + all_grads = self.nn_state.rbm_am.effective_energy_gradient( + space, reduce=False + ) + + grad_KL[0] += torch.mv( + all_grads.t(), target_r - probs + ) # average the gradients, weighted by probs + + return grad_KL + for b in range(len(all_bases)): if isinstance(target, dict): target_r = target[all_bases[b]] - target_r = torch.diagonal(cplx.real(target_r)) else: - target_r = rotate_rho_probs( - self.nn_state, all_bases[b], space, rho=target - ).to(dtype=torch.double) + target_r = rotate_rho(self.nn_state, all_bases[b], space, rho=target) + target_r = torch.diagonal(cplx.real(target_r)) for i in range(len(space)): rotated_grad = self.nn_state.gradient(space[i], all_bases[b]) diff --git a/tests/test_grads.py b/tests/test_grads.py index ce5caf3b..58b027fc 100644 --- a/tests/test_grads.py +++ b/tests/test_grads.py @@ -186,7 +186,7 @@ def density_matrix_data(request, gpu, num_hidden): hidden_layer_sizes = [pytest.param(9, id="9", marks=[pytest.mark.extra]), 10] -grad_types = ["KL", "NLL"] +grad_types = ["KL", "NLL", "KL-norot", "NLL-norot"] @pytest.fixture(scope="module", params=all_state_types) @@ -214,14 +214,19 @@ def quantum_state_graddata(request, quantum_state_data): grad_type = request.param nn_state, grad_utils = quantum_state_data.nn_state, quantum_state_data.grad_utils - if grad_type == "KL": + if "KL" in grad_type: alg_grad_fn = grad_utils.algorithmic_gradKL num_grad_fn = grad_utils.numeric_gradKL else: alg_grad_fn = grad_utils.algorithmic_gradNLL num_grad_fn = grad_utils.numeric_gradNLL - alg_grads = alg_grad_fn(**quantum_state_data._asdict()) + param_dict = quantum_state_data._asdict() + if "norot" in grad_type: + param_dict["data_bases"] = None + param_dict["all_bases"] = None + + alg_grads = alg_grad_fn(**param_dict) num_grads = [None for _ in nn_state.networks] for n, net in enumerate(nn_state.networks): @@ -231,9 +236,9 @@ def quantum_state_graddata(request, quantum_state_data): num_grad = torch.cat( ( num_grad, - num_grad_fn( - param=param.view(-1), eps=EPS, **quantum_state_data._asdict() - ).to(num_grad), + num_grad_fn(param=param.view(-1), eps=EPS, **param_dict).to( + num_grad + ), ) ) num_grads[n] = num_grad @@ -251,16 +256,20 @@ def get_param_status(i, param_ranges): return p, i == rng[0] -def test_grads(quantum_state_graddata): +@pytest.mark.parametrize("network", ["rbm_am", "rbm_ph"]) +def test_grads(network, quantum_state_graddata): nn_state, alg_grads, num_grads, grad_type, test_tol = quantum_state_graddata print( - "\nTesting {} gradients for {} on {}.".format( - grad_type, nn_state.__class__.__name__, nn_state.device + "\nTesting {} gradients of {} for {} on {}.".format( + grad_type, network, nn_state.__class__.__name__, nn_state.device ) ) for n, net in enumerate(nn_state.networks): + if net != network: + continue + print("\nRBM: %s" % net) rbm = getattr(nn_state, net) diff --git a/tox.ini b/tox.ini index 4c179123..43bdc5c8 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ # and then run "tox" from this directory. [pytest] -addopts = --strict -v +addopts = --strict-markers -v markers = gpu slow @@ -21,37 +21,35 @@ skipsdist = true skip_missing_interpreters = true envlist = cov-init - py{37}-torch{nightly} + py{36,37,38}-torch{17,18,19,110,111,112,nightly} cov-report misc [travis] os = - linux: py{36,37,38}-torch{nightly} - osx: py{36,37,38}-torch{nightly} + linux: py{36,37,38}-torch{17,18,19,110,111,112,nightly} + osx: py{36,37,38}-torch{17,18,19,110,111,112,nightly} [travis:env] TORCH = - 1.0: torch10 - 1.1: torch11 - 1.2: torch12 - 1.3: torch13 - 1.4: torch14 - 1.5: torch15 - 1.6: torch16 + 1.7: torch17 + 1.8: torch18 + 1.9: torch19 + 1.10: torch110 + 1.11: torch111 + 1.12: torch112 nightly: torchnightly [testenv] setenv = COVERAGE_FILE = .coverage.{envname} deps = - torch10: torch>=1.0,<1.1 - torch11: torch>=1.1,<1.2 - torch12: torch>=1.2,<1.3 - torch13: torch>=1.3,<1.4 - torch14: torch>=1.4,<1.5 - torch15: torch>=1.5,<1.6 - torch16: torch>=1.6,<1.7 + torch17: torch>=1.7,<1.8 + torch18: torch>=1.8,<1.9 + torch19: torch>=1.9,<1.10 + torch110: torch>=1.10,<1.11 + torch111: torch>=1.11,<1.12 + torch112: torch>=1.12,<1.13 pytest-xdist -e .[coverage] commands = @@ -72,7 +70,7 @@ commands = setenv = {[testenv:cov-init]setenv} depends = cov-init - py{36,37,38}-torch{nightly} + py{36,37,38}-torch{17,18,19,110,111,112,nightly} deps = {[testenv:cov-init]deps} commands = coverage combine From bc54bb4a2576dd3ddf6f4b4f3205993dffd8b036 Mon Sep 17 00:00:00 2001 From: Ejaaz Merali Date: Thu, 13 Oct 2022 16:02:32 -0400 Subject: [PATCH 5/5] minor --- tests/grads_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/grads_utils.py b/tests/grads_utils.py index ce095d17..137b8e01 100644 --- a/tests/grads_utils.py +++ b/tests/grads_utils.py @@ -137,7 +137,7 @@ def algorithmic_gradKL(self, target, space, all_bases, **kwargs): device=self.nn_state.device, ), ] - Z = self.nn_state.normalization(space) + # Z = self.nn_state.normalization(space) if all_bases is None: if isinstance(target, dict): @@ -146,7 +146,8 @@ def algorithmic_gradKL(self, target, space, all_bases, **kwargs): else: target_r = cplx.absolute_value(target) ** 2 - probs = self.nn_state.probability(space, Z) + probs = self.nn_state.probability(space) + probs /= probs.sum() all_grads = self.nn_state.rbm_am.effective_energy_gradient( space, reduce=False )