Skip to content

Commit 47efa52

Browse files
authored
Merge pull request #82 from satsin06/extract_training
Extract LiDAR data and tests
2 parents 0ffba35 + dec8772 commit 47efa52

File tree

11 files changed

+129
-4
lines changed

11 files changed

+129
-4
lines changed

.github/workflows/python-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
4949
- name: Upload coverage to Codecov
5050
uses: codecov/codecov-action@v1
51-
if: matrix.os == 'ubuntu-20.04' && matrix.python-version == '3.8' && matrix.r-version == 'release'
51+
if: matrix.os == 'ubuntu-20.04' && matrix.python-version == '3.8'
5252
env:
5353
OS: ${{ runner.os }}
5454
PYTHON: ${{ matrix.python-version }}
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Function to extract lidar data using RGB data and vst data."""
2+
import laspy
3+
import numpy as np
4+
import os
5+
import re
6+
from neonwranglerpy.lib.retrieve_aop_data import retrieve_aop_data
7+
8+
9+
def extract_lidar_data(rgb_data,
                       vst_data,
                       year,
                       savepath="/content",
                       dpID="DP1.30003.001",
                       site="DELA"):
    """
    Extract LiDAR data using RGB data and vst data.

    The LiDAR product is first retrieved with ``retrieve_aop_data``. Every
    ``*_classified_point_cloud.laz`` file found under ``<savepath>/<dpID>``
    whose name contains one of the 1 km tile identifiers derived from the
    site's ``easting``/``northing`` columns is then read, and its points are
    clipped to the bounding box of each geometry in ``rgb_data``. Each
    non-empty clip is saved as a ``.npy`` file under ``<savepath>/data/lidar``.

    Arguments:
    ------------
    rgb_data: GeoDataFrame containing the plot data; only the bounding box
        (``.bounds``) of each row's ``geometry`` is used
    vst_data: DataFrame containing the plot data; must provide the
        ``plotID``, ``easting`` and ``northing`` columns
    year: Year of the data
    savepath: Path to save the data
    dpID: LiDAR data product ID
    site: Site name

    Returns:
    ------------
    Array with one row per extracted point and the columns x, y, z, stacked
    over all matching tiles and geometries. An empty array of shape (0, 3)
    is returned when no point falls inside any geometry.
    """
    retrieve_aop_data(vst_data, year, dpID, savepath)

    site_level_data = vst_data[vst_data.plotID.str.contains(site)]
    get_tiles = (((site_level_data.easting / 1000).astype(int) * 1000).astype(str) + "_" +
                 ((site_level_data.northing / 1000).astype(int) * 1000).astype(str))

    tile_names = get_tiles.unique()
    if len(tile_names) == 0:
        # No plot of the requested site: nothing can match, and an empty
        # alternation would match every point cloud file.
        return np.empty((0, 3))

    # Build an explicit alternation of the tile names. Interpolating the
    # array itself would yield "['a' 'b']", which the regex engine reads as a
    # character class and therefore matches almost any tile.
    tile_alternatives = "|".join(re.escape(tile) for tile in tile_names)
    pattern = re.compile(fr"(?:{tile_alternatives})_classified_point_cloud\.laz")

    saveFile = savepath + "/" + dpID
    lidar_dir = savepath + "/data/lidar"

    filtered_data_list = []

    for root, _dirs, files in os.walk(saveFile):
        for file in files:
            if not pattern.search(file):
                continue

            lidar_file = os.path.join(root, file)
            print(lidar_file)

            lidar = laspy.read(lidar_file)

            # One row per point, columns x, y, z.
            data = np.vstack((lidar.x, lidar.y, lidar.z)).transpose()

            os.makedirs(lidar_dir, exist_ok=True)

            for index, row in rgb_data.iterrows():
                minx, miny, maxx, maxy = row['geometry'].bounds

                # Axis-aligned bounding-box test; the exact geometry outline
                # is not taken into account.
                inside = ((data[:, 0] >= minx) & (data[:, 0] <= maxx) &
                          (data[:, 1] >= miny) & (data[:, 1] <= maxy))
                filtered_data = data[inside]

                if len(filtered_data) > 0:
                    filtered_data_list.append(filtered_data)

                    filename = os.path.join(lidar_dir,
                                            f"lidar_{file}_{index}.npy")
                    np.save(filename, filtered_data)

                    print(f"LiDAR data for index {index} saved as '{filename}'")
                else:
                    print(f"No LiDAR data for index {index}")

    if not filtered_data_list:
        # np.concatenate raises ValueError on an empty sequence.
        return np.empty((0, 3))

    return np.concatenate(filtered_data_list, axis=0)

neonwranglerpy/lib/predict_aop_data.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import re
66
import matplotlib.pyplot as plt
7+
import pandas as pd
78
from shapely.geometry import Point
89
import rasterio
910
from deepforest import main
@@ -42,8 +43,8 @@ def predict_aop_data(vst_data,
4243
vst_data['utmZone'].map(lambda x: (326 * 100) + int(x[:-1]))).astype(str)
4344
geo_data_frame = gpd.GeoDataFrame(vst_data, geometry=geometry, crs=epsg_codes.iloc[0])
4445
site_level_data = vst_data[vst_data.plotID.str.contains(site)]
45-
get_tiles = ((site_level_data.easting / 1000).astype(int) * 1000).astype(str) + "_"
46-
+((site_level_data.northing / 1000).astype(int) * 1000).astype(str)
46+
get_tiles = (((site_level_data.easting / 1000).astype(int) * 1000).astype(str) + "_" +
47+
((site_level_data.northing / 1000).astype(int) * 1000).astype(str))
4748
print(get_tiles.unique())
4849

4950
pattern = fr"{year}_{site}_.*_{get_tiles.unique()[0]}"
@@ -96,4 +97,6 @@ def predict_aop_data(vst_data,
9697

9798
all_predictions.append(prediction)
9899

99-
return all_predictions
100+
all_predictions_df = pd.concat(all_predictions)
101+
102+
return all_predictions_df

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ opencv-python
1414
numpy
1515
rasterio
1616
deepforest
17+
laspy[lazrs,laszip]

tests/raw_data/dataframe.cpg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
ISO-8859-1

tests/raw_data/dataframe.dbf

749 KB
Binary file not shown.

tests/raw_data/dataframe.prj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
PROJCS["WGS_1984_UTM_Zone_16N",GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137.0,298.257223563]],PRIMEM["Greenwich",0.0],UNIT["Degree",0.0174532925199433]],PROJECTION["Transverse_Mercator"],PARAMETER["False_Easting",500000.0],PARAMETER["False_Northing",0.0],PARAMETER["Central_Meridian",-87.0],PARAMETER["Scale_Factor",0.9996],PARAMETER["Latitude_Of_Origin",0.0],UNIT["Meter",1.0]]

tests/raw_data/dataframe.shp

27.9 KB
Binary file not shown.

tests/raw_data/dataframe.shx

1.73 KB
Binary file not shown.

tests/test_extract_lidar_data.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""Test extract_lidar_data.py file."""
2+
import geopandas as gpd
3+
import pandas as pd
4+
from neonwranglerpy.lib.extract_lidar_data import extract_lidar_data
5+
6+
7+
def test_extract_lidar_data():
    """Run extract_lidar_data on the DELA fixtures and expect some points."""
    raw_data_dir = 'tests/raw_data'

    # Fixtures: vegetation structure table and the plot geometries.
    vst_data = pd.read_csv('tests/raw_data/vst_data.csv')
    plot_shapes = gpd.read_file("tests/raw_data/dataframe.shp")

    points = extract_lidar_data(
        rgb_data=plot_shapes,
        vst_data=vst_data,
        year="2018",
        savepath=raw_data_dir,
        dpID="DP1.30003.001",
        site="DELA",
    )

    # The input table must be non-empty in both dimensions.
    n_rows, n_cols = vst_data.shape
    assert n_rows > 0 and n_cols > 0

    # At least one LiDAR point must have been extracted.
    assert len(points) > 0

0 commit comments

Comments
 (0)