Skip to content

Commit

Permalink
Fixed some issues related to tensor device
Browse files Browse the repository at this point in the history
  • Loading branch information
jrzaurin committed Mar 8, 2022
1 parent ef181de commit ac81b25
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 17 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
</p>

[![PyPI version](https://badge.fury.io/py/pytorch-widedeep.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Python 3.7 3.8 3.9 3.10](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Build Status](https://github.com/jrzaurin/pytorch-widedeep/actions/workflows/build.yml/badge.svg)](https://github.com/jrzaurin/pytorch-widedeep/actions)
[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
Expand Down
2 changes: 1 addition & 1 deletion pypi_README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[![PyPI version](https://badge.fury.io/py/pytorch-widedeep.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Python 3.7 3.8 3.9 3.10](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9%20%7C%203.10-blue.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Python 3.7 3.8 3.9](https://img.shields.io/badge/python-3.7%20%7C%203.8%20%7C%203.9-blue.svg)](https://pypi.org/project/pytorch-widedeep/)
[![Build Status](https://github.com/jrzaurin/pytorch-widedeep/actions/workflows/build.yml/badge.svg)](https://github.com/jrzaurin/pytorch-widedeep/actions)
[![Documentation Status](https://readthedocs.org/projects/pytorch-widedeep/badge/?version=latest)](https://pytorch-widedeep.readthedocs.io/en/latest/?badge=latest)
[![codecov](https://codecov.io/gh/jrzaurin/pytorch-widedeep/branch/master/graph/badge.svg)](https://codecov.io/gh/jrzaurin/pytorch-widedeep)
Expand Down
33 changes: 21 additions & 12 deletions pytorch_widedeep/models/fds_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,27 +114,36 @@ def update_last_epoch_stats(self, epoch):
self.running_mean_last_epoch = self.running_mean
self.running_var_last_epoch = self.running_var

smoothed_mean_last_epoch_inp = F.pad(
self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0),
pad=(self.half_ks, self.half_ks),
mode="reflect",
)
smoothed_mean_last_epoch_weight = self.kernel_window.view(1, 1, -1).to(
smoothed_mean_last_epoch_inp.device
)
self.smoothed_mean_last_epoch = (
F.conv1d(
input=F.pad(
self.running_mean_last_epoch.unsqueeze(1).permute(2, 1, 0),
pad=(self.half_ks, self.half_ks),
mode="reflect",
),
weight=self.kernel_window.view(1, 1, -1),
input=smoothed_mean_last_epoch_inp,
weight=smoothed_mean_last_epoch_weight,
padding=0,
)
.permute(2, 1, 0)
.squeeze(1)
)

smoothed_var_last_epoch_inp = F.pad(
self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0),
pad=(self.half_ks, self.half_ks),
mode="reflect",
)
smoothed_var_last_epoch_weight = self.kernel_window.view(1, 1, -1).to(
smoothed_var_last_epoch_inp.device
)
self.smoothed_var_last_epoch = (
F.conv1d(
input=F.pad(
self.running_var_last_epoch.unsqueeze(1).permute(2, 1, 0),
pad=(self.half_ks, self.half_ks),
mode="reflect",
),
weight=self.kernel_window.view(1, 1, -1),
input=smoothed_var_last_epoch_inp,
weight=smoothed_var_last_epoch_weight,
padding=0,
)
.permute(2, 1, 0)
Expand Down
11 changes: 8 additions & 3 deletions pytorch_widedeep/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,7 +988,11 @@ def _train_step(
lds_weightt: Tensor,
):

lds_weight = None if torch.all(lds_weightt == 0) else lds_weightt.view(-1, 1)
lds_weight = (
None
if torch.all(lds_weightt == 0)
else lds_weightt.view(-1, 1).to(self.device)
)
if (
self.with_lds
and lds_weight is not None
Expand Down Expand Up @@ -1079,7 +1083,6 @@ def _fds_step(
self,
data: Dict[str, Tensor],
target: Tensor,
lds_weight: Union[None, Tensor],
epoch: int,
) -> Tuple[Tensor, Tensor]:
self.model.train()
Expand All @@ -1096,7 +1099,9 @@ def _update_fds_stats(self, train_loader: DataLoader, epoch: int):
for idx, (data, targett, lds_weight) in zip(t, train_loader):
t.set_description("FDS update")
deeptab_features, deeptab_preds = self._fds_step(
data, targett, lds_weight, epoch
data,
targett,
epoch,
)
features_l.append(deeptab_features)
y_pred_l.append(deeptab_preds)
Expand Down
1 change: 1 addition & 0 deletions pytorch_widedeep/utils/deeptabular_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def find_bin(
)
indices = np.where(indices != len(bin_edges), indices, indices - 2)
elif type(bin_edges) == Tensor and type(values) == Tensor:
bin_edges = bin_edges.to(values.device)
indices = torch.searchsorted(bin_edges, values, right=False)
indices = torch.where(
(indices == 0) | (indices == len(bin_edges)), indices, indices - 1
Expand Down

0 comments on commit ac81b25

Please sign in to comment.