scan_hdf5.py
#!/usr/bin/env python3
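"""Scan an HDF5 file and report on its groups, datasets, and attributes.

Optionally report the minimum, maximum, and average values of integer and
float datasets, show a sample of low-cardinality 1-d datasets, and rewrite
the file with lower-precision float datasets.  Run with --help for options.
"""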
import argparse
import itertools
import math
import sys
import time

import h5py
import numpy as np


# Inspired by https://stackoverflow.com/a/43374773
def scan_dataset(v, analyze, showattrs, sample, indent):
    """Report on a single dataset; return (min, max) or (None, None)."""
    minv = maxv = None
    print('%s - %s %s %r %r %s' % (
        ' ' * (indent + 1), v.name, v.dtype, v.shape,
        v.chunks, v.compression))
    if showattrs:
        for ak in v.attrs:
            print('%s :%s: %r' % (' ' * (indent + 1), ak, v.attrs[ak]))
    if v.dtype.kind in {'f', 'i'} and analyze and v.size:
        # Walk the dataset one storage chunk at a time so very large
        # datasets never have to fit in memory; treat contiguous
        # (unchunked) datasets as a single chunk.
        chunks = v.chunks or v.shape
        sumv = 0
        for coor in itertools.product(*(
                range(0, v.shape[idx], chunks[idx])
                for idx in range(len(v.shape)))):
            field = tuple(
                slice(coor[idx], min(coor[idx] + chunks[idx], v.shape[idx]))
                for idx in range(len(v.shape)))
            part = v[field]
            if minv is None:
                minv = np.amin(part)
                maxv = np.amax(part)
            else:
                minv = min(minv, np.amin(part))
                maxv = max(maxv, np.amax(part))
            if part.dtype == np.float16:
                # Accumulate in float32 so the running sum cannot overflow
                # float16's limited range.
                part = part.astype(np.float32)
            sumv += part.sum()
        avgv = sumv / v.size
        print('%s [%g,%g] %g' % (
            ' ' * (indent + 1), minv, maxv, avgv))
    if sample and len(v.shape) == 1:
        checksize = int(math.ceil(v.shape[0] ** 0.5))
        # Cheap pre-check on the first 2 * sqrt(n) entries; only scan the
        # whole dataset for unique values if those already repeat.
        sampleset = np.unique(v[:min(v.shape[0], checksize * 2)])
        if len(sampleset) < checksize:
            sampleset = dict(zip(*np.unique(v, return_counts=True)))
            # Sort the value -> count table by descending count.
            sampleset = {k: val for val, k in sorted([
                (val, k) for k, val in sampleset.items()], reverse=True)}
            if len(sampleset) < max(10, checksize):
                print('%s [%d kinds] %r' % (
                    ' ' * (indent + 1), len(sampleset),
                    {k: sampleset[k] for k in itertools.islice(sampleset, 100)}))
    return minv, maxv
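

# Both scan_dataset and write_dataset visit a dataset one chunk-aligned
# block at a time.  For example, with shape (100, 60) and chunks (32, 32)
# the block corners are (0, 0), (0, 32), (32, 0), ..., (96, 32), and each
# block is clipped to the dataset bounds before slicing.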
def write_dataset(k, v, dest, analyze, convert, minv, maxv, indent):
    """Copy dataset v into group dest as k, recompressing it and optionally
    reducing float64/float32 data to a lower-precision float type."""
    lasttime = time.time()
    conv = convert and (v.dtype == np.float64 or (
        v.dtype == np.float32 and convert == 'float16'))
    if conv:
        # float16 cannot represent magnitudes above 65504, so fall back to
        # float32 when the analyzed value range would overflow it.
        conv = (
            np.float32 if convert == 'float32' or (
                minv is not None and maxv is not None and
                max(abs(minv), maxv) >= 65504) else
            np.float16)
        if conv == v.dtype:
            conv = False
    if conv:
        destv = dest.create_dataset(
            k, shape=v.shape,
            dtype=conv,
            chunks=True, fillvalue=0,
            compression='gzip', compression_opts=9, shuffle=True)
    else:
        destv = dest.create_dataset(
            k, shape=v.shape,
            dtype=v.dtype,
            chunks=True, fillvalue=v.fillvalue,
            compression='gzip', compression_opts=9, shuffle=v.shuffle)
    for ak in v.attrs:
        destv.attrs[ak] = v.attrs[ak]
    # Total number of chunk-aligned blocks, for progress reporting.
    steps = math.prod(
        len(range(0, v.shape[idx], destv.chunks[idx]))
        for idx in range(len(v.shape)))
    skip = 0
    for cidx, coor in enumerate(itertools.product(*(
            range(0, v.shape[idx], destv.chunks[idx])
            for idx in range(len(v.shape))))):
        if time.time() - lasttime > 10:
            # Emit a progress line at most every ten seconds.
            sys.stdout.write(' %5.2f%% %r %r %r\r' % (
                100.0 * cidx / steps, coor, v.shape, destv.chunks))
            sys.stdout.flush()
            lasttime = time.time()
        field = tuple(
            slice(coor[idx], min(coor[idx] + destv.chunks[idx], v.shape[idx]))
            for idx in range(len(v.shape)))
        part = v[field]
        if conv:
            if not part.any():
                # All zeros: the fill value already covers this block.
                skip += 1
                continue
            part = part.astype(conv)
        destv[field] = part
    print('%s > %s %s %r %r %s%s' % (
        ' ' * (indent + 1), destv.name, destv.dtype, destv.shape,
        destv.chunks, destv.compression,
        ' %d' % skip if skip else ''))


def scan_node(src, dest=None, analyze=False, showattrs=False, convert=None,
              exclude=None, sample=False, indent=0):
    """Recursively print (and optionally copy) a group and its children."""
    if exclude and src.name in exclude:
        return
    print('%s%s' % (' ' * indent, src.name))
    for ak in src.attrs:
        if showattrs:
            print('%s:%s: %r' % (' ' * indent, ak, src.attrs[ak]))
        if dest:
            dest.attrs[ak] = src.attrs[ak]
    for k, v in src.items():
        if exclude and v.name in exclude:
            continue
        if isinstance(v, h5py.Dataset):
            minv, maxv = scan_dataset(v, analyze, showattrs, sample, indent)
            if dest:
                write_dataset(k, v, dest, analyze, convert, minv, maxv, indent)
        elif isinstance(v, h5py.Group):
            destv = None
            if dest:
                destv = dest.create_group(k)
            scan_node(v, destv, analyze, showattrs, convert, exclude, sample,
                      indent=indent + 1)


def scan_hdf5(path, analyze=False, showattrs=False, outpath=None, convert=None,
              exclude=None, sample=False):
    """Scan the file at path; if outpath is given, also write a converted
    copy of it there."""
    if convert:
        # Conversion relies on min/max analysis to pick a safe output dtype.
        analyze = True
    with h5py.File(path, 'r') as fptr:
        fptr2 = None
        if outpath:
            fptr2 = h5py.File(outpath, 'w')
        try:
            scan_node(fptr, fptr2, analyze, showattrs, convert, exclude, sample)
        finally:
            # Make sure the output file is flushed and closed.
            if fptr2 is not None:
                fptr2.close()
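

# Example invocations (file names here are hypothetical):
#   python scan_hdf5.py data.h5 --analyze --attrs
#   python scan_hdf5.py data.h5 --dest small.h5 --convert float16 --exclude /raw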
def command():
    parser = argparse.ArgumentParser(
        description='Scan an hdf5 file and report on its groups, datasets, '
        'and attributes. Optionally report minimum, maximum, and average '
        'values for datasets with integer or float datatypes. Optionally '
        'rewrite the file with lower precision float datasets.')
    parser.add_argument(
        'source', type=str, help='Source file to read and analyze.')
    parser.add_argument(
        '--analyze', '-s', action='store_true',
        help='Analyze the min/max/average of datasets.')
    parser.add_argument(
        '--sample', action='store_true',
        help='Show a sample of 1-d data sets if they have fewer unique values '
        'than the square root of their size.')
    parser.add_argument(
        '--attrs', '-k', action='store_true',
        help='Show attributes on groups and datasets.')
    parser.add_argument(
        '--dest', help='Write a new output file.')
    parser.add_argument(
        '--convert', choices=('float16', 'float32'),
        help='Reduce the precision of the output file.')
    parser.add_argument(
        '--exclude', action='append',
        help='Exclude a dataset or group from the output file.')
    opts = parser.parse_args()
    scan_hdf5(opts.source, opts.analyze, opts.attrs, opts.dest, opts.convert,
              opts.exclude, opts.sample)


if __name__ == '__main__':
    command()