-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathdata_loader.py
More file actions
145 lines (125 loc) · 4.9 KB
/
data_loader.py
File metadata and controls
145 lines (125 loc) · 4.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import logging
import glob
import torch
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch import Tensor
import h5py
def worker_init(wrk_id):
    """Seed per-worker RNGs so each DataLoader worker produces a distinct,
    reproducible random stream.

    Args:
        wrk_id: Worker index supplied by DataLoader (unused; the worker's
            unique per-epoch seed comes from ``get_worker_info()`` instead).
    """
    # PyTorch assigns each worker a unique seed (base_seed + worker_id).
    worker_seed = torch.utils.data.get_worker_info().seed
    # np.random.seed only accepts values in [0, 2**32 - 1], so reduce it.
    np.random.seed(worker_seed % (2**32 - 1))
    # Bug fix: also seed Python's `random` module (imported at module top);
    # previously any python-level randomness was identical in every worker.
    random.seed(worker_seed)
def get_data_loader(params, files_pattern, distributed, train):
    """Build an ERA5 DataLoader (and a DistributedSampler when requested).

    Args:
        params: Config object providing ``local_batch_size``,
            ``num_data_workers`` and the ``ERA5Dataset`` settings; for model
            parallelism it must also carry ``data_num_shards`` and
            ``data_shard_id``.
        files_pattern: Directory containing the yearly ``*.h5`` files
            (forwarded to ``ERA5Dataset`` as its location).
        distributed: Whether to shard the dataset with a DistributedSampler.
        train: Training mode; controls shuffling and the return signature.

    Returns:
        ``(dataloader, dataset, sampler)`` when ``train`` is True,
        ``(dataloader, dataset)`` otherwise.
    """
    dataset = ERA5Dataset(params, files_pattern, train)

    sampler = None
    if distributed:
        if hasattr(params, "data_num_shards"):
            # Model parallelism: every shard must know both the shard count
            # and its own rank within the data-parallel group.
            assert hasattr(
                params, "data_shard_id"
            ), "please set data_num_shards and data_shard_id"
            sampler = DistributedSampler(
                dataset,
                shuffle=train,
                num_replicas=params.data_num_shards,
                rank=params.data_shard_id,
            )
        else:
            sampler = DistributedSampler(dataset, shuffle=train)

    dataloader = DataLoader(
        dataset,
        batch_size=int(params.local_batch_size),
        num_workers=params.num_data_workers,
        # Bug fix: only shuffle during training. The original used
        # shuffle=(sampler is None), which shuffled validation data
        # whenever the loader was not distributed. `shuffle` and `sampler`
        # remain mutually exclusive as DataLoader requires.
        shuffle=(sampler is None) and train,
        sampler=sampler,
        worker_init_fn=worker_init,
        drop_last=True,
        # persistent_workers=train,
        pin_memory=torch.cuda.is_available(),
    )

    if train:
        return dataloader, dataset, sampler
    return dataloader, dataset
class ERA5Dataset(Dataset):
    """Map-style dataset over a directory of yearly ERA5 HDF5 files.

    Each file exposes a ``"fields"`` array indexed as
    (sample, channel, x, y) — established by the slicing in ``__getitem__``
    and the shape checks in ``_get_files_stats``. ``__getitem__`` returns a
    normalized ``(input, target)`` pair of frames ``dt`` steps apart within
    the same year.
    """

    def __init__(self, params, location, train):
        """
        Args:
            params: Config providing ``dt``, ``n_in_channels``,
                ``n_out_channels``, ``img_size``, ``global_means_path``,
                ``global_stds_path`` and ``limit_nsamples`` /
                ``limit_nsamples_val``.
            location: Directory containing the yearly ``*.h5`` files.
            train: Selects which sample-count limit applies.
        """
        self.params = params
        self.location = location
        self.train = train
        self.dt = params.dt  # temporal offset (in samples) between input and target
        self.n_in_channels = params.n_in_channels
        self.n_out_channels = params.n_out_channels
        self.normalize = True  # toggle for _normalize; always on here
        # Global per-channel statistics; the leading [0] drops a singleton
        # axis so they broadcast over (channel, x, y) fields.
        self.means = np.load(params.global_means_path)[0]
        self.stds = np.load(params.global_stds_path)[0]
        # Optional cap on dataset length (None means no cap).
        self.limit_nsamples = (
            params.limit_nsamples if train else params.limit_nsamples_val
        )
        self._get_files_stats()

    def _get_files_stats(self):
        """Scan the data directory and cache per-file shape metadata."""
        self.files_paths = glob.glob(self.location + "/*.h5")
        self.files_paths.sort()
        # Year is encoded as the last four characters of each file stem.
        self.years = [
            int(os.path.splitext(os.path.basename(x))[0][-4:]) for x in self.files_paths
        ]
        self.n_years = len(self.files_paths)
        # NOTE(review): assumes every yearly file holds the same number of
        # samples as the first one — confirm for leap/partial years.
        with h5py.File(self.files_paths[0], "r") as _f:
            logging.info("Getting file stats from {}".format(self.files_paths[0]))
            self.n_samples_per_year = _f["fields"].shape[0]
            # Crop size requested by the config; must fit inside the stored
            # spatial extent (axes 2 and 3 of "fields").
            self.img_shape_x = self.params.img_size[0]
            self.img_shape_y = self.params.img_size[1]
            assert (
                self.img_shape_x <= _f["fields"].shape[2]
                and self.img_shape_y <= _f["fields"].shape[3]
            ), "image shapes are greater than dataset image shapes"
        self.n_samples_total = self.n_years * self.n_samples_per_year
        if self.limit_nsamples is not None:
            self.n_samples_total = min(self.n_samples_total, self.limit_nsamples)
            logging.info(
                "Overriding total number of samples to: {}".format(self.n_samples_total)
            )
        # Lazily opened per-year "fields" handles; filled by _open_file on
        # first access (so each DataLoader worker opens its own handles).
        self.files = [None for _ in range(self.n_years)]
        logging.info("Number of samples per year: {}".format(self.n_samples_per_year))
        logging.info(
            "Found data at path {}. Number of examples: {}. Image Shape: {} x {} x {}".format(
                self.location,
                self.n_samples_total,
                self.img_shape_x,
                self.img_shape_y,
                self.n_in_channels,
            )
        )

    def _open_file(self, year_idx):
        """Open the HDF5 file for ``year_idx`` and cache its "fields" dataset."""
        # The h5py.File object is kept alive through the dataset reference.
        _file = h5py.File(self.files_paths[year_idx], "r")
        self.files[year_idx] = _file["fields"]

    def __len__(self):
        # Possibly capped by limit_nsamples (see _get_files_stats).
        return self.n_samples_total

    def _normalize(self, img):
        """Standardize ``img`` in place with the global stats, return a tensor."""
        if self.normalize:
            img -= self.means
            img /= self.stds
        # Zero-copy wrap where possible.
        return torch.as_tensor(img)

    def __getitem__(self, global_idx):
        """Map a flat index to (year, local sample) and return (input, target)."""
        year_idx = int(global_idx / self.n_samples_per_year)  # which year
        local_idx = int(
            global_idx % self.n_samples_per_year
        )  # which sample in that year
        # open image file (lazy, per worker)
        if self.files[year_idx] is None:
            self._open_file(year_idx)
        step = self.dt  # time step
        # boundary conditions to ensure we don't pull data that is not in a specific year:
        # wrap so that local_idx + step stays inside this year's file.
        local_idx = local_idx % (self.n_samples_per_year - step)
        # NOTE(review): this also shifts the first `step` indices forward,
        # duplicating some early-year samples — presumably to keep
        # local_idx - step valid in related variants; verify intent.
        if local_idx < step:
            local_idx += step
        # pre-process and get the image fields, cropped to the configured size
        inp_field = self.files[year_idx][
            local_idx, :, 0 : self.img_shape_x, 0 : self.img_shape_y
        ]
        tar_field = self.files[year_idx][
            local_idx + step, :, 0 : self.img_shape_x, 0 : self.img_shape_y
        ]
        inp, tar = self._normalize(inp_field), self._normalize(tar_field)
        return inp, tar