diff --git a/depthy/misc/plots.py b/depthy/misc/plots.py index fe76854..18a61a6 100644 --- a/depthy/misc/plots.py +++ b/depthy/misc/plots.py @@ -35,25 +35,29 @@ def plot_point_cloud(disp_arr: np.ndarray, raise IndexError('Downscale factor is %s and out-of-range.' % down_scale) # rgb image presence/absence handling - if rgb_img is None or disp_arr.shape[:2] != rgb_img.shape[:2]: - rgb = np.ones(disp_arr.shape)[::down_scale, ::down_scale, ...] + if rgb_img is None or disp_arr.shape[:2] != rgb_img.shape[:2] or len(rgb_img.shape) != 3: + rgb = np.zeros(disp_arr.shape+(3,))[::down_scale, ::down_scale, ...] if rgb_img is not None: warnings.warn('Depth map and RGB image dimension mismatch.') else: # flip x-axis and downscale rgb image rgb = rgb_img[:, ::-1, ...][::down_scale, ::down_scale, ...] + # normalize rgb values to 0-1 range + rgb = Normalizer(rgb).type_norm(new_min=max(0, rgb.min()), new_max=1) # flip x-axis and downscale depth map zz = disp_arr[:, ::-1][::down_scale, ::down_scale, ...] xx, yy = np.meshgrid(np.arange(zz.shape[1]), np.arange(zz.shape[0])) - # normalize rgb values to 0-1 range - rgb = Normalizer(rgb).type_norm(new_min=0, new_max=1) + # sort according to depth value to avoid occlusion order problem + order = np.argsort(zz.ravel()) # plot depth data fig, ax = (plt.figure(), plt.axes(projection='3d')) if ax is None else (None, ax) ax.set_axis_on() if show_axes else ax.set_axis_off() - ax.scatter(xx, yy, zz, c=rgb.reshape(-1, rgb.shape[-1]), s=s) + ax.scatter(xx.ravel()[order], yy.ravel()[order], zz.ravel()[order], c=rgb.reshape(-1, rgb.shape[2])[order], s=s) ax.view_init(view_angles[0], view_angles[1]) + ax.set_ylim(0, zz.shape[0]) + ax.set_xlim(0, zz.shape[1]) return ax diff --git a/tests/test_plt.py b/tests/test_plt.py index 8a39d4d..fdf0276 100644 --- a/tests/test_plt.py +++ b/tests/test_plt.py @@ -1,8 +1,9 @@ import unittest import numpy as np import matplotlib.pyplot as plt +import os -from depthy.misc import plot_point_cloud +from depthy.misc import plot_point_cloud, load_pfm class PlotDepthTestCase(unittest.TestCase): @@ -11,6 +12,14 @@ def setUp(self): self.plot_opt = False + def test_real_data(self): + + self.depth_example, _ = load_pfm(os.path.join('..', 'docs', 'img', 'pens.pfm')) + + # test case with real data + plot_point_cloud(disp_arr=self.depth_example, rgb_img=None, down_scale=4, show_axes=True) + plt.show() + def test_point_cloud(self): # test invalid downscale parameters