@@ -64,16 +64,50 @@ def mask(self) -> torch.Tensor:
6464 self ._mask = self ._tensor_mask ()
6565 return self ._mask
6666
67- def to (self , device : torch .device ) -> ParityTensor :
67+ def to (self , whatever : torch . device | torch . dtype | str , device : torch .device | None = None , dtype : torch . dtype | None = None ) -> ParityTensor :
6868 """
69- Copy the tensor to a specified device.
69+ Copy the tensor to a specified device or copy it to a specified data type .
7070 """
71- return dataclasses .replace (
72- self ,
73- _tensor = self ._tensor .to (device ),
74- _parity = tuple (p .to (device ) for p in self ._parity ) if self ._parity is not None else None ,
75- _mask = self ._mask .to (device ) if self ._mask is not None else None ,
76- )
71+ if isinstance (whatever , torch .device ):
72+ assert device is None , "Duplicate device specification."
73+ device = whatever
74+ elif isinstance (whatever , torch .dtype ):
75+ assert dtype is None , "Duplicate dtype specification."
76+ dtype = whatever
77+ elif isinstance (whatever , str ):
78+ if dtype is None :
79+ dtype = torch .dtype (whatever ) # type: ignore[call-arg]
80+ elif device is None :
81+ device = torch .device (whatever )
82+ else :
83+ try :
84+ dtype = torch .dtype (whatever ) # type: ignore[call-arg]
85+ except TypeError :
86+ device = torch .device (whatever )
87+ else :
88+ raise TypeError (f"Unsupported type for 'to': { type (whatever )} . Expected torch.device, torch.dtype, or str." )
89+ match (device , dtype ):
90+ case (None , None ):
91+ return self
92+ case (None , _):
93+ return dataclasses .replace (
94+ self ,
95+ _tensor = self ._tensor .to (dtype = dtype ),
96+ )
97+ case (_, None ):
98+ return dataclasses .replace (
99+ self ,
100+ _tensor = self ._tensor .to (device = device ),
101+ _parity = tuple (p .to (device ) for p in self ._parity ) if self ._parity is not None else None ,
102+ _mask = self ._mask .to (device ) if self ._mask is not None else None ,
103+ )
104+ case _:
105+ return dataclasses .replace (
106+ self ,
107+ _tensor = self ._tensor .to (device = device , dtype = dtype ),
108+ _parity = tuple (p .to (device = device , dtype = dtype ) for p in self ._parity ) if self ._parity is not None else None ,
109+ _mask = self ._mask .to (device = device , dtype = dtype ) if self ._mask is not None else None ,
110+ )
77111
78112 def update_mask (self ) -> ParityTensor :
79113 """
0 commit comments