Skip to content

Commit

Permalink
more explicit handling of device
Browse files (browse the repository at this point in the history)
  • Loading branch information
Butsko Christina committed Nov 19, 2024
1 parent 4e6e83c commit a53b5de
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion presto/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def get_presto_features(
from_url=from_url,
strict=False,
valid_month_as_token=use_valid_date_token,
- )
+ ).to(device)

# Compile for optimized inference. Note that warmup takes some time
# so this is only recommended for larger inference jobs
Expand Down
9 changes: 4 additions & 5 deletions presto/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,12 +418,11 @@ def forward(
valid_month: Optional[torch.Tensor] = None,
eval_task: bool = True,
):
# device = x.device

if mask is None:
- mask = torch.zeros_like(x, device=x.device)
+ mask = torch.zeros_like(x, device=device)

- months = month_to_tensor(month, x.shape[0], x.shape[1], x.device)
+ months = month_to_tensor(month, x.shape[0], x.shape[1], device)
month_embedding = self.month_embed(months)
positional_embedding = repeat(
self.pos_embed[:, : x.shape[1], :],
Expand Down Expand Up @@ -622,7 +621,7 @@ def add_embeddings(self, x, month: Union[torch.Tensor, int]):
# channel group doesn't have timesteps
num_timesteps = int((x.shape[1] - 2) / (num_channel_groups - 1))
srtm_index = self.band_group_to_idx["SRTM"] * num_timesteps
- months = month_to_tensor(month, x.shape[0], num_timesteps, x.device)
+ months = month_to_tensor(month, x.shape[0], num_timesteps, device)

# when we expand the encodings, each channel_group gets num_timesteps
# encodings. However, there is only one SRTM token so we remove the
Expand Down Expand Up @@ -676,7 +675,7 @@ def reconstruct_inputs(self, x) -> Tuple[torch.Tensor, torch.Tensor]:
srtm_index = self.band_group_to_idx["SRTM"] * num_timesteps
srtm_token = x[:, srtm_index : srtm_index + 1, :]

- mask = torch.full((x.shape[1],), True, device=x.device)
+ mask = torch.full((x.shape[1],), True, device=device)
mask[torch.tensor(srtm_index)] = False
x = x[:, mask]

Expand Down

0 comments on commit a53b5de

Please sign in to comment.