-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
605 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
#!/bin/bash | ||
|
||
# Install radarmeetsvision | ||
pip install -e . | ||
|
||
RESULTS_DIR=tests/resources/results | ||
if [ -d $RESULTS_DIR ]; then | ||
rm -r $RESULTS_DIR | ||
fi | ||
rm -rf tests/resources/*.pkl | ||
|
||
# Run the evaluation script | ||
python3 scripts/evaluation/evaluate_networks.py --dataset tests/resources --config tests/resources/test_evaluation.json --output tests/resources --network tests/resources | ||
|
||
TEX_FILE="tests/resources/results/results_table0.tex" | ||
if [[ -f "$TEX_FILE" && -s "$TEX_FILE" ]]; then | ||
echo "Evaluation script successful, .tex table exists and is not empty." | ||
else | ||
echo "Evaluation script failed, .tex table does not exist or is empty." | ||
exit 1 | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#!/bin/bash | ||
|
||
user=$USER | ||
checkpoints=tests/resources | ||
save_path=tests/resources | ||
datasets=tests/resources | ||
|
||
config_path=tests/resources | ||
config_radar=$config_path/test_train_radar.json | ||
config_metric=$config_path/test_train_metric.json | ||
config_relative=$config_path/test_train_relative.json | ||
|
||
# Install radarmeetsvision | ||
pip install -e rmv-core | ||
|
||
# RADAR TRAINING (depth prior + 2 output channels) | ||
python3 scripts/train.py \ | ||
--checkpoints $checkpoints \ | ||
--config $config_radar \ | ||
--datasets $datasets \ | ||
--results "" | ||
|
||
if [ $? -eq 0 ]; then | ||
echo "Radar training script successful" | ||
else | ||
echo "Training script failed" | ||
exit 1 | ||
fi | ||
|
||
# RGB TRAINING (no depth prior + 1 output channel) | ||
python3 scripts/train.py \ | ||
--checkpoints $checkpoints \ | ||
--config $config_metric \ | ||
--datasets $datasets \ | ||
--results "" | ||
|
||
if [ $? -eq 0 ]; then | ||
echo "RGB training script successful" | ||
else | ||
echo "Training script failed" | ||
exit 1 | ||
fi | ||
|
||
# Relative RGB TRAINING (no depth prior + 1 output channel) | ||
python3 scripts/train.py \ | ||
--checkpoints $checkpoints \ | ||
--config $config_relative \ | ||
--datasets $datasets \ | ||
--results "" | ||
|
||
if [ $? -eq 0 ]; then | ||
echo "Relative RGB training script successful" | ||
else | ||
echo "Training script failed" | ||
exit 1 | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
{ | ||
"scenarios": { | ||
"Industrial Hall": "maschinenhalle0", | ||
"Agricultural Field": "outdoor0", | ||
"Rhône Glacier": "rhone_flight" | ||
}, | ||
"networks": { | ||
"Metric Depth \\cite{depthanythingv2}-S": "rgb_s_bs8_e9.pth", | ||
"Metric Depth \\cite{depthanythingv2}-B": "rgb_b_bs4_e8.pth", | ||
"Scaled Relative Depth \\cite{depthanythingv2}-S": "relrgb_s_bs8_e9.pth", | ||
"Scaled Relative Depth \\cite{depthanythingv2}-B": "relrgb_b_bs4_e9.pth", | ||
"Ours-S": "radar_s_bs8_e19.pth", | ||
"Ours-B": "radar_b_bs4_e21.pth" | ||
}, | ||
"Metric Depth \\cite{depthanythingv2}-S": { | ||
"use_depth_prior": false, | ||
"output_channels": 1, | ||
"relative_depth": 0, | ||
"depth_min": 0.19983673095703125, | ||
"depth_max": 120.49285888671875, | ||
"encoder": "vits", | ||
"marker": "X", | ||
"plot_prediction": 0 | ||
}, | ||
"Metric Depth \\cite{depthanythingv2}-B": { | ||
"use_depth_prior": false, | ||
"output_channels": 1, | ||
"relative_depth": 0, | ||
"depth_min": 0.19983673095703125, | ||
"depth_max": 120.49285888671875, | ||
"encoder": "vitb", | ||
"marker": "X", | ||
"plot_prediction": 0 | ||
}, | ||
"Scaled Relative Depth \\cite{depthanythingv2}-S": { | ||
"use_depth_prior": false, | ||
"output_channels": 1, | ||
"relative_depth": 1, | ||
"depth_min": 0.0, | ||
"depth_max": 1.0, | ||
"encoder": "vits", | ||
"marker": "D", | ||
"plot_prediction": 0 | ||
}, | ||
"Scaled Relative Depth \\cite{depthanythingv2}-B": { | ||
"use_depth_prior": false, | ||
"output_channels": 1, | ||
"relative_depth": 1, | ||
"depth_min": 0.0, | ||
"depth_max": 1.0, | ||
"encoder": "vitb", | ||
"marker": "D", | ||
"plot_prediction": 0 | ||
}, | ||
"Ours-S": { | ||
"use_depth_prior": true, | ||
"output_channels": 2, | ||
"relative_depth": 0, | ||
"depth_min": 0.19983673095703125, | ||
"depth_max": 120.49285888671875, | ||
"encoder": "vits", | ||
"marker": "o", | ||
"plot_prediction": 0 | ||
}, | ||
"Ours-B": { | ||
"use_depth_prior": true, | ||
"output_channels": 2, | ||
"relative_depth": 0, | ||
"depth_min": 0.19983673095703125, | ||
"depth_max": 120.49285888671875, | ||
"encoder": "vitb", | ||
"marker": "o", | ||
"plot_prediction": 1 | ||
}, | ||
"height": 480, | ||
"width": 640 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import matplotlib as mpl | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from matplotlib.gridspec import GridSpec | ||
from pathlib import Path | ||
from PIL import Image | ||
|
||
mpl.rcParams['font.size'] = 8 | ||
mpl.rcParams['figure.figsize'] = [9, 2.7] | ||
mpl.rcParams['lines.linewidth'] = 0.6 | ||
mpl.rcParams['grid.linewidth'] = 0.5 | ||
mpl.rcParams['axes.linewidth'] = 0.6 | ||
mpl.rcParams['axes.labelsize'] = 8 | ||
mpl.rcParams['axes.titlesize'] = 8 | ||
mpl.rcParams['legend.fontsize'] = 8 | ||
mpl.rcParams['xtick.labelsize'] = 8 | ||
mpl.rcParams['ytick.labelsize'] = 8 | ||
mpl.rcParams['text.usetex'] = True | ||
mpl.rcParams['font.family'] = 'serif' | ||
|
||
def create_scatter_plot(results_per_scenario_sample, config, output_dir, subsample=10): | ||
scenario_color = { | ||
'Industrial Hall': 'k', | ||
'Agricultural Field': 'green', | ||
'Rhône Glacier': 'navy' | ||
} | ||
label_dict = { | ||
"Metric Depth \\cite{depthanythingv2}-B": "Metric Depth-B", | ||
"Ours-B": "Ours-B" | ||
} | ||
samples = [['00000_rgb.jpg', '00000_dp.jpg'], | ||
['00050_rgb.jpg', '00050_dp.jpg'], | ||
['00250_rgb.jpg', '00250_dp.jpg']] | ||
|
||
sample_dir = Path('scripts/evaluation/samples') | ||
img_out = None | ||
for i, sample in enumerate(samples): | ||
rgb_file = sample_dir / sample[0] | ||
dp_file = sample_dir / sample[1] | ||
|
||
rgb = np.array(Image.open(rgb_file)) | ||
dp = np.array(Image.open(dp_file)) | ||
img = np.concatenate((rgb, dp), axis=1) | ||
if img_out is None: | ||
img_out = img | ||
else: | ||
img_out = np.concatenate((img_out, img), axis=0) | ||
border=5 | ||
img_out[:border, :] = (0.0, 0.0, 0.0) | ||
img_out[-border:, :] = (0.0, 0.0, 0.0) | ||
img_out[:, :border] = (0.0, 0.0, 0.0) | ||
img_out[:, -border:] = (0.0, 0.0, 0.0) | ||
|
||
|
||
fig = plt.figure(figsize=(9, 2.7)) # Adjust the figure size if needed | ||
gs = GridSpec(1, 2, width_ratios=[1, 2]) # Set the width ratios to 1:2 | ||
|
||
ax0 = fig.add_subplot(gs[0]) | ||
ax0.imshow(img_out) | ||
ax0.axis('off') | ||
|
||
ax1 = fig.add_subplot(gs[1]) | ||
for scenario_key in config['scenarios'].keys(): | ||
for i, network_key in enumerate(config['networks'].keys()): | ||
if '-B' in network_key and not 'Scaled' in network_key: | ||
average_depths = results_per_scenario_sample[scenario_key][network_key]['average_depth'] | ||
abs_rel_values = results_per_scenario_sample[scenario_key][network_key]['abs_rel'] | ||
average_depths_subsampled = average_depths[::subsample] | ||
abs_rel_values_subsampled = abs_rel_values[::subsample] | ||
label = label_dict[network_key] + ' ' + (scenario_key if 'Metric' in network_key else '') | ||
ax1.scatter(average_depths_subsampled, abs_rel_values_subsampled, label=label, marker=config[network_key]['marker'], c=scenario_color[scenario_key], s=25, alpha=0.5) | ||
|
||
|
||
# Set axis labels, title, and legend | ||
plt.xlabel('Average Scene Depth [m]') | ||
plt.ylabel('Absolute Relative Error [ ]') | ||
plt.legend(loc='upper right') | ||
ax1.grid() | ||
plt.tight_layout() | ||
|
||
# Save the plot | ||
output_file = Path(output_dir) / f'results_overview.png' | ||
plt.savefig(str(output_file), transparent=True, bbox_inches='tight', dpi=400) | ||
plt.close() | ||
|
||
# Post-process the saved image to crop any unnecessary white space | ||
img = Image.open(str(output_file)) | ||
img = img.convert("RGBA") | ||
bbox = img.getbbox() | ||
cropped_img = img.crop(bbox) | ||
cropped_img.save(str(output_file)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
###################################################################### | ||
# | ||
# Copyright (c) 2024 ETHZ Autonomous Systems Lab. All rights reserved. | ||
# | ||
###################################################################### | ||
|
||
import argparse | ||
import logging | ||
import pickle | ||
import radarmeetsvision as rmv | ||
|
||
from pathlib import Path | ||
from results_table_template import generate_tables | ||
from create_scatter_plot import create_scatter_plot | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Evaluation: | ||
def __init__(self, config, scenario_key, network_key, args): | ||
self.results_per_sample = {} | ||
self.results_dict = {} | ||
self.interface = rmv.Interface() | ||
self.networks_dir = Path(args.network) | ||
self.datasets_dir = args.dataset | ||
self.results, self.results_per_sample = None, None | ||
self.setup_interface(config, scenario_key, network_key) | ||
self.run(network_key) | ||
|
||
def run(self, network_key): | ||
self.interface.validate_epoch(0, self.loader) | ||
self.results, self.results_per_sample = self.interface.get_results() | ||
self.results['method'] = network_key | ||
|
||
def setup_interface(self, config, scenario_key, network_key): | ||
self.interface.set_epochs(1) | ||
|
||
network_config = config[network_key] | ||
self.interface.set_encoder(network_config['encoder']) | ||
|
||
depth_min, depth_max = network_config['depth_min'], network_config['depth_max'] | ||
self.interface.set_depth_range((depth_min, depth_max)) | ||
self.interface.set_output_channels(network_config['output_channels']) | ||
self.interface.set_use_depth_prior(network_config['use_depth_prior']) | ||
|
||
network_file = config['networks'][network_key] | ||
if network_file is not None: | ||
network_file = self.networks_dir / network_file | ||
self.interface.load_model(pretrained_from=network_file) | ||
|
||
self.interface.set_size(config['height'], config['width']) | ||
self.interface.set_batch_size(1) | ||
self.interface.set_criterion() | ||
|
||
dataset_list = [config['scenarios'][scenario_key]] | ||
self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list) | ||
|
||
def get_results_per_sample(self): | ||
return self.results_per_sample | ||
|
||
def get_results(self): | ||
return self.results | ||
|
||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description='Create evaluation results for paper') | ||
parser.add_argument('--config', type=str, default='scripts/evaluation/config.json', help='Path to the JSON config file.') | ||
parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory') | ||
parser.add_argument('--output', type=str, required=True, help='Path to the output directory') | ||
parser.add_argument('--network', type=str, help='Path to the network directory') | ||
args = parser.parse_args() | ||
|
||
rmv.setup_global_logger() | ||
|
||
config = rmv.load_config(args.config) | ||
|
||
results_per_scenario = {} | ||
results_per_scenario_sample = {} | ||
results_per_scenario_file = Path(args.output) / "results_per_scenario.pkl" | ||
results_per_scenario_sample_file = Path(args.output) / "results_per_scenario_sample.pkl" | ||
if not results_per_scenario_file.is_file() or not results_per_scenario_sample_file.is_file(): | ||
for scenario_key in config['scenarios'].keys(): | ||
results_per_scenario[scenario_key] = [] | ||
results_per_sample = {} | ||
|
||
for network_key in config['networks'].keys(): | ||
logger.info(f'Evaluation: {scenario_key} {network_key}') | ||
eval_obj = Evaluation(config, scenario_key, network_key, args) | ||
results_per_sample[network_key] = eval_obj.get_results_per_sample() | ||
|
||
results_dict = eval_obj.get_results() | ||
if results_dict: | ||
# Used to generate main results table | ||
results_per_scenario[scenario_key].append(results_dict.copy()) | ||
logger.info(f'{scenario_key} {network_key} {results_dict["abs_rel"]:.3f}') | ||
|
||
# Used to generate visual grid | ||
results_per_scenario_sample[scenario_key] = results_per_sample | ||
|
||
with results_per_scenario_file.open('wb') as f: | ||
pickle.dump(results_per_scenario, f) | ||
|
||
with results_per_scenario_sample_file.open('wb') as f: | ||
pickle.dump(results_per_scenario_sample, f) | ||
|
||
else: | ||
with results_per_scenario_file.open('rb') as f: | ||
results_per_scenario = pickle.load(f) | ||
|
||
with results_per_scenario_sample_file.open('rb') as f: | ||
results_per_scenario_sample = pickle.load(f) | ||
|
||
# Generate the main results table | ||
if results_per_scenario: | ||
generate_tables(args.output, results_per_scenario) | ||
|
||
create_scatter_plot(results_per_scenario_sample, config, args.output) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.