Skip to content

Commit

Permalink
feat: rasterize normal option (#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure authored Jun 10, 2024
1 parent 176f048 commit 063d9f0
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 25 deletions.
23 changes: 21 additions & 2 deletions src/gaussian/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,35 @@ pub enum GaussianCloudDrawMode {
HighlightSelected,
}


#[derive(
Clone,
Copy,
Debug,
Default,
Eq,
Hash,
PartialEq,
Reflect,
)]
pub enum GaussianCloudRasterize {
#[default]
Color,
Depth,
Normal,
}


#[derive(Component, Reflect, Clone)]
#[reflect(Component)]
pub struct GaussianCloudSettings {
pub aabb: bool,
pub global_scale: f32,
pub transform: Transform,
pub visualize_bounding_box: bool,
pub visualize_depth: bool,
pub sort_mode: SortMode,
pub draw_mode: GaussianCloudDrawMode,
pub rasterize_mode: GaussianCloudRasterize,
}

impl Default for GaussianCloudSettings {
Expand All @@ -39,9 +58,9 @@ impl Default for GaussianCloudSettings {
global_scale: 1.0,
transform: Transform::IDENTITY,
visualize_bounding_box: false,
visualize_depth: false,
sort_mode: SortMode::default(),
draw_mode: GaussianCloudDrawMode::default(),
rasterize_mode: GaussianCloudRasterize::default(),
}
}
}
74 changes: 57 additions & 17 deletions src/render/gaussian.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -130,27 +130,15 @@ struct GaussianVertexOutput {
#endif


// https://github.com/cvlab-epfl/gaussian-splatting-web/blob/905b3c0fb8961e42c79ef97e64609e82383ca1c2/src/shaders.ts#L185
// TODO: precompute
fn compute_cov3d(scale: vec3<f32>, rotation: vec4<f32>) -> array<f32, 6> {
let S = mat3x3<f32>(
scale.x * gaussian_uniforms.global_scale, 0.0, 0.0,
0.0, scale.y * gaussian_uniforms.global_scale, 0.0,
0.0, 0.0, scale.z * gaussian_uniforms.global_scale,
);

fn get_rotation_matrix(
rotation: vec4<f32>,
) -> mat3x3<f32> {
let r = rotation.x;
let x = rotation.y;
let y = rotation.z;
let z = rotation.w;

let T = mat3x3<f32>(
gaussian_uniforms.transform[0].xyz,
gaussian_uniforms.transform[1].xyz,
gaussian_uniforms.transform[2].xyz,
);

let R = mat3x3<f32>(
return mat3x3<f32>(
1.0 - 2.0 * (y * y + z * z),
2.0 * (x * y - r * z),
2.0 * (x * z + r * y),
Expand All @@ -163,6 +151,31 @@ fn compute_cov3d(scale: vec3<f32>, rotation: vec4<f32>) -> array<f32, 6> {
2.0 * (y * z + r * x),
1.0 - 2.0 * (x * x + y * y),
);
}

fn get_scale_matrix(
scale: vec3<f32>,
) -> mat3x3<f32> {
return mat3x3<f32>(
scale.x * gaussian_uniforms.global_scale, 0.0, 0.0,
0.0, scale.y * gaussian_uniforms.global_scale, 0.0,
0.0, 0.0, scale.z * gaussian_uniforms.global_scale,
);
}


// https://github.com/cvlab-epfl/gaussian-splatting-web/blob/905b3c0fb8961e42c79ef97e64609e82383ca1c2/src/shaders.ts#L185
// TODO: precompute
fn compute_cov3d(scale: vec3<f32>, rotation: vec4<f32>) -> array<f32, 6> {
let S = get_scale_matrix(scale);

let T = mat3x3<f32>(
gaussian_uniforms.transform[0].xyz,
gaussian_uniforms.transform[1].xyz,
gaussian_uniforms.transform[2].xyz,
);

let R = get_rotation_matrix(rotation);

let M = S * R;
let Sigma = transpose(M) * M;
Expand Down Expand Up @@ -356,7 +369,7 @@ fn vs_points(

var rgb = vec3<f32>(0.0);

#ifdef VISUALIZE_DEPTH
#ifdef RASTERIZE_DEPTH
let first_position = vec4<f32>(get_position(get_entry(1u).value), 1.0);
let last_position = vec4<f32>(get_position(get_entry(gaussian_uniforms.count - 1u).value), 1.0);

Expand All @@ -374,6 +387,33 @@ fn vs_points(
min_distance,
max_distance,
);
#else ifdef RASTERIZE_NORMAL
let T = mat3x3<f32>(
gaussian_uniforms.transform[0].xyz,
gaussian_uniforms.transform[1].xyz,
gaussian_uniforms.transform[2].xyz,
);

let R = get_rotation_matrix(get_rotation(splat_index));
let S = get_scale_matrix(get_scale(splat_index));

let M = S * R;
let Sigma = transpose(M) * M;

let N = T * Sigma * transpose(T);
let normal = vec3<f32>(
N[0][0],
N[0][1],
N[1][1],
);

let t = normalize(normal);

rgb = vec3<f32>(
0.5 * (t.x + 1.0),
0.5 * (t.y + 1.0),
0.5 * (t.z + 1.0)
);
#else
rgb = get_color(splat_index, ray_direction);
#endif
Expand Down
15 changes: 9 additions & 6 deletions src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ use crate::{
cloud::GaussianCloud,
settings::{
GaussianCloudDrawMode,
GaussianCloudRasterize,
GaussianCloudSettings,
},
},
Expand Down Expand Up @@ -305,8 +306,8 @@ fn queue_gaussians(
let key = GaussianCloudPipelineKey {
aabb: settings.aabb,
visualize_bounding_box: settings.visualize_bounding_box,
visualize_depth: settings.visualize_depth,
draw_mode: settings.draw_mode,
rasterize_mode: settings.rasterize_mode,
sample_count: msaa.samples(),
};

Expand Down Expand Up @@ -521,10 +522,6 @@ pub fn shader_defs(
#[cfg(feature = "morph_particles")]
shader_defs.push("READ_WRITE_POINTS".into());

if key.visualize_depth {
shader_defs.push("VISUALIZE_DEPTH".into());
}

#[cfg(feature = "packed")]
shader_defs.push("PACKED".into());

Expand Down Expand Up @@ -561,6 +558,12 @@ pub fn shader_defs(
#[cfg(feature = "webgl2")]
shader_defs.push("WEBGL2".into());

match key.rasterize_mode {
GaussianCloudRasterize::Color => {},
GaussianCloudRasterize::Depth => shader_defs.push("RASTERIZE_DEPTH".into()),
GaussianCloudRasterize::Normal => shader_defs.push("RASTERIZE_NORMAL".into()),
}

match key.draw_mode {
GaussianCloudDrawMode::All => {},
GaussianCloudDrawMode::Selected => shader_defs.push("DRAW_SELECTED".into()),
Expand All @@ -574,8 +577,8 @@ pub fn shader_defs(
pub struct GaussianCloudPipelineKey {
pub aabb: bool,
pub visualize_bounding_box: bool,
pub visualize_depth: bool,
pub draw_mode: GaussianCloudDrawMode,
pub rasterize_mode: GaussianCloudRasterize,
pub sample_count: u32,
}

Expand Down

0 comments on commit 063d9f0

Please sign in to comment.