Skip to content

Commit 691cd83

Browse files
committed
Follow self.tensor.device during various operators.
1 parent 1d4f421 commit 691cd83

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

parity_tensor/parity_tensor.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def permute(self, before_by_after: tuple[int, ...]) -> ParityTensor:
8181
for j in range(self.tensor.dim())
8282
for i in range(0, j) # all 0 <= i < j < dim
8383
if before_by_after[i] > before_by_after[j]),
84-
torch.zeros([], dtype=torch.bool),
84+
torch.zeros([], dtype=torch.bool, device=self.tensor.device),
8585
)
8686
tensor = torch.where(total_parity, -tensor, +tensor)
8787

@@ -97,13 +97,11 @@ def __post_init__(self) -> None:
9797
for dim, (even, odd) in zip(self._tensor.shape, self._edges):
9898
assert even >= 0 and odd >= 0 and dim == even + odd, f"Dimension {dim} must equal sum of even ({even}) and odd ({odd}) parts, and both must be non-negative."
9999

100-
@classmethod
101-
def _unqueeze(cls, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor:
100+
def _unqueeze(self, tensor: torch.Tensor, index: int, dim: int) -> torch.Tensor:
102101
return tensor.view([-1 if i == index else 1 for i in range(dim)])
103102

104-
@classmethod
105-
def _edge_mask(cls, even: int, odd: int) -> torch.Tensor:
106-
return torch.cat([torch.zeros(even, dtype=torch.bool), torch.ones(odd, dtype=torch.bool)])
103+
def _edge_mask(self, even: int, odd: int) -> torch.Tensor:
104+
return torch.cat([torch.zeros(even, dtype=torch.bool, device=self.tensor.device), torch.ones(odd, dtype=torch.bool, device=self.tensor.device)])
107105

108106
def _tensor_mask(self) -> torch.Tensor:
109107
return functools.reduce(

0 commit comments

Comments
 (0)