Skip to content

Commit db6994f

Browse files
committed
Add the basic parity tensor class and its arithmetic operators.
1 parent 376ad05 commit db6994f

File tree

2 files changed

+146
-1
lines changed

2 files changed

+146
-1
lines changed

parity_tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
A parity tensor package.
33
"""
44

5-
__all__ = ["__version__"]
5+
__all__ = ["__version__", "ParityTensor"]
66

77
from .version import __version__
8+
from .parity_tensor import ParityTensor

parity_tensor/parity_tensor.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""
2+
A parity tensor class.
3+
"""
4+
5+
__all__ = ["ParityTensor"]
6+
7+
import dataclasses
8+
import torch
9+
10+
11+
@dataclasses.dataclass
12+
class ParityTensor:
13+
"""
14+
A parity tensor class, which stores a tensor along with information about its edges.
15+
Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers.
16+
"""
17+
18+
edges: tuple[tuple[int, int], ...]
19+
tensor: torch.Tensor
20+
21+
def __post_init__(self):
22+
assert len(self.edges) == self.tensor.dim(), "Edges must match tensor dimensions."
23+
for dim, [even, odd] in zip(self.tensor.shape, self.edges):
24+
assert even >= 0 and odd >= 0 and dim == even + odd, "Each dimension must match the sum of even and odd parts."
25+
26+
def __add__(self, other):
27+
if isinstance(other, ParityTensor): # pylint: disable=no-else-return
28+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
29+
return ParityTensor(
30+
edges=self.edges,
31+
tensor=self.tensor + other.tensor,
32+
)
33+
else:
34+
return ParityTensor(
35+
edges=self.edges,
36+
tensor=self.tensor + other,
37+
)
38+
39+
def __radd__(self, other):
40+
return ParityTensor(
41+
edges=self.edges,
42+
tensor=other + self.tensor,
43+
)
44+
45+
def __iadd__(self, other):
46+
if isinstance(other, ParityTensor):
47+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
48+
self.tensor += other.tensor
49+
else:
50+
self.tensor += other
51+
return self
52+
53+
def __sub__(self, other):
54+
if isinstance(other, ParityTensor): # pylint: disable=no-else-return
55+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
56+
return ParityTensor(
57+
edges=self.edges,
58+
tensor=self.tensor - other.tensor,
59+
)
60+
else:
61+
return ParityTensor(
62+
edges=self.edges,
63+
tensor=self.tensor - other,
64+
)
65+
66+
def __rsub__(self, other):
67+
return ParityTensor(
68+
edges=self.edges,
69+
tensor=other - self.tensor,
70+
)
71+
72+
def __isub__(self, other):
73+
if isinstance(other, ParityTensor):
74+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
75+
self.tensor -= other.tensor
76+
else:
77+
self.tensor -= other
78+
return self
79+
80+
def __mul__(self, other):
81+
if isinstance(other, ParityTensor): # pylint: disable=no-else-return
82+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
83+
return ParityTensor(
84+
edges=self.edges,
85+
tensor=self.tensor * other.tensor,
86+
)
87+
else:
88+
return ParityTensor(
89+
edges=self.edges,
90+
tensor=self.tensor * other,
91+
)
92+
93+
def __rmul__(self, other):
94+
return ParityTensor(
95+
edges=self.edges,
96+
tensor=other * self.tensor,
97+
)
98+
99+
def __imul__(self, other):
100+
if isinstance(other, ParityTensor):
101+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
102+
self.tensor *= other.tensor
103+
else:
104+
self.tensor *= other
105+
return self
106+
107+
def __truediv__(self, other):
108+
if isinstance(other, ParityTensor): # pylint: disable=no-else-return
109+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
110+
return ParityTensor(
111+
edges=self.edges,
112+
tensor=self.tensor / other.tensor,
113+
)
114+
else:
115+
return ParityTensor(
116+
edges=self.edges,
117+
tensor=self.tensor / other,
118+
)
119+
120+
def __rtruediv__(self, other):
121+
return ParityTensor(
122+
edges=self.edges,
123+
tensor=other / self.tensor,
124+
)
125+
126+
def __itruediv__(self, other):
127+
if isinstance(other, ParityTensor):
128+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
129+
self.tensor /= other.tensor
130+
else:
131+
self.tensor /= other
132+
return self
133+
134+
def __pos__(self):
135+
return ParityTensor(
136+
edges=self.edges,
137+
tensor=+self.tensor,
138+
)
139+
140+
def __neg__(self):
141+
return ParityTensor(
142+
edges=self.edges,
143+
tensor=-self.tensor,
144+
)

0 commit comments

Comments
 (0)