diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index c68e030a..d4c9ebe7 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -17,12 +17,13 @@ def __init__( bias: bool = True, dtype: Optional[torch.dtype] = None, skip_bias_add: bool = False, + device: torch.device = None, ): super(Linear, self).__init__( in_features=in_features, out_features=out_features, bias=bias, - device=torch.device(torch.cuda.current_device()), + device=device, dtype=dtype, ) self.skip_bias_add = skip_bias_add