77__all__ = ["ParityTensor" ]
88
99import dataclasses
10+ import functools
1011import typing
1112import torch
1213
@@ -20,11 +21,29 @@ class ParityTensor:
2021
2122 edges : tuple [tuple [int , int ], ...]
2223 tensor : torch .Tensor
24+ mask : torch .Tensor | None = None
2325
2426 def __post_init__ (self ) -> None :
2527 assert len (self .edges ) == self .tensor .dim (), f"Edges length ({ len (self .edges )} ) must match tensor dimensions ({ self .tensor .dim ()} )."
2628 for dim , (even , odd ) in zip (self .tensor .shape , self .edges ):
2729 assert even >= 0 and odd >= 0 and dim == even + odd , f"Dimension { dim } must equal sum of even ({ even } ) and odd ({ odd } ) parts, and both must be non-negative."
30+ if self .mask is None :
31+ self .mask = self ._tensor_mask ()
32+
33+ @classmethod
34+ def _unqueeze (cls , tensor : torch .Tensor , index : int , dim : int ) -> torch .Tensor :
35+ return tensor .view ([- 1 if i == index else 1 for i in range (dim )])
36+
37+ @classmethod
38+ def _edge_mask (cls , even : int , odd : int ) -> torch .Tensor :
39+ return torch .cat ([torch .zeros (even , dtype = torch .bool ), torch .ones (odd , dtype = torch .bool )])
40+
41+ def _tensor_mask (self ) -> torch .Tensor :
42+ return functools .reduce (
43+ torch .logical_xor ,
44+ (self ._unqueeze (self ._edge_mask (even , odd ), index , self .tensor .dim ()) for index , (even , odd ) in enumerate (self .edges )),
45+ torch .ones_like (self .tensor , dtype = torch .bool ),
46+ )
2847
2948 def _validate_edge_compatibility (self , other : ParityTensor ) -> None :
3049 """
0 commit comments