Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API for eval, prep_display #456

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pycocotools

from data import cfg, set_cfg, set_dataset
from data.config import Config

import numpy as np
import torch
Expand Down Expand Up @@ -132,16 +133,27 @@ def parse_args(argv=None):
coco_cats_inv = {}
color_cache = defaultdict(lambda: {})

def prep_display(dets_out, img, h, w, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str=''):
def prep_display(dets_out, img, h=None, w=None, undo_transform=True, class_color=False, mask_alpha=0.45, fps_str='', override_args:Config=None):
"""
Note: If undo_transform=False then im_h and im_w are allowed to be None.
process image by the network and display all the detections, features as requested by args (command line).

Note: If undo_transform=False then h,w are allowed to be None.

@param override_args - (Config, default None) arguments, overrides args parsed from command-line. Useful when calling as API, where we
don't have cmd args.
"""
if undo_transform:
assert w is not None and h is not None, "with undo_transform=True, w,h params must be specified!"
img_numpy = undo_image_transformation(img, w, h)
img_gpu = torch.Tensor(img_numpy).cuda()
else:
img_gpu = img / 255.0
h, w, _ = img.shape

global args
if override_args is not None:
#override the command line args by the given arguments (type Config)
args = override_args

with timer.env('Postprocess'):
save = cfg.rescore_bbox
Expand Down Expand Up @@ -593,22 +605,33 @@ def badhash(x):
return x

def evalimage(net:Yolact, path:str, save_path:str=None):
frame = torch.from_numpy(cv2.imread(path)).cuda().float()
batch = FastBaseTransform()(frame.unsqueeze(0))
preds = net(batch)
"""
Evaluate a single image given:
@argument net - Yolact object, the network
@argument path - (string) image path
@argument save_path - (string, default None) where to output the labeled image.
@argument args.display - (uses the global congig --display) display the image

img_numpy = prep_display(preds, frame, None, None, undo_transform=False)

if save_path is None:
img_numpy = img_numpy[:, :, (2, 1, 0)]
@return the labeled image as numpy array
"""
with torch.no_grad():
frame = torch.from_numpy(cv2.imread(path)).cuda().float()
batch = FastBaseTransform()(frame.unsqueeze(0))
preds = net(batch)

if save_path is None:
plt.imshow(img_numpy)
img_numpy = prep_display(preds, frame, undo_transform=False, override_args=args)

if args.display:
plt.imshow(cv2.cvtColor(img_numpy, cv2.COLOR_BGR2RGB)) #matplotlib's imshow() needs image converted from BGR(cv2) to RGB(pyplot)
plt.title(path)
plt.show()
else:

if save_path is not None:
cv2.imwrite(save_path, img_numpy)

return img_numpy


def evalimages(net:Yolact, input_folder:str, output_folder:str):
if not os.path.exists(output_folder):
os.mkdir(output_folder)
Expand Down Expand Up @@ -709,7 +732,7 @@ def eval_network(inp):
def prep_frame(inp, fps_str):
with torch.no_grad():
frame, preds = inp
return prep_display(preds, frame, None, None, undo_transform=False, class_color=True, fps_str=fps_str)
return prep_display(preds, frame, undo_transform=False, class_color=True, fps_str=fps_str, override_args=args)

frame_buffer = Queue()
video_fps = 0
Expand Down Expand Up @@ -949,7 +972,7 @@ def evaluate(net:Yolact, dataset, train_mode=False):
preds = net(batch)
# Perform the meat of the operation here depending on our mode.
if args.display:
img_numpy = prep_display(preds, img, h, w)
img_numpy = prep_display(preds, img, h, w, override_args=args)
elif args.benchmark:
prep_benchmark(preds, h, w)
else:
Expand Down