Skip to content

Commit

Permalink
bugfix #66
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Dec 3, 2023
1 parent cf093ac commit 468397e
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions fdtd/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,21 +322,21 @@ def numpy(self, arr):
class TorchCudaBackend(TorchBackend):
"""Torch Cuda Backend"""

def ones(self, shape):
def ones(self, shape, **kwargs):
"""create an array filled with ones"""
return torch.ones(shape, device="cuda")
return torch.ones(shape, device="cuda", **kwargs)

def zeros(self, shape):
def zeros(self, shape, **kwargs):
"""create an array filled with zeros"""
return torch.zeros(shape, device="cuda")
return torch.zeros(shape, device="cuda", **kwargs)

def array(self, arr, dtype=None):
def array(self, arr, dtype=None, **kwargs):
"""create an array from an array-like sequence"""
if dtype is None:
dtype = torch.get_default_dtype()
if torch.is_tensor(arr):
return arr.clone().to(device="cuda", dtype=dtype)
return torch.tensor(arr, device="cuda", dtype=dtype)
return arr.clone().to(device="cuda", dtype=dtype, **kwargs)
return torch.tensor(arr, device="cuda", dtype=dtype, **kwargs)

# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The same warning applies here.
Expand Down

0 comments on commit 468397e

Please sign in to comment.