Skip to content

Commit 3e97afa

Browse files
committed
Add support to more feasiable function to.
1 parent f6fb29e commit 3e97afa

File tree

1 file changed

+42
-8
lines changed

1 file changed

+42
-8
lines changed

parity_tensor/parity_tensor.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,50 @@ def mask(self) -> torch.Tensor:
6464
self._mask = self._tensor_mask()
6565
return self._mask
6666

67-
def to(self, device: torch.device) -> ParityTensor:
67+
def to(self, whatever: torch.device | torch.dtype | str, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ParityTensor:
6868
"""
69-
Copy the tensor to a specified device.
69+
Copy the tensor to a specified device or copy it to a specified data type.
7070
"""
71-
return dataclasses.replace(
72-
self,
73-
_tensor=self._tensor.to(device),
74-
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
75-
_mask=self._mask.to(device) if self._mask is not None else None,
76-
)
71+
if isinstance(whatever, torch.device):
72+
assert device is None, "Duplicate device specification."
73+
device = whatever
74+
elif isinstance(whatever, torch.dtype):
75+
assert dtype is None, "Duplicate dtype specification."
76+
dtype = whatever
77+
elif isinstance(whatever, str):
78+
if dtype is None:
79+
dtype = torch.dtype(whatever) # type: ignore[call-arg]
80+
elif device is None:
81+
device = torch.device(whatever)
82+
else:
83+
try:
84+
dtype = torch.dtype(whatever) # type: ignore[call-arg]
85+
except TypeError:
86+
device = torch.device(whatever)
87+
else:
88+
raise TypeError(f"Unsupported type for 'to': {type(whatever)}. Expected torch.device, torch.dtype, or str.")
89+
match (device, dtype):
90+
case (None, None):
91+
return self
92+
case (None, _):
93+
return dataclasses.replace(
94+
self,
95+
_tensor=self._tensor.to(dtype=dtype),
96+
)
97+
case (_, None):
98+
return dataclasses.replace(
99+
self,
100+
_tensor=self._tensor.to(device=device),
101+
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
102+
_mask=self._mask.to(device) if self._mask is not None else None,
103+
)
104+
case _:
105+
return dataclasses.replace(
106+
self,
107+
_tensor=self._tensor.to(device=device, dtype=dtype),
108+
_parity=tuple(p.to(device=device, dtype=dtype) for p in self._parity) if self._parity is not None else None,
109+
_mask=self._mask.to(device=device, dtype=dtype) if self._mask is not None else None,
110+
)
77111

78112
def update_mask(self) -> ParityTensor:
79113
"""

0 commit comments

Comments
 (0)