Skip to content

Commit

Permalink
evaluation: Move over files
Browse files Browse the repository at this point in the history
  • Loading branch information
marcojob committed Oct 1, 2024
1 parent d4d168e commit dfe240f
Show file tree
Hide file tree
Showing 18 changed files with 605 additions and 1 deletion.
4 changes: 4 additions & 0 deletions .devcontainer/devcontainer-all-packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ main() {
build-essential
ca-certificates
ccache
cm-super
curl
gawk
gnupg
Expand All @@ -26,6 +27,9 @@ main() {
software-properties-common
ssh
sudo
texlive
texlive-fonts-recommended
texlive-latex-extra
udev
unzip
usbutils
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:

strategy:
matrix:
ci_script: [pr_unittest]
ci_script: [pr_unittest, pr_evaluate_networks, pr_train_networks]

steps:
- name: Checkout
Expand Down
21 changes: 21 additions & 0 deletions ci/pr_evaluate_networks.bash
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#!/bin/bash

# Install radarmeetsvision
pip install -e .

RESULTS_DIR=tests/resources/results
if [ -d $RESULTS_DIR ]; then
rm -r $RESULTS_DIR
fi
rm -rf tests/resources/*.pkl

# Run the evaluation script
python3 scripts/evaluation/evaluate_networks.py --dataset tests/resources --config tests/resources/test_evaluation.json --output tests/resources --network tests/resources

TEX_FILE="tests/resources/results/results_table0.tex"
if [[ -f "$TEX_FILE" && -s "$TEX_FILE" ]]; then
echo "Evaluation script successful, .tex table exists and is not empty."
else
echo "Evaluation script failed, .tex table does not exist or is empty."
exit 1
fi
56 changes: 56 additions & 0 deletions ci/pr_train_networks.bash
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
#!/bin/bash

user=$USER
checkpoints=tests/resources
save_path=tests/resources
datasets=tests/resources

config_path=tests/resources
config_radar=$config_path/test_train_radar.json
config_metric=$config_path/test_train_metric.json
config_relative=$config_path/test_train_relative.json

# Install radarmeetsvision
pip install -e rmv-core

# RADAR TRAINING (depth prior + 2 output channels)
python3 scripts/train.py \
--checkpoints $checkpoints \
--config $config_radar \
--datasets $datasets \
--results ""

if [ $? -eq 0 ]; then
echo "Radar training script successful"
else
echo "Training script failed"
exit 1
fi

# RGB TRAINING (no depth prior + 1 output channel)
python3 scripts/train.py \
--checkpoints $checkpoints \
--config $config_metric \
--datasets $datasets \
--results ""

if [ $? -eq 0 ]; then
echo "RGB training script successful"
else
echo "Training script failed"
exit 1
fi

# Relative RGB TRAINING (no depth prior + 1 output channel)
python3 scripts/train.py \
--checkpoints $checkpoints \
--config $config_relative \
--datasets $datasets \
--results ""

if [ $? -eq 0 ]; then
echo "Relative RGB training script successful"
else
echo "Training script failed"
exit 1
fi
77 changes: 77 additions & 0 deletions scripts/evaluation/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
{
"scenarios": {
"Industrial Hall": "maschinenhalle0",
"Agricultural Field": "outdoor0",
"Rhône Glacier": "rhone_flight"
},
"networks": {
"Metric Depth \\cite{depthanythingv2}-S": "rgb_s_bs8_e9.pth",
"Metric Depth \\cite{depthanythingv2}-B": "rgb_b_bs4_e8.pth",
"Scaled Relative Depth \\cite{depthanythingv2}-S": "relrgb_s_bs8_e9.pth",
"Scaled Relative Depth \\cite{depthanythingv2}-B": "relrgb_b_bs4_e9.pth",
"Ours-S": "radar_s_bs8_e19.pth",
"Ours-B": "radar_b_bs4_e21.pth"
},
"Metric Depth \\cite{depthanythingv2}-S": {
"use_depth_prior": false,
"output_channels": 1,
"relative_depth": 0,
"depth_min": 0.19983673095703125,
"depth_max": 120.49285888671875,
"encoder": "vits",
"marker": "X",
"plot_prediction": 0
},
"Metric Depth \\cite{depthanythingv2}-B": {
"use_depth_prior": false,
"output_channels": 1,
"relative_depth": 0,
"depth_min": 0.19983673095703125,
"depth_max": 120.49285888671875,
"encoder": "vitb",
"marker": "X",
"plot_prediction": 0
},
"Scaled Relative Depth \\cite{depthanythingv2}-S": {
"use_depth_prior": false,
"output_channels": 1,
"relative_depth": 1,
"depth_min": 0.0,
"depth_max": 1.0,
"encoder": "vits",
"marker": "D",
"plot_prediction": 0
},
"Scaled Relative Depth \\cite{depthanythingv2}-B": {
"use_depth_prior": false,
"output_channels": 1,
"relative_depth": 1,
"depth_min": 0.0,
"depth_max": 1.0,
"encoder": "vitb",
"marker": "D",
"plot_prediction": 0
},
"Ours-S": {
"use_depth_prior": true,
"output_channels": 2,
"relative_depth": 0,
"depth_min": 0.19983673095703125,
"depth_max": 120.49285888671875,
"encoder": "vits",
"marker": "o",
"plot_prediction": 0
},
"Ours-B": {
"use_depth_prior": true,
"output_channels": 2,
"relative_depth": 0,
"depth_min": 0.19983673095703125,
"depth_max": 120.49285888671875,
"encoder": "vitb",
"marker": "o",
"plot_prediction": 1
},
"height": 480,
"width": 640
}
92 changes: 92 additions & 0 deletions scripts/evaluation/create_scatter_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

from matplotlib.gridspec import GridSpec
from pathlib import Path
from PIL import Image

mpl.rcParams['font.size'] = 8
mpl.rcParams['figure.figsize'] = [9, 2.7]
mpl.rcParams['lines.linewidth'] = 0.6
mpl.rcParams['grid.linewidth'] = 0.5
mpl.rcParams['axes.linewidth'] = 0.6
mpl.rcParams['axes.labelsize'] = 8
mpl.rcParams['axes.titlesize'] = 8
mpl.rcParams['legend.fontsize'] = 8
mpl.rcParams['xtick.labelsize'] = 8
mpl.rcParams['ytick.labelsize'] = 8
mpl.rcParams['text.usetex'] = True
mpl.rcParams['font.family'] = 'serif'

def create_scatter_plot(results_per_scenario_sample, config, output_dir, subsample=10):
scenario_color = {
'Industrial Hall': 'k',
'Agricultural Field': 'green',
'Rhône Glacier': 'navy'
}
label_dict = {
"Metric Depth \\cite{depthanythingv2}-B": "Metric Depth-B",
"Ours-B": "Ours-B"
}
samples = [['00000_rgb.jpg', '00000_dp.jpg'],
['00050_rgb.jpg', '00050_dp.jpg'],
['00250_rgb.jpg', '00250_dp.jpg']]

sample_dir = Path('scripts/evaluation/samples')
img_out = None
for i, sample in enumerate(samples):
rgb_file = sample_dir / sample[0]
dp_file = sample_dir / sample[1]

rgb = np.array(Image.open(rgb_file))
dp = np.array(Image.open(dp_file))
img = np.concatenate((rgb, dp), axis=1)
if img_out is None:
img_out = img
else:
img_out = np.concatenate((img_out, img), axis=0)
border=5
img_out[:border, :] = (0.0, 0.0, 0.0)
img_out[-border:, :] = (0.0, 0.0, 0.0)
img_out[:, :border] = (0.0, 0.0, 0.0)
img_out[:, -border:] = (0.0, 0.0, 0.0)


fig = plt.figure(figsize=(9, 2.7)) # Adjust the figure size if needed
gs = GridSpec(1, 2, width_ratios=[1, 2]) # Set the width ratios to 1:2

ax0 = fig.add_subplot(gs[0])
ax0.imshow(img_out)
ax0.axis('off')

ax1 = fig.add_subplot(gs[1])
for scenario_key in config['scenarios'].keys():
for i, network_key in enumerate(config['networks'].keys()):
if '-B' in network_key and not 'Scaled' in network_key:
average_depths = results_per_scenario_sample[scenario_key][network_key]['average_depth']
abs_rel_values = results_per_scenario_sample[scenario_key][network_key]['abs_rel']
average_depths_subsampled = average_depths[::subsample]
abs_rel_values_subsampled = abs_rel_values[::subsample]
label = label_dict[network_key] + ' ' + (scenario_key if 'Metric' in network_key else '')
ax1.scatter(average_depths_subsampled, abs_rel_values_subsampled, label=label, marker=config[network_key]['marker'], c=scenario_color[scenario_key], s=25, alpha=0.5)


# Set axis labels, title, and legend
plt.xlabel('Average Scene Depth [m]')
plt.ylabel('Absolute Relative Error [ ]')
plt.legend(loc='upper right')
ax1.grid()
plt.tight_layout()

# Save the plot
output_file = Path(output_dir) / f'results_overview.png'
plt.savefig(str(output_file), transparent=True, bbox_inches='tight', dpi=400)
plt.close()

# Post-process the saved image to crop any unnecessary white space
img = Image.open(str(output_file))
img = img.convert("RGBA")
bbox = img.getbbox()
cropped_img = img.crop(bbox)
cropped_img.save(str(output_file))
121 changes: 121 additions & 0 deletions scripts/evaluation/evaluate_networks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
######################################################################
#
# Copyright (c) 2024 ETHZ Autonomous Systems Lab. All rights reserved.
#
######################################################################

import argparse
import logging
import pickle
import radarmeetsvision as rmv

from pathlib import Path
from results_table_template import generate_tables
from create_scatter_plot import create_scatter_plot

logger = logging.getLogger(__name__)


class Evaluation:
def __init__(self, config, scenario_key, network_key, args):
self.results_per_sample = {}
self.results_dict = {}
self.interface = rmv.Interface()
self.networks_dir = Path(args.network)
self.datasets_dir = args.dataset
self.results, self.results_per_sample = None, None
self.setup_interface(config, scenario_key, network_key)
self.run(network_key)

def run(self, network_key):
self.interface.validate_epoch(0, self.loader)
self.results, self.results_per_sample = self.interface.get_results()
self.results['method'] = network_key

def setup_interface(self, config, scenario_key, network_key):
self.interface.set_epochs(1)

network_config = config[network_key]
self.interface.set_encoder(network_config['encoder'])

depth_min, depth_max = network_config['depth_min'], network_config['depth_max']
self.interface.set_depth_range((depth_min, depth_max))
self.interface.set_output_channels(network_config['output_channels'])
self.interface.set_use_depth_prior(network_config['use_depth_prior'])

network_file = config['networks'][network_key]
if network_file is not None:
network_file = self.networks_dir / network_file
self.interface.load_model(pretrained_from=network_file)

self.interface.set_size(config['height'], config['width'])
self.interface.set_batch_size(1)
self.interface.set_criterion()

dataset_list = [config['scenarios'][scenario_key]]
self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list)

def get_results_per_sample(self):
return self.results_per_sample

def get_results(self):
return self.results


def main():
parser = argparse.ArgumentParser(description='Create evaluation results for paper')
parser.add_argument('--config', type=str, default='scripts/evaluation/config.json', help='Path to the JSON config file.')
parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory')
parser.add_argument('--output', type=str, required=True, help='Path to the output directory')
parser.add_argument('--network', type=str, help='Path to the network directory')
args = parser.parse_args()

rmv.setup_global_logger()

config = rmv.load_config(args.config)

results_per_scenario = {}
results_per_scenario_sample = {}
results_per_scenario_file = Path(args.output) / "results_per_scenario.pkl"
results_per_scenario_sample_file = Path(args.output) / "results_per_scenario_sample.pkl"
if not results_per_scenario_file.is_file() or not results_per_scenario_sample_file.is_file():
for scenario_key in config['scenarios'].keys():
results_per_scenario[scenario_key] = []
results_per_sample = {}

for network_key in config['networks'].keys():
logger.info(f'Evaluation: {scenario_key} {network_key}')
eval_obj = Evaluation(config, scenario_key, network_key, args)
results_per_sample[network_key] = eval_obj.get_results_per_sample()

results_dict = eval_obj.get_results()
if results_dict:
# Used to generate main results table
results_per_scenario[scenario_key].append(results_dict.copy())
logger.info(f'{scenario_key} {network_key} {results_dict["abs_rel"]:.3f}')

# Used to generate visual grid
results_per_scenario_sample[scenario_key] = results_per_sample

with results_per_scenario_file.open('wb') as f:
pickle.dump(results_per_scenario, f)

with results_per_scenario_sample_file.open('wb') as f:
pickle.dump(results_per_scenario_sample, f)

else:
with results_per_scenario_file.open('rb') as f:
results_per_scenario = pickle.load(f)

with results_per_scenario_sample_file.open('rb') as f:
results_per_scenario_sample = pickle.load(f)

# Generate the main results table
if results_per_scenario:
generate_tables(args.output, results_per_scenario)

create_scatter_plot(results_per_scenario_sample, config, args.output)


if __name__ == "__main__":
main()
Loading

0 comments on commit dfe240f

Please sign in to comment.