Skip to content

Commit

Permalink
[FIX] Tensor and generator devices for sampling (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
runame authored Feb 13, 2024
1 parent a46820c commit 4364ae8
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 5 additions & 4 deletions curvlinops/fisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from torch import (
Generator,
Tensor,
as_tensor,
autograd,
multinomial,
normal,
softmax,
tensor,
zeros_like,
)
from torch.nn import CrossEntropyLoss, MSELoss, Parameter
Expand Down Expand Up @@ -187,7 +187,7 @@ def _matmat(self, M: ndarray) -> ndarray:
Returns:
Matrix-multiplication result ``mat @ M``.
"""
if self._generator is None:
if self._generator is None or self._generator.device != self._device:
self._generator = Generator(device=self._device)
self._generator.manual_seed(self._seed)

Expand Down Expand Up @@ -271,8 +271,9 @@ def sample_grad_output(self, output: Tensor) -> Tensor:
C = output.shape[1]

if isinstance(self._loss_func, MSELoss):
std = tensor(
sqrt(0.5 / C) if self._loss_func.reduction == "mean" else sqrt(0.5)
std = as_tensor(
sqrt(0.5 / C) if self._loss_func.reduction == "mean" else sqrt(0.5),
device=output.device,
)
return 2 * normal(zeros_like(output), std, generator=self._generator)

Expand Down
2 changes: 1 addition & 1 deletion curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def _compute_kfac(self):
)

# loop over data set, computing the Kronecker factors
if self._generator is None:
if self._generator is None or self._generator.device != self._device:
self._generator = Generator(device=self._device)
self._generator.manual_seed(self._seed)

Expand Down

0 comments on commit 4364ae8

Please sign in to comment.