Skip to content

Commit

Permalink
Merge pull request #81 from IGNF/fix-inference-when-num_nodes-is-one
Browse files Browse the repository at this point in the history
Dev: allow inference for smallest clouds possible + addresses many smaller issues.
  • Loading branch information
CharlesGaydon committed Aug 9, 2023
2 parents bbbf4c8 + b98f120 commit b4d2ef8
Show file tree
Hide file tree
Showing 14 changed files with 142 additions and 99 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/cicd.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ jobs:
- name: Example inference run via Docker with default config and checkpoint
run: >
docker run
-v /var/data/cicd/CICD_github_assets/myria3d_V3.3.0/inputs/:/inputs/
-v /var/data/cicd/CICD_github_assets/myria3d_V3.3.0/outputs/:/outputs/
-v /var/data/cicd/CICD_github_assets/myria3d_V3.4.0/inputs/:/inputs/
-v /var/data/cicd/CICD_github_assets/myria3d_V3.4.0/outputs/:/outputs/
--ipc=host
--shm-size=2gb
myria3d
Expand All @@ -54,14 +54,14 @@ jobs:
- name: Example inference run via Docker with inference-time subtiles overlap to smooth-out results.
run: >
docker run
-v /var/data/cicd/CICD_github_assets/myria3d_V3.3.0/inputs/:/inputs/
-v /var/data/cicd/CICD_github_assets/myria3d_V3.3.0/outputs/:/outputs/
-v /var/data/cicd/CICD_github_assets/myria3d_V3.4.0/inputs/:/inputs/
-v /var/data/cicd/CICD_github_assets/myria3d_V3.4.0/outputs/:/outputs/
--ipc=host
--shm-size=2gb
myria3d
python run.py
--config-path /inputs/
--config-name proto151_V2.0_epoch_100_Myria3DV3.1.0_predict_config_V3.3.0
--config-name proto151_V2.0_epoch_100_Myria3DV3.1.0_predict_config_V3.4.0
predict.ckpt_path=/inputs/proto151_V2.0_epoch_100_Myria3DV3.1.0.ckpt
predict.src_las=/inputs/792000_6272000_subset_buildings.las
predict.output_dir=/outputs/
Expand Down
28 changes: 28 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# main

# 3.4.8
- Raise an informative error in case of unexpected task_name.

# 3.4.7
- Remove tqdm when splitting a lidar tile to avoid cluttered logs during data preparation.

# 3.4.6
- Document the possible use of ign-pdal-tools for colorization.

# 3.4.5
- Set a default task_name (fit) to avoid common error at lauch time.

# 3.4.4
- Remove duplicated experiment configuration.

# 3.4.3
- Remove outdated and incorrect hydra parameter in config.yaml.

# 3.4.2
- Reconstruct absolute path of input LAS files explicitely, removing a costly glob operation.

# 3.4.1
- Fix dataset description for pacasam: there was an unwanted int-to-int mapping in classification_dict.

# 3.4.0
- Allow inference for the smallest possible patches (num_nodes=1) to have consistent inference behavior
4 changes: 0 additions & 4 deletions configs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,3 @@ defaults:
# enable color logging
- override hydra/hydra_logging: colorlog
- override hydra/job_logging: colorlog

hydra:
searchpath:
- pkg://default_files_for_predict
2 changes: 1 addition & 1 deletion configs/datamodule/hdf5_datamodule.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pre_filter:
_target_: functools.partial
_args_:
- "${get_method:myria3d.pctl.dataset.utils.pre_filter_below_n_points}"
min_num_nodes: 50
min_num_nodes: 1

tile_width: 1000
subtile_width: 50
Expand Down
16 changes: 16 additions & 0 deletions configs/dataset_description/20230601_lidarhd_pacasam_dataset.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
_convert_: all # For omegaconf struct to be converted to python dictionnaries
# classification_preprocessing_dict = {source_class_code_int: target_class_code_int},
# 3: medium vegetation -> vegetation
# 4: high vegetation -> vegetation
# 0: no processing --> unclassified
# 66: synthetic points --> noise (synthetic points are useful for specific modelling task on already classified data).
# We set them to noise so that they are ignored during training.
# Codes that should not have been in the data: 100, 101.
classification_preprocessing_dict: {3: 5, 4: 5, 0: 1, 66: 65, 100: 1, 101: 1}

# classification_dict = {code_int: name_str, ...} and MUST be sorted (increasing order).
classification_dict: {1: "unclassified", 2: "ground", 5: vegetation, 6: "building", 9: water, 17: bridge, 64: lasting_above}

# Input and output dims of neural net are dataset dependant:
d_in: 9
num_classes: 7
18 changes: 0 additions & 18 deletions configs/experiment/RandLaNet_base_run_FR_pyg_randla_net.yaml

This file was deleted.

22 changes: 12 additions & 10 deletions docs/source/tutorials/prepare_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

## Peprocessing functions

The loading function is dataset dependant, and is `lidar_hd_pre_transform` by default. The function takes points loaded from a LAS file via pdal as input, and returns a `pytorch_geometric.Data` object following the standard naming convention of `pytorch_geometric`, plus a list of features names for later use in transforms.
The loading function is dataset dependant, and is `lidar_hd_pre_transform` by default. The function takes points loaded from a LAS file via pdal as input, and returns a `pytorch_geometric.Data` object following the standard naming convention of `pytorch_geometric`, plus a list of features names for later use in transforms. In the loading function, the return number and color information (RGBI) are scaled to 0-1 interval, a NDVI and an average color ((R+G+B)/3) dimension are created, and points that may be occluded (as indicated by higher return number) have their color set to 0.

It is adapted to the French Lidar HD data provided by IGN (see [the official page](https://geoservices.ign.fr/lidarhd) - link in French). Return number and color information (RGBI) are scaled to 0-1 interval, a NDVI and an average color ((R+G+B)/3) dimension are created, and points that may be occluded (as indicated by higher return number) have their color set to 0.
Customization: You may want to implement your own logic (e.g. with custom, additional features) in directory `points_pre_transform`. It then needs to be referenced similarly to `lidar_hd_pre_transform`.

You may want to implement your own logic (e.g. with custom, additional features) in directory `points_pre_transform`. It then needs to be referenced similarly to `lidar_hd_pre_transform`.
The loading function is designed for the French Lidar HD data provided by IGN (see [the official page](https://geoservices.ign.fr/lidarhd) - link in French). Note that the clouds are shared without color information, and should be colorized (RGB+Infrared) to use myria3d. The [open-source ign-pdal-tools library](https://pypi.org/project/ign-pdal-tools/) is a convenient toolkit that can be used to colorize the raw clouds with IGN aerial imagery (see function 'pdaltools.color.color(...)').

If you use your own classification convention , you will need to create a `dataset_description` configuration (for an example see `configs/dataset_description/20220607_151_dalles_proto.yaml`).
Customization: If you use a different classification (e.g. additional classes), you will need to create a `dataset_description` configuration (similar to `configs/dataset_description/20220607_151_dalles_proto.yaml`).

Additionnaly, you can control cloud sampling parameters via two configurations:
- `configs/datamodule/transforms/preparations/points_budget.yaml`: (defaut) allows variable cloud size within lower and higher boundaries.
Expand All @@ -17,15 +17,17 @@ Additionnaly, you can control cloud sampling parameters via two configurations:

## Preparing the dataset

Input point clouds need to be splitted in subtiles that can be digested by segmentation models. We found that a receptive field of 50m*50m was a good balance between context and memory intensity. For faster training, this split can be done once, to avoid loading large file in memory multiple times.

To perform a training, you will need to specify these parameters of the datamodule config group:
- `data_dir`: path to a directory in which a set of LAS files are stored (can be nested in subdirectories).
To perform a training, you will need to specify these parameters in the datamodule config group:
- `data_dir`: path to a directory in which a set of LAS files are stored. Clouds must be nested in subdirectories named according to their spli: train, val, or test.
- `split_csv_path`: path to a CSV file with schema `basename,split`, specifying a train/val/test spit for your data.

These will be composed into a single file dataset for which you can specify a path via the `datamodule.hdf5_file_path` parameter. This happens on the fly, therefore a first training might take some time, but this should only happens once.
Under the hood, the path of each LAS file will be reconstructed like this: '{data_dir}/{split}/{basename}'.

Large input point clouds need to be divided in smaller clouds that can be digested by segmentation models. We found that a receptive field of 50m x 50m was a good balance between context and memory intensity. The division is performed once, to avoid loading large file in memory multiple times during training.

After division, the smaller clouds are preprocessed (i.e. selection of specific LAS dimensions, on-the-fly creation of dimensions) and regrouped into a single HDF5 file whose path is specified via the `datamodule.hdf5_file_path` parameter.

Once this is done, you do not need sources anymore, and simply specifying the path to the HDF5 dataset is enough.
The HDF5 dataset is created at training time. It should only happens once. Once this is done, you do not need sources anymore, and simply specifying the path to the HDF5 dataset is enough (there is no need for data_dir or split_csv_path parameters anymore).

It's also possible to create the hdf5 file without training any model: just fill the `datamodule.hdf5_file_path` parameter as before to specify the file path, but use `task=create_hdf5` instead of `task=fit`.

Expand Down
40 changes: 15 additions & 25 deletions myria3d/pctl/dataset/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import glob
import json
import math
from pathlib import Path
import subprocess as sp
from numbers import Number
from typing import Dict, List, Literal, Union
Expand All @@ -10,7 +11,6 @@
import pdal
from scipy.spatial import cKDTree
from shapely.geometry import Point
from tqdm import tqdm

SPLIT_TYPE = Union[Literal["train"], Literal["val"], Literal["test"]]
SHAPE_TYPE = Union[Literal["disk"], Literal["square"]]
Expand All @@ -31,9 +31,7 @@ def find_file_in_dir(data_dir: str, basename: str) -> str:
return files[0]


def get_mosaic_of_centers(
tile_width: Number, subtile_width: Number, subtile_overlap: Number = 0
):
def get_mosaic_of_centers(tile_width: Number, subtile_width: Number, subtile_overlap: Number = 0):
if subtile_overlap < 0:
raise ValueError("datamodule.subtile_overlap must be positive.")

Expand Down Expand Up @@ -63,9 +61,7 @@ def pdal_read_las_array(las_path: str):
def pdal_read_las_array_as_float32(las_path: str):
"""Read LAS as a a named array, casted to floats."""
arr = pdal_read_las_array(las_path)
all_floats = np.dtype(
{"names": arr.dtype.names, "formats": ["f4"] * len(arr.dtype.names)}
)
all_floats = np.dtype({"names": arr.dtype.names, "formats": ["f4"] * len(arr.dtype.names)})
return arr.astype(all_floats)


Expand Down Expand Up @@ -101,6 +97,7 @@ def get_pdal_info_metadata(las_path: str) -> Dict:

return json_info["metadata"]


# hdf5, iterable


Expand All @@ -125,32 +122,26 @@ def split_cloud_into_samples(
"""
points = pdal_read_las_array_as_float32(las_path)
pos = np.asarray(
[points["X"], points["Y"], points["Z"]], dtype=np.float32
).transpose()
pos = np.asarray([points["X"], points["Y"], points["Z"]], dtype=np.float32).transpose()
kd_tree = cKDTree(pos[:, :2] - pos[:, :2].min(axis=0))
XYs = get_mosaic_of_centers(
tile_width, subtile_width, subtile_overlap=subtile_overlap
)
for center in tqdm(XYs, desc="Centers"):
XYs = get_mosaic_of_centers(tile_width, subtile_width, subtile_overlap=subtile_overlap)
for center in XYs:
radius = subtile_width // 2 # Square receptive field.
minkowski_p = np.inf
if shape == "disk":
# Disk receptive field.
# Adapt radius to have complete coverage of the data, with a slight overlap between samples.
minkowski_p = 2
radius = radius * math.sqrt(2)
sample_idx = np.array(
kd_tree.query_ball_point(center, r=radius, p=minkowski_p)
)
sample_idx = np.array(kd_tree.query_ball_point(center, r=radius, p=minkowski_p))
if not len(sample_idx):
# no points in this receptive fields
continue
sample_points = points[sample_idx]
yield sample_idx, sample_points


def pre_filter_below_n_points(data, min_num_nodes=50):
def pre_filter_below_n_points(data, min_num_nodes=1):
return data.pos.shape[0] < min_num_nodes


Expand All @@ -171,16 +162,15 @@ def make_circle_wkt(center, subtile_width):
return wkt


def get_las_paths_by_split_dict(data_dir: str, split_csv_path: str) -> LAS_PATHS_BY_SPLIT_DICT_TYPE:
def get_las_paths_by_split_dict(
data_dir: str, split_csv_path: str
) -> LAS_PATHS_BY_SPLIT_DICT_TYPE:
las_paths_by_split_dict: LAS_PATHS_BY_SPLIT_DICT_TYPE = {}
split_df = pd.read_csv(split_csv_path)
for phase in ["train", "val", "test"]:
basenames = split_df[
split_df.split == phase
].basename.tolist()
las_paths_by_split_dict[phase] = [
find_file_in_dir(data_dir, b) for b in basenames
]
basenames = split_df[split_df.split == phase].basename.tolist()
# Reminder: an explicit data structure with ./val, ./train, ./test subfolder is required.
las_paths_by_split_dict[phase] = [str(Path(data_dir) / phase / b) for b in basenames]

if not las_paths_by_split_dict:
raise FileNotFoundError(
Expand Down
3 changes: 2 additions & 1 deletion myria3d/pctl/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ def standardize_channel(self, channel_data: torch.Tensor, clamp_sigma: int = 3):
"""Sample-wise standardization y* = (y-y_mean)/y_std. clamping to ignore large values."""
mean = channel_data.mean()
std = channel_data.std() + 10**-6
if torch.isnan(std):
std = 1.0
standard = (channel_data - mean) / std
clamp = clamp_sigma * std
clamped = torch.clamp(input=standard, min=-clamp, max=clamp)
Expand Down Expand Up @@ -177,7 +179,6 @@ def __init__(
classification_preprocessing_dict: Dict[int, int],
classification_dict: Dict[int, str],
):

self._set_preprocessing_mapper(classification_preprocessing_dict)
self._set_mapper(classification_dict)

Expand Down
2 changes: 1 addition & 1 deletion package_metadata.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__: "3.3.3"
__version__: "3.4.8"
__name__: "myria3d"
__url__: "https://github.com/IGNF/myria3d"
__description__: "Deep Learning for the Semantic Segmentation of Aerial Lidar Point Clouds"
Expand Down
Loading

0 comments on commit b4d2ef8

Please sign in to comment.