@@ -81,7 +81,7 @@ def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
8181 for j in range (self .tensor .dim ())
8282 for i in range (0 , j ) # all 0 <= i < j < dim
8383 if before_by_after [i ] > before_by_after [j ]),
84- torch .zeros ([], dtype = torch .bool ),
84+ torch .zeros ([], dtype = torch .bool , device = self . tensor . device ),
8585 )
8686 tensor = torch .where (total_parity , - tensor , + tensor )
8787
@@ -97,13 +97,11 @@ def __post_init__(self) -> None:
9797 for dim , (even , odd ) in zip (self ._tensor .shape , self ._edges ):
9898 assert even >= 0 and odd >= 0 and dim == even + odd , f"Dimension { dim } must equal sum of even ({ even } ) and odd ({ odd } ) parts, and both must be non-negative."
9999
100- @classmethod
101- def _unqueeze (cls , tensor : torch .Tensor , index : int , dim : int ) -> torch .Tensor :
100+ def _unqueeze (self , tensor : torch .Tensor , index : int , dim : int ) -> torch .Tensor :
102101 return tensor .view ([- 1 if i == index else 1 for i in range (dim )])
103102
104- @classmethod
105- def _edge_mask (cls , even : int , odd : int ) -> torch .Tensor :
106- return torch .cat ([torch .zeros (even , dtype = torch .bool ), torch .ones (odd , dtype = torch .bool )])
103+ def _edge_mask (self , even : int , odd : int ) -> torch .Tensor :
104+ return torch .cat ([torch .zeros (even , dtype = torch .bool , device = self .tensor .device ), torch .ones (odd , dtype = torch .bool , device = self .tensor .device )])
107105
108106 def _tensor_mask (self ) -> torch .Tensor :
109107 return functools .reduce (
0 commit comments