Skip to content

Commit

Permalink
[WebGPU] Make dataToGPU upload to GPU if data is on CPU (#8483)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattsoulanille authored Dec 19, 2024
1 parent cb6206c commit 2644bd0
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -594,16 +594,19 @@ export class WebGPUBackend extends KernelBackend {
* @param dataId The source tensor.
*/
override readToGPU(dataId: DataId): GPUData {
const srcTensorData = this.tensorMap.get(dataId);
const {values, dtype, shape, resource} = srcTensorData;
let srcTensorData = this.tensorMap.get(dataId);
const {values, dtype, shape} = srcTensorData;
let resource = srcTensorData.resource;

if (dtype === 'complex64') {
throw new Error('Does not support reading buffer for complex64 dtype.');
}

if (resource == null) {
if (values != null) {
throw new Error('Data is not on GPU but on CPU.');
this.uploadToGPU(dataId);
srcTensorData = this.tensorMap.get(dataId);
resource = srcTensorData.resource;
} else {
throw new Error('There is no data on GPU or CPU.');
}
Expand Down
12 changes: 12 additions & 0 deletions tfjs-backend-webgpu/src/backend_webgpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,18 @@ describeWebGPU('backend webgpu', () => {
await c3.data();
tf.env().set('WEBGPU_DEFERRED_SUBMIT_BATCH_SIZE', savedFlag);
});

it('dataToGPU uploads to GPU if the tensor is on CPU', async () => {
const webGPUBackend = (tf.backend() as WebGPUBackend);
const data = [1,2,3,4,5];
const tensor = tf.tensor1d(data);
const res = tensor.dataToGPU();
expect(res.buffer).toBeDefined();
const resData = await webGPUBackend.getBufferData(res.buffer);
const values = tf.util.convertBackendValuesAndArrayBuffer(
resData, res.tensorRef.dtype);
expectArraysEqual(values, data);
});
});

describeWebGPU('backendWebGPU', () => {
Expand Down

0 comments on commit 2644bd0

Please sign in to comment.