Skip to content

Commit

Permalink
0.1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Apr 17, 2023
1 parent f336a25 commit 076963e
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 50 deletions.
2 changes: 1 addition & 1 deletion kiui/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import torch
import numpy as np

import cv2
import json
import varname
from PIL import Image

from rich.console import Console
from rich.text import Text

# inspect array like object x and report stats
def lo(*xs, verbose=0):
Expand Down
119 changes: 79 additions & 40 deletions kiui/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import matplotlib.cm as cm
import matplotlib.pyplot as plt

import trimesh
from .utils import lo

def map_color(value, cmap_name='viridis', vmin=None, vmax=None):
Expand All @@ -18,31 +17,65 @@ def map_color(value, cmap_name='viridis', vmin=None, vmax=None):
rgb = cmap(value)[:, :3] # will return rgba, we take only first 3 so we get rgb
return rgb

# visualize some 2D matrix, different from plot_image, this will keep the original range and plot channel-by-channel
def plot_matrix(*xs):
# x: [B, C, H, W], [C, H, W], or [H, W] torch.Tensor
# [B, H, W, C], [H, W, C], or [H, W] numpy.ndarray

def _plot_matrix(matrix):

if isinstance(matrix, torch.Tensor):
if len(matrix.shape) == 3:
matrix = matrix.permute(1,2,0).squeeze()
matrix = matrix.detach().cpu().numpy()

lo(matrix)

if len(matrix.shape) == 3:
# per channel
for i in range(matrix.shape[-1]):
plt.matshow(matrix[..., i])
plt.show()
else:
plt.matshow(matrix)
plt.show()

def plot_image(*xs, renormalize=False):
# x: [3, H, W] or [1, H, W] or [H, W] torch.Tensor
# [H, W, 3] or [H, W] numpy.ndarray

def _plot_image(x):

if isinstance(x, torch.Tensor):
if len(x.shape) == 3:
x = x.permute(1,2,0).squeeze()
x = x.detach().cpu().numpy()

lo(x)

x = x.astype(np.float32)
for x in xs:
if len(x.shape) == 4:
for i in range(x.shape[0]):
_plot_matrix(x[i])
else:
_plot_matrix(x)

# sequentially plot provided images
def plot_image(*xs, normalize=False):
# x: [B, 3, H, W], [3, H, W], [1, H, W] or [H, W] torch.Tensor
# [B, H, W, 3], [H, W, 3], [H, W, 1] or [H, W] numpy.ndarray

def _plot_image(image):

if isinstance(image, torch.Tensor):
if len(image.shape) == 3:
image = image.permute(1,2,0).squeeze()
image = image.detach().cpu().numpy()

lo(image)

image = image.astype(np.float32)

# renormalize
if renormalize:
x = (x - x.min(axis=0, keepdims=True)) / (x.max(axis=0, keepdims=True) - x.min(axis=0, keepdims=True) + 1e-8)
# normalize
if normalize:
image = (image - image.min(axis=0, keepdims=True)) / (image.max(axis=0, keepdims=True) - image.min(axis=0, keepdims=True) + 1e-8)

plt.imshow(x)
plt.imshow(image)
plt.show()

for x in xs:
_plot_image(x)
if len(x.shape) == 4:
for i in range(x.shape[0]):
_plot_image(x[i])
else:
_plot_image(x)


def plot_pointcloud(pc, color=None):
Expand All @@ -51,28 +84,33 @@ def plot_pointcloud(pc, color=None):

lo(pc)

# import open3d as o3d
# pcd = o3d.geometry.PointCloud()
# pcd.points = o3d.utility.Vector3dVector(pc)
# if color is not None:
# pcd.colors = o3d.utility.Vector3dVector(color)
# o3d.visualization.draw_geometries([pcd])

pc = trimesh.PointCloud(pc, color)
# axis
axes = trimesh.creation.axis(axis_length=4)
# sphere
box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
box.colors = np.array([[128, 128, 128]] * len(box.entities))

trimesh.Scene([pc, axes, box]).show()
if color is None or color.shape[-1] == 3:
# use o3d as it's better to control
import open3d as o3d
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(pc)
if color is not None:
pcd.colors = o3d.utility.Vector3dVector(color)
o3d.visualization.draw_geometries([pcd])

else:
import trimesh
pc = trimesh.PointCloud(pc, color)
# axis
axes = trimesh.creation.axis(axis_length=4)
# sphere
box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
box.colors = np.array([[128, 128, 128]] * len(box.entities))

trimesh.Scene([pc, axes, box]).show()


def plot_poses(poses, size=0.05, bound=1, points=None):
def plot_poses(poses, size=0.05, bound=1, points=None, mesh=None):
# poses: [B, 4, 4]

lo(poses)

import trimesh
axes = trimesh.creation.axis(axis_length=4)
box = trimesh.primitives.Box(extents=[2*bound]*3).as_outline()
box.colors = np.array([[128, 128, 128]] * len(box.entities))
Expand Down Expand Up @@ -100,15 +138,16 @@ def plot_poses(poses, size=0.05, bound=1, points=None):
objects.append(segs)

if points is not None:
print('[visualize points]', points.shape, points.dtype, points.min(0), points.max(0))

lo(points)

colors = np.zeros((points.shape[0], 4), dtype=np.uint8)
colors[:, 2] = 255 # blue
colors[:, 3] = 30 # transparent
objects.append(trimesh.PointCloud(points, colors))

# tmp: verify mesh matches the points
# mesh = trimesh.load('trial_garden_colmap/mesh_stage0/mesh.ply')
# objects.append(mesh)
if mesh is not None:
objects.append(mesh)

scene = trimesh.Scene(objects)
scene.set_camera(distance=bound, center=[0, 0, 0])
Expand Down
16 changes: 7 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
if __name__ == '__main__':
setup(
name="kiui",
version='0.1.3',
version='0.1.4',
description="self-use toolkits",
long_description=open('README.md', encoding='utf-8').read(),
long_description_content_type='text/markdown',
Expand All @@ -19,13 +19,11 @@
install_requires=[
'varname',
'rich',
'trimesh',
'numpy',
'torch',
'tqdm',
'matplotlib',
'opencv-python',
],
extras_require={
'full': [
'tqdm',
'numpy',
'matplotlib',
'opencv-python',
],
},
)

0 comments on commit 076963e

Please sign in to comment.