diff --git a/.devcontainer/devcontainer-all-packages.sh b/.devcontainer/devcontainer-all-packages.sh
index 8634bf3..b1e2960 100755
--- a/.devcontainer/devcontainer-all-packages.sh
+++ b/.devcontainer/devcontainer-all-packages.sh
@@ -10,6 +10,7 @@ main() {
     build-essential
     ca-certificates
     ccache
+    cm-super
     curl
     gawk
     gnupg
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ef92958..22fe9d0 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -12,7 +12,7 @@ jobs:
     strategy:
       matrix:
-        ci_script: [pr_unittest]
+        ci_script: [pr_unittest, pr_evaluate_networks, pr_train_networks]
 
     steps:
       - name: Checkout
diff --git a/ci/pr_evaluate_networks.bash b/ci/pr_evaluate_networks.bash
new file mode 100755
index 0000000..b87d740
--- /dev/null
+++ b/ci/pr_evaluate_networks.bash
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+# Install radarmeetsvision
+pip install -e .
+
+RESULTS_DIR=tests/resources/results
+if [ -d $RESULTS_DIR ]; then
+    rm -r $RESULTS_DIR
+fi
+rm -rf tests/resources/*.pkl
+
+# Run the evaluation script
+python3 scripts/evaluation/evaluate_networks.py --dataset tests/resources --config tests/resources/test_evaluation.json --output tests/resources --network tests/resources
+
+TEX_FILE="tests/resources/results/results_table0.tex"
+if [[ -f "$TEX_FILE" && -s "$TEX_FILE" ]]; then
+    echo "Evaluation script successful, .tex table exists and is not empty."
+else
+    echo "Evaluation script failed, .tex table does not exist or is empty."
+    exit 1
+fi
diff --git a/ci/pr_train_networks.bash b/ci/pr_train_networks.bash
new file mode 100755
index 0000000..d93b07e
--- /dev/null
+++ b/ci/pr_train_networks.bash
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+checkpoints=tests/resources
+save_path=tests/resources
+datasets=tests/resources
+
+config_path=tests/resources
+config_radar=$config_path/test_train_radar.json
+config_metric=$config_path/test_train_metric.json
+config_relative=$config_path/test_train_relative.json
+
+# Install radarmeetsvision
+pip install -e .
+
+# RADAR TRAINING (depth prior + 2 output channels)
+python3 scripts/train.py \
+--checkpoints $checkpoints \
+--config $config_radar \
+--datasets $datasets \
+--results ""
+
+if [ $? -eq 0 ]; then
+    echo "Radar training script successful"
+else
+    echo "Training script failed"
+    exit 1
+fi
+
+# RGB TRAINING (no depth prior + 1 output channel)
+python3 scripts/train.py \
+--checkpoints $checkpoints \
+--config $config_metric \
+--datasets $datasets \
+--results ""
+
+if [ $? -eq 0 ]; then
+    echo "RGB training script successful"
+else
+    echo "Training script failed"
+    exit 1
+fi
+
+# Relative RGB TRAINING (no depth prior + 1 output channel)
+python3 scripts/train.py \
+--checkpoints $checkpoints \
+--config $config_relative \
+--datasets $datasets \
+--results ""
+
+if [ $? -eq 0 ]; then
+    echo "Relative RGB training script successful"
+else
+    echo "Training script failed"
+    exit 1
+fi
diff --git a/ci/pr_unittest.bash b/ci/pr_unittest.bash
index 63a60d3..731f6c8 100755
--- a/ci/pr_unittest.bash
+++ b/ci/pr_unittest.bash
@@ -1,6 +1,6 @@
 #!/bin/bash
 
-if pip install .
+if pip install -e .
 then
     echo "Installation successful"
 else
diff --git a/scripts/evaluation/config.json b/scripts/evaluation/config.json
new file mode 100644
index 0000000..cc25b37
--- /dev/null
+++ b/scripts/evaluation/config.json
@@ -0,0 +1,77 @@
+{
+    "scenarios": {
+        "Industrial Hall": "maschinenhalle0",
+        "Agricultural Field": "outdoor0",
+        "Rhône Glacier": "rhone_flight"
+    },
+    "networks": {
+        "Metric Depth \\cite{depthanythingv2}-S": "rgb_s_bs8_e9.pth",
+        "Metric Depth \\cite{depthanythingv2}-B": "rgb_b_bs4_e8.pth",
+        "Scaled Relative Depth \\cite{depthanythingv2}-S": "relrgb_s_bs8_e9.pth",
+        "Scaled Relative Depth \\cite{depthanythingv2}-B": "relrgb_b_bs4_e9.pth",
+        "Ours-S": "radar_s_bs8_e19.pth",
+        "Ours-B": "radar_b_bs4_e21.pth"
+    },
+    "Metric Depth \\cite{depthanythingv2}-S": {
+        "use_depth_prior": false,
+        "output_channels": 1,
+        "relative_depth": 0,
+        "depth_min": 0.19983673095703125,
+        "depth_max": 120.49285888671875,
+        "encoder": "vits",
+        "marker": "X",
+        "plot_prediction": 0
+    },
+    "Metric Depth \\cite{depthanythingv2}-B": {
+        "use_depth_prior": false,
+        "output_channels": 1,
+        "relative_depth": 0,
+        "depth_min": 0.19983673095703125,
+        "depth_max": 120.49285888671875,
+        "encoder": "vitb",
+        "marker": "X",
+        "plot_prediction": 0
+    },
+    "Scaled Relative Depth \\cite{depthanythingv2}-S": {
+        "use_depth_prior": false,
+        "output_channels": 1,
+        "relative_depth": 1,
+        "depth_min": 0.0,
+        "depth_max": 1.0,
+        "encoder": "vits",
+        "marker": "D",
+        "plot_prediction": 0
+    },
+    "Scaled Relative Depth \\cite{depthanythingv2}-B": {
+        "use_depth_prior": false,
+        "output_channels": 1,
+        "relative_depth": 1,
+        "depth_min": 0.0,
+        "depth_max": 1.0,
+        "encoder": "vitb",
+        "marker": "D",
+        "plot_prediction": 0
+    },
+    "Ours-S": {
+        "use_depth_prior": true,
+        "output_channels": 2,
+        "relative_depth": 0,
+        "depth_min": 0.19983673095703125,
+        "depth_max": 120.49285888671875,
+        "encoder": "vits",
+        "marker": "o",
+        "plot_prediction": 0
+    },
+    "Ours-B": {
+        "use_depth_prior": true,
+        "output_channels": 2,
+        "relative_depth": 0,
+        "depth_min": 0.19983673095703125,
+        "depth_max": 120.49285888671875,
+        "encoder": "vitb",
+        "marker": "o",
+        "plot_prediction": 1
+    },
+    "height": 480,
+    "width": 640
+}
diff --git a/scripts/evaluation/create_scatter_plot.py b/scripts/evaluation/create_scatter_plot.py
new file mode 100644
index 0000000..c4319f2
--- /dev/null
+++ b/scripts/evaluation/create_scatter_plot.py
@@ -0,0 +1,98 @@
+######################################################################
+#
+# Copyright (c) 2024 ETHZ Autonomous Systems Lab. All rights reserved.
+#
+######################################################################
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+
+from matplotlib.gridspec import GridSpec
+from pathlib import Path
+from PIL import Image
+
+mpl.rcParams['font.size'] = 8
+mpl.rcParams['figure.figsize'] = [9, 2.7]
+mpl.rcParams['lines.linewidth'] = 0.6
+mpl.rcParams['grid.linewidth'] = 0.5
+mpl.rcParams['axes.linewidth'] = 0.6
+mpl.rcParams['axes.labelsize'] = 8
+mpl.rcParams['axes.titlesize'] = 8
+mpl.rcParams['legend.fontsize'] = 8
+mpl.rcParams['xtick.labelsize'] = 8
+mpl.rcParams['ytick.labelsize'] = 8
+mpl.rcParams['text.usetex'] = True
+mpl.rcParams['font.family'] = 'serif'
+
+def create_scatter_plot(results_per_scenario_sample, config, output_dir, subsample=10):
+    scenario_color = {
+        'Industrial Hall': 'k',
+        'Agricultural Field': 'green',
+        'Rhône Glacier': 'navy'
+    }
+    label_dict = {
+        "Metric Depth \\cite{depthanythingv2}-B": "Metric Depth-B",
+        "Ours-B": "Ours-B"
+    }
+    samples = [['00000_rgb.jpg', '00000_dp.jpg'],
+               ['00050_rgb.jpg', '00050_dp.jpg'],
+               ['00250_rgb.jpg', '00250_dp.jpg']]
+
+    sample_dir = Path('scripts/evaluation/samples')
+    img_out = None
+    for i, sample in enumerate(samples):
+        rgb_file = sample_dir / sample[0]
+        dp_file = sample_dir / sample[1]
+
+        rgb = np.array(Image.open(rgb_file))
+        dp = np.array(Image.open(dp_file))
+        img = np.concatenate((rgb, dp), axis=1)
+        if img_out is None:
+            img_out = img
+        else:
+            img_out = np.concatenate((img_out, img), axis=0)
+    border=5
+    img_out[:border, :] = (0.0, 0.0, 0.0)
+    img_out[-border:, :] = (0.0, 0.0, 0.0)
+    img_out[:, :border] = (0.0, 0.0, 0.0)
+    img_out[:, -border:] = (0.0, 0.0, 0.0)
+
+
+    fig = plt.figure(figsize=(9, 2.7)) # Adjust the figure size if needed
+    gs = GridSpec(1, 2, width_ratios=[1, 2]) # Set the width ratios to 1:2
+
+    ax0 = fig.add_subplot(gs[0])
+    ax0.imshow(img_out)
+    ax0.axis('off')
+
+    ax1 = fig.add_subplot(gs[1])
+    for scenario_key in config['scenarios'].keys():
+        for i, network_key in enumerate(config['networks'].keys()):
+            if '-B' in network_key and not 'Scaled' in network_key:
+                average_depths = results_per_scenario_sample[scenario_key][network_key]['average_depth']
+                abs_rel_values = results_per_scenario_sample[scenario_key][network_key]['abs_rel']
+                average_depths_subsampled = average_depths[::subsample]
+                abs_rel_values_subsampled = abs_rel_values[::subsample]
+                label = label_dict[network_key] + ' ' + (scenario_key if 'Metric' in network_key else '')
+                ax1.scatter(average_depths_subsampled, abs_rel_values_subsampled, label=label, marker=config[network_key]['marker'], c=scenario_color[scenario_key], s=25, alpha=0.5)
+
+
+    # Set axis labels, title, and legend
+    plt.xlabel('Average Scene Depth [m]')
+    plt.ylabel('Absolute Relative Error [ ]')
+    plt.legend(loc='upper right')
+    ax1.grid()
+    plt.tight_layout()
+
+    # Save the plot
+    output_file = Path(output_dir) / f'results_overview.png'
+    plt.savefig(str(output_file), transparent=True, bbox_inches='tight', dpi=400)
+    plt.close()
+
+    # Post-process the saved image to crop any unnecessary white space
+    img = Image.open(str(output_file))
+    img = img.convert("RGBA")
+    bbox = img.getbbox()
+    cropped_img = img.crop(bbox)
+    cropped_img.save(str(output_file))
diff --git a/scripts/evaluation/evaluate_networks.py b/scripts/evaluation/evaluate_networks.py
new file mode 100644
index 0000000..1049fca
--- /dev/null
+++ b/scripts/evaluation/evaluate_networks.py
@@ -0,0 +1,121 @@
+######################################################################
+#
+# Copyright (c) 2024 ETHZ Autonomous Systems Lab. All rights reserved.
+#
+######################################################################
+
+import argparse
+import logging
+import pickle
+import radarmeetsvision as rmv
+
+from pathlib import Path
+from results_table_template import generate_tables
+from create_scatter_plot import create_scatter_plot
+
+logger = logging.getLogger(__name__)
+
+
+class Evaluation:
+    def __init__(self, config, scenario_key, network_key, args):
+        self.results_per_sample = {}
+        self.results_dict = {}
+        self.interface = rmv.Interface()
+        self.networks_dir = Path(args.network)
+        self.datasets_dir = args.dataset
+        self.results, self.results_per_sample = None, None
+        self.setup_interface(config, scenario_key, network_key)
+        self.run(network_key)
+
+    def run(self, network_key):
+        self.interface.validate_epoch(0, self.loader)
+        self.results, self.results_per_sample = self.interface.get_results()
+        self.results['method'] = network_key
+
+    def setup_interface(self, config, scenario_key, network_key):
+        self.interface.set_epochs(1)
+
+        network_config = config[network_key]
+        self.interface.set_encoder(network_config['encoder'])
+
+        depth_min, depth_max = network_config['depth_min'], network_config['depth_max']
+        self.interface.set_depth_range((depth_min, depth_max))
+        self.interface.set_output_channels(network_config['output_channels'])
+        self.interface.set_use_depth_prior(network_config['use_depth_prior'])
+
+        network_file = config['networks'][network_key]
+        if network_file is not None:
+            network_file = self.networks_dir / network_file
+        self.interface.load_model(pretrained_from=network_file)
+
+        self.interface.set_size(config['height'], config['width'])
+        self.interface.set_batch_size(1)
+        self.interface.set_criterion()
+
+        dataset_list = [config['scenarios'][scenario_key]]
+        self.loader, _ = self.interface.get_dataset_loader('val_all', self.datasets_dir, dataset_list)
+
+    def get_results_per_sample(self):
+        return self.results_per_sample
+
+    def get_results(self):
+        return self.results
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Create evaluation results for paper')
+    parser.add_argument('--config', type=str, default='scripts/evaluation/config.json', help='Path to the JSON config file.')
+    parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory')
+    parser.add_argument('--output', type=str, required=True, help='Path to the output directory')
+    parser.add_argument('--network', type=str, help='Path to the network directory')
+    args = parser.parse_args()
+
+    rmv.setup_global_logger()
+
+    config = rmv.load_config(args.config)
+
+    results_per_scenario = {}
+    results_per_scenario_sample = {}
+    results_per_scenario_file = Path(args.output) / "results_per_scenario.pkl"
+    results_per_scenario_sample_file = Path(args.output) / "results_per_scenario_sample.pkl"
+    if not results_per_scenario_file.is_file() or not results_per_scenario_sample_file.is_file():
+        for scenario_key in config['scenarios'].keys():
+            results_per_scenario[scenario_key] = []
+            results_per_sample = {}
+
+            for network_key in config['networks'].keys():
+                logger.info(f'Evaluation: {scenario_key} {network_key}')
+                eval_obj = Evaluation(config, scenario_key, network_key, args)
+                results_per_sample[network_key] = eval_obj.get_results_per_sample()
+
+                results_dict = eval_obj.get_results()
+                if results_dict:
+                    # Used to generate main results table
+                    results_per_scenario[scenario_key].append(results_dict.copy())
+                    logger.info(f'{scenario_key} {network_key} {results_dict["abs_rel"]:.3f}')
+
+            # Used to generate visual grid
+            results_per_scenario_sample[scenario_key] = results_per_sample
+
+        with results_per_scenario_file.open('wb') as f:
+            pickle.dump(results_per_scenario, f)
+
+        with results_per_scenario_sample_file.open('wb') as f:
+            pickle.dump(results_per_scenario_sample, f)
+
+    else:
+        with results_per_scenario_file.open('rb') as f:
+            results_per_scenario = pickle.load(f)
+
+        with results_per_scenario_sample_file.open('rb') as f:
+            results_per_scenario_sample = pickle.load(f)
+
+    # Generate the main results table
+    if results_per_scenario:
+        generate_tables(args.output, results_per_scenario)
+
+    create_scatter_plot(results_per_scenario_sample, config, args.output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/evaluation/results_table_template.py b/scripts/evaluation/results_table_template.py
new file mode 100644
index 0000000..05cfc62
--- /dev/null
+++ b/scripts/evaluation/results_table_template.py
@@ -0,0 +1,136 @@
+######################################################################
+#
+# Copyright (c) 2024 ETHZ Autonomous Systems Lab. All rights reserved.
+#
+######################################################################
+
+from datetime import datetime
+from pathlib import Path
+
+def generate_table_top(scenarios):
+    current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    header = f"% DO NOT MODIFY, AUTOGENERATED AT {current_time}\n"
+    header += "\\renewcommand{\\thetable}{\\arabic{table}}\n\\captionsetup[table]{labelformat=simple, labelsep=colon, name=Tab.}\n"
+    header += "\\begin{table*}[t]\n\\centering\n\\sisetup{detect-all}\n\\NewDocumentCommand{\\B}{}{\\fontseries{b}\\selectfont}\n"
+
+    # Define the custom underline command
+    header += (
+        "\\def\\Decimal{.000}% Corresponds to the \".2\" of \"table-format\"\n"
+        "\\def\\Uline#1{\\Ulinehelp#1 }\n"
+        "\\def\\Ulinehelp#1.#2 {%\n"
+        " #1.#2\\setbox0=\\hbox{#1\\Decimal}\\hspace{-\\wd0}{\\if\\relax#2\\relax%\n"
+        " \\uline{\\phantom{#1.0}}\\else\\uline{\\phantom{#1.#2}}\\fi}%\n}\n"
+    )
+
+    column_format = "\n@{}\nl\n" + "".join([f"S[table-format=2.3]\nS[table-format=2.3]\nS[table-format=2.3]\n" for _ in scenarios]) + "@{}\n"
+    header += f"\\begin{{tabular}}{{{column_format}}}\n\\toprule\n"
+
+    scenario_headers = " & " + " & ".join([f"\\multicolumn{{3}}{{c}}{{{scenario}}}" for scenario in scenarios]) + " \\\\\n"
+    metric_headers = " ".join([f"\\cmidrule(lr){{{3*i+2}-{3*i+4}}}" for i in range(len(scenarios))]) + "\n"
+    metrics = "Models & " + " & ".join(["{AbsRel $(\\downarrow)$} & {$\\delta_1 (\\uparrow)$} & {RMSE $(\\downarrow)$}" for _ in scenarios]) + " \\\\\n\\midrule\n"
+
+    return header + scenario_headers + metric_headers + metrics
+
+def generate_table_bottom():
+    return "\\bottomrule\n\\end{tabular}\n\\label{tab:results}\\caption{Comparison of metric and scaled relative Depth Anything V2 \\cite{depthanythingv2} with our approach. The suffix -S and -B indicate the pre-trained network size, small and base respectively. The best values are in bold, and the second-best values are underlined.}\n\\end{table*}"
+
+def get_best_values(result_dict, scenarios):
+    num_scenarios = len(scenarios)
+
+    # Initialize best and second-best values for `AbsRel`, `d1`, and `RMSE`
+    best_abs_rel = [float('inf')] * num_scenarios
+    second_best_abs_rel = [float('inf')] * num_scenarios
+    best_d1 = [-float('inf')] * num_scenarios
+    second_best_d1 = [-float('inf')] * num_scenarios
+    best_rmse = [float('inf')] * num_scenarios
+    second_best_rmse = [float('inf')] * num_scenarios
+
+    # Iterate through each scenario
+    for j, scenario in enumerate(scenarios):
+        for result in result_dict[scenario]:
+            abs_rel = result['abs_rel']
+            d1 = result['d1']
+            rmse = result['rmse']
+
+            # Update best and second-best for AbsRel
+            if abs_rel < best_abs_rel[j]:
+                second_best_abs_rel[j] = best_abs_rel[j]
+                best_abs_rel[j] = abs_rel
+            elif abs_rel < second_best_abs_rel[j]:
+                second_best_abs_rel[j] = abs_rel
+
+            # Update best and second-best for d1
+            if d1 > best_d1[j]:
+                second_best_d1[j] = best_d1[j]
+                best_d1[j] = d1
+            elif d1 > second_best_d1[j]:
+                second_best_d1[j] = d1
+
+            # Update best and second-best for RMSE
+            if rmse < best_rmse[j]:
+                second_best_rmse[j] = best_rmse[j]
+                best_rmse[j] = rmse
+            elif rmse < second_best_rmse[j]:
+                second_best_rmse[j] = rmse
+
+    return best_abs_rel, second_best_abs_rel, best_d1, second_best_d1, best_rmse, second_best_rmse
+
+def get_model_lines(result_dict, scenarios):
+    results_table_template_models_line = "{} " + " ".join(["& {} & {} & {}" for _ in range(len(scenarios))]) + " \\\\\n"
+
+    best_abs_rel, second_best_abs_rel, best_d1, second_best_d1, best_rmse, second_best_rmse = get_best_values(result_dict, scenarios)
+
+    model_lines = []
+    methods = [result['method'] for result in result_dict[scenarios[0]]]
+
+    for i, method in enumerate(methods):
+        metric_list = [method]
+        for j, scenario in enumerate(scenarios):
+            abs_rel = result_dict[scenario][i]['abs_rel']
+            d1 = result_dict[scenario][i]['d1']
+            rmse = result_dict[scenario][i]['rmse']
+
+            # Mark best and second-best values with \B and \Uline respectively
+            if abs_rel == best_abs_rel[j]:
+                abs_rel = "\\B " + f"{abs_rel:.3f}"
+            elif abs_rel == second_best_abs_rel[j]:
+                abs_rel = "\\Uline " + f"{abs_rel:.3f}"
+            else:
+                abs_rel = f"{abs_rel:.3f}"
+
+            if d1 == best_d1[j]:
+                d1 = "\\B " + f"{d1:.3f}"
+            elif d1 == second_best_d1[j]:
+                d1 = "\\Uline " + f"{d1:.3f}"
+            else:
+                d1 = f"{d1:.3f}"
+
+            if rmse == best_rmse[j]:
+                rmse = "\\B " + f"{rmse:.3f}"
+            elif rmse == second_best_rmse[j]:
+                rmse = "\\Uline " + f"{rmse:.3f}"
+            else:
+                rmse = f"{rmse:.3f}"
+
+            metric_list.append(abs_rel)
+            metric_list.append(d1)
+            metric_list.append(rmse)
+
+        line = results_table_template_models_line.format(*metric_list)
+        model_lines.append(line)
+
+    return model_lines
+
+def generate_tables(output_dir, result_dict):
+    scenarios = list(result_dict.keys())
+
+    results_table = generate_table_top(scenarios)
+    results_table += "".join(get_model_lines(result_dict, scenarios))
+    results_table += generate_table_bottom()
+
+    results_dir = Path(output_dir) / 'results'
+    results_dir.mkdir(exist_ok=True)
+    filename = results_dir / 'results_table0.tex'
+    with filename.open('w') as f:
+        f.write(results_table)
diff --git a/scripts/evaluation/samples/00000_dp.jpg b/scripts/evaluation/samples/00000_dp.jpg
new file mode 100644
index 0000000..aa63dba
Binary files /dev/null and b/scripts/evaluation/samples/00000_dp.jpg differ
diff --git a/scripts/evaluation/samples/00000_rgb.jpg b/scripts/evaluation/samples/00000_rgb.jpg
new file mode 100644
index 0000000..ce7f7b7
Binary files /dev/null and b/scripts/evaluation/samples/00000_rgb.jpg differ
diff --git a/scripts/evaluation/samples/00050_dp.jpg b/scripts/evaluation/samples/00050_dp.jpg
new file mode 100644
index 0000000..4c2b4dc
Binary files /dev/null and b/scripts/evaluation/samples/00050_dp.jpg differ
diff --git a/scripts/evaluation/samples/00050_rgb.jpg b/scripts/evaluation/samples/00050_rgb.jpg
new file mode 100644
index 0000000..b50703a
Binary files /dev/null and b/scripts/evaluation/samples/00050_rgb.jpg differ
diff --git a/scripts/evaluation/samples/00250_dp.jpg b/scripts/evaluation/samples/00250_dp.jpg
new file mode 100644
index 0000000..b960383
Binary files /dev/null and b/scripts/evaluation/samples/00250_dp.jpg differ
diff --git a/scripts/evaluation/samples/00250_rgb.jpg b/scripts/evaluation/samples/00250_rgb.jpg
new file mode 100644
index 0000000..1f09720
Binary files /dev/null and b/scripts/evaluation/samples/00250_rgb.jpg differ
diff --git a/tests/resources/test_evaluation.json b/tests/resources/test_evaluation.json
new file mode 100644
index 0000000..0251721
--- /dev/null
+++ b/tests/resources/test_evaluation.json
@@ -0,0 +1,40 @@
+{
+    "scenarios": {
+        "Training": "tiny_dataset",
+        "Industrial Hall": "tiny_dataset_validation"
+    },
+    "networks": {
+        "RGB": null,
+        "Naive": null,
+        "Ours-S": null
+    },
+    "RGB": {
+        "use_depth_prior": false,
+        "output_channels": 1,
+        "relative_depth": 0,
+        "depth_min": 0.19983673095703125,
+        "depth_max": 120.49285888671875,
+        "encoder": "vits",
+        "plot_prediction": 0
+    },
+    "Ours-S": {
+        "use_depth_prior": true,
+        "output_channels": 2,
+        "relative_depth": 0,
+        "depth_min": 0.19983673095703125,
+        "depth_max": 120.49285888671875,
+        "encoder": "vits",
+        "plot_prediction": 1
+    },
+    "Naive": {
+        "use_depth_prior": false,
+        "output_channels": 1,
+        "relative_depth": 1,
+        "depth_min": 0.0,
+        "depth_max": 1.0,
+        "encoder": "vits",
+        "plot_prediction": 0
+    },
+    "height": 480,
+    "width": 640
+}
diff --git a/tests/resources/test_train_metric.json b/tests/resources/test_train_metric.json
new file mode 100644
index 0000000..4cbb8f7
--- /dev/null
+++ b/tests/resources/test_train_metric.json
@@ -0,0 +1,21 @@
+{
+    "epochs": 1,
+    "encoder": "vits",
+    "depth_min": 0.19983673095703125,
+    "depth_max": 120.49285888671875,
+    "output_channels": 1,
+    "use_depth_prior": 0,
+    "pretrained_from": null,
+    "height": 518,
+    "width": 518,
+    "task": {
+        "train_all": {
+            "dir": "",
+            "datasets": ["tiny_dataset"]
+        },
+        "val_all": {
+            "dir": "",
+            "datasets": ["tiny_dataset_validation"]
+        }
+    }
+}
diff --git a/tests/resources/test_train_radar.json b/tests/resources/test_train_radar.json
new file mode 100644
index 0000000..94d49d4
--- /dev/null
+++ b/tests/resources/test_train_radar.json
@@ -0,0 +1,21 @@
+{
+    "epochs": 1,
+    "encoder": "vits",
+    "depth_min": 0.19983673095703125,
+    "depth_max": 120.49285888671875,
+    "output_channels": 2,
+    "use_depth_prior": 1,
+    "pretrained_from": null,
+    "height": 518,
+    "width": 518,
+    "task": {
+        "train_all": {
+            "dir": "",
+            "datasets": ["tiny_dataset"]
+        },
+        "val_all": {
+            "dir": "",
+            "datasets": ["tiny_dataset_validation"]
+        }
+    }
+}
diff --git a/tests/resources/test_train_relative.json b/tests/resources/test_train_relative.json
new file mode 100644
index 0000000..856ecdd
--- /dev/null
+++ b/tests/resources/test_train_relative.json
@@ -0,0 +1,21 @@
+{
+    "epochs": 1,
+    "encoder": "vits",
+    "depth_min": 0.0,
+    "depth_max": 1.0,
+    "output_channels": 1,
+    "use_depth_prior": 0,
+    "pretrained_from": null,
+    "height": 518,
+    "width": 518,
+    "task": {
+        "train_all": {
+            "dir": "",
+            "datasets": ["tiny_dataset"]
+        },
+        "val_all": {
+            "dir": "",
+            "datasets": ["tiny_dataset_validation"]
+        }
+    }
+}