-
Notifications
You must be signed in to change notification settings - Fork 0
/
useful_functions.py
237 lines (184 loc) · 7.48 KB
/
useful_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
Quick and useful functions for data science: pre-processing, plotting, etc.
author @xaviernogueira
"""
import pandas as pd
import logging
import matplotlib.pyplot as plt
import os
def init_logger(filename, log_name=None):
    """Initialize root logging to a .log file plus a stderr stream handler.

    By default the log file is named after ``filename`` (e.g. pass ``__file__``)
    with its extension swapped for ``.log``, relative to the current working
    directory. If ``log_name`` is a path ending in ``.log`` whose directory
    exists (a bare file name means the current directory), that path is used.

    :param filename: path of the calling script; its base name seeds the default log name
    :param log_name: optional explicit path to a ``.log`` file
    :return: None. If ``log_name`` points into a missing directory an error is printed
        and logging is left unconfigured.
    """
    if log_name is not None and log_name.endswith('.log'):
        # os.path.dirname('x.log') is '' and os.path.exists('') is False, so a bare
        # file name must be treated as "the current directory" explicitly.
        log_dir = os.path.dirname(log_name) or os.curdir
        if not os.path.exists(log_dir):
            print('ERROR: Logger cannot be initiated @ %s' % log_name)
            return None
        name = log_name
    else:
        # splitext swaps only the final extension; str.replace('.py', ...) rewrote
        # every '.py' occurring anywhere in the base name.
        name = os.path.splitext(os.path.basename(filename))[0] + '.log'
    # filemode='w' truncates a previous log. basicConfig does nothing if the root
    # logger already has handlers.
    logging.basicConfig(filename=name, filemode='w', level=logging.INFO)
    stderr_logger = logging.StreamHandler()
    stderr_logger.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
    logging.getLogger().addHandler(stderr_logger)
    return None
def to_df(data):
    """
    Normalize a data frame or a .csv path to a pandas data frame, flagging csv input.
    :param data: a pandas data frame, or a path (str) to a .csv file (extension matched
        case-insensitively)
    :return: tuple (df, c): the data frame, and True if it was read from a .csv file.
        A data frame input is returned as-is (not copied) with c=False.
        Any other input prints an error and returns None.
    """
    if isinstance(data, pd.DataFrame):
        return data, False
    if isinstance(data, str) and data.lower().endswith('.csv'):
        return pd.read_csv(data), True
    # Previously one unsupported case fell off the end of the function silently;
    # report every unsupported input (non-csv string or other type) the same way.
    print('ERROR: Input must be csv or pandas data frame')
    return None
def drop_redundant_cols(data, thresh=1):
    """Drop columns that have ``thresh`` or fewer distinct values (NaN counts as a value).
    :param data: a pandas data frame, or a path to a .csv file
    :param thresh: columns whose number of unique values is <= thresh are dropped
        (default 1, i.e. constant columns)
    :return: same format as the input: the data frame (modified in place), or the csv
        path after that csv has been overwritten (without an added index column).
    """
    df, c = to_df(data)
    # Index by the label itself: df[str(col)] raised KeyError for non-string headers.
    # dropna=False counts NaN like Series.unique() did.
    drop_list = [col for col in df.columns if df[col].nunique(dropna=False) <= thresh]
    if drop_list:
        df.drop(columns=drop_list, inplace=True)
    if c:
        # index=False keeps the file's column layout (writing the RangeIndex added an
        # 'Unnamed: 0' column on every round trip); return the path, since
        # to_csv(path) itself returns None.
        df.to_csv(data, index=False)
        return data
    return df
def spaces_format(data):
    """
    Replace spaces with underscores in column headers and in the string values of
    object / string-dtype columns.
    :param data: a pandas data frame, or a path to a .csv file
    :return: same format as the input: the data frame (modified in place), or the csv
        path after that csv has been overwritten (without an added index column).
    """
    df, c = to_df(data)
    for col in df.columns:
        dtype = df[col].dtype
        if dtype == object or isinstance(dtype, pd.StringDtype):
            # Assign back: Series.replace(..., inplace=True) on df[col] is chained
            # assignment and is not guaranteed to modify df.
            df[col] = df[col].replace(' ', '_', regex=True)
    # The old df.rename(...) result was discarded, so headers were never renamed.
    renames = {col: str(col).replace(' ', '_') for col in df.columns if ' ' in str(col)}
    if renames:
        df.rename(columns=renames, inplace=True)
    if c:
        df.to_csv(data, index=False)
        return data
    return df
def cartography(ax, projection):
    """Add standard geographic features and gridlines to a cartopy axes.

    May not work unless on cartopy 0.20.0.
    :param ax: axes exposing ``add_feature``/``gridlines`` (presumably a cartopy GeoAxes)
    :param projection: coordinate reference system passed to ``ax.gridlines`` as ``crs``
    """
    import cartopy.feature as cfeature
    for feature in (cfeature.LAND, cfeature.OCEAN, cfeature.LAKES,
                    cfeature.COASTLINE, cfeature.BORDERS, cfeature.STATES):
        ax.add_feature(feature)
    # alpha=0 makes the grid lines transparent while draw_labels=True requests labels
    # (presumably so only the coordinate labels are visible — confirm intent).
    ax.gridlines(crs=projection, draw_labels=True, alpha=0, linestyle='--')
def make_test_csv(csv, rows=500, random_state=None):
    """
    Randomly sample rows from a csv to make a smaller ML test csv (faster computation).
    The sample keeps the input's columns (no extra index columns are added) and is saved
    next to the input as ``<name>_test_<n>_rows.csv``.
    :param csv: path to a csv file
    :param rows: number of rows to sample (non-negative int, default 500); clamped to the
        number of rows available
    :param random_state: optional seed forwarded to ``DataFrame.sample`` for reproducible
        samples (default None = random)
    :return: path of the new test csv, or None (after printing an error) if rows is invalid
    """
    # Validate before paying for the read; bool is an int subclass but not a row count.
    if not isinstance(rows, int) or isinstance(rows, bool) or rows < 0:
        print('ERROR: Rows parameter must be an integer')
        return None
    in_df = pd.read_csv(csv)
    # sample() already returns rows in random order, so no separate shuffle is needed.
    # Clamping avoids "Cannot take a larger sample than population" on small inputs.
    n = min(rows, len(in_df))
    out_df = in_df.sample(n=n, random_state=random_state)
    root = os.path.splitext(os.path.basename(csv))[0]
    # os.path.join keeps relative inputs in place (a leading '\\' pointed at the drive
    # root) and splitext never yields the input's own file name.
    out_csv = os.path.join(os.path.dirname(csv), '%s_test_%s_rows.csv' % (root, n))
    out_df.to_csv(out_csv, index=False)
    return out_csv
def bbox_poly(bbox, region, out_folder):
    """
    Save a bounding box as a single-polygon shapefile in EPSG:4326.
    :param bbox: sequence of four numbers ordered (long0, long1, lat0, lat1)
    :param region: name used to build the output file name ``<region>_bbox.shp``
    :param out_folder: output folder (created if it does not exist)
    :return: path of the written shapefile
    """
    import geopandas as gpd
    from shapely.geometry import Polygon
    # define output location
    os.makedirs(out_folder, exist_ok=True)
    out_shp = os.path.join(out_folder, '%s_bbox.shp' % region)
    # get bounding box coordinates and format
    long0, long1, lat0, lat1 = bbox
    # Lazy %-args: eager "'...%s' % bbox" raised TypeError when bbox was a 4-tuple.
    logging.info('Prediction extent coordinates: %s', bbox)
    poly = Polygon([(long0, lat0),
                    (long1, lat0),
                    (long1, lat1),
                    (long0, lat1)])
    # save as a shapefile and return its path; the {'init': ...} CRS form is deprecated
    gpd.GeoDataFrame(pd.DataFrame(['p1'], columns=['geom']),
                     crs='EPSG:4326',
                     geometry=[poly]).to_file(out_shp)
    return out_shp
def plot_month_rasters(cropped_raster_dict, month_index, out_folder):
    """
    Stack one layer per variable for a given month, plot the stack, and save the figure.
    Rasters with 12 bands contribute the band for ``month_index``; single-band rasters
    contribute their only band.
    :param cropped_raster_dict: dict mapping variable names (str) to AOI-cropped raster .tif files
    :param month_index: month number (int), January = 1 ... December = 12
    :param out_folder: folder where ``<Month>_input_rasters.png`` is saved
    :return: (stack, month): a pyspatialml Raster with one layer per variable, and the
        month name. Logs an error and returns None if month_index is not in 1..12 or a
        raster has neither 1 nor 12 bands.
    """
    from pyspatialml import Raster
    # set up static lists
    cmaps = ['Purples', 'Greens', 'Reds', 'YlGnBu', 'RdPu']
    months = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August',
              'September', 'October', 'November', 'December']
    # month_index 0 or negative would silently wrap around to December etc.
    if not 1 <= month_index <= 12:
        logging.error('ERROR: month_index must be between 1 and 12, got %s' % month_index)
        return None
    month = months[month_index - 1]
    stack_list = []
    names_list = []
    # for each raster, take the month band (12 bands) or its only band (1 band)
    for i, (name, raster_path) in enumerate(cropped_raster_dict.items()):
        full_raster = Raster(raster_path)
        bands = full_raster.count
        logging.info('Raster %s has %s bands' % (name, bands))
        if bands == 12:
            raster = full_raster.iloc[month_index - 1]
        elif bands == 1:
            raster = full_raster.iloc[0]
        else:
            logging.error('ERROR: Input raster %s is not single band and also does not contain 12 month bands' % name)
            return None
        # cycle through the colormaps starting at the first one (the old counter was
        # incremented before use, so 'Purples' was skipped for the first layer)
        raster.cmap = cmaps[i % len(cmaps)]
        stack_list.append(raster)
        names_list.append(name)
    # convert the stacked raster layers into a raster, and plot each layer
    stack = Raster(stack_list)
    stack.plot(
        title_fontsize=10,
        label_fontsize=8,
        legend_fontsize=6,
        names=names_list,
        fig_kwds={"figsize": (8, 4)},
        subplots_kwds={"wspace": 0.3}
    )
    # plt.title would only retitle the current (last) subplot; title the whole figure
    plt.suptitle(month)
    fig_name = os.path.join(out_folder, '%s_input_rasters.png' % month)
    # save BEFORE show(): once show() returns the figure may be closed, and a later
    # savefig would write a blank canvas
    plt.savefig(fig_name, dpi=300, bbox_inches='tight')
    plt.show()
    logging.info('Done. Input raster plot for month %s saved @ %s' % (month, fig_name))
    return stack, month
def main(csv, rows=500):
    """Script entry point: write a randomly sampled test copy of ``csv``.

    :param csv: path to the source csv
    :param rows: number of rows to sample (default 500)
    """
    make_test_csv(csv, rows=rows)
# ------------------------------ DEFINE INPUTS ------------------------------
CSV_DIR = r'C:\Users\xrnogueira\Documents\Data\NO2_stations'
main_csv = '\\'.join([CSV_DIR, 'master_no2_daily.csv'])

if __name__ == '__main__':
    # default run: build a sampled test csv from the master NO2 csv
    main(main_csv)