@@ -218,15 +218,19 @@ def _get_encodings(self, dl: DataLoader) -> np.ndarray:
218
218
np.ndarray: Array containing encoded features.
219
219
"""
220
220
221
- all_encodings = []
221
+ encodings = np .empty (
222
+ [len (dl .dataset ), self .model .encoder .embedding_size ], # type: ignore[arg-type]
223
+ dtype = np .float32 ,
224
+ )
225
+
226
+ with torch .no_grad ():
222
227
223
- for x , dw , latlons , month , variable_mask in dl :
224
- x_f , dw_f , latlons_f , month_f , variable_mask_f = [
225
- t .to (device ) for t in (x , dw , latlons , month , variable_mask )
226
- ]
228
+ for i , ( x , dw , latlons , month , variable_mask ) in enumerate ( dl ) :
229
+ x_f , dw_f , latlons_f , month_f , variable_mask_f = [
230
+ t .to (device ) for t in (x , dw , latlons , month , variable_mask )
231
+ ]
227
232
228
- with torch .no_grad ():
229
- encodings = (
233
+ encodings [i * self .batch_size : i * self .batch_size + self .batch_size , :] = (
230
234
self .model .encoder (
231
235
x_f ,
232
236
dynamic_world = dw_f .long (),
@@ -238,9 +242,7 @@ def _get_encodings(self, dl: DataLoader) -> np.ndarray:
238
242
.numpy ()
239
243
)
240
244
241
- all_encodings .append (encodings )
242
-
243
- return np .concatenate (all_encodings , axis = 0 )
245
+ return encodings
244
246
245
247
def extract_presto_features (self , inarr : xr .DataArray , epsg : int = 4326 ) -> xr .DataArray :
246
248
0 commit comments