diff --git a/packages/typegpu/src/common/writeSoA.ts b/packages/typegpu/src/common/writeSoA.ts index a867373b92..79c53f6d67 100644 --- a/packages/typegpu/src/common/writeSoA.ts +++ b/packages/typegpu/src/common/writeSoA.ts @@ -1,15 +1,19 @@ import { invariant } from '../errors.ts'; import { roundUp } from '../mathUtils.ts'; +import type { Undecorate } from '../data/dataTypes.ts'; import { alignmentOf } from '../data/alignmentOf.ts'; +import { undecorate } from '../data/dataTypes.ts'; import { offsetsForProps } from '../data/offsets.ts'; import { sizeOf } from '../data/sizeOf.ts'; import type { BaseData, TypedArrayFor, WgslArray, WgslStruct } from '../data/wgslTypes.ts'; -import { isMat, isMat2x2f, isMat3x3f, isWgslArray } from '../data/wgslTypes.ts'; +import { isAtomic, isMat, isMat2x2f, isMat3x3f, isWgslArray } from '../data/wgslTypes.ts'; import type { BufferWriteOptions, TgpuBuffer } from '../core/buffer/buffer.ts'; import type { Prettify } from '../shared/utilityTypes.ts'; -type UnwrapWgslArray = T extends WgslArray ? UnwrapWgslArray : T; -type PackedSoAInputFor = TypedArrayFor>; +type PackedScalarFor = + Undecorate extends WgslArray ? PackedScalarFor : Undecorate; + +type PackedSoAInputFor = TypedArrayFor>; type SoAFieldsFor> = { [K in keyof T as [PackedSoAInputFor] extends [never] ? never : K]: PackedSoAInputFor; @@ -19,74 +23,53 @@ type SoAInputFor> = [keyof T] extends [keyof ? Prettify> : never; -function getPackedMatrixLayout(schema: BaseData) { - if (!isMat(schema)) { - return undefined; - } - - const dim = isMat3x3f(schema) ? 3 : isMat2x2f(schema) ? 2 : 4; - const packedColumnSize = dim * 4; +function packedSchemaOf(schema: BaseData): BaseData { + const unpackedSchema = undecorate(schema); + return isAtomic(unpackedSchema) ? unpackedSchema.inner : unpackedSchema; +} - return { - dim, - packedColumnSize, - packedSize: dim * packedColumnSize, - } as const; +function packedMatrixDimOf(schema: BaseData): 2 | 3 | 4 | undefined { + return isMat3x3f(schema) ? 3 : isMat2x2f(schema) ? 2 : isMat(schema) ? 4 : undefined; } function packedSizeOf(schema: BaseData): number { - const matrixLayout = getPackedMatrixLayout(schema); - if (matrixLayout) { - return matrixLayout.packedSize; + const packedSchema = packedSchemaOf(schema); + const matrixDim = packedMatrixDimOf(packedSchema); + if (matrixDim) { + return matrixDim * matrixDim * 4; } - - if (isWgslArray(schema)) { - return schema.elementCount * packedSizeOf(schema.elementType); + if (isWgslArray(packedSchema)) { + return packedSchema.elementCount * packedSizeOf(packedSchema.elementType); } - - return sizeOf(schema); + return sizeOf(packedSchema); } -function inferSoAElementCount( +function computeSoAByteLength( arraySchema: WgslArray, soaData: Record, ): number | undefined { const structSchema = arraySchema.elementType as WgslStruct; let inferredCount: number | undefined; - for (const key in soaData) { + for (const key in structSchema.propTypes) { const srcArray = soaData[key]; const fieldSchema = structSchema.propTypes[key]; if (srcArray === undefined || fieldSchema === undefined) { continue; } - - const fieldPackedSize = packedSizeOf(fieldSchema); - if (fieldPackedSize === 0) { + const packedFieldSize = packedSizeOf(fieldSchema); + if (packedFieldSize === 0) { continue; } - - const fieldElementCount = Math.floor(srcArray.byteLength / fieldPackedSize); + const fieldElementCount = Math.floor(srcArray.byteLength / packedFieldSize); inferredCount = inferredCount === undefined ? fieldElementCount : Math.min(inferredCount, fieldElementCount); } - - return inferredCount; -} - -function computeSoAByteLength( - arraySchema: WgslArray, - soaData: Record, -): number | undefined { - const elementCount = inferSoAElementCount(arraySchema, soaData); - if (elementCount === undefined) { + if (inferredCount === undefined) { return undefined; } - const elementStride = roundUp( - sizeOf(arraySchema.elementType), - alignmentOf(arraySchema.elementType), - ); - return elementCount * elementStride; + const elementStride = roundUp(sizeOf(structSchema), alignmentOf(structSchema)); + return inferredCount * elementStride; } function writePackedValue( @@ -96,41 +79,42 @@ function writePackedValue( dstOffset: number, srcOffset: number, ): void { - const matrixLayout = getPackedMatrixLayout(schema); - if (matrixLayout) { - const gpuColumnStride = roundUp(matrixLayout.packedColumnSize, alignmentOf(schema)); - - for (let col = 0; col < matrixLayout.dim; col++) { + const unpackedSchema = undecorate(schema); + const packedSchema = isAtomic(unpackedSchema) ? unpackedSchema.inner : unpackedSchema; + const matrixDim = packedMatrixDimOf(packedSchema); + if (matrixDim) { + const packedColumnSize = matrixDim * 4; + const gpuColumnStride = roundUp(packedColumnSize, alignmentOf(schema)); + for (let col = 0; col < matrixDim; col++) { target.set( srcBytes.subarray( - srcOffset + col * matrixLayout.packedColumnSize, - srcOffset + col * matrixLayout.packedColumnSize + matrixLayout.packedColumnSize, + srcOffset + col * packedColumnSize, + srcOffset + col * packedColumnSize + packedColumnSize, ), dstOffset + col * gpuColumnStride, ); } - return; } - - if (isWgslArray(schema)) { - const packedElementSize = packedSizeOf(schema.elementType); - const gpuElementStride = roundUp(sizeOf(schema.elementType), alignmentOf(schema.elementType)); - - for (let i = 0; i < schema.elementCount; i++) { + if (isWgslArray(unpackedSchema)) { + const packedElementSize = packedSizeOf(unpackedSchema.elementType); + const gpuElementStride = roundUp( + sizeOf(unpackedSchema.elementType), + alignmentOf(unpackedSchema.elementType), + ); + + for (let i = 0; i < unpackedSchema.elementCount; i++) { writePackedValue( target, - schema.elementType, + unpackedSchema.elementType, srcBytes, dstOffset + i * gpuElementStride, srcOffset + i * packedElementSize, ); } - return; } - - target.set(srcBytes.subarray(srcOffset, srcOffset + sizeOf(schema)), dstOffset); + target.set(srcBytes.subarray(srcOffset, srcOffset + sizeOf(packedSchema)), dstOffset); } function scatterSoA( @@ -141,7 +125,6 @@ function scatterSoA( endOffset: number, ): void { const structSchema = arraySchema.elementType as WgslStruct; - const offsets = offsetsForProps(structSchema); const elementStride = roundUp(sizeOf(structSchema), alignmentOf(structSchema)); invariant( startOffset % elementStride === 0, @@ -150,6 +133,7 @@ function scatterSoA( const startElement = Math.floor(startOffset / elementStride); const endElement = Math.min(arraySchema.elementCount, Math.ceil(endOffset / elementStride)); const elementCount = Math.max(0, endElement - startElement); + const offsets = offsetsForProps(structSchema); for (const key in structSchema.propTypes) { const fieldSchema = structSchema.propTypes[key]; @@ -158,12 +142,10 @@ function scatterSoA( } const srcArray = soaData[key]; invariant(srcArray !== undefined, `Missing SoA data for field '${key}'`); - const fieldOffset = offsets[key]?.offset; invariant(fieldOffset !== undefined, `Field ${key} not found in struct schema`); - const srcBytes = new Uint8Array(srcArray.buffer, srcArray.byteOffset, srcArray.byteLength); - const packedFieldSize = packedSizeOf(fieldSchema); + const srcBytes = new Uint8Array(srcArray.buffer, srcArray.byteOffset, srcArray.byteLength); for (let i = 0; i < elementCount; i++) { writePackedValue( target, diff --git a/packages/typegpu/src/data/wgslTypes.ts b/packages/typegpu/src/data/wgslTypes.ts index facc581f95..820f1911c7 100644 --- a/packages/typegpu/src/data/wgslTypes.ts +++ b/packages/typegpu/src/data/wgslTypes.ts @@ -62,13 +62,15 @@ export type TypedArrayFor = T extends F32 | Vec2f | Vec3f | Vec4f | Mat2x2f | ? Float32Array : T extends F16 | Vec2h | Vec3h | Vec4h ? Float16Array - : T extends I32 | Vec2i | Vec3i | Vec4i + : T extends I32 | Vec2i | Vec3i | Vec4i | Atomic ? Int32Array - : T extends U32 | Vec2u | Vec3u | Vec4u + : T extends U32 | Vec2u | Vec3u | Vec4u | Atomic ? Uint32Array : T extends U16 ? Uint16Array - : never; + : T extends Decorated + ? TypedArrayFor + : never; /** * Vector infix notation. diff --git a/packages/typegpu/tests/buffer.test.ts b/packages/typegpu/tests/buffer.test.ts index 2354321713..11b1342c34 100644 --- a/packages/typegpu/tests/buffer.test.ts +++ b/packages/typegpu/tests/buffer.test.ts @@ -692,6 +692,31 @@ describe('TgpuBuffer', () => { ]); }); + it('should fast-path aligned raw input for a buffer with decorated data', ({ root, device }) => { + const DecoratedSchema = d.struct({ + a: d.size(12, d.f32), + b: d.align(16, d.u32), + c: d.arrayOf(d.u32, 3), + }); + + const decoratedBuffer = root.createBuffer(DecoratedSchema); + const rawDecoratedBuffer = root.unwrap(decoratedBuffer); + + const aligned = new ArrayBuffer(32); + new Float32Array(aligned, 0, 1)[0] = 1.0; + new Uint32Array(aligned, 16, 1)[0] = 2; + new Uint32Array(aligned, 20, 3).set([3, 4, 5]); + + decoratedBuffer.write(aligned); + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawDecoratedBuffer, 0, expect.any(ArrayBuffer), 0, 32], + ]); + + const uploaded = device.mock.queue.writeBuffer.mock.calls[0]?.[2] as ArrayBuffer; + expect([...new Uint8Array(uploaded)]).toStrictEqual([...new Uint8Array(aligned)]); + }); + it('should throw an error on the type level when using a schema containing boolean', ({ root, }) => { @@ -775,6 +800,40 @@ describe('TgpuBuffer', () => { expect([...new Uint16Array(written, 0, 4)]).toStrictEqual([10, 20, 30, 40]); }); + it('should accept typed views when writing to arrays of atomics at the type level', ({ + root, + }) => { + const u32Buffer = root.createBuffer(d.arrayOf(d.atomic(d.u32), 4)); + const i32Buffer = root.createBuffer(d.arrayOf(d.atomic(d.i32), 4)); + + expectTypeOf(u32Buffer.write) + .parameter(0) + .toEqualTypeOf(); + + expectTypeOf(i32Buffer.write).parameter(0).toEqualTypeOf(); + }); + + it('should fast-path typed views when writing to arrays of atomics', ({ root, device }) => { + const u32Buffer = root.createBuffer(d.arrayOf(d.atomic(d.u32), 4)); + const i32Buffer = root.createBuffer(d.arrayOf(d.atomic(d.i32), 4)); + const rawU32Buffer = root.unwrap(u32Buffer); + const rawI32Buffer = root.unwrap(i32Buffer); + + u32Buffer.write(new Uint32Array([10, 20, 30, 40])); + i32Buffer.write(new Int32Array([-1, -2, -3, -4])); + + expect(device.mock.queue.writeBuffer.mock.calls).toStrictEqual([ + [rawU32Buffer, 0, expect.any(ArrayBuffer), 0, 16], + [rawI32Buffer, 0, expect.any(ArrayBuffer), 0, 16], + ]); + + const writtenU32 = device.mock.queue.writeBuffer.mock.calls[0]?.[2] as ArrayBuffer; + const writtenI32 = device.mock.queue.writeBuffer.mock.calls[1]?.[2] as ArrayBuffer; + + expect([...new Uint32Array(writtenU32, 0, 4)]).toStrictEqual([10, 20, 30, 40]); + expect([...new Int32Array(writtenI32, 0, 4)]).toStrictEqual([-1, -2, -3, -4]); + }); + it('should allow an array of u32 to be used as an index buffer as well as any other usage', ({ root, }) => { @@ -1225,6 +1284,139 @@ describe('ValidateBufferSchema', () => { expect([...headings[1]!]).toStrictEqual([4, 5, 6]); }); + it('should treat atomics like normal scalars when writing SoA', ({ root, device }) => { + const Entry = d.struct({ + id: d.atomic(d.u32), + states: d.arrayOf(d.atomic(d.i32), 4), + }); + + const schema = d.arrayOf(Entry, 2); + const buffer = root.createBuffer(schema); + root.unwrap(buffer); + + common.writeSoA(buffer, { + id: new Uint32Array([1000, 2000]), + states: new Int32Array([1, 2, 3, 4, 5, 6, 7, 8]), + }); + + const uploadedBuffer = device.mock.queue.writeBuffer.mock.calls[0]?.[2] as ArrayBuffer; + + const ids = [ + new DataView(uploadedBuffer).getUint32(0, true), + new DataView(uploadedBuffer).getUint32(20, true), + ]; + const states = [new Int32Array(uploadedBuffer, 4, 4), new Int32Array(uploadedBuffer, 24, 4)]; + + expect(ids).toStrictEqual([1000, 2000]); + expect([...states[0]!]).toStrictEqual([1, 2, 3, 4]); + expect([...states[1]!]).toStrictEqual([5, 6, 7, 8]); + }); + + it('should treat decorated types like normal types when writing SoA', ({ root, device }) => { + const Entry = d.struct({ + magic: d.u32, + id: d.align(16, d.u32), + pos: d.size(64, d.vec3f), + someData: d.arrayOf(d.f32, 4), + }); + + const schema = d.arrayOf(Entry, 2); + const buffer = root.createBuffer(schema); + root.unwrap(buffer); + + common.writeSoA(buffer, { + magic: new Uint32Array([10, 20]), + id: new Uint32Array([100, 200]), + pos: new Float32Array([1, 2, 3, 4, 5, 6]), + someData: new Float32Array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]), + }); + + const uploadedBuffer = device.mock.queue.writeBuffer.mock.calls[0]?.[2] as ArrayBuffer; + + const magics = [ + new DataView(uploadedBuffer).getUint32(0, true), + new DataView(uploadedBuffer).getUint32(112, true), + ]; + const ids = [ + new DataView(uploadedBuffer).getUint32(16, true), + new DataView(uploadedBuffer).getUint32(128, true), + ]; + const positions = [ + new Float32Array(uploadedBuffer, 32, 3), + new Float32Array(uploadedBuffer, 144, 3), + ]; + const someData = [ + new Float32Array(uploadedBuffer, 96, 4), + new Float32Array(uploadedBuffer, 208, 4), + ]; + + expect(magics).toStrictEqual([10, 20]); + expect(ids).toStrictEqual([100, 200]); + expect([...positions[0]!]).toStrictEqual([1, 2, 3]); + expect([...positions[1]!]).toStrictEqual([4, 5, 6]); + expect([...someData[0]!].map((value) => Number(value.toFixed(6)))).toStrictEqual([ + 0.1, 0.2, 0.3, 0.4, + ]); + expect([...someData[1]!].map((value) => Number(value.toFixed(6)))).toStrictEqual([ + 0.5, 0.6, 0.7, 0.8, + ]); + }); + + it('should treat decorated array fields like normal types when writing SoA', ({ + root, + device, + }) => { + const Entry = d.struct({ + id: d.u32, + values: d.align(16, d.arrayOf(d.f32, 4)), + }); + + const schema = d.arrayOf(Entry, 2); + const buffer = root.createBuffer(schema); + root.unwrap(buffer); + + common.writeSoA(buffer, { + id: new Uint32Array([7, 8]), + values: new Float32Array([1, 2, 3, 4, 5, 6, 7, 8]), + }); + + const uploadedBuffer = device.mock.queue.writeBuffer.mock.calls[0]?.[2] as ArrayBuffer; + const ids = [ + new DataView(uploadedBuffer).getUint32(0, true), + new DataView(uploadedBuffer).getUint32(32, true), + ]; + const values = [ + new Float32Array(uploadedBuffer, 16, 4), + new Float32Array(uploadedBuffer, 48, 4), + ]; + + expect(ids).toStrictEqual([7, 8]); + expect([...values[0]!]).toStrictEqual([1, 2, 3, 4]); + expect([...values[1]!]).toStrictEqual([5, 6, 7, 8]); + }); + + it('should write SoA data for decorated array fields with padded elements', ({ + root, + device, + }) => { + const Entry = d.struct({ + values: d.align(16, d.arrayOf(d.vec3f, 2)), + }); + + const schema = d.arrayOf(Entry, 2); + const buffer = root.createBuffer(schema); + root.unwrap(buffer); + + common.writeSoA(buffer, { + values: new Float32Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), + }); + + const uploadedBuffer = device.mock.queue.writeBuffer.mock.calls[0]?.[2] as ArrayBuffer; + const result = new Float32Array(uploadedBuffer); + + expect([...result]).toStrictEqual([1, 2, 3, 0, 4, 5, 6, 0, 7, 8, 9, 0, 10, 11, 12, 0]); + }); + it('should accept SoA input for struct fields that are fixed-size arrays of primitives', () => { type Test = { a: d.F32; @@ -1245,6 +1437,18 @@ describe('ValidateBufferSchema', () => { }>(); }); + it('should accept SoA input for decorated array fields', () => { + type Test = { + id: d.U32; + values: d.Decorated, [d.Align<16>]>; + }; + + expectTypeOf>().toEqualTypeOf<{ + id: Uint32Array; + values: Float32Array; + }>(); + }); + it('should reject SoA input for struct fields that contain nested structs', () => { const Nested = d.struct({ x: d.f32,