added all modified files #313

Open · wants to merge 4 commits into base: mask_rcnn
6 changes: 5 additions & 1 deletion .gitignore
@@ -19,7 +19,11 @@
checkpoint
*.coverage
*.npy
.DS_Store
*.p
*.zip
*.iml
*.jpeg
*.tgz

!.github/manifest.xml

85 changes: 85 additions & 0 deletions examples/mask_rcnn/README.md
@@ -0,0 +1,85 @@
# Mask R-CNN 5-Stage Architecture:
1. **Backbone:**
The backbone is the main feature extractor of Mask R-CNN. Common choices
are residual networks (ResNets) with or without a feature pyramid network
(FPN). For simplicity, a ResNet without FPN is used here. When a raw image
is fed into the ResNet backbone, the data passes through multiple residual
bottleneck blocks and turns into a feature map. The feature map from the
final convolutional layer of the backbone contains abstract information
about the image, e.g. the different object instances, their classes and
their spatial properties. It is then fed to the RPN (a minimal sketch of
this step follows).
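As a rough illustration of this step, the snippet below runs an image
through a stock Keras ResNet. This is illustrative only; the backbone
implemented in this repository is built differently and uses its own
weights.

```python
# Sketch: a ResNet backbone reduces a raw image to a coarse feature map.
# tf.keras.applications.ResNet101 stands in for the repo's own backbone.
import tensorflow as tf

backbone = tf.keras.applications.ResNet101(
    include_top=False, weights=None, input_shape=(1024, 1024, 3))

image = tf.random.uniform((1, 1024, 1024, 3))  # one raw RGB image
feature_map = backbone(image)                  # shape: (1, 32, 32, 2048)
print(feature_map.shape)
```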

2. **Region Proposal Network:**
RPN stands for Region Proposal Network. It scans the feature map and
proposes regions that may contain objects (Regions of Interest, or RoIs).
Concretely, a convolutional layer processes the feature map and outputs a
c-channel tensor in which each spatial vector (also with c channels) is
associated with an anchor center. Given one anchor center, a set of anchor
boxes with different scales and aspect ratios is generated. These anchor
boxes are evenly distributed over the whole image and cover it completely.
Two sibling 1x1 convolutional layers then process the c-channel tensor.
One is a binary classifier: it predicts whether each anchor box contains
an object, mapping each c-channel vector to a k-channel vector
(representing the k anchor boxes with different scales and aspect ratios
that share one anchor center). The other is an object bounding-box
regressor: it predicts the offsets between the true object bounding box
and the anchor box, mapping each c-channel vector to a 4k-channel vector.
Among overlapping boxes that may describe the same object, the one with
the highest objectness score is kept and the others are dropped; this is
non-maximum suppression. The result is a set of proposed RoIs (see the
sketch below).
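A minimal sketch of the two sibling 1x1 convolutions described above, with
illustrative values for c and k; the actual RPN in this repository may
differ in layer names and sizes.

```python
# Sketch of the RPN head: a shared conv followed by two sibling 1x1 convs,
# one for objectness (k channels) and one for box deltas (4k channels).
import tensorflow as tf
from tensorflow.keras import layers

k = 3    # anchor boxes (scales x aspect ratios) per anchor center
c = 512  # channels of the shared RPN feature tensor

feature = layers.Input(shape=(None, None, c))
shared = layers.Conv2D(c, 3, padding='same', activation='relu')(feature)
objectness = layers.Conv2D(k, 1, activation='sigmoid')(shared)  # object vs. background
box_deltas = layers.Conv2D(4 * k, 1)(shared)                    # per-anchor box offsets
rpn_head = tf.keras.Model(feature, [objectness, box_deltas])
rpn_head.summary()
```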

3. **RoIAlign:**
RoIAlign (Region of Interest alignment) extracts feature vectors from the
feature map based on the RoIs proposed by the RPN and turns them into
fixed-size tensors for further processing. Each RoI is aligned with its
corresponding area of the feature map by scaling. These regions come in
different locations, scales and aspect ratios. To get feature tensors of
uniform shape, we sample over the relevant aligned areas of the feature
map: picture the feature map as a white-bordered grid and the RoIs as
black-bordered grids laid over it. Each RoI is divided into a fixed number
of bins, and in each bin four dots mark sample locations. For each dot,
the feature vectors on the surrounding feature-map grid points are
bilinearly interpolated to give the dot's vector. The dot vectors within
each bin are then pooled to produce a small fixed-size feature map per
RoI. Each RoI's feature map is passed through a set of residual bottleneck
blocks to extract features further. The result is a finer feature map per
RoI, which is processed by two parallel branches: the object detection
branch and the mask generation branch (see the sketch below).
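The bilinear sampling step can be approximated with
tf.image.crop_and_resize, as sketched below. The true RoIAlign layer
averages several samples per bin; this is an illustration, not the
repository's implementation.

```python
# Sketch: RoIAlign-style pooling approximated with crop_and_resize,
# which bilinearly samples a fixed-size window from each RoI.
import tensorflow as tf

feature_map = tf.random.uniform((1, 32, 32, 256))     # backbone output
rois = tf.constant([[0.1, 0.1, 0.6, 0.5],
                    [0.3, 0.2, 0.9, 0.8]])            # (y1, x1, y2, x2), normalized
box_indices = tf.zeros((2,), dtype=tf.int32)          # both RoIs from image 0

pooled = tf.image.crop_and_resize(
    feature_map, rois, box_indices, crop_size=(7, 7))
print(pooled.shape)                                   # (2, 7, 7, 256)
```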

4. **Object detection branch:**
Once each RoI's feature map is obtained, this branch predicts its object
category and a finer instance bounding box. It is a fully connected head
that maps the feature vectors to n class scores and 4n instance
bounding-box coordinates.
5. **Mask generation branch:**
The mask generation branch feeds each RoI's feature map through a
transposed convolutional layer and a convolutional layer in succession,
forming a fully convolutional network. One binary segmentation mask is
generated per class, and the output mask is picked according to the class
predicted by the object detection branch. In this way, per-pixel mask
prediction avoids competition between different classes. Both heads are
sketched below.
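A minimal sketch of the two parallel heads with illustrative sizes; the
repository's heads differ in depth and naming.

```python
# Sketch of the two heads on one RoI's aligned features:
# a fully connected detection head and a fully convolutional mask head.
import tensorflow as tf
from tensorflow.keras import layers

n = 4                                          # number of classes (incl. background)
roi_feature = layers.Input(shape=(7, 7, 256))  # one RoI's pooled feature map

# Object detection branch: flatten, then fully connected outputs.
x = layers.Flatten()(roi_feature)
x = layers.Dense(1024, activation='relu')(x)
class_logits = layers.Dense(n)(x)              # n class scores
box_deltas = layers.Dense(4 * n)(x)            # 4n box coordinates (per class)

# Mask generation branch: upsample, then one binary mask per class.
m = layers.Conv2DTranspose(256, 2, strides=2, activation='relu')(roi_feature)
masks = layers.Conv2D(n, 1, activation='sigmoid')(m)  # (14, 14, n)

heads = tf.keras.Model(roi_feature, [class_logits, box_deltas, masks])
```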

## **Losses Used:**
1. **rpn_class_loss**: how well the Region Proposal Network separates
background from objects.
2. **rpn_bbox_loss**: how well the Region Proposal Network localizes
objects.
3. **mrcnn_bbox_loss**: how well Mask R-CNN localizes objects.
4. **mrcnn_class_loss**: how well Mask R-CNN recognizes each class of
object.
5. **mrcnn_mask_loss**: how well Mask R-CNN segments objects (a sketch of
this loss follows the list).
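As an illustration of the last loss, below is a minimal sketch of a
per-pixel mask loss: binary cross-entropy between the ground-truth mask
and the predicted mask of the ground-truth class, so classes never compete
per pixel. The repository's actual loss functions live in the model code
and may differ in detail.

```python
# Sketch: per-pixel binary cross-entropy on the mask of the true class.
import tensorflow as tf

def sketch_mask_loss(target_masks, predicted_masks, class_ids):
    """target_masks: (num_rois, h, w), predicted_masks: (num_rois, h, w, n),
    class_ids: (num_rois,) int32 ground-truth class per RoI."""
    # Select, for each RoI, only the predicted mask of its true class.
    per_class_first = tf.transpose(predicted_masks, [0, 3, 1, 2])
    indices = tf.stack([tf.range(tf.shape(class_ids)[0]), class_ids], axis=1)
    picked = tf.gather_nd(per_class_first, indices)   # (num_rois, h, w)
    return tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(target_masks, picked))
```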
## **How to use?**

Run shapes_train.py to train on the custom Shapes dataset and use
shapes_demo.py to test the results.

**NOTE: When testing, specify the path of the shapes weights saved during
training.**

Use coco_demo.py to test on the television test image using COCO
pretrained weights; a condensed version of that script is shown below.
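Condensed from coco_demo.py in this PR; the paths assume the weights and
image have already been downloaded, e.g. by the get_file calls in the full
script.

```python
# Condensed from coco_demo.py; assumes mask_rcnn_coco.h5 and
# television.jpeg are already present in the working directory.
from paz.backend.image.opencv_image import load_image
from mask_rcnn.inference import test

image = load_image('television.jpeg')
results = test(image, 'mask_rcnn_coco.h5',
               ROIs_per_image=128, num_classes=81, batch_size=1,
               images_per_gpu=1, anchor_ratios=(32, 64, 128, 256, 512),
               image_shape=[1024, 1024], min_image_scale=1)
print(results[0]['class_ids'], results[0]['scores'])
```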

## **TODO List:**

1. Modify the train file to train on the COCO dataset.
2. Refactor the utils.py file and use the paz library for the display
function.
3. Recheck the x and y coordinates of the bounding-box functions in
mask_rcnn.backend to match the paz format, then retrain.
52 changes: 25 additions & 27 deletions examples/mask_rcnn/backend/boxes.py
@@ -94,32 +94,6 @@ def generate_anchors(scales, ratios, shape, feature_stride, anchor_stride):
return boxes


def extract_boxes_from_masks(masks):
"""Compute bounding boxes from masks.

# Arguments:
mask: [height, width, num_instances]. Mask pixels are either 1 or 0.

# Returns:
box array [no. of instances, (x1, y1, x2, y2)].
"""
boxes = np.zeros([masks.shape[-1], 4], dtype=np.int32)

for instance in range(masks.shape[-1]):
mask = masks[:, :, instance]
horizontal_indicies = np.where(np.any(mask, axis=0))[0]
vertical_indicies = np.where(np.any(mask, axis=1))[0]
if horizontal_indicies.shape[0]:
x1, x2 = horizontal_indicies[[0, -1]]
y1, y2 = vertical_indicies[[0, -1]]
x2 += 1
y2 += 1
else:
x1, x2, y1, y2 = 0, 0, 0, 0
boxes[instance] = np.array([y1, x1, y2, x2])
return boxes


def compute_RPN_bounding_box(anchors, rpn_match, groundtruth_boxes,
anchor_iou_argmax,
std_dev=np.array([0.1, 0.1, 0.2, 0.2])):
@@ -189,7 +163,7 @@ def compute_anchor_boxes_overlaps(anchors, groundtruth_class_ids,
non_crowd_ix = np.where(groundtruth_class_ids > 0)[0]
crowd_boxes = groundtruth_boxes[crowd_instances]

groundtruth_boxes = groundtruth_boxes[non_crowd_instances]
groundtruth_boxes = groundtruth_boxes[non_crowd_ix]
crowd_overlaps = compute_ious(anchors, crowd_boxes)
crowd_iou_max = np.amax(crowd_overlaps, axis=1)
no_crowd_bool = (crowd_iou_max < crowd_threshold)
@@ -241,3 +215,27 @@ def compute_RPN_match(anchors, overlaps, no_crowd_bool, anchor_size=256):
ids = np.random.choice(ids, extra, replace=False)
rpn_match[ids] = 0
return rpn_match, anchor_iou_argmax


def concatenate_RPN_values(RPN_bounding_box, RPN_match, anchors_shape,
RPN_box_shape=256):
"""
Concatenate RPN match and RPN bounding box values to handle RPN losses
directly to the model.

# Arguments:
RPN_bounding_box: [N, (x1, y1, x2, y2)]
RPN_match: [anchor_shape, num_groundtruth_boxes]
anchors_shape: [N, 4]

# Return:
RPN_values: [N]
"""
zeros_array = np.zeros((anchors_shape, 3))
RPN_bounding_box_padded = np.zeros((RPN_box_shape, 4))
RPN_match = np.reshape(RPN_match, (-1, 1))
RPN_match_padded = np.concatenate((RPN_match, zeros_array), axis=1)
RPN_bounding_box_padded[:len(RPN_bounding_box)] = RPN_bounding_box
RPN_values = np.concatenate((RPN_bounding_box_padded,
RPN_match_padded))
return RPN_values
8 changes: 5 additions & 3 deletions examples/mask_rcnn/backend/image.py
@@ -48,10 +48,11 @@ def crop_resize_masks(boxes, mask, small_mask_shape):
"""
smaller_masks = np.zeros(small_mask_shape + (mask.shape[-1],), dtype=bool)
for instance in range(mask.shape[-1]):
small_mask = mask[:, :, instance].astype(bool)
small_mask = mask[:, :, instance]
y1, x1, y2, x2 = boxes[instance, :4]
small_mask = small_mask[y1:y2, x1:x2]
small_mask = resize_image(small_mask, small_mask_shape)

small_mask = resize_image(np.array(small_mask), small_mask_shape)
smaller_masks[:, :, instance] = np.around(small_mask).astype(bool)
return smaller_masks

@@ -67,10 +68,11 @@ def resize_to_original_size(mask, box, image_shape, threshold=0.5):
# Returns:
A binary mask with the same size as the original image.
"""
box = [int(x) for x in box]
y_min, x_min, y_max, x_max = box
mask = resize_image(mask, (int(x_max - x_min), int(y_max - y_min)))
mask = np.where(mask >= threshold, 1, 0).astype(np.bool)

mask = np.where(mask >= threshold, 1, 0).astype(np.bool)
full_mask = np.zeros(image_shape[:2], dtype=np.bool)
full_mask[int(y_min):int(y_max), int(x_min):int(x_max)] = mask
return full_mask
57 changes: 57 additions & 0 deletions examples/mask_rcnn/coco_demo.py
@@ -0,0 +1,57 @@
import os
from tensorflow.keras.utils import get_file

from paz.backend.image.opencv_image import load_image

from mask_rcnn.inference import test
from mask_rcnn.utils import display_instances

image_min_dim = 800
image_max_dim = 1024
image_scale = 0
anchor_ratios = (32, 64, 128, 256, 512)
images_per_gpu = 1
num_classes = 81

# GitHub release assets are served from the 'download' URL, not the tag page.
url = 'https://github.com/oarriaga/altamira-data/releases/download/v0.18/'

weights_local_path = os.path.join(os.getcwd(), 'mask_rcnn_coco.h5')
image_local_path = os.path.join(os.getcwd(), 'television.jpeg')

weights_path = get_file(weights_local_path, url + 'mask_rcnn_coco.h5')
image_path = get_file(image_local_path, url + 'television.jpeg')

image = load_image(image_path)

class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
'bus', 'train', 'truck', 'boat', 'traffic light',
'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
'kite', 'baseball bat', 'baseball glove', 'skateboard',
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
'teddy bear', 'hair drier', 'toothbrush']

results = test(image, weights_path, 128, num_classes, 1, images_per_gpu,
anchor_ratios, [1024, 1024], 1)
r = results[0]
print(r)
display_instances(image, r['rois'], r['masks'], r['class_ids'], class_names,
r['scores'])
13 changes: 7 additions & 6 deletions examples/mask_rcnn/datasets/shapes.py
@@ -22,15 +22,15 @@ class Shapes(Loader):
containing
"""
def __init__(self, num_samples, image_size, split='train',
class_names='all', iou_thresh=0.3, max_num_shapes=4):
class_names='all', IoU_thresh=0.3, max_num_shapes=4):
if class_names == 'all':
class_names = ['BG', 'square', 'circle', 'triangle']
self.name_to_arg = dict(zip(class_names, range(len(class_names))))
self.arg_to_name = dict(zip(range(len(class_names)), class_names))
self.num_samples, self.image_size = num_samples, image_size
self.labels = ['input_image', 'input_gt_class_ids', 'input_gt_boxes',
'input_gt_masks']
self.iou_thresh = iou_thresh
self.IoU_thresh = IoU_thresh
self.max_num_shapes = max_num_shapes
super(Shapes, self).__init__(None, split, class_names, 'Shapes')

@@ -40,7 +40,7 @@ def load_data(self):
def load_sample(self):
shapes = self._sample_shapes(self.max_num_shapes, *self.image_size)
boxes = self._compute_bounding_boxes(shapes)
shapes, boxes = self._filter_shapes(boxes, shapes, self.iou_thresh)
shapes, boxes = self._filter_shapes(boxes, shapes, self.IoU_thresh)
image = self._draw_shapes(shapes)
masks = self._draw_masks(shapes)

@@ -83,9 +83,9 @@ def _compute_bounding_boxes(self, shapes):
boxes.append(box)
return np.asarray(boxes)

def _filter_shapes(self, boxes, shapes, iou_thresh):
def _filter_shapes(self, boxes, shapes, IoU_thresh):
scores = np.ones(len(boxes))
args, num_boxes = apply_non_max_suppression(boxes, scores, iou_thresh)
args, num_boxes = apply_non_max_suppression(boxes, scores, IoU_thresh)
box_args = args[:num_boxes]
selected_shapes = []
for box_arg in box_args:
@@ -133,7 +133,8 @@ def extract_boxes_from_masks(self, masks):
"""Compute bounding boxes from masks.

# Arguments:
mask: [height, width, num_instances]. Mask pixels are either 1 or 0.
mask: [height, width, num_instances]. Mask pixels are either 1 or
0.

# Returns:
box array [no. of instances, (x1, y1, x2, y2)].
46 changes: 21 additions & 25 deletions examples/mask_rcnn/inference.py
@@ -1,37 +1,33 @@
from paz.abstract import SequentialProcessor

from mask_rcnn.model.model import MaskRCNN
from mask_rcnn.evaluation.inference_graph import InferenceGraph
from mask_rcnn.model.model import MaskRCNN, norm_all_boxes
from mask_rcnn.pipelines.detection import ResizeImages, NormalizeImages
from mask_rcnn.pipelines.detection import Detect, PostprocessInputs

from mask_rcnn.utils import norm_boxes_graph

image_min_dim = 128
image_max_dim = 128
image_scale = 0
image_shape = [128, 128, 3]
anchor_ratios = (8, 16, 32, 64, 128)
images_per_gpu = 1


def test(images, weights_path):
config = TestConfig()
resize = SequentialProcessor([ResizeImages(image_min_dim, image_scale, image_max_dim)])
molded_images, windows = resize(images)
def test(images, weights_path, ROIs_per_image, num_classes, batch_size,
images_per_gpu, anchor_ratios, image_shape, min_image_scale):
resize = SequentialProcessor([ResizeImages(image_shape[0], min_image_scale,
image_shape[1])])
molded_images, windows = resize([images])
image_shape = molded_images[0].shape
window = norm_boxes_graph(windows[0], image_shape[:2])
window = norm_all_boxes(windows[0], image_shape[:2])

base_model = MaskRCNN(model_dir='../../mask_rcnn', image_shape=image_shape, backbone="resnet101",
batch_size=1, images_per_gpu=1, rpn_anchor_scales=(8, 16, 32, 64, 128),
train_rois_per_image=32, num_classes=4, window=window)
inference_model = base_model.build_inference_model()
base_model = MaskRCNN(model_dir='../../mask_rcnn',
image_shape=image_shape,
backbone="resnet101",
batch_size=batch_size, images_per_gpu=images_per_gpu,
RPN_anchor_scales=anchor_ratios,
train_ROIs_per_image=ROIs_per_image,
num_classes=num_classes,
window=window)

base_model.keras_model = inference_model
base_model.build_model(train=False)
base_model.keras_model.load_weights(weights_path, by_name=True)
preprocess = SequentialProcessor([ResizeImages(image_min_dim, image_scale, image_max_dim),
NormalizeImages()])
preprocess = SequentialProcessor([ResizeImages(), NormalizeImages()])
postprocess = SequentialProcessor([PostprocessInputs()])
detect = Detect(base_model, anchor_ratios, images_per_gpu, preprocess, postprocess)
results = detect(images)

detect = Detect(base_model, anchor_ratios, images_per_gpu, preprocess,
postprocess)
results = detect([images])
return results