Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce storageStruct #29908

Open
wants to merge 25 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/nodes/TSL.js
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export * from './core/IndexNode.js';
export * from './core/ParameterNode.js';
export * from './core/PropertyNode.js';
export * from './core/StackNode.js';
export * from './core/StructTypeNode.js';
export * from './core/UniformGroupNode.js';
export * from './core/UniformNode.js';
export * from './core/VaryingNode.js';
Expand Down
10 changes: 10 additions & 0 deletions src/nodes/accessors/StorageBufferNode.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class StorageBufferNode extends BufferNode {
this.isAtomic = false;

this.bufferObject = false;
this.bufferStruct = false;
this.bufferCount = bufferCount;

this._attribute = null;
Expand Down Expand Up @@ -84,6 +85,14 @@ class StorageBufferNode extends BufferNode {

}

setBufferStruct( value ) {

this.bufferStruct = value;

return this;

}

setAccess( value ) {

this.access = value;
Expand Down Expand Up @@ -166,4 +175,5 @@ export default StorageBufferNode;

// Read-Write Storage
export const storage = ( value, type, count ) => nodeObject( new StorageBufferNode( value, type, count ) );
export const storageStruct = ( value, type, count ) => nodeObject( new StorageBufferNode( value, type, count ).setBufferStruct( true ) );
export const storageObject = ( value, type, count ) => nodeObject( new StorageBufferNode( value, type, count ).setBufferObject( true ) );
16 changes: 16 additions & 0 deletions src/nodes/core/StructTypeNode.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,19 @@ class StructTypeNode extends Node {
}

export default StructTypeNode;

export const struct = ( members ) => {

return Object.entries( members ).map( ( [ name, value ] ) => {

if ( typeof value === 'string' ) {

return { name, type: value, isAtomic: false };

}

return { name, type: value.type, isAtomic: value.atomic || false };

} );

};
167 changes: 162 additions & 5 deletions src/renderers/webgpu/nodes/WGSLNodeBuilder.js
Original file line number Diff line number Diff line change
Expand Up @@ -964,6 +964,88 @@ ${ flowData.code }

}

getTypeFromCustomStruct( shaderStage ) {

const uniforms = this.uniforms[ shaderStage ];

const bufferStructMap = new Map();

uniforms.forEach( ( uniform ) => {

const { name, node } = uniform;
const hasBufferStruct = node.bufferStruct === true;
bufferStructMap.set( name, hasBufferStruct );

} );

return bufferStructMap;

}

getMembersFromCustomStruct( members ) {

const structMembers = members.map( ( { name, type, isAtomic } ) => {

let finalType = wgslTypeLib[ type ];

if ( ! finalType ) {

console.warn( `Unrecognized type: ${type}` );
finalType = 'vec4<f32>';

}

return `${name}: ${isAtomic ? `atomic<${finalType}>` : finalType}`;

} );

return `\t${structMembers.join( ',\n\t' )}`;

}

getCustomStructNameFromShader( source ) {

const functionRegex = /fn\s+\w+\s*\(([\s\S]*?)\)/g; // filter shader header
const parameterRegex = /(\w+)\s*:\s*(ptr<\s*([\w]+),\s*(?:array<([\w<>]+)>|(\w+))[^>]*>|[\w<>,]+)/g; // filter parameters
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's strange to me to find regular expression here. This shouldn't happen when code is generated from nodes, and not the other way around or reprocessed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally had the parser in mind that you also use for the pointers because it is part of the WGSLNodeBuilder. However, the parser does not recognize stageData.codes as WGSL code, so it always thrown an error. At first I thought that this would only be the case with compute shaders because the -> void expression in stageData.codes is missing. _getWGSLComputeCode doesn't have that. But even when trying to use vertex and fragment shaders, the parser slipped into the else expression because it didn't recognize the shaders from stageData.codes as WGSL shaders.

So far stageData.codes is not used anywhere. It was only created. In your opinion, does the code from stageData.codes have to be compatible with the parser?
I then concentrated on reading out only the header, i.e. the part in the ( ), since the rest is not important. That seemed to me to be the safest and most stable way. If you see a better way to extract the datas from the shaders I welcome your idea. I originally wanted to forego the function entirely and do it with the parser.
Is the process clear from the review text?


const results = [];

let match;

while ( ( match = functionRegex.exec( source ) ) !== null ) {

const parameterString = match[ 1 ];

let paramMatch;

while ( ( paramMatch = parameterRegex.exec( parameterString ) ) !== null ) {

const [ , name, fullType, ptrType, arrayType, directStructName ] = paramMatch;

const structName = arrayType || directStructName || null;

const type = ptrType || fullType;

if ( Object.values( wgslTypeLib ).includes( structName ) || structName === null ) {

continue;

}

results.push( {
name,
type,
Fixed Show fixed Hide fixed
structName
} );

}

}

return results;

}

getStructMembers( struct ) {

const snippets = [];
Expand Down Expand Up @@ -1171,12 +1253,14 @@ ${ flowData.code }
const bufferType = this.getType( bufferNode.bufferType );
const bufferCount = bufferNode.bufferCount;


const isArray = bufferNode.value.array.length !== bufferNode.value.itemSize;
const bufferCountSnippet = bufferCount > 0 && uniform.type === 'buffer' ? ', ' + bufferCount : '';
const bufferTypeSnippet = bufferNode.isAtomic ? `atomic<${bufferType}>` : `${bufferType}`;
const bufferSnippet = `\t${ uniform.name } : array< ${ bufferTypeSnippet }${ bufferCountSnippet } >\n`;
const bufferSnippet = bufferNode.bufferStruct ? this.getMembersFromCustomStruct( bufferType ) : `\t${ uniform.name } : array< ${ bufferTypeSnippet }${ bufferCountSnippet } >\n`;
const bufferAccessMode = bufferNode.isStorageBufferNode ? `storage, ${ this.getStorageAccess( bufferNode ) }` : 'uniform';

bufferSnippets.push( this._getWGSLStructBinding( 'NodeBuffer_' + bufferNode.id, bufferSnippet, bufferAccessMode, uniformIndexes.binding ++, uniformIndexes.group ) );
bufferSnippets.push( this._getWGSLStructBinding( bufferNode.bufferStruct, isArray, 'NodeBuffer_' + bufferNode.id, bufferSnippet, bufferAccessMode, uniformIndexes.binding ++, uniformIndexes.group ) );

} else {

Expand All @@ -1199,7 +1283,7 @@ ${ flowData.code }

const group = uniformGroups[ name ];

structSnippets.push( this._getWGSLStructBinding( name, group.snippets.join( ',\n' ), 'uniform', group.index, group.id ) );
structSnippets.push( this._getWGSLStructBinding( false, false, name, group.snippets.join( ',\n' ), 'uniform', group.index, group.id ) );

}

Expand Down Expand Up @@ -1228,9 +1312,69 @@ ${ flowData.code }
stageData.codes = this.getCodes( shaderStage );
stageData.directives = this.getDirectives( shaderStage );
stageData.scopedArrays = this.getScopedArrays( shaderStage );
stageData.isBufferStruct = this.getTypeFromCustomStruct( shaderStage );
stageData.customStructNames = this.getCustomStructNameFromShader( stageData.codes );

//

const reduceFlow = ( flow ) => {

return flow.replace( /&(\w+)\.(\w+)/g, ( match, bufferName, uniformName ) =>

stageData.isBufferStruct.get( uniformName ) === true ? `&${bufferName}` : match

);

};

const extractPointerNames = ( source ) => {

const match = source.match( /\(([^)]+)\)/ );
if ( ! match ) return [];

const content = match[ 1 ];

return content
.split( /\s*,\s*/ )
.map( part => part.trim() )
.filter( part => part.includes( '&' ) )
.map( part => part.replace( /&/g, '' ) )
.filter( part => ! part.includes( '.' ) );

};

const createStructNameMapping = ( nodeBuffers, structs ) => {

const resultMap = new Map();

for ( let i = 0; i < nodeBuffers.length; i ++ ) {

const bufferName = nodeBuffers[ i ];
const struct = structs[ i ];

resultMap.set( bufferName, struct.structName );

}

return resultMap;

};

const replaceStructNamesInUniforms = ( shaderCode, map ) => {

for ( const [ key, value ] of map.entries() ) {

const regex = new RegExp( `\\b${key}Struct\\b`, 'g' );
shaderCode = shaderCode.replace( regex, value );

}

return shaderCode;

};


let pointerNames, structnameMapping;
let flow = '// code\n\n';
flow += this.flowCode[ shaderStage ];

Expand All @@ -1255,6 +1399,13 @@ ${ flowData.code }

flow += `${ flowSlotData.code }\n\t`;



flow = reduceFlow( flow );
pointerNames = extractPointerNames( flow );
structnameMapping = createStructNameMapping( pointerNames, stageData.customStructNames );
stageData.uniforms = replaceStructNamesInUniforms( stageData.uniforms, structnameMapping );

if ( node === mainNode && shaderStage !== 'compute' ) {

flow += '// result\n\n\t';
Expand Down Expand Up @@ -1289,6 +1440,11 @@ ${ flowData.code }

}

flow = reduceFlow( flow );
pointerNames = extractPointerNames( flow );
structnameMapping = createStructNameMapping( pointerNames, stageData.customStructNames );
stageData.uniforms = replaceStructNamesInUniforms( stageData.uniforms, structnameMapping );

}

}
Expand Down Expand Up @@ -1493,14 +1649,15 @@ ${vars}

}

_getWGSLStructBinding( name, vars, access, binding = 0, group = 0 ) {
_getWGSLStructBinding( isBufferStruct, isArray, name, vars, access, binding = 0, group = 0 ) {

const structName = name + 'Struct';
const structSnippet = this._getWGSLStruct( structName, vars );
const structName_ = isBufferStruct ? ( isArray ? `array<${structName}>` : structName ) : structName;

return `${structSnippet}
@binding( ${binding} ) @group( ${group} )
var<${access}> ${name} : ${structName};`;
var<${access}> ${name} : ${structName_};`;

}

Expand Down