Skip to content

Commit 90dfc55

Browse files
committed
Merge branch 'dev/add-support-to-dtype-to'
2 parents 2ffc31a + 9508f04 commit 90dfc55

File tree

2 files changed

+93
-8
lines changed

2 files changed

+93
-8
lines changed

parity_tensor/parity_tensor.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,45 @@ def mask(self) -> torch.Tensor:
6464
self._mask = self._tensor_mask()
6565
return self._mask
6666

67-
def to(self, device: torch.device) -> ParityTensor:
67+
def to(self, whatever: torch.device | torch.dtype | str | None = None, *, device: torch.device | None = None, dtype: torch.dtype | None = None) -> ParityTensor:
6868
"""
69-
Copy the tensor to a specified device.
69+
Copy the tensor to a specified device or copy it to a specified data type.
7070
"""
71-
return dataclasses.replace(
72-
self,
73-
_tensor=self._tensor.to(device),
74-
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
75-
_mask=self._mask.to(device) if self._mask is not None else None,
76-
)
71+
if whatever is None:
72+
pass
73+
elif isinstance(whatever, torch.device):
74+
assert device is None, "Duplicate device specification."
75+
device = whatever
76+
elif isinstance(whatever, torch.dtype):
77+
assert dtype is None, "Duplicate dtype specification."
78+
dtype = whatever
79+
elif isinstance(whatever, str):
80+
assert device is None, "Duplicate device specification."
81+
device = torch.device(whatever)
82+
else:
83+
raise TypeError(f"Unsupported type for 'to': {type(whatever)}. Expected torch.device, torch.dtype, or str.")
84+
match (device, dtype):
85+
case (None, None):
86+
return self
87+
case (None, _):
88+
return dataclasses.replace(
89+
self,
90+
_tensor=self._tensor.to(dtype=dtype),
91+
)
92+
case (_, None):
93+
return dataclasses.replace(
94+
self,
95+
_tensor=self._tensor.to(device=device),
96+
_parity=tuple(p.to(device) for p in self._parity) if self._parity is not None else None,
97+
_mask=self._mask.to(device) if self._mask is not None else None,
98+
)
99+
case _:
100+
return dataclasses.replace(
101+
self,
102+
_tensor=self._tensor.to(device=device, dtype=dtype),
103+
_parity=tuple(p.to(device=device) for p in self._parity) if self._parity is not None else None,
104+
_mask=self._mask.to(device=device) if self._mask is not None else None,
105+
)
77106

78107
def update_mask(self) -> ParityTensor:
79108
"""

tests/conversion_test.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import typing
2+
import pytest
3+
import torch
4+
from parity_tensor import ParityTensor
5+
6+
7+
@pytest.fixture()
8+
def x() -> ParityTensor:
9+
return ParityTensor((False, False), ((2, 2), (1, 3)), torch.randn([4, 4]))
10+
11+
12+
@pytest.mark.parametrize("dtype_arg", ["position", "keyword", "none"])
13+
@pytest.mark.parametrize("device_arg", ["position", "keyword", "none"])
14+
@pytest.mark.parametrize("device_format", ["object", "string"])
15+
def test_conversion(
16+
x: ParityTensor,
17+
dtype_arg: typing.Literal["position", "keyword", "none"],
18+
device_arg: typing.Literal["position", "keyword", "none"],
19+
device_format: typing.Literal["object", "string"],
20+
) -> None:
21+
args: list[typing.Any] = []
22+
kwargs: dict[str, typing.Any] = {}
23+
24+
device = torch.device("cpu") if device_format == "object" else "cpu"
25+
match device_arg:
26+
case "position":
27+
args.append(device)
28+
case "keyword":
29+
kwargs["device"] = device
30+
case _:
31+
pass
32+
33+
match dtype_arg:
34+
case "position":
35+
args.append(torch.complex128)
36+
case "keyword":
37+
kwargs["dtype"] = torch.complex128
38+
case _:
39+
pass
40+
41+
if len(args) <= 1:
42+
y = x.to(*args, **kwargs)
43+
44+
45+
def test_conversion_invalid_type(x: ParityTensor) -> None:
46+
with pytest.raises(TypeError):
47+
x.to(2333) # type: ignore[arg-type]
48+
49+
50+
def test_conversion_duplicated_value(x: ParityTensor) -> None:
51+
with pytest.raises(AssertionError):
52+
x.to(torch.device("cpu"), device=torch.device("cpu"))
53+
with pytest.raises(AssertionError):
54+
x.to(torch.complex128, dtype=torch.complex128)
55+
with pytest.raises(AssertionError):
56+
x.to("cpu", device=torch.device("cpu"))

0 commit comments

Comments
 (0)