Skip to content

Commit fb4141e

Browse files
committed
Add support to more feasible function to.
1 parent f6fb29e commit fb4141e

File tree

1 file changed

+44
-8
lines changed

1 file changed

+44
-8
lines changed

parity_tensor/parity_tensor.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,52 @@ def mask(self) -> torch.Tensor:
6464
self._mask = self._tensor_mask()
6565
return self._mask
6666

67-
def to(self, device: torch.device) -> ParityTensor:
67+
def to(self, whatever: torch.device | torch.dtype | str | None = None, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ParityTensor:
6868
"""
69-
Copy the tensor to a specified device.
69+
Copy the tensor to a specified device or copy it to a specified data type.
7070
"""
71-
return dataclasses.replace(
72-
self,
73-
_tensor=self._tensor.to(device),
74-
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
75-
_mask=self._mask.to(device) if self._mask is not None else None,
76-
)
71+
if whatever is None:
72+
pass
73+
elif isinstance(whatever, torch.device):
74+
assert device is None, "Duplicate device specification."
75+
device = whatever
76+
elif isinstance(whatever, torch.dtype):
77+
assert dtype is None, "Duplicate dtype specification."
78+
dtype = whatever
79+
elif isinstance(whatever, str):
80+
if dtype is None:
81+
dtype = torch.dtype(whatever) # type: ignore[call-arg]
82+
elif device is None:
83+
device = torch.device(whatever)
84+
else:
85+
try:
86+
dtype = torch.dtype(whatever) # type: ignore[call-arg]
87+
except TypeError:
88+
device = torch.device(whatever)
89+
else:
90+
raise TypeError(f"Unsupported type for 'to': {type(whatever)}. Expected torch.device, torch.dtype, or str.")
91+
match (device, dtype):
92+
case (None, None):
93+
return self
94+
case (None, _):
95+
return dataclasses.replace(
96+
self,
97+
_tensor=self._tensor.to(dtype=dtype),
98+
)
99+
case (_, None):
100+
return dataclasses.replace(
101+
self,
102+
_tensor=self._tensor.to(device=device),
103+
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
104+
_mask=self._mask.to(device) if self._mask is not None else None,
105+
)
106+
case _:
107+
return dataclasses.replace(
108+
self,
109+
_tensor=self._tensor.to(device=device, dtype=dtype),
110+
_parity=tuple(p.to(device=device) for p in self._parity) if self._parity is not None else None,
111+
_mask=self._mask.to(device=device) if self._mask is not None else None,
112+
)
77113

78114
def update_mask(self) -> ParityTensor:
79115
"""

0 commit comments

Comments
 (0)