evaluation.py
import torch
import argparse
import os
from torch.utils.data import DataLoader
from utils import RunningAverage, RunningAverageDict, compute_scale_and_shift, compute_errors, compute_metrics, ImageDataset
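
# Evaluates depth predictions against 16-bit ground-truth images: each
# prediction is aligned to its ground truth with a per-image scale and shift
# before the error metrics are accumulated over the test set.
#
# Example invocation (directory paths below are placeholders):
#   python evaluation.py --path_pred ./predictions --gt_path ./ground_truth --sport foot
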
def main():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(description='Evaluate depth predictions')
    parser.add_argument('--path_pred', type=str, required=True, help='Path to prediction images')
    parser.add_argument('--gt_path', type=str, required=True, help='Path to ground truth images')
    parser.add_argument('--sport', type=str, required=True, help='Specify the sport you want to evaluate your predictions on (basket or foot)')
    args = parser.parse_args()

    exclude_file = "basket_game_17_video_2_color_45.png"
    if args.sport is None:
        pred_files = [os.path.join(args.path_pred, f) for f in os.listdir(args.path_pred)
                      if os.path.isfile(os.path.join(args.path_pred, f)) and f != exclude_file]
        gt_files = [os.path.join(args.gt_path, f) for f in os.listdir(args.gt_path)
                    if os.path.isfile(os.path.join(args.gt_path, f))]
    else:
        pred_files = [os.path.join(args.path_pred, f) for f in os.listdir(args.path_pred)
                      if os.path.isfile(os.path.join(args.path_pred, f)) and args.sport in f and f != exclude_file]
        gt_files = [os.path.join(args.gt_path, f) for f in os.listdir(args.gt_path)
                    if os.path.isfile(os.path.join(args.gt_path, f)) and args.sport in f]

    pred_files.sort()
    gt_files.sort()

    print(f"Number of prediction files: {len(pred_files)}")
    print(f"Number of ground truth files: {len(gt_files)}")

    pred_dataset = ImageDataset(pred_files)
    gt_dataset = ImageDataset(gt_files)
    pred_loader = DataLoader(pred_dataset, batch_size=1, num_workers=2, pin_memory=True)
    gt_loader = DataLoader(gt_dataset, batch_size=1, num_workers=2, pin_memory=True)
    print("Loading done.")
    # Validity mask for the scale-and-shift alignment; masked-out regions are set to 0
    mask = torch.ones(1080, 1920, device=device)
    mask_score = False
    if args.sport == "basket":
        # Mask out a fixed on-screen region in the basketball footage
        mask[870:1016, 1570:1829] = 0

    # List of soccer test frames that contain a score banner
    with open('test_score.txt', 'r') as f:
        file_contents = f.read().splitlines()

    mask = mask.to(device).squeeze()

    print("Starting evaluation...")
    metrics = RunningAverageDict()
    i = 0
    with torch.no_grad():
        for preds, gts in zip(pred_loader, gt_loader):
            mask_score = False
            preds, gts = preds.to(device).squeeze(), gts.to(device).squeeze()

            # Special case for some of the soccer test files that contain a score banner
            gt_file = gt_files[i]
            if gt_file in file_contents:
                mask[70:122, 95:612] = 0
                mask_score = True
            else:
                # Reset the banner region so the extra masking only applies to listed files
                mask[70:122, 95:612] = 1
            i += 1

            # Normalize 16-bit depth values to [0, 1]
            gts = gts / 65535.0
            # Skip frames whose ground truth is entirely empty or saturated
            if torch.all(gts == 1) or torch.all(gts == 0):
                continue
            preds = preds / 65535.0

            # Align each prediction to the ground truth with a per-image scale and shift
            scale, shift = compute_scale_and_shift(preds, gts, mask)
            scaled_predictions = scale.view(-1, 1, 1) * preds + shift.view(-1, 1, 1)
            metrics.update(compute_metrics(gts, scaled_predictions[0], mask_score, args.sport))

    print("Evaluation completed.")
    print("\n".join(f"{k}: {v}" for k, v in metrics.get_value().items()))


if __name__ == "__main__":
    main()