Skip to content

Commit

Permalink
Use override for dynamic shader construction and consolidate command …
Browse files Browse the repository at this point in the history
…encoding
  • Loading branch information
Wayne Wu authored and Wayne Wu committed May 8, 2023
1 parent 4ae9079 commit 0ef9554
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 131 deletions.
243 changes: 115 additions & 128 deletions src/sample/crowd/main.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import renderWGSL from '../../shaders/background.render.wgsl';
import crowdWGSL from '../../shaders/crowd.render.wgsl';
import explicitIntegrationWGSL from '../../shaders/explicitIntegration.compute.wgsl';
import assignCellsWGSL from '../../shaders/assignCells.compute.wgsl';
import bitonicSortWGSL from '../../shaders/bitonicSort.compute.wgsl';
import buildHashGrid from '../../shaders/buildHashGrid.compute.wgsl';
import contactSolveWGSL from '../../shaders/contactSolve.compute.wgsl';
import constraintSolveWGSL from '../../shaders/constraintSolve.compute.wgsl';
Expand All @@ -31,46 +32,6 @@ function resetCameraFunc(x: number = 50, y: number = 50, z: number = 50) {
camera.updateProjectionMatrix();
}

function getSortStepWGSL(numAgents : number, k : number, j : number, ){
// bitonic sort requires a device-wide join after every "step" to avoid
// race conditions. The least gross way I can think to do that is to create a new pipeline
// for each step.
const baseWGSL = `
@binding(1) @group(0) var<storage, read_write> agentData : Agents;
fn swap(idx1 : u32, idx2 : u32) {
var tmp = agentData.agents[idx1];
agentData.agents[idx1] = agentData.agents[idx2];
agentData.agents[idx2] = tmp;
}
fn agentlt(idx1 : u32, idx2 : u32) -> bool {
return agentData.agents[idx1].cell < agentData.agents[idx2].cell;
}
fn agentgt(idx1 : u32, idx2 : u32) -> bool {
return agentData.agents[idx1].cell > agentData.agents[idx2].cell;
}
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) GlobalInvocationID : vec3<u32>) {
var idx = GlobalInvocationID.x ;
var j : u32 = ${j}u;
var k : u32 = ${k}u;
var l = idx ^ j;
if (l > idx){
if ( ((idx & k) == 0u && agentgt(idx,l)) || ((idx & k) != 0u && agentlt(idx, l))){
swap(idx, l);
}
}
}`;

// minify the wgsl
return baseWGSL.replace('/\s+/g', ' ').trim();
}


function fillSortPipelineList(device,
numAgents : number,
computePipelinesSort,
Expand All @@ -79,21 +40,28 @@ function fillSortPipelineList(device,
// be sure the list is empty before pushing new pipelines
computePipelinesSort.length = 0;

var pipelineLayout = device.createPipelineLayout({
bindGroupLayouts: [compBuffManager.bindGroupLayout]
});
var shaderModule = device.createShaderModule({
code: headerWGSL + bitonicSortWGSL,
});

// set up sort pipelines
// adapted from Wikipedia's non-recursive example of bitonic sort:
// https://en.wikipedia.org/wiki/Bitonic_sorter
for (let k = 2; k <= numAgents; k *= 2){ // k is doubled every iteration
for (let j = k/2; j > 0; j = Math.floor(j/2)){ // j is halved at every iteration, with truncation of fractional parts
computePipelinesSort.push(
device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [compBuffManager.bindGroupLayout]
}),
layout: pipelineLayout,
compute: {
module: device.createShaderModule({
code: headerWGSL + getSortStepWGSL(numAgents, k, j),
}),
module: shaderModule,
entryPoint: 'main',
constants: {
1100: j,
1200: k,
}
},
})
);
Expand Down Expand Up @@ -294,13 +262,16 @@ const init: SampleInit = async ({ canvasRef, gui, stats }) => {
var computePipelinesSort = [];
var computePipelinesPostSort = [];


var pipelineLayout = device.createPipelineLayout({
bindGroupLayouts: [compBuffManager.bindGroupLayout]
});

// set up pre-sort pipelines
for(let i = 0; i < computeShadersPreSort.length; i++){
computePipelinesPreSort.push(
device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [compBuffManager.bindGroupLayout]
}),
layout: pipelineLayout,
compute: {
module: device.createShaderModule({
code: computeShadersPreSort[i],
Expand All @@ -318,21 +289,53 @@ const init: SampleInit = async ({ canvasRef, gui, stats }) => {
compBuffManager);

// set up post sort pipelines
for(let i = 0; i < computeShadersPostSort.length; i++){
let i = 0;
for(;i < 2;){
computePipelinesPostSort.push(
device.createComputePipeline({
layout: device.createPipelineLayout({
bindGroupLayouts: [compBuffManager.bindGroupLayout]
}),
layout: pipelineLayout,
compute: {
module: device.createShaderModule({
code: computeShadersPostSort[i],
code: computeShadersPostSort[i++],
}),
entryPoint: 'main',
},
})
);
}

var shaderModule = device.createShaderModule({
code: computeShadersPostSort[i++],
});
for(let j = 0; j < 6; j++){
computePipelinesPostSort.push(
device.createComputePipeline({
layout: pipelineLayout,
compute: {
module: shaderModule,
entryPoint: 'main',
constants: {
1000 : j + 1,
}
},
})
);
}

for(;i < computeShadersPostSort.length;){
computePipelinesPostSort.push(
device.createComputePipeline({
layout: pipelineLayout,
compute: {
module: device.createShaderModule({
code: computeShadersPostSort[i++],
}),
entryPoint: 'main',
},
})
);
}

}

function setTestScene(camPos: vec3, displayAgentSlider: boolean, numAgents: number,
Expand Down Expand Up @@ -360,7 +363,6 @@ const init: SampleInit = async ({ canvasRef, gui, stats }) => {
}

// get compute bind group
var computeBindGroup;
var computeBindGroup1 = compBuffManager.getBindGroup(false);
var computeBindGroup2 = compBuffManager.getBindGroup(true);

Expand Down Expand Up @@ -469,101 +471,85 @@ const init: SampleInit = async ({ canvasRef, gui, stats }) => {
resetSim = false;
}

var command = device.createCommandEncoder();

if(simulationParams.simulate) {

computeBindGroup = computeBindGroup1;
var computeCommand = device.createCommandEncoder();
// write the parameters to the Uniform buffer for our compute shaders
compBuffManager.writeSimParams(simulationParams);

// execute each compute shader in the order they were pushed onto
// the computePipelines array
var passEncoder = computeCommand.beginComputePass();
//// ----- Compute Pass Before Sort -----
for (let i = 0; i < computePipelinesPreSort.length; i++){
passEncoder.setPipeline(computePipelinesPreSort[i]);
passEncoder.setBindGroup(0, computeBindGroup);
// kick off the compute shader
passEncoder.dispatchWorkgroups(Math.ceil(compBuffManager.numAgents / 64));
}
const computeWorkgroupCount = Math.ceil(compBuffManager.numAgents/64);
const sortWorkgroupCount = Math.ceil(compBuffManager.numAgents/256);

// ----- Compute Pass Sort -----
for (let i = 0; i < computePipelinesSort.length; i++){
passEncoder.setPipeline(computePipelinesSort[i]);
passEncoder.setBindGroup(0, computeBindGroup);
// kick off the compute shader
passEncoder.dispatchWorkgroups(Math.ceil(compBuffManager.numAgents / 256));
}

// ----- Compute Pass Post Sort 1 -----
const constraintShaderIdx = 2;
for (let i = 0; i < constraintShaderIdx; i++){
passEncoder.setPipeline(computePipelinesPostSort[i]);
passEncoder.setBindGroup(0, computeBindGroup);
// kick off the compute shader
passEncoder.dispatchWorkgroups(Math.ceil(compBuffManager.numAgents / 64));
}
passEncoder.end();

device.queue.submit([computeCommand.finish()]);


// Stability/Contact solve will write to different buffer
computeBindGroup = computeBindGroup2;

// ----- Compute Pass Constraint Solve -----
// Since WebGPU does not support push constants, need to add write buffer to queue
// SEE: https://github.com/gpuweb/gpuweb/issues/762#issuecomment-625622428
for (let i = 0; i < 6; i++) {
computeCommand = device.createCommandEncoder();
compBuffManager.setIteration(i+1); // Set iteration number for compute shader
passEncoder = computeCommand.beginComputePass();
passEncoder.setPipeline(computePipelinesPostSort[constraintShaderIdx]);
// write the parameters to the Uniform buffer for our compute shaders
compBuffManager.writeSimParams(simulationParams);

// execute each compute shader in the order they were pushed onto
// the computePipelines array
var passEncoder = command.beginComputePass();
passEncoder.setBindGroup(0, computeBindGroup1);

//// ----- Compute Pass Before Sort -----
for (let i = 0; i < computePipelinesPreSort.length; i++){
passEncoder.setPipeline(computePipelinesPreSort[i]);
passEncoder.dispatchWorkgroups(computeWorkgroupCount);
}

// ----- Compute Pass Sort -----
for (let i = 0; i < computePipelinesSort.length; i++){
passEncoder.setPipeline(computePipelinesSort[i]);
passEncoder.dispatchWorkgroups(sortWorkgroupCount);
}

// ----- Compute Pass Post Sort 1 -----
let i = 0;
for (;i < 2 /* constraint shader index */; i++){
passEncoder.setPipeline(computePipelinesPostSort[i]);
passEncoder.dispatchWorkgroups(computeWorkgroupCount);
}

// Stability/Contact solve will write to different buffer
var computeBindGroup = computeBindGroup2;

// ----- Compute Pass Constraint Solve -----
for (; i < 6; i++) {
passEncoder.setPipeline(computePipelinesPostSort[i]);
passEncoder.setBindGroup(0, computeBindGroup);
passEncoder.dispatchWorkgroups(computeWorkgroupCount);

// ping-pong buffers
if(computeBindGroup == computeBindGroup1)
computeBindGroup = computeBindGroup2;
else if(computeBindGroup == computeBindGroup2)
computeBindGroup = computeBindGroup1;
}

passEncoder.setBindGroup(0, computeBindGroup);
passEncoder.dispatchWorkgroups(Math.ceil(compBuffManager.numAgents / 64));

// ----- Compute Pass Post Sort 2 -----
for (;i < computePipelinesPostSort.length; i++){
passEncoder.setPipeline(computePipelinesPostSort[i]);
passEncoder.dispatchWorkgroups(computeWorkgroupCount);
}

passEncoder.end();
device.queue.submit([computeCommand.finish()]);

// ping-pong buffers
if(computeBindGroup == computeBindGroup1)
computeBindGroup = computeBindGroup2;
else if(computeBindGroup == computeBindGroup2)
computeBindGroup = computeBindGroup1;
}

// ----- Compute Pass Post Sort 2 -----
computeCommand = device.createCommandEncoder();
passEncoder = computeCommand.beginComputePass();
for (let i = constraintShaderIdx+1; i < computePipelinesPostSort.length; i++){
passEncoder.setPipeline(computePipelinesPostSort[i]);
passEncoder.setBindGroup(0, computeBindGroup);
// kick off the compute shader
passEncoder.dispatchWorkgroups(Math.ceil(compBuffManager.numAgents / 64));
}
passEncoder.end();

device.queue.submit([computeCommand.finish()])
}
}

// ------------------ Render Calls ------------------------- //
if (bufManagerExists) {

renderBuffManager.updateSceneUBO(camera, guiParams.gridOn, time, sceneParams.shadowOn);

const renderCommand = device.createCommandEncoder();

const agentsBuffer : GPUBuffer = computeBindGroup == computeBindGroup2 ? compBuffManager.agents1Buffer : compBuffManager.agents2Buffer;

if(sceneParams.shadowOn)
renderBuffManager.drawCrowdShadow(device, renderCommand, agentsBuffer, compBuffManager.numAgents);
renderBuffManager.drawCrowdShadow(device, command, agentsBuffer, compBuffManager.numAgents);

// const transformationMatrix = getTransformationMatrix();
renderBuffManager.renderPassDescriptor.colorAttachments[0].view = context
.getCurrentTexture()
.createView();

const renderPass = renderCommand.beginRenderPass(renderBuffManager.renderPassDescriptor);
const renderPass = command.beginRenderPass(renderBuffManager.renderPassDescriptor);

// ----------------------- Draw ------------------------- //
renderBuffManager.drawPlatform(device, renderPass, platformWidth);
Expand All @@ -578,9 +564,10 @@ const init: SampleInit = async ({ canvasRef, gui, stats }) => {
}

renderPass.end();
device.queue.submit([renderCommand.finish()]);
}

device.queue.submit([command.finish()]);

requestAnimationFrame(frame);
stats.end();
}
Expand Down
32 changes: 32 additions & 0 deletions src/shaders/bitonicSort.compute.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@

@id(1100) override j : u32;
@id(1200) override k : u32;

@binding(1) @group(0) var<storage, read_write> agentData : Agents;

fn swap(idx1 : u32, idx2 : u32) {
var tmp = agentData.agents[idx1];
agentData.agents[idx1] = agentData.agents[idx2];
agentData.agents[idx2] = tmp;
}

fn agentlt(idx1 : u32, idx2 : u32) -> bool {
return agentData.agents[idx1].cell < agentData.agents[idx2].cell;
}

fn agentgt(idx1 : u32, idx2 : u32) -> bool {
return agentData.agents[idx1].cell > agentData.agents[idx2].cell;
}

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) GlobalInvocationID : vec3<u32>) {
var idx = GlobalInvocationID.x ;

var l = idx ^ j;
if (l > idx){
if ( ((idx & k) == 0u && agentgt(idx,l)) || ((idx & k) != 0u && agentlt(idx, l))){
swap(idx, l);
}
}
}

Loading

0 comments on commit 0ef9554

Please sign in to comment.