Skip to content

Commit

Permalink
Ensure device comparison always between string representations (#1225)
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasGesseyJonesPX authored Aug 20, 2024
1 parent dcfdf35 commit b3254ed
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sbi/utils/nn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def check_net_device(

if isinstance(net, nn.Identity):
return net
if str(next(net.parameters()).device) != device:
if str(next(net.parameters()).device) != str(device):
warn(
message or f"Network is not on the correct device. Moving it to {device}.",
stacklevel=2,
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ def validate_theta_and_x(
assert theta.dtype == float32, "Type of parameters must be float32."
assert x.dtype == float32, "Type of simulator outputs must be float32."

if str(x.device) != data_device:
if str(x.device) != str(data_device):
warnings.warn(
f"Data x has device '{x.device}'. "
f"Moving x to the data_device '{data_device}'. "
Expand All @@ -724,7 +724,7 @@ def validate_theta_and_x(
)
x = x.to(data_device)

if str(theta.device) != data_device:
if str(theta.device) != str(data_device):
warnings.warn(
f"Parameters theta has device '{theta.device}'. "
f"Moving theta to the data_device '{data_device}'. "
Expand Down

0 comments on commit b3254ed

Please sign in to comment.