Skip to content

Commit 6775a8e

Browse files
authored
Merge pull request #87 from WorldCereal/inference-optimization
Small improvement computing encodings
2 parents c24008c + 2b38bc9 commit 6775a8e

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

presto/inference.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,19 @@ def _get_encodings(self, dl: DataLoader) -> np.ndarray:
218218
np.ndarray: Array containing encoded features.
219219
"""
220220

221-
all_encodings = []
221+
encodings = np.empty(
222+
[len(dl.dataset), self.model.encoder.embedding_size], # type: ignore[arg-type]
223+
dtype=np.float32,
224+
)
225+
226+
with torch.no_grad():
222227

223-
for x, dw, latlons, month, variable_mask in dl:
224-
x_f, dw_f, latlons_f, month_f, variable_mask_f = [
225-
t.to(device) for t in (x, dw, latlons, month, variable_mask)
226-
]
228+
for i, (x, dw, latlons, month, variable_mask) in enumerate(dl):
229+
x_f, dw_f, latlons_f, month_f, variable_mask_f = [
230+
t.to(device) for t in (x, dw, latlons, month, variable_mask)
231+
]
227232

228-
with torch.no_grad():
229-
encodings = (
233+
encodings[i * self.batch_size : i * self.batch_size + self.batch_size, :] = (
230234
self.model.encoder(
231235
x_f,
232236
dynamic_world=dw_f.long(),
@@ -238,9 +242,7 @@ def _get_encodings(self, dl: DataLoader) -> np.ndarray:
238242
.numpy()
239243
)
240244

241-
all_encodings.append(encodings)
242-
243-
return np.concatenate(all_encodings, axis=0)
245+
return encodings
244246

245247
def extract_presto_features(self, inarr: xr.DataArray, epsg: int = 4326) -> xr.DataArray:
246248

0 commit comments

Comments
 (0)