Skip to content

Commit dec8772

Browse files
committed
Add test for predict_aop_data
1 parent 06ed817 commit dec8772

File tree

3 files changed

+27
-5
lines changed

3 files changed

+27
-5
lines changed

neonwranglerpy/lib/extract_lidar_data.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Some text here."""
1+
"""Function to extract lidar data using RGB data and vst data."""
22
import laspy
33
import numpy as np
44
import os
@@ -13,9 +13,10 @@ def extract_lidar_data(rgb_data,
1313
dpID="DP1.30003.001",
1414
site="DELA"):
1515
"""
16-
Extract LiDAR data with geo_data_frame and image predictions.
16+
Extract LiDAR data using RGB data and vst data.
1717
1818
Arguments:
19+
------------
1920
rgb_data: GeoDataFrame containing the plot data
2021
vst_data: DataFrame containing the plot data
2122
year: Year of the data

neonwranglerpy/lib/predict_aop_data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import re
66
import matplotlib.pyplot as plt
7+
import pandas as pd
78
from shapely.geometry import Point
89
import rasterio
910
from deepforest import main
@@ -42,8 +43,8 @@ def predict_aop_data(vst_data,
4243
vst_data['utmZone'].map(lambda x: (326 * 100) + int(x[:-1]))).astype(str)
4344
geo_data_frame = gpd.GeoDataFrame(vst_data, geometry=geometry, crs=epsg_codes.iloc[0])
4445
site_level_data = vst_data[vst_data.plotID.str.contains(site)]
45-
get_tiles = ((site_level_data.easting / 1000).astype(int) * 1000).astype(str) + "_"
46-
+((site_level_data.northing / 1000).astype(int) * 1000).astype(str)
46+
get_tiles = (((site_level_data.easting / 1000).astype(int) * 1000).astype(str) + "_" +
47+
((site_level_data.northing / 1000).astype(int) * 1000).astype(str))
4748
print(get_tiles.unique())
4849

4950
pattern = fr"{year}_{site}_.*_{get_tiles.unique()[0]}"
@@ -96,4 +97,6 @@ def predict_aop_data(vst_data,
9697

9798
all_predictions.append(prediction)
9899

99-
return all_predictions
100+
all_predictions_df = pd.concat(all_predictions)
101+
102+
return all_predictions_df

tests/test_predict_aop_data.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""Test predict_aop_data.py file."""
2+
import pandas as pd
3+
from neonwranglerpy.lib.predict_aop_data import predict_aop_data
4+
5+
6+
def test_predict_aop_data():
7+
"""Test predict_aop_data function."""
8+
savepath = 'tests/raw_data'
9+
vst_data = pd.read_csv('tests/raw_data/vst_data.csv')
10+
11+
result = predict_aop_data(vst_data=vst_data.iloc[1:10, :], year='2018',
12+
dpID='DP3.30010.001', savepath=savepath, site='DELA',
13+
plot_crop=False)
14+
15+
assert (vst_data.shape[0] > 0) & (vst_data.shape[1] > 0)
16+
assert len(result) > 0
17+
assert isinstance(result, pd.DataFrame)
18+
assert result[['xmin', 'ymin', 'xmax', 'ymax']].duplicated().any()

0 commit comments

Comments
 (0)