Skip to content

Commit 54e1526

Browse files
internal code quality attempts 3
1 parent 91583d0 commit 54e1526

File tree

8 files changed

+77
-93
lines changed

8 files changed

+77
-93
lines changed

.pylintrc

Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,24 +56,6 @@ max-module-lines=2000
5656
overgeneral-exceptions=BaseException,
5757
Exception
5858

59-
60-
[BASIC]
61-
62-
good-names=Bell_pair_block,
63-
QAOA_block,
64-
Grid2D_block,
65-
dtypestr,
66-
libjax,
67-
jnp,
68-
jsp,
69-
torchlib,
70-
N,
71-
M,
72-
U,
73-
S,
74-
V,
75-
R,
76-
Q,
77-
A,
78-
E,
79-
P
59+
# Note how codecc doesn't accept goddnames in pylintrc, and how we use pylint disable invalid name per file instead
60+
# it is not neat at all, but this is codecc's badness :(
61+
# stupid to check variable name convention when you are a scientist dealing with lots of N, Pauli or QAOA

tensorcircuit/backends/ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
customized ops for ML framework
33
"""
4+
# pylint: disable=invalid-name
45

56
from typing import Any, Tuple, Sequence
67

tensorcircuit/backends/pytorch_backend.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import logging
66
from typing import Any, Callable, Optional, Sequence, Tuple, Union
77

8+
import tensornetwork
89
from tensornetwork.backends.pytorch import pytorch_backend
910

1011
try: # old version tn compatiblity
@@ -30,6 +31,20 @@
3031
# To be added once pytorch backend is ready
3132

3233

34+
def _sum_torch(
35+
self: Any,
36+
tensor: Tensor,
37+
axis: Optional[Sequence[int]] = None,
38+
keepdims: bool = False,
39+
) -> Tensor:
40+
if axis is None:
41+
axis = tuple([i for i in range(len(tensor.shape))])
42+
return torchlib.sum(tensor, dim=axis, keepdim=keepdims)
43+
44+
45+
tensornetwork.backends.pytorch.pytorch_backend.PyTorchBackend.sum = _sum_torch
46+
47+
3348
class PyTorchBackend(pytorch_backend.PyTorchBackend): # type: ignore
3449
def __init__(self) -> None:
3550
super(PyTorchBackend, self).__init__()
@@ -155,33 +170,39 @@ def cond(
155170
return false_fun()
156171

157172
def switch(self, index: Tensor, branches: Sequence[Callable[[], Tensor]]) -> Tensor:
158-
return branches[index]()
173+
return branches[index.numpy()]()
159174

160175
def grad(
161176
self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
162177
) -> Callable[..., Any]:
163178
def wrapper(*args: Any, **kws: Any) -> Any:
164-
x = []
165-
if isinstance(argnums, int):
166-
argnumsl = [argnums]
167-
# if you also call lhs as argnums, something weird may happen
168-
# the reason is that python then take it as local vars
169-
else:
170-
argnumsl = argnums # type: ignore
171-
for i, arg in enumerate(args):
172-
if i in argnumsl:
173-
x.append(arg.requires_grad_(True))
174-
else:
175-
x.append(arg)
176-
y = f(*x, **kws)
177-
y.backward()
178-
gs = [x[i].grad for i in argnumsl]
179-
if len(gs) == 1:
180-
gs = gs[0]
181-
return gs
179+
_, gr = self.value_and_grad(f, argnums)(*args, **kws)
180+
return gr
182181

183182
return wrapper
184183

184+
# def wrapper(*args: Any, **kws: Any) -> Any:
185+
# x = []
186+
# if isinstance(argnums, int):
187+
# argnumsl = [argnums]
188+
# # if you also call lhs as argnums, something weird may happen
189+
# # the reason is that python then take it as local vars
190+
# else:
191+
# argnumsl = argnums # type: ignore
192+
# for i, arg in enumerate(args):
193+
# if i in argnumsl:
194+
# x.append(arg.requires_grad_(True))
195+
# else:
196+
# x.append(arg)
197+
# y = f(*x, **kws)
198+
# y.backward()
199+
# gs = [x[i].grad for i in argnumsl]
200+
# if len(gs) == 1:
201+
# gs = gs[0]
202+
# return gs
203+
204+
# return wrapper
205+
185206
def value_and_grad(
186207
self, f: Callable[..., Any], argnums: Union[int, Sequence[int]] = 0
187208
) -> Callable[..., Tuple[Any, Any]]:

tests/test_backends.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pylint: disable=invalid-name
2+
13
import sys
24
import os
35
from functools import partial
@@ -54,6 +56,16 @@ def test_vmap_torch(torchb):
5456
assert r.numpy()[0, 0] == 3.0
5557

5658

59+
def test_grad_torch(torchb):
60+
a = tc.backend.ones([2], dtype="float32")
61+
62+
@tc.backend.grad
63+
def f(x):
64+
return tc.backend.sum(x)
65+
66+
np.testing.assert_allclose(f(a), np.ones([2]), atol=1e-5)
67+
68+
5769
@pytest.mark.parametrize("backend", [lf("npb"), lf("tfb"), lf("jaxb")])
5870
def test_backend_scatter(backend):
5971
assert np.allclose(

tests/test_circuit.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -573,8 +573,6 @@ def f(param, max_singular_values=None, max_truncation_err=None, fixed_choice=Non
573573
s3, g3 = f_vag(tc.backend.ones([4, n]), max_singular_values=2, fixed_choice=1)
574574

575575
np.testing.assert_allclose(s1, s3, atol=1e-5)
576-
print(g1[:, :])
577-
print(g3[:, :])
578576
# DONE(@refraction-ray): nan on jax backend?
579577
# i see, complex value SVD is not supported on jax for now :)
580578
# I shall further customize complex SVD, finally it has applications
@@ -615,13 +613,9 @@ def f(param, max_singular_values=None, max_truncation_err=None, fixed_choice=Non
615613
c.rx(i, theta=param[2 * j + 1, i])
616614
loss = c.expectation(
617615
(
618-
tc.gates.z(),
616+
tc.gates.x(),
619617
[1],
620618
),
621-
(
622-
tc.gates.z(),
623-
[2],
624-
),
625619
)
626620
return tc.backend.real(loss)
627621

tests/test_dmcircuit.py

Lines changed: 16 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -109,55 +109,25 @@ def test_inputs_and_kraus():
109109

110110

111111
def test_gate_depolarizing():
112-
c = tc.DMCircuit(2)
113-
c.H(0)
114-
c.apply_general_kraus(depolarizingchannel(0.1, 0.1, 0.1), [(1,)])
115-
np.testing.assert_allclose(
116-
c.densitymatrix(check=True),
117-
np.array(
118-
[[0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1], [0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1]]
119-
),
120-
atol=1e-5,
112+
ans = np.array(
113+
[[0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1], [0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1]]
121114
)
122115

123-
c = tc.DMCircuit(2)
124-
c.H(0)
125-
kraus = depolarizingchannel(0.1, 0.1, 0.1)
126-
# kraus = [k.tensor for k in kraus]
127-
c.depolarizing(1, px=0.1, py=0.1, pz=0.1)
128-
np.testing.assert_allclose(
129-
c.densitymatrix(check=True),
130-
np.array(
131-
[[0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1], [0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1]]
132-
),
133-
atol=1e-5,
134-
)
135-
136-
c = tc.DMCircuit2(2)
137-
c.H(0)
138-
kraus = depolarizingchannel(0.1, 0.1, 0.1)
139-
# kraus = [k.tensor for k in kraus]
140-
c.apply_general_kraus(kraus, 1)
141-
np.testing.assert_allclose(
142-
c.densitymatrix(check=True),
143-
np.array(
144-
[[0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1], [0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1]]
145-
),
146-
atol=1e-5,
147-
)
116+
def check_template(c, api="v1"):
117+
c.H(0)
118+
if api == "v1":
119+
kraus = depolarizingchannel(0.1, 0.1, 0.1)
120+
c.apply_general_kraus(kraus, [(1,)])
121+
else:
122+
c.depolarizing(1, px=0.1, py=0.1, pz=0.1)
123+
np.testing.assert_allclose(
124+
c.densitymatrix(check=True),
125+
ans,
126+
atol=1e-5,
127+
)
148128

149-
c = tc.DMCircuit2(2)
150-
c.H(0)
151-
kraus = depolarizingchannel(0.1, 0.1, 0.1)
152-
# kraus = [k.tensor for k in kraus]
153-
c.depolarizing(1, px=0.1, py=0.1, pz=0.1)
154-
np.testing.assert_allclose(
155-
c.densitymatrix(check=True),
156-
np.array(
157-
[[0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1], [0.4, 0, 0.4, 0], [0, 0.1, 0, 0.1]]
158-
),
159-
atol=1e-5,
160-
)
129+
for c, v in zip([tc.DMCircuit(2), tc.DMCircuit2(2)], ["v1", "v2"]):
130+
check_template(c, v)
161131

162132

163133
@pytest.mark.parametrize("backend", [lf("jaxb"), lf("tfb")])

tests/test_mpscircuit.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# pylint: disable=unused-variable
2+
# pylint: disable=invalid-name
3+
24
import sys
35
import os
46
import numpy as np

tests/test_quantum.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# pylint: disable=invalid-name
2+
13
from functools import partial
24
import os
35
import sys

0 commit comments

Comments
 (0)