Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 49 additions & 67 deletions packages/typegpu/src/common/writeSoA.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import { invariant } from '../errors.ts';
import { roundUp } from '../mathUtils.ts';
import type { Undecorate } from '../data/dataTypes.ts';
import { alignmentOf } from '../data/alignmentOf.ts';
import { undecorate } from '../data/dataTypes.ts';
import { offsetsForProps } from '../data/offsets.ts';
import { sizeOf } from '../data/sizeOf.ts';
import type { BaseData, TypedArrayFor, WgslArray, WgslStruct } from '../data/wgslTypes.ts';
import { isMat, isMat2x2f, isMat3x3f, isWgslArray } from '../data/wgslTypes.ts';
import { isAtomic, isMat, isMat2x2f, isMat3x3f, isWgslArray } from '../data/wgslTypes.ts';
import type { BufferWriteOptions, TgpuBuffer } from '../core/buffer/buffer.ts';
import type { Prettify } from '../shared/utilityTypes.ts';

type UnwrapWgslArray<T> = T extends WgslArray<infer U> ? UnwrapWgslArray<U> : T;
type PackedSoAInputFor<T> = TypedArrayFor<UnwrapWgslArray<T>>;
type PackedScalarFor<T> =
Undecorate<T> extends WgslArray<infer TElement> ? PackedScalarFor<TElement> : Undecorate<T>;

type PackedSoAInputFor<T> = TypedArrayFor<PackedScalarFor<T>>;

type SoAFieldsFor<T extends Record<string, BaseData>> = {
[K in keyof T as [PackedSoAInputFor<T[K]>] extends [never] ? never : K]: PackedSoAInputFor<T[K]>;
Expand All @@ -19,74 +23,53 @@ type SoAInputFor<T extends Record<string, BaseData>> = [keyof T] extends [keyof
? Prettify<SoAFieldsFor<T>>
: never;

function getPackedMatrixLayout(schema: BaseData) {
if (!isMat(schema)) {
return undefined;
}

const dim = isMat3x3f(schema) ? 3 : isMat2x2f(schema) ? 2 : 4;
const packedColumnSize = dim * 4;
function packedSchemaOf(schema: BaseData): BaseData {
const unpackedSchema = undecorate(schema);
return isAtomic(unpackedSchema) ? unpackedSchema.inner : unpackedSchema;
}

return {
dim,
packedColumnSize,
packedSize: dim * packedColumnSize,
} as const;
function packedMatrixDimOf(schema: BaseData): 2 | 3 | 4 | undefined {
return isMat3x3f(schema) ? 3 : isMat2x2f(schema) ? 2 : isMat(schema) ? 4 : undefined;
}

function packedSizeOf(schema: BaseData): number {
const matrixLayout = getPackedMatrixLayout(schema);
if (matrixLayout) {
return matrixLayout.packedSize;
const packedSchema = packedSchemaOf(schema);
const matrixDim = packedMatrixDimOf(packedSchema);
if (matrixDim) {
return matrixDim * matrixDim * 4;
}

if (isWgslArray(schema)) {
return schema.elementCount * packedSizeOf(schema.elementType);
if (isWgslArray(packedSchema)) {
return packedSchema.elementCount * packedSizeOf(packedSchema.elementType);
}

return sizeOf(schema);
return sizeOf(packedSchema);
}

function inferSoAElementCount(
function computeSoAByteLength(
arraySchema: WgslArray,
soaData: Record<string, ArrayBufferView>,
): number | undefined {
const structSchema = arraySchema.elementType as WgslStruct;
let inferredCount: number | undefined;

for (const key in soaData) {
for (const key in structSchema.propTypes) {
const srcArray = soaData[key];
const fieldSchema = structSchema.propTypes[key];
if (srcArray === undefined || fieldSchema === undefined) {
continue;
}

const fieldPackedSize = packedSizeOf(fieldSchema);
if (fieldPackedSize === 0) {
const packedFieldSize = packedSizeOf(fieldSchema);
if (packedFieldSize === 0) {
continue;
}

const fieldElementCount = Math.floor(srcArray.byteLength / fieldPackedSize);
const fieldElementCount = Math.floor(srcArray.byteLength / packedFieldSize);
inferredCount =
inferredCount === undefined ? fieldElementCount : Math.min(inferredCount, fieldElementCount);
}

return inferredCount;
}

function computeSoAByteLength(
arraySchema: WgslArray,
soaData: Record<string, ArrayBufferView>,
): number | undefined {
const elementCount = inferSoAElementCount(arraySchema, soaData);
if (elementCount === undefined) {
if (inferredCount === undefined) {
return undefined;
}
const elementStride = roundUp(
sizeOf(arraySchema.elementType),
alignmentOf(arraySchema.elementType),
);
return elementCount * elementStride;
const elementStride = roundUp(sizeOf(structSchema), alignmentOf(structSchema));
return inferredCount * elementStride;
}

function writePackedValue(
Expand All @@ -96,41 +79,42 @@ function writePackedValue(
dstOffset: number,
srcOffset: number,
): void {
const matrixLayout = getPackedMatrixLayout(schema);
if (matrixLayout) {
const gpuColumnStride = roundUp(matrixLayout.packedColumnSize, alignmentOf(schema));

for (let col = 0; col < matrixLayout.dim; col++) {
const unpackedSchema = undecorate(schema);
const packedSchema = isAtomic(unpackedSchema) ? unpackedSchema.inner : unpackedSchema;
const matrixDim = packedMatrixDimOf(packedSchema);
if (matrixDim) {
const packedColumnSize = matrixDim * 4;
const gpuColumnStride = roundUp(packedColumnSize, alignmentOf(schema));
for (let col = 0; col < matrixDim; col++) {
target.set(
srcBytes.subarray(
srcOffset + col * matrixLayout.packedColumnSize,
srcOffset + col * matrixLayout.packedColumnSize + matrixLayout.packedColumnSize,
srcOffset + col * packedColumnSize,
srcOffset + col * packedColumnSize + packedColumnSize,
),
dstOffset + col * gpuColumnStride,
);
}

return;
}

if (isWgslArray(schema)) {
const packedElementSize = packedSizeOf(schema.elementType);
const gpuElementStride = roundUp(sizeOf(schema.elementType), alignmentOf(schema.elementType));

for (let i = 0; i < schema.elementCount; i++) {
if (isWgslArray(unpackedSchema)) {
const packedElementSize = packedSizeOf(unpackedSchema.elementType);
const gpuElementStride = roundUp(
sizeOf(unpackedSchema.elementType),
alignmentOf(unpackedSchema.elementType),
);

for (let i = 0; i < unpackedSchema.elementCount; i++) {
writePackedValue(
target,
schema.elementType,
unpackedSchema.elementType,
srcBytes,
dstOffset + i * gpuElementStride,
srcOffset + i * packedElementSize,
);
}

return;
}

target.set(srcBytes.subarray(srcOffset, srcOffset + sizeOf(schema)), dstOffset);
target.set(srcBytes.subarray(srcOffset, srcOffset + sizeOf(packedSchema)), dstOffset);
}

function scatterSoA(
Expand All @@ -141,7 +125,6 @@ function scatterSoA(
endOffset: number,
): void {
const structSchema = arraySchema.elementType as WgslStruct;
const offsets = offsetsForProps(structSchema);
const elementStride = roundUp(sizeOf(structSchema), alignmentOf(structSchema));
invariant(
startOffset % elementStride === 0,
Expand All @@ -150,6 +133,7 @@ function scatterSoA(
const startElement = Math.floor(startOffset / elementStride);
const endElement = Math.min(arraySchema.elementCount, Math.ceil(endOffset / elementStride));
const elementCount = Math.max(0, endElement - startElement);
const offsets = offsetsForProps(structSchema);

for (const key in structSchema.propTypes) {
const fieldSchema = structSchema.propTypes[key];
Expand All @@ -158,12 +142,10 @@ function scatterSoA(
}
const srcArray = soaData[key];
invariant(srcArray !== undefined, `Missing SoA data for field '${key}'`);

const fieldOffset = offsets[key]?.offset;
invariant(fieldOffset !== undefined, `Field ${key} not found in struct schema`);
const srcBytes = new Uint8Array(srcArray.buffer, srcArray.byteOffset, srcArray.byteLength);

const packedFieldSize = packedSizeOf(fieldSchema);
const srcBytes = new Uint8Array(srcArray.buffer, srcArray.byteOffset, srcArray.byteLength);
for (let i = 0; i < elementCount; i++) {
writePackedValue(
target,
Expand Down
8 changes: 5 additions & 3 deletions packages/typegpu/src/data/wgslTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,15 @@ export type TypedArrayFor<T> = T extends F32 | Vec2f | Vec3f | Vec4f | Mat2x2f |
? Float32Array
: T extends F16 | Vec2h | Vec3h | Vec4h
? Float16Array
: T extends I32 | Vec2i | Vec3i | Vec4i
: T extends I32 | Vec2i | Vec3i | Vec4i | Atomic<I32>
? Int32Array
: T extends U32 | Vec2u | Vec3u | Vec4u
: T extends U32 | Vec2u | Vec3u | Vec4u | Atomic<U32>
? Uint32Array
: T extends U16
? Uint16Array
: never;
: T extends Decorated<infer TBase>
? TypedArrayFor<TBase>
: never;

/**
* Vector infix notation.
Expand Down
Loading
Loading