@@ -64,16 +64,45 @@ def mask(self) -> torch.Tensor:
6464 self ._mask = self ._tensor_mask ()
6565 return self ._mask
6666
67- def to (self , device : torch .device ) -> ParityTensor :
67+ def to (self , whatever : torch . device | torch . dtype | str | None = None , * , device : torch .device | None = None , dtype : torch . dtype | None = None ) -> ParityTensor :
6868 """
69- Copy the tensor to a specified device.
69+ Copy the tensor to a specified device or copy it to a specified data type .
7070 """
71- return dataclasses .replace (
72- self ,
73- _tensor = self ._tensor .to (device ),
74- _parity = tuple (p .to (device ) for p in self ._parity ) if self ._parity is not None else None ,
75- _mask = self ._mask .to (device ) if self ._mask is not None else None ,
76- )
71+ if whatever is None :
72+ pass
73+ elif isinstance (whatever , torch .device ):
74+ assert device is None , "Duplicate device specification."
75+ device = whatever
76+ elif isinstance (whatever , torch .dtype ):
77+ assert dtype is None , "Duplicate dtype specification."
78+ dtype = whatever
79+ elif isinstance (whatever , str ):
80+ assert device is None , "Duplicate device specification."
81+ device = torch .device (whatever )
82+ else :
83+ raise TypeError (f"Unsupported type for 'to': { type (whatever )} . Expected torch.device, torch.dtype, or str." )
84+ match (device , dtype ):
85+ case (None , None ):
86+ return self
87+ case (None , _):
88+ return dataclasses .replace (
89+ self ,
90+ _tensor = self ._tensor .to (dtype = dtype ),
91+ )
92+ case (_, None ):
93+ return dataclasses .replace (
94+ self ,
95+ _tensor = self ._tensor .to (device = device ),
96+ _parity = tuple (p .to (device ) for p in self ._parity ) if self ._parity is not None else None ,
97+ _mask = self ._mask .to (device ) if self ._mask is not None else None ,
98+ )
99+ case _:
100+ return dataclasses .replace (
101+ self ,
102+ _tensor = self ._tensor .to (device = device , dtype = dtype ),
103+ _parity = tuple (p .to (device = device ) for p in self ._parity ) if self ._parity is not None else None ,
104+ _mask = self ._mask .to (device = device ) if self ._mask is not None else None ,
105+ )
77106
78107 def update_mask (self ) -> ParityTensor :
79108 """
0 commit comments