Skip to content

Commit

Permalink
[js/webgpu] Donot record with computePassEncoder when capturing
Browse files Browse the repository at this point in the history
  • Loading branch information
axinging committed Nov 21, 2024
1 parent 497b06f commit 048a78a
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 21 deletions.
26 changes: 20 additions & 6 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,21 @@ import {
TimestampQuery,
} from './webgpu/types';

interface CommandInfo {
interface ComputeCommand {
readonly kernelId: number;
readonly computePipeline: GPUComputePipeline;
readonly bindGroup: GPUBindGroup;
readonly dispatchGroup: [number, number, number];
}

interface MemcpyCommand {
readonly source: GPUBuffer;
readonly dest: GPUBuffer;
readonly size: number;
}

type Command = ComputeCommand | MemcpyCommand;

interface KernelInfo {
readonly kernelType: string;
readonly kernelName: string;
Expand Down Expand Up @@ -234,9 +242,9 @@ export class WebGpuBackend {
env: Env;
sessionStatus: SessionState = 'default';
/**
* a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
* a SessionID -> Command[] mapping. It's used to record all GPU commands for corresponding session.
*/
capturedCommandList: Map<number, CommandInfo[]> = new Map();
capturedCommandList: Map<number, Command[]> = new Map();

/**
* a SessionID -> PendingKernelInfo[] mapping for profiling.
Expand Down Expand Up @@ -910,9 +918,15 @@ export class WebGpuBackend {
const computePassEncoder = this.getComputePassEncoder();
const command = sessionCommandList![i];
this.writeTimestamp(this.pendingDispatchNumber * 2);
computePassEncoder.setPipeline(command.computePipeline);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
if ('bindGroup' in command) {
computePassEncoder.setPipeline(command.computePipeline);
computePassEncoder.setBindGroup(0, command.bindGroup);
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
} else {
const commandEncoder = this.getCommandEncoder();
this.endComputePass();
commandEncoder.copyBufferToBuffer(command.source, 0, command.dest, 0, command.size);
}
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
this.pendingDispatchNumber++;
if (this.queryType !== 'none') {
Expand Down
43 changes: 33 additions & 10 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,39 @@ class GpuDataManagerImpl implements GpuDataManager {

const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize);

// GPU copy
const commandEncoder = this.backend.getCommandEncoder();
this.backend.endComputePass();
commandEncoder.copyBufferToBuffer(
sourceGpuDataCache.gpuData.buffer,
0,
destinationGpuDataCache.gpuData.buffer,
0,
size,
);
if (this.backend.sessionStatus === 'capturing') {
const command = {
source: sourceGpuDataCache.gpuData.buffer,
dest: destinationGpuDataCache.gpuData.buffer,
size: size,
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(command);
} else {
// GPU copy
const commandEncoder = this.backend.getCommandEncoder();
this.backend.endComputePass();
commandEncoder.copyBufferToBuffer(
sourceGpuDataCache.gpuData.buffer,
0,
destinationGpuDataCache.gpuData.buffer,
0,
size,
);
}

this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
this.backend.pendingDispatchNumber++;

if (
this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber ||
this.backend.queryType === 'at-passes'
) {
this.backend.endComputePass();
}
if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) {
this.backend.flush();
}
}

registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number {
Expand Down
10 changes: 5 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/program-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ export class ProgramManager {
): void {
TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
const entries = [];
for (const input of inputs) {
Expand All @@ -68,11 +67,12 @@ export class ProgramManager {
};
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
sessionCommandList!.push(commandInfo);
} else {
const computePassEncoder = this.backend.getComputePassEncoder();
computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
}

computePassEncoder.setPipeline(buildArtifact.computePipeline);
computePassEncoder.setBindGroup(0, bindGroup);
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
this.backend.pendingDispatchNumber++;

Expand Down

0 comments on commit 048a78a

Please sign in to comment.