This repository has been archived by the owner on Jan 3, 2024. It is now read-only.

Commit

Add fragmented code
hmellor committed Jun 22, 2019
1 parent f45365d commit 140fe1e
Showing 4 changed files with 276 additions and 0 deletions.
Empty file added pytorch-superpixels/__init__.py
104 changes: 104 additions & 0 deletions pytorch-superpixels/metrics.py
@@ -0,0 +1,104 @@
'''For superpixel validation'''
# NOTE: imports added for completeness; the commit ships these fragments
# without import statements. The sibling-module imports assume the package
# directory is on sys.path.
import torch
from os import remove
from os.path import join, exists
from skimage import io
from tqdm import tqdm

from preprocess import get_image_list
from runtime import root


def mask_accuracy(target, mask):
    # Upper-bound accuracy: label every superpixel with its majority class
    target_s = torch.zeros_like(target)
    superpixels = mask.unique().numel()
    for superpixel in range(superpixels):
        # Select the pixels belonging to this superpixel
        segment_mask = mask == superpixel
        # Mode over the masked pixels gives the majority class
        target_s[segment_mask] = target[segment_mask].view(-1).mode()[0]
    accuracy = torch.mean((target == target_s).float())
    return accuracy


def dataset_accuracy(superpixels):
    # Generate image list (the trainval_super split only exists after
    # fix_broken_images has been run)
    if superpixels is not None:
        image_list = get_image_list('trainval_super')
    else:
        image_list = get_image_list()

    mask_acc = 0
    mask_dir = "SegmentationClass/{}_sp".format(superpixels)
    target_dir = "SegmentationClass/pre_encoded"
    for image_number in tqdm(image_list):
        mask_path = join(root, mask_dir, image_number + ".pt")
        target_path = join(root, target_dir, image_number + ".png")
        mask = torch.load(mask_path)
        target = io.imread(target_path)
        target = torch.from_numpy(target)
        mask_acc += mask_accuracy(target, mask)
    dataset_acc = mask_acc / len(image_list)
    return dataset_acc


def find_smallest_object():
    # Generate image list
    image_list = get_image_list()
    smallest_object = float('inf')
    for image_number in tqdm(image_list):
        target_name = image_number + ".png"
        target_path = join(root, "SegmentationClass/pre_encoded", target_name)
        target = io.imread(target_path)
        target = torch.from_numpy(target)
        # Count non-background pixels (class 0 is background)
        object_size = torch.ne(target, 0).sum().item()
        if object_size < smallest_object:
            smallest_object = object_size
            print(smallest_object, image_number)
    return smallest_object


def find_usable_images(split, superpixels):
    # Generate image list
    image_list = get_image_list(split)
    usable = []
    target_dir = join(
        root,
        "SegmentationClass/pre_encoded_{}_sp".format(superpixels)
    )
    for image_number in image_list:
        target_name = image_number + ".pt"
        target_path = join(target_dir, target_name)
        target = torch.load(target_path)
        # Keep images with at least one non-background superpixel
        if target.nonzero().numel() > 0:
            usable.append(image_number)
    return usable


def fix_broken_images(superpixels):
    # Rewrite each *_super split file to list only usable images
    for split in ["train", "val", "trainval"]:
        usable = find_usable_images(split=split, superpixels=superpixels)
        super_path = join(root, "ImageSets/Segmentation", split + "_super.txt")
        if exists(super_path):
            remove(super_path)
        with open(super_path, "w+") as file:
            for image_number in usable:
                file.write(image_number + "\n")


def find_size_variance(superpixels):
    # Generate image list
    if superpixels is not None:
        image_list = get_image_list('trainval_super')
    else:
        image_list = get_image_list()
    mask_dir = "SegmentationClass/{}_sp".format(superpixels)
    dataset_variance = 0
    for image_number in tqdm(image_list):
        mask_path = join(root, mask_dir, image_number + ".pt")
        mask = torch.load(mask_path)
        # Initialise one size entry per superpixel
        Q = mask.unique().numel()
        size = torch.zeros(Q)
        counter = torch.ones_like(mask)
        # Calculate the size of each superpixel by scatter-accumulating ones
        size.put_(mask, counter.float(), True)
        # Mean and standard deviation of the superpixel sizes
        std = size.std()
        mean = size.mean()
        # Accumulate the per-image coefficient of variation (std/mean)
        dataset_variance += std / mean
    dataset_variance /= len(image_list)
    return dataset_variance
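
A minimal usage sketch (editor's addition, not part of the commit): scoring one image's superpixel mask and then the whole dataset. It assumes the 100-superpixel masks and the trainval_super split already exist (see setup_superpixels in runtime.py), the package directory is on sys.path, and the image id is illustrative.

import torch
from os.path import join
from skimage import io

from metrics import mask_accuracy, dataset_accuracy
from runtime import root

# Upper-bound accuracy for a single (hypothetical) image id
image_number = "2007_000032"
mask = torch.load(join(root, "SegmentationClass/100_sp", image_number + ".pt"))
target = torch.from_numpy(
    io.imread(join(root, "SegmentationClass/pre_encoded", image_number + ".png")))
print("single image:", mask_accuracy(target, mask))

# Mean upper-bound accuracy over the trainval_super split
print("dataset:", dataset_accuracy(superpixels=100))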
98 changes: 98 additions & 0 deletions pytorch-superpixels/preprocess.py
@@ -0,0 +1,98 @@
'''For pre-processing'''
# NOTE: imports added for completeness; the commit ships these fragments
# without import statements. The sibling-module import assumes the package
# directory is on sys.path.
import torch
from os import mkdir
from os.path import join, exists
from skimage import io
from skimage.util import img_as_float
from skimage.segmentation import slic
from tqdm import tqdm

from runtime import root


def create_masks(numSegments=100, limOverseg=None):
    # Generate image list
    image_list = get_image_list()
    for image_number in tqdm(image_list):
        # Load image/target pair
        image_name = image_number + ".jpg"
        target_name = image_number + ".png"
        image_path = join(root, "JPEGImages", image_name)
        target_path = join(root, "SegmentationClass/pre_encoded", target_name)
        image = img_as_float(io.imread(image_path))
        target = io.imread(target_path)
        target = torch.from_numpy(target)
        # Create mask for image/target pair
        mask, target_s = create_mask(
            image=image,
            target=target,
            numSegments=numSegments,
            limOverseg=limOverseg
        )

        # Save for later
        image_save_dir = join(
            root,
            "SegmentationClass/{}_sp".format(numSegments)
        )
        target_s_save_dir = join(
            root,
            "SegmentationClass/pre_encoded_{}_sp".format(numSegments)
        )
        if not exists(image_save_dir):
            mkdir(image_save_dir)
        if not exists(target_s_save_dir):
            mkdir(target_s_save_dir)
        save_name = image_number + ".pt"
        image_save_path = join(image_save_dir, save_name)
        target_s_save_path = join(target_s_save_dir, save_name)
        torch.save(mask, image_save_path)
        torch.save(target_s, target_s_save_path)


def create_mask(image, target, numSegments, limOverseg):
    # Perform SLIC segmentation
    mask = slic(image, n_segments=numSegments, slic_zero=True)
    mask = torch.from_numpy(mask)

    if limOverseg is not None:
        # Oversegmentation step: split superpixels that straddle a
        # ground-truth boundary
        superpixels = mask.unique().numel()
        overseg = superpixels
        for superpixel in range(superpixels):
            overseg -= 1
            # Define mask for superpixel
            segment_mask = mask == superpixel
            # Classes in this superpixel
            classes = target[segment_mask].unique(sorted=True)
            # Check if superpixel is on target boundary
            on_boundary = classes.numel() > 1
            # If current superpixel is on a ground-truth boundary
            if on_boundary:
                # Find how many pixels of each class are in the superpixel
                class_hist = torch.bincount(target[segment_mask])
                # Remove zero elements
                class_hist = class_hist[class_hist.nonzero()].float()
                # Find minority class in superpixel
                min_class = class_hist.min()
                # Is the minority class large enough to oversegment?
                above_threshold = min_class > class_hist.sum() * limOverseg
                if above_threshold:
                    # Leave the first class in the superpixel as-is
                    for c in classes[1:]:
                        # Bump the offset so each new segment gets a fresh,
                        # consecutive label
                        overseg += 1
                        # Add offset to class c in the mask
                        mask[segment_mask] += (target[segment_mask]
                                               == c).long() * overseg

    # (Re)define how many superpixels there are and create target_s
    superpixels = mask.unique().numel()
    target_s = torch.zeros(superpixels, dtype=torch.long)
    for superpixel in range(superpixels):
        # Define mask for superpixel
        segment_mask = mask == superpixel
        # Apply mask; the mode gives the majority class
        target_s[superpixel] = target[segment_mask].view(-1).mode()[0]
    return mask, target_s


def get_image_list(split=None):
    if split is None:
        image_list_path = join(root, "ImageSets/Segmentation/trainval.txt")
    else:
        image_list_path = join(root, "ImageSets/Segmentation/", split + ".txt")
    image_list = tuple(open(image_list_path, "r"))
    image_list = [id_.rstrip() for id_ in image_list]
    return image_list
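
A hedged end-to-end sketch (editor's addition): generating the 100-superpixel dataset on disk, or building a single mask in memory. The image id and the limOverseg value of 0.05 are illustrative, not values taken from the commit.

import torch
from os.path import join
from skimage import io
from skimage.util import img_as_float

from preprocess import create_mask, create_masks
from runtime import root

# One-off: write <id>.pt mask/target_s pairs for every trainval image (slow)
create_masks(numSegments=100, limOverseg=0.05)

# Or segment one image/target pair in memory
image_number = "2007_000032"
image = img_as_float(io.imread(join(root, "JPEGImages", image_number + ".jpg")))
target = torch.from_numpy(
    io.imread(join(root, "SegmentationClass/pre_encoded", image_number + ".png")))
mask, target_s = create_mask(image, target, numSegments=100, limOverseg=0.05)
print(mask.shape, target_s.shape)  # (H, W) superpixel ids; (Q,) majority classes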
74 changes: 74 additions & 0 deletions pytorch-superpixels/runtime.py
@@ -0,0 +1,74 @@
'''For use during runtime'''
import torch
from os.path import join
from os.path import exists
from os.path import dirname
from os.path import abspath
from os import mkdir
from os import remove
from os import listdir
from tqdm import tqdm
from skimage import io
from skimage.util import img_as_float
from skimage.segmentation import slic

# Define absolute path for accessing dataset files
package_dir = dirname(abspath(__file__))
dataset_dir = "../../datasets/VOCdevkit/VOC2011"
root = join(package_dir, dataset_dir)


def convert_to_superpixels(input, target, mask):
    # Extract size data from input and target
    images, c, h, w = input.size()
    if images > 1:
        raise RuntimeError("Not implemented for batch sizes greater than 1")
    # Initialise variables to use
    Q = mask.unique().numel()
    output = torch.zeros((Q, c), device=input.device)
    size = torch.zeros(Q, device=input.device)
    counter = torch.ones(mask.size(), device=input.device)
    # Calculate the size of each superpixel
    size.put_(mask, counter, True)
    # Calculate the mean value of each superpixel: offset each channel's
    # indices by Q so one put_ can scatter-accumulate all channels at once
    input = input.view(c, -1)
    mask = mask.view(1, -1).repeat(c, 1)
    arange = torch.arange(start=1, end=c, device=input.device)
    mask[arange, :] += Q * arange.view(-1, 1)
    output = output.put_(mask, input, True).view(c, Q).t()
    output = (output.t() / size).t()
    return output, target.view(-1), size


def convert_to_pixels(input, output, mask):
    n, c, h, w = output.size()
    # Broadcast each superpixel's value back to its member pixels
    # (writes into `output` in place)
    for k in range(c):
        output[0, k, :, :] = torch.gather(
            input[:, k], 0, mask.view(-1)).view(h, w)
    return output


def to_super_to_pixels(input, mask):
    # Round trip: average to superpixels, then broadcast back to pixels
    target = torch.tensor([])
    input_s, _, _ = convert_to_superpixels(input, target, mask)
    output = convert_to_pixels(input_s, input, mask)
    return output


def setup_superpixels(superpixels):
    # Imported here rather than at module top to avoid a circular import;
    # these names live in sibling modules of this fragmented package
    from preprocess import create_masks, get_image_list
    from metrics import fix_broken_images

    image_save_dir = join(
        root,
        "SegmentationClass/{}_sp".format(superpixels)
    )
    target_s_save_dir = join(
        root,
        "SegmentationClass/pre_encoded_{}_sp".format(superpixels)
    )
    dirs = [image_save_dir, target_s_save_dir]
    dataset_len = len(get_image_list())
    # Regenerate unless both directories exist and are complete
    if not all(exists(x) and len(listdir(x)) == dataset_len for x in dirs):
        print("Superpixel dataset of scale {} superpixels either doesn't "
              "exist or is incomplete".format(superpixels))
        print("Generating superpixel dataset now...")
        create_masks(superpixels)

    fix_broken_images(superpixels)
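
A hedged sketch of the runtime path (editor's addition): averaging per-pixel network outputs into per-superpixel values and broadcasting them back, using random tensors in place of a real model. Shapes follow the batch-size-1 contract of convert_to_superpixels; every superpixel id 0..Q-1 must appear in the mask. All sizes here are illustrative.

import torch
from runtime import convert_to_superpixels, convert_to_pixels, to_super_to_pixels

c, h, w = 21, 8, 8                            # 21 VOC classes, tiny image
Q = 4                                         # toy superpixel count
logits = torch.randn(1, c, h, w)
mask = (torch.arange(h * w) % Q).view(h, w)   # toy mask; every id 0..Q-1 appears
target = torch.randint(0, c, (Q,))            # per-superpixel ground truth

# Per-pixel logits -> mean logit per superpixel, shape (Q, c)
output_s, target_s, size = convert_to_superpixels(logits, target, mask)
print(output_s.shape, size)                   # size holds each superpixel's pixel count

# Broadcast the superpixel means back to pixels, shape (1, c, h, w)
# (convert_to_pixels writes into its second argument, so pass a clone)
output = convert_to_pixels(output_s, logits.clone(), mask)
print(output.shape)

# Or do both steps in one call
smoothed = to_super_to_pixels(logits.clone(), mask)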
