Skip to content

Commit da41a99

Browse files
committed
Merge branch 'dev/add-generate-tensor-mask'
2 parents f44736d + d70991c commit da41a99

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

parity_tensor/parity_tensor.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
__all__ = ["ParityTensor"]
88

99
import dataclasses
10+
import functools
1011
import typing
1112
import torch
1213

@@ -20,11 +21,29 @@ class ParityTensor:
2021

2122
edges: tuple[tuple[int, int], ...]
2223
tensor: torch.Tensor
24+
mask: torch.Tensor | None = None
2325

2426
def __post_init__(self) -> None:
2527
assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})."
2628
for dim, (even, odd) in zip(self.tensor.shape, self.edges):
2729
assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative."
30+
if self.mask is None:
31+
self.mask = self._tensor_mask()
32+
33+
@classmethod
34+
def _unqueeze(cls, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor:
35+
return tensor.view([-1 if i == index else 1 for i in range(dim)])
36+
37+
@classmethod
38+
def _edge_mask(cls, even: int, odd: int) -> torch.Tensor:
39+
return torch.cat([torch.zeros(even, dtype=torch.bool), torch.ones(odd, dtype=torch.bool)])
40+
41+
def _tensor_mask(self) -> torch.Tensor:
42+
return functools.reduce(
43+
torch.logical_xor,
44+
(self._unqueeze(self._edge_mask(even, odd), index, self.tensor.dim()) for index, (even, odd) in enumerate(self.edges)),
45+
torch.ones_like(self.tensor, dtype=torch.bool),
46+
)
2847

2948
def _validate_edge_compatibility(self, other: ParityTensor) -> None:
3049
"""

0 commit comments

Comments
 (0)