diff --git a/depthy/misc/plots.py b/depthy/misc/plots.py index ab7f43a..6b134ba 100644 --- a/depthy/misc/plots.py +++ b/depthy/misc/plots.py @@ -1,5 +1,6 @@ import numpy as np import warnings +from depthy.misc import Normalizer try: import matplotlib.pyplot as plt @@ -42,9 +43,12 @@ def plot_point_cloud(disp_arr: np.ndarray, zz = disp_arr[:, ::-1][::down_scale, ::down_scale, ...] xx, yy = np.meshgrid(np.arange(zz.shape[1]), np.arange(zz.shape[0])) + # normalize rgb values to 0-1 range + rgb = Normalizer(rgb).type_norm(new_min=0, new_max=1) + # plot depth data fig, ax = (plt.figure(), plt.axes(projection='3d')) if ax is None else (None, ax) - ax.scatter(xx, yy, zz, c=rgb.reshape(-1, rgb.shape[-1]) / rgb.max(), s=0.5) + ax.scatter(xx, yy, zz, c=rgb.reshape(-1, rgb.shape[-1]), s=0.5) ax.view_init(view_angles[0], view_angles[1]) return ax diff --git a/tests/test_plt.py b/tests/test_plt.py index bcab98b..8a39d4d 100644 --- a/tests/test_plt.py +++ b/tests/test_plt.py @@ -29,6 +29,9 @@ def test_point_cloud(self): # test valid case plot_point_cloud(disp_arr=np.ones([6, 6]), rgb_img=np.ones([6, 6, 3]), down_scale=2) + # test invalid rgb values + plot_point_cloud(disp_arr=np.ones([6, 6]), rgb_img=np.ones([6, 6, 3])*-1, down_scale=2) + # test Axes3D argument fig, ax = plt.figure(), plt.axes(projection='3d') ax_type = type(ax)