-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
uuvdataset: Add basic code for dataset and running
- Loading branch information
Showing
6 changed files
with
195 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
###################################################################### | ||
# | ||
# Copyright (c) 2024 ETHZ Autonomous Systems Lab. All rights reserved. | ||
# | ||
###################################################################### | ||
|
||
import argparse | ||
import logging | ||
import pickle | ||
import radarmeetsvision as rmv | ||
|
||
from pathlib import Path | ||
from results_table_template import generate_tables | ||
from create_scatter_plot import create_scatter_plot | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description='Purely evalute a network') | ||
parser.add_argument('--dataset', type=str, required=True, help='Path to the dataset directory') | ||
parser.add_argument('--output', type=str, required=True, help='Path to the output directory') | ||
parser.add_argument('--network', type=str, help='Path to the network file') | ||
args = parser.parse_args() | ||
rmv.setup_global_logger() | ||
|
||
interface = rmv.Interface() | ||
interface.set_encoder('vitb') | ||
depth_min = 0.19983673095703125 | ||
depth_max = 120.49285888671875 | ||
interface.set_depth_range((depth_min, depth_max)) | ||
interface.set_output_channels(2) | ||
interface.set_use_depth_prior(True) | ||
interface.load_model(pretrained_from=args.network) | ||
|
||
interface.set_size(720, 1280) | ||
interface.set_batch_size(1) | ||
interface.set_criterion() | ||
|
||
loader = interface.get_single_dataset_loader(args.dataset) | ||
interface.validate_epoch(0, loader, save_outputs=True) | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
import argparse | ||
import cv2 | ||
import laspy | ||
import logging | ||
import matplotlib.pyplot as plt | ||
import multiprocessing as mp | ||
import numpy as np | ||
import pickle | ||
import re | ||
import rosbag | ||
import scipy.ndimage as ndi | ||
import sensor_msgs.point_cloud2 as pc2 | ||
|
||
from pathlib import Path | ||
from cv_bridge import CvBridge | ||
from scipy.spatial.transform import Rotation | ||
from scipy.spatial import cKDTree | ||
|
||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger("dataset") | ||
|
||
class UUVDataset: | ||
def __init__(self, input_dir): | ||
for file_path in input_dir.iterdir(): | ||
if '.bag' in str(file_path): | ||
self.bag_file = file_path | ||
|
||
def generate(self): | ||
output_dir = Path(self.bag_file).parent / 'output' | ||
output_dir.mkdir(exist_ok=True) | ||
rgb_dir = output_dir / 'rgb' | ||
rgb_dir.mkdir(exist_ok=True) | ||
depth_dir = output_dir / 'depth' | ||
depth_dir.mkdir(exist_ok=True) | ||
depth_prior_dir = output_dir / 'depth_prior' | ||
depth_prior_dir.mkdir(exist_ok=True) | ||
|
||
topics = ['/ted/image', '/navigation/plane_approximation', '/sensor/dvl_position'] | ||
image_count = 0 | ||
depth_priors = None | ||
bridge = CvBridge() | ||
radius_prior_pixel = 5 | ||
|
||
with rosbag.Bag(self.bag_file, 'r') as bag: | ||
for i, (topic, msg, t) in enumerate(bag.read_messages(topics=topics)): | ||
if topic == '/navigation/plane_approximation': | ||
depth_priors = [msg.NetDistance] | ||
|
||
elif topic == '/ted/image' and depth_priors is not None: | ||
# RGB image | ||
img = bridge.imgmsg_to_cv2(msg, desired_encoding='bgr8') | ||
flipped_image = cv2.flip(img, 0) | ||
height, width, _ = flipped_image.shape | ||
width = width //2 | ||
right_image = flipped_image[:, width:, :] | ||
|
||
img_file = rgb_dir / f'{image_count:05d}_rgb.jpg' | ||
cv2.imwrite(str(img_file), right_image) | ||
|
||
# Depth prior | ||
depth_prior = np.zeros((width, height), dtype=np.float32) | ||
circular_mask = self.create_circular_mask(2 * width, 2 * height, radius=radius_prior_pixel) | ||
x, y = width // 2, height // 2 | ||
translated_mask = circular_mask[int(width - x):int(2 * width - x), int(height - y):int(2 * height - y)] | ||
depth_prior += depth_priors[0] * translated_mask | ||
depth_prior = depth_prior.T | ||
|
||
depth_prior_file = output_dir / 'depth_prior' / f'{image_count:05d}_ra.npy' | ||
np.save(depth_prior_file, depth_prior) | ||
plt.imsave(output_dir / 'depth_prior' / f'{image_count:05d}_ra.jpg', depth_prior, vmin=0,vmax=15) | ||
|
||
image_count += 1 | ||
|
||
def create_circular_mask(self, h, w, center=None, radius=None): | ||
# From: | ||
# https://stackoverflow.com/questions/44865023/how-can-i-create-a-circular-mask-for-a-numpy-array | ||
if center is None: # use the middle of the image | ||
center = (int(w / 2), int(h / 2)) | ||
if radius is None: # use the smallest distance between the center and image walls | ||
radius = min(center[0], center[1], w - center[0], h - center[1]) | ||
|
||
Y, X = np.ogrid[:h, :w] | ||
dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2) | ||
|
||
mask = dist_from_center <= radius | ||
return mask | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description='Process the camera parameters, offset, and point cloud files.') | ||
parser.add_argument('input_dir', type=Path, help='Path to folder containing all required files') | ||
|
||
args = parser.parse_args() | ||
dataset = UUVDataset(args.input_dir) | ||
dataset.generate() |