Skip to content

Commit

Permalink
[js/webgpu] Donot record with computePassEncoder when capturing
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 19, 2024
1 parent 497b06f commit e8b80b9
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 23 deletions.
25 changes: 17 additions & 8 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ import {
} from './webgpu/types';

interface CommandInfo {
readonly kernelId: number;
readonly computePipeline: GPUComputePipeline;
readonly bindGroup: GPUBindGroup;
readonly dispatchGroup: [number, number, number];
readonly kernelId?: number;
readonly computePipeline?: GPUComputePipeline;
readonly bindGroup?: GPUBindGroup;
readonly dispatchGroup?: [number, number, number];
readonly source?: GPUBuffer;
readonly dest?: GPUBuffer;
readonly size?: number;
}

interface KernelInfo {
Expand Down Expand Up @@ -909,10 +912,16 @@ export class WebGpuBackend {
for (let i = 0; i < length; i++) {
const computePassEncoder = this.getComputePassEncoder();
const command = sessionCommandList![i];
this.writeTimestamp(this.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(command.computePipeline);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
if (command.bindGroup) {
this.writeTimestamp(this.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(command.computePipeline!);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup!);
} else {
this.writeTimestamp(this.pendingDispatchNumber * 2);
const commandEncoder = this.getCommandEncoder();
commandEncoder.copyBufferToBuffer(command.source!, 0, command.dest!, 0, command.size!);
}
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
this.pendingDispatchNumber++;
if (this.queryType !== 'none') {
Expand Down
43 changes: 33 additions & 10 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,39 @@ class GpuDataManagerImpl implements GpuDataManager {

const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize);

// GPU copy
const commandEncoder = this.backend.getCommandEncoder();
this.backend.endComputePass();
commandEncoder.copyBufferToBuffer(
sourceGpuDataCache.gpuData.buffer,
0,
destinationGpuDataCache.gpuData.buffer,
0,
size,
);
if (this.backend.sessionStatus === 'capturing') {
const commandInfo = {
source: sourceGpuDataCache.gpuData.buffer,
dest: destinationGpuDataCache.gpuData.buffer,
size: size,
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(commandInfo);
} else {
// GPU copy
const commandEncoder = this.backend.getCommandEncoder();
this.backend.endComputePass();
commandEncoder.copyBufferToBuffer(
sourceGpuDataCache.gpuData.buffer,
0,
destinationGpuDataCache.gpuData.buffer,
0,
size,
);
}

this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
this.backend.pendingDispatchNumber++;

if (
this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber ||
this.backend.queryType === 'at-passes'
) {
this.backend.endComputePass();
}
if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) {
this.backend.flush();
}
}

registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number {
Expand Down
10 changes: 5 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ export class ProgramManager {
): void {
TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
const entries = [];
for (const input of inputs) {
Expand All @@ -68,11 +67,12 @@ export class ProgramManager {
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(commandInfo);
} else {
const computePassEncoder = this.backend.getComputePassEncoder();
computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
}

computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
this.backend.pendingDispatchNumber++;

Expand Down

0 comments on commit e8b80b9

Please sign in to comment.