evaluate_patch.py
import os
import torch
import argparse
from utils.config_loader import load_yaml
from utils.data_loader import load_hf_dataset, load_manhole_set
from adv_manhole.models import load_models, ModelType
from adv_manhole.texture_mapping.depth_mapping import DepthTextureMapping
from adv_manhole.attack.framework import AdvManholeFramework
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", help='Specify Config Path', default='configs/generate_patch.yml')
def main():
    args = parser.parse_args()

    # Load config file
    cfg = load_yaml(args.config_path)

    # Set compute device
    if cfg["device"]["gpu"] == 'cpu':
        device = torch.device("cpu")
    else:
        device = torch.device("cuda")
        torch.cuda.set_device(cfg["device"]["gpu"])

    # Create the log directory if it does not exist
    if not os.path.exists(cfg["log"]["log_main_dir"]):
        os.makedirs(cfg["log"]["log_main_dir"])

    # Load dataset
    batch_size = cfg['dataset']['batch_size']
    dataset, filtered_dataset = load_hf_dataset(
        dataset_name=cfg['dataset']['name'],
        batch_size=batch_size,
        cache_dir=cfg['dataset']['cache_dir'],
        filter_set='eval',
        selected_columns=["rgb", "raw_depth", "camera_config", "semantic"]
    )

    # Load manhole candidate set
    manhole_set = load_manhole_set(
        manhole_set_path=cfg['manhole_set']['manhole_candidate_path'],
        image_size=cfg['manhole_set']['image_size'],
        adversarial_sample_images=cfg['manhole_set']['adversarial_sample_images'],
        is_eval=True
    )

    # Load MonoDepth2 model
    mde_model = load_models(ModelType.MDE, cfg['model']['mde_model'])

    # Load DDRNet model
    ss_model = load_models(ModelType.SS, cfg['model']['ss_model'])

    # Define depth planar mapping
    depth_planar_mapping = DepthTextureMapping(
        random_scale=(0.0, 0.01),
        with_circle_mask=True,
        device=device
    )

    # Number of batches per split (ceiling division)
    train_total_batch = len(filtered_dataset["train"]) // batch_size + (1 if len(filtered_dataset["train"]) % batch_size != 0 else 0)
    val_total_batch = len(filtered_dataset["validation"]) // batch_size + (1 if len(filtered_dataset["validation"]) % batch_size != 0 else 0)
    test_total_batch = len(filtered_dataset["test"]) // batch_size + (1 if len(filtered_dataset["test"]) % batch_size != 0 else 0)

    adv_manhole_instance = AdvManholeFramework(
        mde_model=mde_model,
        ss_model=ss_model,
        depth_planar_mapping=depth_planar_mapping,
        device=device
    )

    adv_manhole_instance.evaluate(
        dataset=dataset,
        total_batch=[train_total_batch, val_total_batch, test_total_batch],
        manhole_set=manhole_set,
        log_name=cfg['log']['log_name']
    )


if __name__ == "__main__":
    main()
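
# Example invocation (the config path below is the script's default; substitute
# your own evaluation config as needed):
#   python evaluate_patch.py --config_path configs/generate_patch.yml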