Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/super_gradients/common/object_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ class Transforms:
DetectionNormalize = "DetectionNormalize"
DetectionPadIfNeeded = "DetectionPadIfNeeded"
DetectionLongestMaxSize = "DetectionLongestMaxSize"
# Optical flow transforms
OpticalFlowColorJitter = "OpticalFlowColorJitter"
OpticalFlowOcclusion = "OpticalFlowOcclusion"
OpticalFlowRandomRescale = "OpticalFlowRandomRescale"
OpticalFlowRandomFlip = "OpticalFlowRandomFlip"
OpticalFlowCrop = "OpticalFlowCrop"
OpticalFlowToTensor = "OpticalFlowToTensor"
#
RandomResizedCropAndInterpolation = "RandomResizedCropAndInterpolation"
RandAugmentTransform = "RandAugmentTransform"
Expand Down Expand Up @@ -438,6 +445,7 @@ class Datasets:
COCO_KEY_POINTS_DATASET = "COCOKeypointsDataset"
COCO_POSE_ESTIMATION_DATASET = "COCOPoseEstimationDataset"
NYUV2_DEPTH_ESTIMATION_DATASET = "NYUv2DepthEstimationDataset"
KITTI_OPTICAL_FLOW_DATASET = "KITTIOpticalFlowDataset"


class Processings:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from super_gradients.training.datasets.optical_flow_datasets.abstract_optical_flow_dataset import AbstractOpticalFlowDataset
from super_gradients.training.datasets.optical_flow_datasets.kitti_dataset import KITTIOpticalFlowDataset

__all__ = ["AbstractOpticalFlowDataset", "KITTIOpticalFlowDataset"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import abc
from typing import List, Tuple

import random

import numpy as np
from data_gradients.common.decorators import resolve_param
from matplotlib import pyplot as plt
from torch.utils.data.dataloader import Dataset

from super_gradients.common.factories.list_factory import ListFactory
from super_gradients.common.factories.transforms_factory import TransformsFactory
from super_gradients.training.samples import OpticalFlowSample
from super_gradients.training.transforms.optical_flow import AbstractOpticalFlowTransform
from super_gradients.training.utils.visualization.optical_flow import FlowVisualization


class AbstractOpticalFlowDataset(Dataset):
    """
    Abstract base class for optical flow datasets.

    Attempts to follow the principles used by the pose estimation dataset: subclasses implement
    ``load_sample`` (and ``__len__``), while this class applies the transforms pipeline in
    ``__getitem__`` and provides a ``plot`` helper for visual inspection.
    """

    @resolve_param("transforms", ListFactory(TransformsFactory()))
    def __init__(self, transforms: List[AbstractOpticalFlowTransform] = None):
        """
        :param transforms: Transforms applied, in order, to every sample returned by ``__getitem__``.
                           ``None`` (or an empty list) means no transforms are applied.
        """
        super().__init__()
        self.transforms = transforms or []

    @abc.abstractmethod
    def load_sample(self, index: int) -> OpticalFlowSample:
        """
        Load an optical flow sample from the dataset.

        :param index: Index of the sample to load.
        :return: Instance of OpticalFlowSample.
        """
        raise NotImplementedError()

    def load_random_sample(self) -> OpticalFlowSample:
        """
        Return a random (untransformed) sample from the dataset.

        Relies on the subclass implementing ``__len__``.

        :return: Instance of OpticalFlowSample.
        """
        num_samples = len(self)
        random_index = random.randrange(0, num_samples)
        return self.load_sample(random_index)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """
        Get a transformed optical flow sample from the dataset.

        After applying the transforms pipeline, the images are expected to be in 2HWC format (an image pair)
        and the flow map in HWC format (H x W x 2).

        Before returning, the images are moved to 2CHW format and the flow map to CHW format.
        Everything is cast to float32.

        :param index: Index of the sample to retrieve.
        :return: Tuple ``(images, (flow_map, valid))`` of np.ndarrays, where ``valid`` is the sample's validity mask.
        """
        sample = self.load_sample(index)
        for transform in self.transforms:
            sample = transform(sample)

        images = np.transpose(sample.images, (0, 3, 1, 2)).astype(np.float32)
        flow_map = np.transpose(sample.flow_map, (2, 0, 1)).astype(np.float32)
        valid = sample.valid.astype(np.float32)
        return images, (flow_map, valid)

    def plot(
        self,
        max_samples_per_plot: int = 8,
        n_plots: int = 1,
        plot_transformed_data: bool = True,
    ):
        """
        Combine samples of image pairs with their flow maps into plots and display the result.

        Each plot has three rows: first image, second image, and the visualized flow map.
        Plotting stops once the dataset is exhausted; unused columns are left blank.

        :param max_samples_per_plot: Maximum number of samples (image pair with flow map) to be displayed per plot.
        :param n_plots:              Number of plots to display.
        :param plot_transformed_data: If True, the plot will be over samples after applying transforms (i.e., on __getitem__).
                                      If False, the plot will be over the raw samples (i.e., on load_sample).

        :return: None
        """
        num_samples = len(self)

        for plot_i in range(n_plots):
            first_index = plot_i * max_samples_per_plot
            if first_index >= num_samples:
                # Nothing left to show; avoid opening an empty figure.
                break

            # squeeze=False keeps `axes` two-dimensional even when max_samples_per_plot == 1,
            # so the `axes[row, col]` indexing below is always valid.
            fig, axes = plt.subplots(3, max_samples_per_plot, figsize=(20, 7), squeeze=False)

            for sample_i in range(max_samples_per_plot):
                index = first_index + sample_i

                if index >= num_samples:
                    # Dataset exhausted mid-plot: hide the remaining (empty) axes.
                    for row in range(3):
                        axes[row, sample_i].axis("off")
                    continue

                if plot_transformed_data:
                    images, (flow_map, valid) = self[index]

                    # Transpose the images back to HWC format for visualization
                    images = images.transpose(0, 2, 3, 1)
                    # NOTE(review): flow_map is CHW here but HWC in the raw branch below -
                    # confirm which layout FlowVisualization.process_flow_map_for_visualization expects.
                    flow_map = flow_map.squeeze()  # Remove dummy dimension
                else:
                    sample = self.load_sample(index)
                    images, flow_map = sample.images, sample.flow_map

                # Plot the image pair
                axes[0, sample_i].imshow(images[0].astype(np.uint8))
                axes[0, sample_i].axis("off")
                axes[0, sample_i].set_title(f"Sample {index} image1")

                axes[1, sample_i].imshow(images[1].astype(np.uint8))
                axes[1, sample_i].axis("off")
                axes[1, sample_i].set_title(f"Sample {index} image2")

                # Plot the flow map, converted to a displayable image
                flow_map = FlowVisualization.process_flow_map_for_visualization(flow_map)
                axes[2, sample_i].imshow(flow_map)
                axes[2, sample_i].axis("off")
                axes[2, sample_i].set_title(f"Flow Map {index}")

            plt.show()
            plt.close(fig)
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import warnings

import numpy as np

from super_gradients.common.object_names import Datasets
from super_gradients.common.registry import register_dataset
from super_gradients.training.datasets.optical_flow_datasets import kitti_utils
from super_gradients.training.datasets.optical_flow_datasets.abstract_optical_flow_dataset import AbstractOpticalFlowDataset

from super_gradients.training.samples import OpticalFlowSample
from glob import glob
import os


@register_dataset(Datasets.KITTI_OPTICAL_FLOW_DATASET)
class KITTIOpticalFlowDataset(AbstractOpticalFlowDataset):
    """
    Dataset class for KITTI 2015 dataset for optical flow.

    :param root: Root directory containing the dataset.
    :param transforms: Transforms to be applied to the samples.

    To use the KITTIOpticalFlowDataset class, ensure that your dataset directory is organized as follows:

    - Root directory (specified as 'root' when initializing the dataset)
        - training
            - image_2
                - 000000_10.png
                - 000000_11.png
                - 000001_10.png
                - 000001_11.png
                - ...
            - flow_occ
                - 000000_10.png
                - 000001_10.png
                - ...

    Data can be obtained at https://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow
    """

    def __init__(self, root: str, transforms=None):
        """
        Initialize KITTIOpticalFlowDataset.

        Samples are discovered from the flow maps under ``<root>/training/flow_occ``; the two frames of each
        sample are looked up by name under ``<root>/training/image_2``. Samples with a missing file are dropped
        with a warning.

        :param root: Root directory containing the dataset.
        :param transforms: Transforms to be applied to the samples.
        :raises FileNotFoundError: If no complete sample (image pair + flow map) is found.
        """
        super(KITTIOpticalFlowDataset, self).__init__(transforms=transforms)

        data_root = os.path.join(root, "training")
        images_dir = os.path.join(data_root, "image_2")
        flow_suffix = "_10.png"

        flow_list = sorted(glob(os.path.join(data_root, "flow_occ", "*" + flow_suffix)))

        # Derive the image paths from each flow map's frame id rather than zipping independently
        # sorted listings: with zip, a single missing file would silently shift every later pairing.
        self.files_list = []
        for flow_path in flow_list:
            frame_id = os.path.basename(flow_path)[: -len(flow_suffix)]
            self.files_list.append(
                (
                    os.path.join(images_dir, frame_id + "_10.png"),
                    os.path.join(images_dir, frame_id + "_11.png"),
                    flow_path,
                )
            )

        self._check_paths_exist()

    @staticmethod
    def _to_three_channels(image: np.ndarray) -> np.ndarray:
        """Return an HxWx3 array: replicate a grayscale image, or drop channels beyond the first three."""
        if image.ndim == 2:
            return np.tile(image[..., None], (1, 1, 3))
        return image[..., :3]

    def load_sample(self, index: int) -> OpticalFlowSample:
        """
        Load an optical flow estimation sample at the specified index.

        :param index: Index of the sample.

        :return: Loaded optical flow estimation sample: images stacked as (2, H, W, 3) uint8, a float32 flow map
                 and a validity mask.
        """
        image1_path, image2_path, flow_path = self.files_list[index]

        flow_map, valid = kitti_utils.read_flow_kitti(flow_path)
        flow_map = np.array(flow_map).astype(np.float32)

        image1 = np.array(kitti_utils.read_gen(image1_path)).astype(np.uint8)
        image2 = np.array(kitti_utils.read_gen(image2_path)).astype(np.uint8)

        # Normalize each frame independently so a grayscale/colour mismatch inside a pair cannot break the stack.
        images = np.stack([self._to_three_channels(image1), self._to_three_channels(image2)])

        if valid is None:
            # No explicit mask available: treat implausibly large displacements as invalid.
            valid = (np.abs(flow_map[:, :, 0]) < 1000) & (np.abs(flow_map[:, :, 1]) < 1000)

        return OpticalFlowSample(images=images, flow_map=flow_map, valid=valid)

    def __len__(self):
        """
        Get the number of samples in the dataset.

        :return: Number of samples in the dataset.
        """
        return len(self.files_list)

    def _check_paths_exist(self):
        """
        Check that every path in self.files_list exists. Entries with a missing path are removed and a warning
        is emitted for each of them.

        :raises FileNotFoundError: If no entry remains after the check.
        """
        valid_paths = []

        for paths in self.files_list:
            if all(os.path.exists(path) for path in paths):
                valid_paths.append(paths)
            else:
                warnings.warn(f"Warning: Removed the following line as one or more paths do not exist: {paths}")

        if not valid_paths:
            raise FileNotFoundError("All lines in the dataset have been removed as some paths do not exist. " "Please check the paths and dataset structure.")

        self.files_list = valid_paths
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import numpy as np
from PIL import Image
import os
import re

import cv2


def readFlow(fn):
    """
    Read a .flo file in Middlebury format.

    Code adapted from:
    http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    The header and payload are read with explicit little-endian dtypes, so the reader does not depend on the
    byte order of the host machine.

    :param fn: Path to the .flo file.
    :return: float32 array of shape (height, width, 2), or None (after printing a message) if the magic number
             is missing or incorrect.
    :raises ValueError: If the header is truncated or the file holds fewer flow values than the header declares.
    """
    with open(fn, "rb") as f:
        magic = np.fromfile(f, "<f4", count=1)
        # An empty file yields an empty array; treat it like any other invalid magic number.
        if magic.size != 1 or magic[0] != 202021.25:
            print("Magic number incorrect. Invalid .flo file")
            return None

        dims = np.fromfile(f, "<i4", count=2)
        if dims.size != 2:
            raise ValueError(f"Truncated .flo header in {fn}")
        # Index the elements explicitly: implicit array -> int conversion is deprecated in recent NumPy.
        w, h = int(dims[0]), int(dims[1])
        if w < 0 or h < 0:
            raise ValueError(f"Invalid .flo dimensions {w} x {h} in {fn}")

        expected = 2 * w * h
        data = np.fromfile(f, "<f4", count=expected)

    if data.size != expected:
        # np.resize would silently repeat the available values to fill the array; fail loudly instead.
        raise ValueError(f"Truncated .flo file {fn}: expected {expected} values, found {data.size}")

    # Stored row by row, with the two flow components interleaved per pixel.
    return data.reshape(h, w, 2).astype(np.float32, copy=False)


def readPFM(file):
    """
    Read a PFM (portable float map) image.

    The header consists of three text lines: the type (``PF`` for 3-channel, ``Pf`` for single-channel),
    the dimensions (``width height``) and a scale factor whose sign encodes the byte order
    (negative: little-endian, otherwise big-endian). The remaining bytes are the float payload.

    :param file: Path to the .pfm file.
    :return: Array of shape (height, width, 3) for ``PF`` or (height, width) for ``Pf``, flipped vertically
             relative to the order in which the rows are stored in the file.
    :raises ValueError: If the type line or the dimensions line is malformed.
    """
    # The context manager guarantees the handle is closed, including when a header check raises.
    with open(file, "rb") as pfm:
        header = pfm.readline().rstrip()
        if header == b"PF":
            color = True
        elif header == b"Pf":
            color = False
        else:
            raise ValueError("Not a PFM file.")

        dim_match = re.match(rb"^(\d+)\s(\d+)\s$", pfm.readline())
        if not dim_match:
            raise ValueError("Malformed PFM header.")
        width, height = map(int, dim_match.groups())

        # Only the sign of the scale is used here: it selects the byte order of the payload.
        scale = float(pfm.readline().rstrip())
        endian = "<" if scale < 0 else ">"

        data = np.fromfile(pfm, endian + "f")

    shape = (height, width, 3) if color else (height, width)
    return np.flipud(np.reshape(data, shape))


def read_gen(file_name, pil=False):
    """
    Read an image or flow file, choosing the reader from the file extension (matched case-insensitively).

    - ``.png`` / ``.jpeg`` / ``.ppm`` / ``.jpg``: opened with ``PIL.Image.open`` (a PIL image is returned).
    - ``.bin`` / ``.raw``: loaded with ``np.load``.
    - ``.flo``: read with ``readFlow`` and cast to float32.
    - ``.pfm``: read with ``readPFM`` and cast to float32; for 3-channel data the last channel is dropped.

    :param file_name: Path of the file to read.
    :param pil: Unused; kept for backward compatibility with existing callers.
    :return: The loaded data, or an empty list if the extension is not recognised.
    :raises ValueError: If a ``.flo`` file cannot be parsed.
    """
    # Lower-case the extension so that e.g. ".PNG" is not silently treated as unsupported.
    ext = os.path.splitext(file_name)[-1].lower()

    if ext in (".png", ".jpeg", ".ppm", ".jpg"):
        return Image.open(file_name)

    if ext in (".bin", ".raw"):
        return np.load(file_name)

    if ext == ".flo":
        flow = readFlow(file_name)
        if flow is None:
            # readFlow signals an invalid magic number with None; surface it as a clear error
            # rather than an AttributeError on `.astype`.
            raise ValueError(f"Invalid .flo file: {file_name}")
        return flow.astype(np.float32)

    if ext == ".pfm":
        flow = readPFM(file_name).astype(np.float32)
        if flow.ndim == 2:
            return flow
        return flow[:, :, :-1]

    return []


def read_flow_kitti(filename):
    """
    Read a KITTI optical flow map stored as a 3-channel 16-bit PNG.

    :param filename: Path to the flow PNG.
    :return: Tuple ``(flow, valid)``: ``flow`` is a float32 array of shape (H, W, 2) and ``valid`` is a
             float32 array of shape (H, W) holding the third channel of the file.
    :raises FileNotFoundError: If the file is missing or cannot be decoded by OpenCV.
    """
    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
    if flow is None:
        # cv2.imread does not raise on failure, it returns None; fail here with a meaningful message
        # instead of a TypeError on the subscript below.
        raise FileNotFoundError(f"Could not read KITTI flow map: {filename}")

    # OpenCV returns channels in BGR order; reverse them to get the stored (R, G, B) order.
    flow = flow[:, :, ::-1].astype(np.float32)
    flow, valid = flow[:, :, :2], flow[:, :, 2]

    # Decode the 16-bit values into signed displacements
    # (presumably the KITTI devkit encoding: offset 2**15, 1/64 pixel resolution - verify against the devkit).
    flow = (flow - 2**15) / 64.0
    return flow, valid
3 changes: 2 additions & 1 deletion src/super_gradients/training/samples/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .pose_estimation_sample import PoseEstimationSample
from .detection_sample import DetectionSample
from .segmentation_sample import SegmentationSample
from .optical_flow_sample import OpticalFlowSample

__all__ = ["PoseEstimationSample", "DetectionSample", "SegmentationSample", "DepthEstimationSample"]
__all__ = ["PoseEstimationSample", "DetectionSample", "SegmentationSample", "DepthEstimationSample", "OpticalFlowSample"]
Loading