Skip to content

Commit

Permalink
uuvdataset: Add basic code for dataset and running
Browse files Browse the repository at this point in the history
  • Loading branch information
marcojob committed Nov 17, 2024
1 parent 5211a36 commit a76252a
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 30 deletions.
48 changes: 33 additions & 15 deletions radarmeetsvision/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def load_model(self, pretrained_from=None):
self.model = get_model(pretrained_from, self.use_depth_prior, self.encoder, self.max_depth, output_channels=self.output_channels)
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
self.model.to(self.device)
else:
logger.error(f"One or multiple undefined: {self.encoder}, {self.max_depth}, {self.output_channels}, {self.use_depth_prior}")

return self.model

Expand Down Expand Up @@ -138,6 +140,9 @@ def get_dataset_loader(self, task, datasets_dir, dataset_list, index_list=None):

return loader, dataset

def get_single_dataset_loader(self, dataset_dir):
    """Build a DataLoader covering the entire dataset at *dataset_dir*.

    The dataset is opened in 'all' mode with the full index range (0..-1),
    using the interface's configured input size and batch size.
    """
    full_dataset = BlearnDataset(dataset_dir, 'all', self.size, 0, -1)
    loader = DataLoader(full_dataset, batch_size=self.batch_size, pin_memory=True, drop_last=True)
    return loader

def update_best_result(self, results, nsamples):
if nsamples:
Expand All @@ -155,18 +160,27 @@ def update_best_result(self, results, nsamples):

def prepare_sample(self, sample, random_flip=False):
    """Move one dataset sample to the device and derive the training mask.

    Args:
        sample: Dict with at least 'image' and 'depth_prior'; optionally
            'depth' and 'valid_mask' (absent for inference-only datasets).
        random_flip: If True, randomly flip image/depth for augmentation
            (only applied when a depth target exists).

    Returns:
        Tuple of (image, depth_prior, depth_target, mask); depth_target and
        mask are None when the sample carries no ground-truth depth.
    """
    image = sample['image'].to(self.device)
    depth_prior = sample['depth_prior'].to(self.device)

    if self.use_depth_prior:
        # Stack the prior as an additional input channel.
        depth_prior = depth_prior.unsqueeze(1)
        image = torch.cat((image, depth_prior), axis=1)

    # BUGFIX: resolve valid_mask *before* it is needed for the depth mask
    # below; previously it was looked up only after being referenced.
    if 'valid_mask' in sample:
        valid_mask = sample['valid_mask'].to(self.device)
    else:
        valid_mask = None

    if 'depth' in sample:
        depth_target = sample['depth'].to(self.device)
        in_range = (depth_target >= self.min_depth) & (depth_target <= self.max_depth)
        # Tolerate datasets that provide depth without an explicit mask.
        mask = in_range if valid_mask is None else (valid_mask == 1) & in_range

        if random_flip:
            image, depth_target, valid_mask = randomly_flip(image, depth_target, mask)

    else:
        depth_target, mask = None, None

    return image, depth_prior, depth_target, mask

Expand Down Expand Up @@ -200,28 +214,32 @@ def train_epoch(self, epoch, train_loader):
logger.info('Iter: {}/{}, LR: {:.7f}, Loss: {:.3f}'.format(i, len(train_loader), self.optimizer.param_groups[0]['lr'], total_loss/(i + 1.0)))


def validate_epoch(self, epoch, val_loader):
def validate_epoch(self, epoch, val_loader, save_outputs=False):
self.model.eval()
print("RUNNING EVAL")

self.results, self.results_per_sample, nsamples = get_empty_results(self.device)
for i, sample in enumerate(val_loader):
image, _, depth_target, mask = self.prepare_sample(sample, random_flip=False)

# TODO: Maybe not hardcode 10 here?
if mask.sum() > 10:
if mask is None or mask.sum() > 10:
with torch.no_grad():
prediction = self.model(image)
prediction = interpolate_shape(prediction, depth_target)
depth_prediction = get_depth_from_prediction(prediction, image)

current_results = eval_depth(depth_prediction[mask], depth_target[mask])
if current_results is not None:
for k in self.results.keys():
self.results[k] += current_results[k]
self.results_per_sample[k].append(current_results[k])
nsamples += 1
if mask is not None:
current_results = eval_depth(depth_prediction[mask], depth_target[mask])
if current_results is not None:
for k in self.results.keys():
self.results[k] += current_results[k]
self.results_per_sample[k].append(current_results[k])
nsamples += 1

if i % 10 == 0:
plt.imshow(depth_prediction)
plt.show()
abs_rel = (self.results["abs_rel"]/nsamples).item()
logger.info(f'Iter: {i}/{len(val_loader)}, Absrel: {abs_rel:.3f}')

Expand Down
19 changes: 10 additions & 9 deletions radarmeetsvision/metric_depth_network/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,16 @@ def get_confidence_from_prediction(prediction):
return prediction[:, 1:, :, :]

def interpolate_shape(prediction, target, mode='bilinear'):
    """Resize *prediction* to match *target*'s spatial shape.

    Returns *prediction* unchanged when *target* is None (inference without
    ground truth). 'nearest' mode ensures a 4D input before interpolating;
    all other modes use align_corners=True.
    """
    if target is None:
        return prediction

    if len(target.shape) > 2:
        target = target.squeeze()

    interp_shape = (target.shape[0], target.shape[1])

    if mode == 'nearest':
        if len(prediction.shape) < 4:
            prediction = prediction.unsqueeze(0)
        return F.interpolate(prediction, interp_shape, mode=mode)

    return F.interpolate(prediction, interp_shape, mode=mode, align_corners=True)
16 changes: 11 additions & 5 deletions radarmeetsvision/metric_depth_network/dataset/blearndataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def get_filelist(self, index_min=0, index_max=-1):
filelist = all_indexes

else:
logger.error(f'Mode not supported: {self.mode}')
filelist = None

# Log the number of files selected
Expand All @@ -120,15 +121,19 @@ def __getitem__(self, item):
depth = self.get_depth(index)
depth_prior = self.get_depth_prior(index, image_cv.copy(), depth)

sample = self.transform({'image': image, 'depth': depth, 'depth_prior': depth_prior})
if depth is not None:
sample = self.transform({'image': image, 'depth': depth, 'depth_prior': depth_prior})
sample['depth'] = torch.from_numpy(sample['depth'])
sample['depth'] = torch.nan_to_num(sample['depth'], nan=0.0)
sample['valid_mask'] = ((sample['depth'] > self.depth_min) & (sample['depth'] <= self.depth_max))
else:
sample = self.transform({'image': image, 'depth_prior': depth_prior})

sample['image'] = torch.from_numpy(sample['image'])
sample['depth'] = torch.from_numpy(sample['depth'])
sample['depth'] = torch.nan_to_num(sample['depth'], nan=0.0)

sample['depth_prior'] = torch.from_numpy(sample['depth_prior'])
sample['depth_prior'] = torch.nan_to_num(sample['depth_prior'], nan=0.0)

sample['valid_mask'] = ((sample['depth'] > self.depth_min) & (sample['depth'] <= self.depth_max))
sample['image_path'] = str(img_path)

return sample
Expand Down Expand Up @@ -230,7 +235,8 @@ def get_depth_range(self):
def get_depth_prior(self, index, img_copy, depth):
if self.depth_prior_dir.is_dir():
depth_prior = np.load(str(self.depth_prior_dir / self.depth_prior_template.format(index))).astype(np.float32)
if depth_prior.max() <= 1.0:
# TODO: Better detecting if dataset is normalized or not
if depth_prior.max() <= 0.01:
depth_prior_valid_mask = (depth_prior > 0.0) & (depth_prior <= 1.0)
depth_prior[depth_prior_valid_mask] *= self.depth_range + self.depth_min

Expand Down
2 changes: 1 addition & 1 deletion radarmeetsvision/metric_depth_network/dataset/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def __call__(self, sample):
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])

if "depth" in sample:
if "depth" in sample and sample["depth"] is not None:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)

Expand Down
45 changes: 45 additions & 0 deletions scripts/evaluation/evaluate_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
######################################################################
#
# Copyright (c) 2024 ETHZ Autonomous Systems Lab. All rights reserved.
#
######################################################################

import argparse
import logging
import pickle
import radarmeetsvision as rmv

from pathlib import Path
from results_table_template import generate_tables
from create_scatter_plot import create_scatter_plot

logger = logging.getLogger(__name__)

def main():
    """Evaluate a pretrained metric-depth network on one dataset directory.

    Loads a ViT-B model with a depth prior input channel, builds a loader
    over the given dataset, and runs a single validation epoch.
    """
    # BUGFIX: 'evalute' -> 'evaluate' in the user-facing help text.
    parser = argparse.ArgumentParser(description='Purely evaluate a network')
    parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory')
    # NOTE(review): --output is parsed but never used below — wire it up or drop it.
    parser.add_argument('--output', type=str, required=True, help='Path to the output directory')
    parser.add_argument('--network', type=str, help='Path to the network file')
    args = parser.parse_args()
    rmv.setup_global_logger()

    interface = rmv.Interface()
    interface.set_encoder('vitb')
    # Depth range of the training configuration.
    # NOTE(review): hard-coded — confirm these match the evaluated checkpoint.
    depth_min = 0.19983673095703125
    depth_max = 120.49285888671875
    interface.set_depth_range((depth_min, depth_max))
    interface.set_output_channels(2)
    interface.set_use_depth_prior(True)
    interface.load_model(pretrained_from=args.network)

    interface.set_size(720, 1280)
    interface.set_batch_size(1)
    interface.set_criterion()

    loader = interface.get_single_dataset_loader(args.dataset)
    interface.validate_epoch(0, loader, save_outputs=True)


if __name__ == "__main__":
    main()
95 changes: 95 additions & 0 deletions scripts/validation_dataset/create_uuv_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import argparse
import cv2
import laspy
import logging
import matplotlib.pyplot as plt
import multiprocessing as mp
import numpy as np
import pickle
import re
import rosbag
import scipy.ndimage as ndi
import sensor_msgs.point_cloud2 as pc2

from pathlib import Path
from cv_bridge import CvBridge
from scipy.spatial.transform import Rotation
from scipy.spatial import cKDTree

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("dataset")

class UUVDataset:
    """Extract an RGB + depth-prior dataset from a UUV ROS bag.

    Searches *input_dir* for a ``.bag`` file; :meth:`generate` writes
    per-frame RGB images and sparse depth-prior maps into an ``output``
    folder next to the bag.
    """

    def __init__(self, input_dir):
        # BUGFIX: fail fast when no bag exists instead of leaving
        # self.bag_file unset and raising AttributeError in generate().
        self.bag_file = None
        for file_path in input_dir.iterdir():
            # BUGFIX: match the actual extension, not a substring of the path.
            if file_path.suffix == '.bag':
                self.bag_file = file_path
        if self.bag_file is None:
            raise FileNotFoundError(f'No .bag file found in {input_dir}')

    def generate(self):
        """Iterate the bag and write RGB frames and depth-prior maps."""
        output_dir = Path(self.bag_file).parent / 'output'
        output_dir.mkdir(exist_ok=True)
        rgb_dir = output_dir / 'rgb'
        rgb_dir.mkdir(exist_ok=True)
        # NOTE(review): depth dir is created but nothing writes to it below.
        depth_dir = output_dir / 'depth'
        depth_dir.mkdir(exist_ok=True)
        depth_prior_dir = output_dir / 'depth_prior'
        depth_prior_dir.mkdir(exist_ok=True)

        topics = ['/ted/image', '/navigation/plane_approximation', '/sensor/dvl_position']
        image_count = 0
        depth_priors = None
        bridge = CvBridge()
        radius_prior_pixel = 5  # radius of the painted prior disc, in pixels

        with rosbag.Bag(self.bag_file, 'r') as bag:
            for i, (topic, msg, t) in enumerate(bag.read_messages(topics=topics)):
                if topic == '/navigation/plane_approximation':
                    # Remember the latest net-distance reading; applied to
                    # every following image frame until the next update.
                    depth_priors = [msg.NetDistance]

                elif topic == '/ted/image' and depth_priors is not None:
                    # RGB image: flip vertically and keep the right half of
                    # the side-by-side frame.
                    img = bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8')
                    flipped_image = cv2.flip(img, 0)
                    height, width, _ = flipped_image.shape
                    width = width // 2
                    right_image = flipped_image[:, width:, :]

                    img_file = rgb_dir / f'{image_count:05d}_rgb.jpg'
                    cv2.imwrite(str(img_file), right_image)

                    # Depth prior: paint the latest range reading as a small
                    # disc at the image centre of an otherwise-zero map.
                    depth_prior = np.zeros((width, height), dtype=np.float32)
                    circular_mask = self.create_circular_mask(2 * width, 2 * height, radius=radius_prior_pixel)
                    x, y = width // 2, height // 2
                    translated_mask = circular_mask[int(width - x):int(2 * width - x), int(height - y):int(2 * height - y)]
                    depth_prior += depth_priors[0] * translated_mask
                    depth_prior = depth_prior.T

                    depth_prior_file = depth_prior_dir / f'{image_count:05d}_ra.npy'
                    np.save(depth_prior_file, depth_prior)
                    plt.imsave(depth_prior_dir / f'{image_count:05d}_ra.jpg', depth_prior, vmin=0, vmax=15)

                    image_count += 1

    def create_circular_mask(self, h, w, center=None, radius=None):
        """Return a boolean (h, w) mask, True inside the given circle.

        From:
        https://stackoverflow.com/questions/44865023/how-can-i-create-a-circular-mask-for-a-numpy-array
        """
        if center is None:  # use the middle of the image
            center = (int(w / 2), int(h / 2))
        if radius is None:  # use the smallest distance between the center and image walls
            radius = min(center[0], center[1], w - center[0], h - center[1])

        Y, X = np.ogrid[:h, :w]
        dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2)

        mask = dist_from_center <= radius
        return mask


if __name__ == '__main__':
    # Parse the input directory and run the dataset extraction end to end.
    arg_parser = argparse.ArgumentParser(description='Process the camera parameters, offset, and point cloud files.')
    arg_parser.add_argument('input_dir', type=Path, help='Path to folder containing all required files')
    parsed_args = arg_parser.parse_args()

    UUVDataset(parsed_args.input_dir).generate()

0 comments on commit a76252a

Please sign in to comment.