Skip to content

Commit 6bf1100

Browse files
committed
Add the basic parity tensor class and its arithmetic operators.
1 parent 376ad05 commit 6bf1100

File tree

2 files changed

+154
-1
lines changed

2 files changed

+154
-1
lines changed

parity_tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
A parity tensor package.
33
"""
44

5-
__all__ = ["__version__"]
5+
__all__ = ["__version__", "ParityTensor"]
66

77
from .version import __version__
8+
from .parity_tensor import ParityTensor

parity_tensor/parity_tensor.py

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
"""
2+
A parity tensor class.
3+
"""
4+
5+
from __future__ import annotations
6+
7+
__all__ = ["ParityTensor"]
8+
9+
import dataclasses
10+
import torch
11+
12+
13+
@dataclasses.dataclass
14+
class ParityTensor:
15+
"""
16+
A parity tensor class, which stores a tensor along with information about its edges.
17+
Each dimension of the tensor is composed of an even and an odd part, represented as a pair of integers.
18+
"""
19+
20+
edges: tuple[tuple[int, int], ...]
21+
tensor: torch.Tensor
22+
23+
def __post_init__(self):
24+
assert len(self.edges) == self.tensor.dim(), "Edges length must match tensor dimensions."
25+
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
26+
assert even >= 0 and odd >= 0 and dim == even + odd, "Each dimension must match the sum of even and odd parts."
27+
28+
def _validate_edge_compatibility(self, other: ParityTensor) -> None:
29+
"""
30+
Validate that the edges of two ParityTensor instances are compatible for arithmetic operations.
31+
"""
32+
assert self.edges == other.edges, "Edges must match for arithmetic operations."
33+
34+
def __add__(self, other):
35+
if isinstance(other, ParityTensor): # pylint: disable=no-else-return
36+
self._validate_edge_compatibility(other)
37+
return ParityTensor(
38+
edges=self.edges,
39+
tensor=self.tensor + other.tensor,
40+
)
41+
else:
42+
return ParityTensor(
43+
edges=self.edges,
44+
tensor=self.tensor + other,
45+
)
46+
47+
def __radd__(self, other):
48+
return ParityTensor(
49+
edges=self.edges,
50+
tensor=other + self.tensor,
51+
)
52+
53+
def __iadd__(self, other):
54+
if isinstance(other, ParityTensor):
55+
self._validate_edge_compatibility(other)
56+
self.tensor += other.tensor
57+
else:
58+
self.tensor += other
59+
return self
60+
61+
def __sub__(self, other):
62+
if isinstance(other, ParityTensor): # pylint: disable=no-else-return
63+
self._validate_edge_compatibility(other)
64+
return ParityTensor(
65+
edges=self.edges,
66+
tensor=self.tensor - other.tensor,
67+
)
68+
else:
69+
return ParityTensor(
70+
edges=self.edges,
71+
tensor=self.tensor - other,
72+
)
73+
74+
def __rsub__(self, other):
75+
return ParityTensor(
76+
edges=self.edges,
77+
tensor=other - self.tensor,
78+
)
79+
80+
def __isub__(self, other):
81+
if isinstance(other, ParityTensor):
82+
self._validate_edge_compatibility(other)
83+
self.tensor -= other.tensor
84+
else:
85+
self.tensor -= other
86+
return self
87+
88+
def __mul__(self, other):
89+
if isinstance(other, ParityTensor): # pylint: disable=no-else-return
90+
self._validate_edge_compatibility(other)
91+
return ParityTensor(
92+
edges=self.edges,
93+
tensor=self.tensor * other.tensor,
94+
)
95+
else:
96+
return ParityTensor(
97+
edges=self.edges,
98+
tensor=self.tensor * other,
99+
)
100+
101+
def __rmul__(self, other):
102+
return ParityTensor(
103+
edges=self.edges,
104+
tensor=other * self.tensor,
105+
)
106+
107+
def __imul__(self, other):
108+
if isinstance(other, ParityTensor):
109+
self._validate_edge_compatibility(other)
110+
self.tensor *= other.tensor
111+
else:
112+
self.tensor *= other
113+
return self
114+
115+
def __truediv__(self, other):
116+
if isinstance(other, ParityTensor): # pylint: disable=no-else-return
117+
self._validate_edge_compatibility(other)
118+
return ParityTensor(
119+
edges=self.edges,
120+
tensor=self.tensor / other.tensor,
121+
)
122+
else:
123+
return ParityTensor(
124+
edges=self.edges,
125+
tensor=self.tensor / other,
126+
)
127+
128+
def __rtruediv__(self, other):
129+
return ParityTensor(
130+
edges=self.edges,
131+
tensor=other / self.tensor,
132+
)
133+
134+
def __itruediv__(self, other):
135+
if isinstance(other, ParityTensor):
136+
self._validate_edge_compatibility(other)
137+
self.tensor /= other.tensor
138+
else:
139+
self.tensor /= other
140+
return self
141+
142+
def __pos__(self):
143+
return ParityTensor(
144+
edges=self.edges,
145+
tensor=+self.tensor,
146+
)
147+
148+
def __neg__(self):
149+
return ParityTensor(
150+
edges=self.edges,
151+
tensor=-self.tensor,
152+
)

0 commit comments

Comments
 (0)