diff --git a/parity_tensor/parity_tensor.py b/parity_tensor/parity_tensor.py index e224234..27ff0ef 100644 --- a/parity_tensor/parity_tensor.py +++ b/parity_tensor/parity_tensor.py @@ -7,6 +7,7 @@ __all__ = ["ParityTensor"] import dataclasses +import functools import typing import torch @@ -20,11 +21,29 @@ class ParityTensor: edges: tuple[tuple[int, int], ...] tensor: torch.Tensor + mask: torch.Tensor | None = None def __post_init__(self) -> None: assert len(self.edges) == self.tensor.dim(), f"Edges length ({len(self.edges)}) must match tensor dimensions ({self.tensor.dim()})." for dim, (even, odd) in zip(self.tensor.shape, self.edges): assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative." + if self.mask is None: + self.mask = self._tensor_mask() + + @classmethod + def _unqueeze(cls, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor: + return tensor.view([-1 if i == index else 1 for i in range(dim)]) + + @classmethod + def _edge_mask(cls, even: int, odd: int) -> torch.Tensor: + return torch.cat([torch.zeros(even, dtype=torch.bool), torch.ones(odd, dtype=torch.bool)]) + + def _tensor_mask(self) -> torch.Tensor: + return functools.reduce( + torch.logical_xor, + (self._unqueeze(self._edge_mask(even, odd), index, self.tensor.dim()) for index, (even, odd) in enumerate(self.edges)), + torch.ones_like(self.tensor, dtype=torch.bool), + ) def _validate_edge_compatibility(self, other: ParityTensor) -> None: """