Skip to content

Commit e2858fb

Browse files
authored
Merge pull request #79 from satsin06/extract_train
Add test for extract training data
2 parents 87c7577 + fbc4e00 commit e2858fb

File tree

8 files changed

+1762
-56
lines changed

8 files changed

+1762
-56
lines changed

neonwranglerpy/lib/extract_training_data.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1+
"""Extract training data from NEON AOP data."""
12
import geopandas as gpd
23
import cv2
34
import os
45
import re
56
import numpy as np
67
import pandas as pd
7-
import matplotlib.pyplot as plt
8-
from shapely.geometry import Point
98
import rasterio
109
from deepforest import main
1110
from deepforest import utilities
@@ -14,11 +13,12 @@
1413

1514
def extract_training_data(vst_data,
1615
geo_data_frame,
17-
year, dpID='DP3.30010.001',
18-
savepath='/content',
19-
site='DELA'):
16+
year,
17+
dpID='DP3.30010.001',
18+
savepath='/content',
19+
site='DELA'):
2020
"""
21-
Extracting training data with geo_data_frame and image predictions.
21+
Extract training data with geo_data_frame and image predictions.
2222
2323
Parameters
2424
------------
@@ -37,12 +37,13 @@ def extract_training_data(vst_data,
3737
int(x[:-1]))).astype(str)
3838
geo_data_frame = gpd.GeoDataFrame(vst_data, geometry=geometry, crs=epsg_codes.iloc[0])
3939
40-
extract_training_data(vst_data=vst_data, geo_data_frame=geo_data_frame, year='2018', dpID='DP3.30010.001',
41-
savepath='/content', site='DELA')
40+
extract_training_data(vst_data=vst_data, geo_data_frame=geo_data_frame, year='2018',
41+
dpID='DP3.30010.001', savepath='/content', site='DELA')
4242
"""
4343
retrieve_aop_data(vst_data, year, dpID, savepath)
4444
site_level_data = vst_data[vst_data.plotID.str.contains(site)]
45-
get_tiles = ((site_level_data.easting/1000).astype(int) * 1000).astype(str) + "_" + ((site_level_data.northing/1000).astype(int) * 1000).astype(str)
45+
get_tiles = (((site_level_data.easting / 1000).astype(int) * 1000).astype(str) + "_" +
46+
((site_level_data.northing / 1000).astype(int) * 1000).astype(str))
4647
print("get_tiles")
4748
print(get_tiles.unique())
4849

@@ -85,46 +86,87 @@ def extract_training_data(vst_data,
8586
easting = row.easting
8687
northing = row.northing
8788

88-
x_min = int(affine[2] + 10/affine[0] - easting)
89-
y_min = int(affine[5] + 10/affine[0] - northing)
90-
x_max = int(affine[2] - 10/affine[0] - easting)
91-
y_max = int(affine[5] - 10/affine[0] - northing)
89+
x_min = int(affine[2] + 10 / affine[0] - easting)
90+
y_min = int(affine[5] + 10 / affine[0] - northing)
91+
x_max = int(affine[2] - 10 / affine[0] - easting)
92+
y_max = int(affine[5] - 10 / affine[0] - northing)
9293

93-
section_file = os.path.join(output_folder, f"section_{x_min}_{y_min}_{x_max}_{y_max}.tif")
94+
file_name = f"section_{x_min}_{y_min}_{x_max}_{y_max}.tif"
9495

96+
section_file = os.path.join(output_folder, file_name)
9597

9698
if section_file not in section_files:
99+
section = image[y_max:y_min, x_max:x_min, :]
100+
print("Section shape:", section.shape)
97101

98-
section = image[y_max:y_min, x_max:x_min, :]
99-
print("Section shape:", section.shape)
102+
section_meta = src.meta.copy()
103+
section_meta['width'] = (affine[2] + x_min) - (affine[2] +
104+
x_max)
105+
section_meta['height'] = (affine[5] + y_min) - (affine[5] +
106+
y_max)
107+
section_meta['transform'] = rasterio.Affine(
108+
affine[0], 0, (affine[2] - x_min), 0, affine[4],
109+
(affine[5] - y_min))
100110

101-
section_meta = src.meta.copy()
102-
section_meta['width'], section_meta['height'] = (affine[2] + x_min) - (affine[2] + x_max), (affine[5] + y_min) - (affine[5] + y_max)
103-
section_meta['transform'] = rasterio.Affine(affine[0], 0, (affine[2] - x_min), 0, affine[4], (affine[5] - y_min))
111+
section_np = np.moveaxis(section, -1, 0)
104112

113+
with rasterio.open(section_file, 'w', **section_meta) as dst:
114+
dst.write(section_np)
115+
section_affine = dst.transform
105116

106-
section_np = np.moveaxis(section, -1, 0)
117+
section_files[section_file] = section_affine
107118

108-
with rasterio.open(section_file, 'w', **section_meta) as dst:
109-
dst.write(section_np)
110-
section_affine = dst.transform
119+
print("Crop affine: ")
120+
print(section_affine)
111121

112-
section_files[section_file] = section_affine
113-
114-
print("Crop affine: ")
115-
print(section_affine)
116-
117-
print("Expected file path:", section_file)
122+
print("Expected file path:", section_file)
118123

119124
prediction = model.predict_image(path=section_file)
120125

121-
gdf = utilities.boxes_to_shapefile(prediction, root_dir=os.path.dirname(section_file), projected=True)
126+
gdf = utilities.boxes_to_shapefile(
127+
prediction,
128+
root_dir=os.path.dirname(section_file),
129+
projected=True)
122130

123131
all_predictions.append(gdf)
124132

125133
all_predictions_df = pd.concat(all_predictions)
126134

135+
all_predictions_df['temp_geo'] = all_predictions_df['geometry']
136+
127137
merged_data = gpd.sjoin(geo_data_frame, all_predictions_df, how="inner", op="within")
138+
merged_data.drop(columns=['geometry'], inplace=True)
139+
merged_data.rename(columns={'temp_geo': 'geometry'}, inplace=True)
140+
canopy_position_mapping = {
141+
np.nan: 0,
142+
'Full shade': 1,
143+
'Mostly shaded': 2,
144+
'Partially shaded': 3,
145+
'Full sun': 4,
146+
'Open grown': 5
147+
}
148+
149+
predictions = merged_data
150+
151+
predictions_copy = predictions.copy()
152+
153+
cp = 'canopyPosition'
154+
155+
predictions_copy[cp] = predictions_copy[cp].replace(canopy_position_mapping)
156+
157+
duplicate_mask = predictions_copy.duplicated(subset=['xmin', 'ymin', 'xmax', 'ymax'],
158+
keep=False)
159+
160+
duplicate_entries = predictions[duplicate_mask]
161+
162+
print(duplicate_entries)
163+
164+
predictions_sorted = predictions.sort_values(by=['height', cp, 'stemDiameter'],
165+
ascending=[False, False, False])
166+
167+
duplicates_mask = predictions_sorted.duplicated(
168+
subset=['xmin', 'ymin', 'xmax', 'ymax'], keep='first')
128169

170+
clean_predictions = predictions_sorted[~duplicates_mask]
129171

130-
return merged_data
172+
return clean_predictions

neonwranglerpy/lib/predict_aop_data.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212

1313
def predict_aop_data(vst_data,
14-
year, dpID='DP3.30010.001',
14+
year,
15+
dpID='DP3.30010.001',
1516
savepath='/content',
1617
site='DELA',
1718
plot_crop=True):
@@ -33,14 +34,16 @@ def predict_aop_data(vst_data,
3334
savepath='/content', site='DELA')
3435
"""
3536
retrieve_aop_data(vst_data, year, dpID, savepath)
36-
geometry = [Point(easting, northing) for easting, northing in
37-
zip(vst_data['easting'], vst_data['northing'])]
38-
epsg_codes = (vst_data['utmZone'].map(lambda x: (326 * 100) +
39-
int(x[:-1]))).astype(str)
37+
geometry = [
38+
Point(easting, northing)
39+
for easting, northing in zip(vst_data['easting'], vst_data['northing'])
40+
]
41+
epsg_codes = (
42+
vst_data['utmZone'].map(lambda x: (326 * 100) + int(x[:-1]))).astype(str)
4043
geo_data_frame = gpd.GeoDataFrame(vst_data, geometry=geometry, crs=epsg_codes.iloc[0])
4144
site_level_data = vst_data[vst_data.plotID.str.contains(site)]
42-
get_tiles = ((site_level_data.easting/1000).astype(int) * 1000).astype(str) + "_"
43-
+ ((site_level_data.northing/1000).astype(int) * 1000).astype(str)
45+
get_tiles = ((site_level_data.easting / 1000).astype(int) * 1000).astype(str) + "_"
46+
+((site_level_data.northing / 1000).astype(int) * 1000).astype(str)
4447
print(get_tiles.unique())
4548

4649
pattern = fr"{year}_{site}_.*_{get_tiles.unique()[0]}"
@@ -74,10 +77,10 @@ def predict_aop_data(vst_data,
7477
easting = row.easting
7578
northing = row.northing
7679

77-
x_min = int(affine[2] + 10/affine[0] - easting)
78-
y_min = int(affine[5] + 10/affine[0] - northing)
79-
x_max = int(affine[2] - 10/affine[0] - easting)
80-
y_max = int(affine[5] - 10/affine[0] - northing)
80+
x_min = int(affine[2] + 10 / affine[0] - easting)
81+
y_min = int(affine[5] + 10 / affine[0] - northing)
82+
x_max = int(affine[2] - 10 / affine[0] - easting)
83+
y_max = int(affine[5] - 10 / affine[0] - northing)
8184

8285
print(x_min, y_min, x_max, y_max)
8386
section = image[y_max:y_min, x_max:x_min, :]
@@ -94,4 +97,3 @@ def predict_aop_data(vst_data,
9497
all_predictions.append(prediction)
9598

9699
return all_predictions
97-

neonwranglerpy/lib/retrieve_aop_data.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Download AOP data around vst data for specified year, site."""
22
from neonwranglerpy.utilities.byTileAOP import by_tile_aop
33

4+
45
def retrieve_aop_data(data, year=2019, dpID=['DP3.30006.001'], savepath=""):
56
"""Download AOP data around vst data for specified year, site.
67
@@ -20,20 +21,18 @@ def retrieve_aop_data(data, year=2019, dpID=['DP3.30006.001'], savepath=""):
2021
savepath : str
2122
The full path to the folder in which the files would be placed locally.
2223
"""
23-
coords_for_tiles = data[[
24-
'plotID', 'siteID', 'utmZone', 'easting', 'northing']]
24+
coords_for_tiles = data[['plotID', 'siteID', 'utmZone', 'easting', 'northing']]
2525
# get tiles dimensions
26-
coords_for_tiles['easting'] = (coords_for_tiles[['easting']] /
26+
coords_for_tiles['easting'] = (coords_for_tiles[['easting']] /
2727
1000).astype(int) * 1000
28-
coords_for_tiles['northing'] = (coords_for_tiles[['northing']] /
28+
coords_for_tiles['northing'] = (coords_for_tiles[['northing']] /
2929
1000).astype(int) * 1000
30-
print(coords_for_tiles.easting.shape[0] )
30+
print(coords_for_tiles.easting.shape[0])
3131
# if there are more than 1 row, drop duplicates
3232
if coords_for_tiles.easting.shape[0] > 1:
33-
# drop duplicates values
33+
# drop duplicates values
3434
tiles = coords_for_tiles.drop_duplicates(
35-
subset=['siteID', 'utmZone', 'easting', 'northing']).reset_index(
36-
drop=True)
35+
subset=['siteID', 'utmZone', 'easting', 'northing']).reset_index(drop=True)
3736
tiles.dropna(axis=0, how='any', inplace=True)
3837

3938
# convert CHEQ into STEI
@@ -49,7 +48,7 @@ def retrieve_aop_data(data, year=2019, dpID=['DP3.30006.001'], savepath=""):
4948
else:
5049
tiles = coords_for_tiles
5150
tiles.dropna(axis=0, how='any', inplace=True)
52-
# convert CHEQ into STEI
51+
# convert CHEQ into STEI
5352
which_cheq = tiles['siteID'] == 'STEI'
5453
if which_cheq:
5554
which_easting = tiles['easting'] > 500000
@@ -64,9 +63,13 @@ def retrieve_aop_data(data, year=2019, dpID=['DP3.30006.001'], savepath=""):
6463
try:
6564
if coords_for_tiles.easting.shape[0] > 1:
6665
tile = tiles.iloc[i, :]
67-
siteID, tile_easting, tile_northing = tile['siteID'], tile['easting'], tile['northing']
66+
siteID = tile['siteID']
67+
tile_easting = tile['easting']
68+
tile_northing = tile['northing']
6869
else:
69-
siteID, tile_easting, tile_northing = tiles['siteID'], tiles['easting'][0], tiles['northing'][0]
70+
siteID = tiles['siteID']
71+
tile_easting = tiles['easting'][0]
72+
tile_northing = tiles['northing'][0]
7073

7174
by_tile_aop(prd,
7275
siteID,

neonwranglerpy/utilities/byTileAOP.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
"""Downloads the AOP data from NEON API."""
2-
import math
32
import os
43
import re
54
import numpy as np

neonwranglerpy/utilities/get_tile_urls.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from neonwranglerpy.utilities.tools import get_api
44
import numpy as np
55

6+
67
def get_tile_urls(
78
month_url,
89
easting,
@@ -29,7 +30,10 @@ def get_tile_urls(
2930
file_urls = [x for x in temp_ if f'_{easting}_{northing}' in x['name']]
3031
elif isinstance(easting, np.ndarray) and isinstance(northing, np.ndarray):
3132
for j in range(len(easting)):
32-
urls = [x for x in temp_ if f'_{easting.iloc[j]}_{northing.iloc[j]}' in x['name']]
33+
urls = [
34+
x for x in temp_
35+
if f'_{easting.iloc[j]}_{northing.iloc[j]}' in x['name']
36+
]
3337

3438
# df1 = df.loc[df['name'].str.contains(str(easting[j]))]
3539
# df2 = df.loc[df['name'].str.contains(str(northing[j]))]

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,7 @@ yapf
1010
sphinx-press-theme
1111
laspy
1212
lazrs
13+
opencv-python
14+
numpy
15+
rasterio
16+
deepforest

0 commit comments

Comments
 (0)