forked from NWTlter/NWT_CLM
-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
231 lines (191 loc) · 8.28 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
"""utility functions"""
"""copied from klindsay, https://github.com/klindsay28/CESM2_coup_carb_cycle_JAMES/blob/master/utils.py"""
import re
import cf_units as cf
import cftime
import numpy as np
import xarray as xr
from cartopy.util import add_cyclic_point
import matplotlib.pyplot as plt
import matplotlib.colors as colors
#from xr_ds_ex import xr_ds_ex
def weighted_annual_mean(array):
    """Return annual means of a monthly series, weighting each month by its length.

    The series is cut into consecutive, non-overlapping 12-step groups along the
    ``time`` dimension (elements 0-11, 12-23, ...). Each group becomes one weighted
    mean, and the result carries the ``time`` coordinate of the group's last element.
    A trailing group shorter than 12 steps is dropped.

    Weights are fixed no-leap month lengths (Feb = 28 days), normalised to sum to 1.
    NOTE(review): this assumes the series starts in January and has a dimension
    literally named ``time`` -- confirm with callers.

    Because the reduction uses ``dot``, a NaN anywhere in a group yields NaN for
    that year.
    """
    days_per_month = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])
    weights = xr.DataArray(days_per_month, dims=['month'])
    weights = weights / weights.sum()

    # Windows of the 12 preceding steps (inclusive) for every time index.
    windows = array.rolling(time=12, center=False).construct("month")
    # Keep only the windows ending at indices 11, 23, 35, ...
    year_windows = windows.isel(time=slice(11, None, 12))
    return year_windows.dot(weights, dims=["month"])
def change_units(ds, variable_str, variable_bounds_str, target_unit_str):
    """Return ds[variable_bounds_str] converted to the units ``target_unit_str``.

    The *source* units are read from ``ds[variable_str].attrs["units"]``, not from
    the variable being converted. This lets a bounds variable, which may carry no
    units attribute of its own, be converted using its parent variable's units.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset containing both variables.
    variable_str : str
        Name of the variable whose ``units`` attribute gives the source units.
    variable_bounds_str : str
        Name of the variable whose values are converted.
    target_unit_str : str
        Destination units, in any form cf_units.Unit accepts.

    Returns
    -------
    xarray.DataArray
        A new array produced by ``xr.apply_ufunc``. ``ds`` is not modified, and the
        result's ``units`` attribute is not updated here.
    """
    source_units = cf.Unit(ds[variable_str].attrs["units"])
    destination_units = cf.Unit(target_unit_str)
    values = ds[variable_bounds_str]
    # destination_units is not an xarray object, so apply_ufunc passes it through
    # unchanged as the second argument of Unit.convert.
    return xr.apply_ufunc(
        source_units.convert,
        values,
        destination_units,
        output_dtypes=[values.dtype],
    )
# Unit tokens that the units parser cannot handle, mapped to acceptable equivalents.
_UNIT_REPLACEMENTS = {'kgC': 'kg', 'gC': 'g', 'gC13': 'g', 'gC14': 'g', 'gN': 'g',
                      'unitless': '1',
                      'years': 'common_years', 'yr': 'common_year',
                      'meq': 'mmol', 'neq': 'nmol'}

# Delimiters are captured so that ''.join() of the split result rebuilds the string.
# The isotope-labelled carbon tokens gC13/gC14 are matched first, as whole tokens.
# Otherwise the generic digit alternative splits 'gC13' into 'gC' + '13', the
# 'gC13'/'gC14' replacements never fire, and the result is 'g13' instead of 'g'.
_UNIT_TOKEN_RE = re.compile(r'(gC1[34](?![0-9])| |\(|\)|\^|\*|/|-[0-9]+|[0-9]+)')


def clean_units(units):
    """Replace troublesome unit terms in a units string with acceptable replacements.

    The string is split into word tokens and delimiters (spaces, parentheses, ``^``,
    ``*``, ``/``, signed or unsigned integers). Each token found in
    ``_UNIT_REPLACEMENTS`` is substituted, and the pieces are joined back together.
    For example, ``'gC/m^2/yr'`` becomes ``'g/m^2/common_year'`` and
    ``'gC13 m-2 s-1'`` becomes ``'g m-2 s-1'``.

    Parameters
    ----------
    units : str
        Units string, e.g. from a variable's ``units`` attribute.

    Returns
    -------
    str
        The cleaned units string. An empty string is returned unchanged.
    """
    tokens = _UNIT_TOKEN_RE.split(units)
    return ''.join(_UNIT_REPLACEMENTS.get(token, token) for token in tokens)
def copy_fill_settings(da_in, da_out):
    """Propagate _FillValue and missing_value settings from da_in to da_out.

    - ``da_out.encoding['_FillValue']`` is set to ``da_in``'s encoded ``_FillValue``,
      or to ``None`` when ``da_in`` has none.
    - If ``da_in.encoding`` has ``missing_value``, it is copied into
      ``da_out.attrs['missing_value']`` (attrs, not encoding).

    da_out is modified in place and also returned.
    """
    # dict.get yields None when the key is absent, matching the explicit fallback.
    da_out.encoding['_FillValue'] = da_in.encoding.get('_FillValue')
    if 'missing_value' in da_in.encoding:
        da_out.attrs['missing_value'] = da_in.encoding['missing_value']
    return da_out
def dim_cnt_check(ds, varname, dim_cnt):
    """Raise ValueError unless ds[varname] has exactly dim_cnt dimensions.

    Returns None when the check passes.
    """
    actual_cnt = len(ds[varname].dims)
    if actual_cnt == dim_cnt:
        return None
    raise ValueError('unexpected dim_cnt=%d, varname=%s' % (actual_cnt, varname))
def time_set_mid(ds, time_name, deep=False):
    """
    Return a copy of ds with ds[time_name] replaced by the midpoints of its bounds.

    The bounds variable is the one named by ds[time_name].attrs['bounds']. If that
    attribute does not exist, an unmodified copy of ds is returned.

    Parameters
    ----------
    ds : xarray.Dataset
    time_name : str
        Name of the time coordinate.
    deep : bool, optional
        Passed to ds.copy; controls whether the returned copy is deep.

    Returns
    -------
    xarray.Dataset
        Except for time_name, a copy of ds (deep or shallow per ``deep``).
        ds[time_name] keeps its attributes and encoding; only its values change.
    """
    ds_out = ds.copy(deep)
    if "bounds" not in ds[time_name].attrs:
        return ds_out

    tb_name = ds[time_name].attrs["bounds"]
    tb = ds[tb_name]
    # The bounds variable has the time dimension plus one extra dimension.
    bounds_dim = next(dim for dim in tb.dims if dim != time_name)
    # Average over the bounds dimension wherever it sits, not a fixed axis.
    bounds_axis = tb.get_axis_num(bounds_dim)

    # Use da = da.copy(data=...), in order to preserve attributes and encoding.
    # If tb is an array of datetime objects, encode time before averaging.
    # Do this because computing the mean on datetime objects with xarray fails
    # if the time span is 293 or more years.
    # https://github.com/klindsay28/CESM2_coup_carb_cycle_JAMES/issues/7
    if tb.dtype == np.dtype("O"):
        tb_objs = tb.values
        # Encode and decode in the calendar of the cftime objects themselves, so
        # midpoints stay in the same calendar. Fall back to noleap when the objects
        # carry no calendar (or tb is empty).
        sample = tb_objs.flat[0] if tb_objs.size else None
        calendar = getattr(sample, "calendar", "") or "noleap"
        units = "days since 0001-01-01"
        tb_vals = cftime.date2num(tb_objs, units=units, calendar=calendar)
        tb_mid_decode = cftime.num2date(
            tb_vals.mean(axis=bounds_axis), units=units, calendar=calendar
        )
        ds_out[time_name] = ds[time_name].copy(data=tb_mid_decode)
    else:
        ds_out[time_name] = ds[time_name].copy(data=tb.mean(bounds_dim))
    return ds_out
# NOTE: superseded earlier implementation of time_set_mid, kept inside a string
# literal so it is never executed. If revived, note that its ds.assign_coords(...)
# calls discard their return value (xarray's assign_coords returns a new object),
# so ds would be returned without the midpoint times.
'''def time_set_mid(ds, time_name):
"""
set ds[time_name] to midpoint of ds[time_name].attrs['bounds'], if bounds attribute exists
type of ds[time_name] is not changed
ds is returned
"""
if 'bounds' not in ds[time_name].attrs:
return ds
# determine units and calendar of unencoded time values
if ds[time_name].dtype == np.dtype('O'):
units = 'days since 0000-01-01'
calendar = 'noleap'
else:
units = ds[time_name].attrs['units']
calendar = ds[time_name].attrs['calendar']
# construct unencoded midpoint values, assumes bounds dim is 2nd
tb_name = ds[time_name].attrs['bounds']
if ds[tb_name].dtype == np.dtype('O'):
tb_vals = cftime.date2num(ds[tb_name].values, units=units, calendar=calendar)
else:
tb_vals = ds[tb_name].values
tb_mid = tb_vals.mean(axis=1)
# set ds[time_name] to tb_mid
if ds[time_name].dtype == np.dtype('O'):
# WW changed for xarray 16
#ds[time_name].values = cftime.num2date(tb_mid, units=units, calendar=calendar)
ds.assign_coords({time_name: tb_mid})
else:
#ds[time_name].values = tb_mid
ds.assign_coords({time_name: tb_mid})
return ds
'''
def time_year_plus_frac(ds, time_name):
    """Return ds[time_name] as year plus fraction of year (a numpy array).

    The times are expressed as days since 0000-01-01 in the no-leap calendar and
    divided by 365, so 0000-01-01 maps to 0.0 and each year spans exactly 1.0.

    Numeric (undecoded) times are first decoded with cftime, using the variable's
    own 'units' and 'calendar' attributes. Both must be present; a missing one
    raises KeyError.
    """
    time_da = ds[time_name]
    if time_da.dtype != np.dtype('O'):
        dates = cftime.num2date(
            time_da.values, time_da.attrs['units'], calendar=time_da.attrs['calendar'])
    else:
        # already an array of datetime objects
        dates = time_da.values
    noleap_days = cftime.date2num(dates, 'days since 0000-01-01', calendar='noleap')
    return noleap_days / 365.0
def cyclic_dataarray(da, coord='lon'):
    """Return a copy of ``da`` with a cyclic point appended along dimension ``coord``.

    The first slice along ``coord`` is repeated after the last one, and the
    coordinate is extended by one value. Both steps are done by
    ``cartopy.util.add_cyclic_point``. The result keeps ``da``'s name, its
    attributes and the attributes of every coordinate.

    Parameters
    ----------
    da : xarray.DataArray
    coord : str, optional
        Name of the dimension/coordinate to wrap (default ``'lon'``).

    Returns
    -------
    xarray.DataArray
        New array, one element longer along ``coord``.

    Raises
    ------
    TypeError
        If ``da`` is not an xarray.DataArray.
    ValueError
        If ``coord`` is not one of ``da.dims``.

    Example
    -------
    >>> import xarray as xr
    >>> data = xr.DataArray([[1, 2, 3], [4, 5, 6]],
    ...                     coords={'x': [1, 2], 'y': [0, 1, 2]},
    ...                     dims=['x', 'y'])
    >>> cd = cyclic_dataarray(data, 'y')
    >>> cd.values.tolist()
    [[1, 2, 3, 1], [4, 5, 6, 4]]
    """
    # Raise rather than assert: asserts are stripped under `python -O`.
    if not isinstance(da, xr.DataArray):
        raise TypeError(
            'cyclic_dataarray expects an xarray.DataArray, got %s' % type(da).__name__)
    lon_idx = da.dims.index(coord)
    cyclic_data, cyclic_coord = add_cyclic_point(da.values,
                                                 coord=da.coords[coord].values,
                                                 axis=lon_idx)
    # Reuse every existing coordinate, replacing only the extended one.
    new_coords = dict(da.coords)
    new_coords[coord] = cyclic_coord
    # Pass name and attrs so the rebuilt array is not anonymous (a nameless
    # DataArray cannot be turned into a Dataset variable without renaming).
    new_da = xr.DataArray(cyclic_data, dims=da.dims, coords=new_coords,
                          name=da.name, attrs=dict(da.attrs))
    # Copy coordinate attributes; the extended coordinate was rebuilt from raw values.
    for c in da.coords:
        new_da.coords[c].attrs.update(da.coords[c].attrs)
    return new_da
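# A Dataset counterpart of cyclic_dataarray (cyclic_dataset) is kept, disabled, in the
# string literal after truncate_colormap; it doesn't work because dims are locked in a Dataset.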
def truncate_colormap(cmapIn='jet', minval=0.0, maxval=1.0, n=100):
    """Return a colormap built from only the [minval, maxval] portion of ``cmapIn``.

    Parameters
    ----------
    cmapIn : str or colormap, optional
        Anything ``plt.get_cmap`` accepts, e.g. a registered colormap name
        (default ``'jet'``).
    minval, maxval : float, optional
        Start and end of the sampled range, as fractions (0-1) of the colormap.
    n : int, optional
        Number of colors sampled evenly between ``minval`` and ``maxval``.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Named ``'trunc(<name>,<minval>,<maxval>)'``, with the bounds formatted to
        two decimals.
    """
    base_cmap = plt.get_cmap(cmapIn)
    new_name = 'trunc({n},{a:.2f},{b:.2f})'.format(n=base_cmap.name, a=minval, b=maxval)
    # Sample the base colormap at n evenly spaced points and interpolate between them.
    sampled_colors = base_cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(new_name, sampled_colors)
# NOTE: disabled Dataset counterpart of cyclic_dataarray, kept inside a string literal
# so it is never executed. As written it could not run: besides the dims issue noted
# above, it calls xr.DataSet, whereas the class checked in its first line is xr.Dataset.
'''
def cyclic_dataset(ds, coord='lon'):
assert isinstance(ds, xr.Dataset)
lon_idx = ds.dims.index(coord)
cyclic_data, cyclic_coord = add_cyclic_point(ds.values,
coord=ds.coords[coord],
axis=lon_idx)
# Copy and add the cyclic coordinate and data
new_coords = dict(ds.coords)
new_coords[coord] = cyclic_coord
new_values = cyclic_data
new_ds = xr.DataSet(new_values, dims=ds.dims, coords=new_coords)
# Copy the attributes for the re-constructed data and coords
for att, val in ds.attrs.items():
new_ds.attrs[att] = val
for c in ds.coords:
for att in ds.coords[c].attrs:
new_ds.coords[c].attrs[att] = ds.coords[c].attrs[att]
return new_ds
'''