Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion packages/unplugin-typegpu/src/babel.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import * as Babel from '@babel/standalone';
import type TemplateGenerator from '@babel/template';
import type { TraverseOptions } from '@babel/traverse';
import type { NodePath, TraverseOptions } from '@babel/traverse';
import type * as babel from '@babel/types';
import defu from 'defu';
import { FORMAT_VERSION } from 'tinyest';
Expand Down Expand Up @@ -108,6 +108,124 @@ function wrapInAutoName(node: babel.Expression, name: string) {
);
}

type UseGpuFunctionPath = NodePath<
babel.FunctionDeclaration | babel.FunctionExpression | babel.ArrowFunctionExpression
>;

function objectDestructuringError(message: string): Error {
return new Error(`Unsupported object destructuring in "use gpu" functions: ${message}`);
}

function hasObjectPatternDeclaration(node: babel.VariableDeclaration): boolean {
return node.declarations.some((decl) => decl.id.type === 'ObjectPattern');
}

function expandObjectPatternDeclaration(
node: babel.VariableDeclaration,
path: NodePath<babel.VariableDeclaration>,
): babel.VariableDeclaration[] | null {
if (!hasObjectPatternDeclaration(node)) {
return null;
}

const expanded: babel.VariableDeclaration[] = [];

for (const declarator of node.declarations) {
if (declarator.id.type === 'Identifier') {
expanded.push(
types.variableDeclaration(node.kind, [types.cloneNode(declarator, true)]),
);
continue;
}

if (declarator.id.type !== 'ObjectPattern') {
throw objectDestructuringError('only flat object patterns are supported');
}

if (!declarator.init) {
throw objectDestructuringError('an initializer is required');
}

let objectSource = declarator.init;

if (objectSource.type !== 'Identifier') {
const tmpId = path.scope.generateUidIdentifier('tmp');

expanded.push(
types.variableDeclaration(node.kind, [
types.variableDeclarator(tmpId, types.cloneNode(objectSource, true)),
]),
);
objectSource = tmpId;
}

for (const property of declarator.id.properties) {
if (property.type === 'RestElement') {
throw objectDestructuringError('rest properties are not supported');
}

if (property.type !== 'ObjectProperty') {
throw objectDestructuringError('only plain object properties are supported');
}

if (property.computed || property.key.type !== 'Identifier') {
throw objectDestructuringError('only identifier property names are supported');
}

if (property.value.type !== 'Identifier') {
if (property.value.type === 'AssignmentPattern') {
throw objectDestructuringError('default values are not supported');
}

throw objectDestructuringError('nested destructuring is not supported');
}

expanded.push(
types.variableDeclaration(node.kind, [
types.variableDeclarator(
types.cloneNode(property.value, true),
types.memberExpression(
types.cloneNode(objectSource, true),
types.identifier(property.key.name),
),
),
]),
);
}
}

return expanded;
}

function normalizeObjectDestructuring(path: UseGpuFunctionPath) {
path.traverse({
Function(innerPath) {
if (innerPath.node !== path.node) {
innerPath.skip();
}
},

VariableDeclaration(innerPath) {
if (hasObjectPatternDeclaration(innerPath.node)) {
const parentPath = innerPath.parentPath;
if (!parentPath.isBlockStatement() && !parentPath.isProgram()) {
throw objectDestructuringError(
'unsupported object destructuring in non-block variable declaration (e.g. for-loop initializer or for-of/in)',
);
}
}

const expanded = expandObjectPatternDeclaration(innerPath.node, innerPath);
if (!expanded) {
return;
}

innerPath.replaceWithMultiple(expanded);
innerPath.skip();
},
});
}

function functionVisitor(ctx: Context): TraverseOptions {
let inUseGpuScope = false;

Expand Down Expand Up @@ -179,6 +297,7 @@ function functionVisitor(ctx: Context): TraverseOptions {
ArrowFunctionExpression: {
enter(path) {
if (containsUseGpuDirective(path.node)) {
normalizeObjectDestructuring(path);
fnNodeToOriginalMap.set(path.node, types.cloneNode(path.node, true));
if (inUseGpuScope) {
throw new Error(`Nesting 'use gpu' functions is not allowed`);
Expand All @@ -200,6 +319,7 @@ function functionVisitor(ctx: Context): TraverseOptions {
FunctionExpression: {
enter(path) {
if (containsUseGpuDirective(path.node)) {
normalizeObjectDestructuring(path);
fnNodeToOriginalMap.set(path.node, types.cloneNode(path.node, true));
if (inUseGpuScope) {
throw new Error(`Nesting 'use gpu' functions is not allowed`);
Expand All @@ -221,6 +341,7 @@ function functionVisitor(ctx: Context): TraverseOptions {
FunctionDeclaration: {
enter(path) {
if (containsUseGpuDirective(path.node)) {
normalizeObjectDestructuring(path);
fnNodeToOriginalMap.set(path.node, types.cloneNode(path.node, true));
if (inUseGpuScope) {
throw new Error(`Nesting 'use gpu' functions is not allowed`);
Expand Down
219 changes: 215 additions & 4 deletions packages/unplugin-typegpu/src/rollup-impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,207 @@ export function containsUseGpuDirective(node: FunctionNode): boolean {
return false;
}

function objectDestructuringError(message: string): Error {
return new Error(`Unsupported object destructuring in "use gpu" functions: ${message}`);
}

function hasObjectPatternDeclaration(node: acorn.VariableDeclaration): boolean {
return node.declarations.some((decl) => decl.id.type === 'ObjectPattern');
}

function cloneIdentifierNode(node: acorn.Identifier): acorn.AnyNode {
return structuredClone(node);
}

function createMemberExpression(
object: acorn.Expression,
propertyName: string,
): acorn.AnyNode {
return {
type: 'MemberExpression',
object: structuredClone(object),
property: { type: 'Identifier', name: propertyName },
computed: false,
} as acorn.AnyNode;
}

function expandObjectPatternDeclaration(
node: acorn.VariableDeclaration,
sliceNode: (node: acorn.Node) => string,
getTmpId: () => string,
): { declarations: acorn.AnyNode[]; replacement: string } | null {
if (!hasObjectPatternDeclaration(node)) {
return null;
}

const expanded: acorn.AnyNode[] = [];
const declarations: string[] = [];

for (const declarator of node.declarations) {
if (declarator.id.type === 'Identifier') {
expanded.push({
type: 'VariableDeclaration',
kind: node.kind,
declarations: [structuredClone(declarator)],
} as acorn.AnyNode);

declarations.push(
`${node.kind} ${declarator.id.name}${
declarator.init ? ` = ${sliceNode(declarator.init)}` : ''
};`,
);
continue;
}

if (declarator.id.type !== 'ObjectPattern') {
throw objectDestructuringError('only flat object patterns are supported');
}

if (!declarator.init) {
throw objectDestructuringError('an initializer is required');
}

let objectSourceStr = sliceNode(declarator.init);
let objectSourceAst = declarator.init as acorn.Expression;

if (objectSourceAst.type !== 'Identifier') {
const tmpName = getTmpId();
objectSourceStr = tmpName;
objectSourceAst = {
type: 'Identifier',
name: tmpName,
} as acorn.AnyNode as acorn.Identifier;

expanded.push({
type: 'VariableDeclaration',
kind: node.kind,
declarations: [
{
type: 'VariableDeclarator',
id: structuredClone(objectSourceAst),
init: declarator.init,
},
],
} as acorn.AnyNode);

declarations.push(`${node.kind} ${tmpName} = ${sliceNode(declarator.init)};`);
}

for (const property of declarator.id.properties) {
if (property.type === 'RestElement') {
throw objectDestructuringError('rest properties are not supported');
}

if (property.type !== 'Property') {
throw objectDestructuringError('only plain object properties are supported');
}

if (property.computed || property.key.type !== 'Identifier') {
throw objectDestructuringError('only identifier property names are supported');
}

if (property.value.type !== 'Identifier') {
if (property.value.type === 'AssignmentPattern') {
throw objectDestructuringError('default values are not supported');
}

throw objectDestructuringError('nested destructuring is not supported');
}

expanded.push({
type: 'VariableDeclaration',
kind: node.kind,
declarations: [
{
type: 'VariableDeclarator',
id: cloneIdentifierNode(property.value),
init: createMemberExpression(objectSourceAst, property.key.name),
},
],
} as acorn.AnyNode);

declarations.push(`${node.kind} ${property.value.name} = ${objectSourceStr}.${property.key.name};`);
}
}

return { declarations: expanded, replacement: declarations.join(' ') };
}

function normalizeObjectDestructuring(
node: acorn.AnyNode,
replaceNode: (node: acorn.Node, content: string) => void,
sliceNode: (node: acorn.Node) => string,
) {
let tmpCounter = 0;
const getTmpId = () => {
const id = tmpCounter === 0 ? '_tmp' : `_tmp${tmpCounter}`;
tmpCounter++;
return id;
};

walk(node as Node, {
enter(current, parent) {
const currentNode = current as acorn.AnyNode;
const parentNode = parent as acorn.AnyNode | undefined;

if (
currentNode.type === 'VariableDeclaration' &&
hasObjectPatternDeclaration(currentNode) &&
parentNode?.type !== 'BlockStatement' &&
parentNode?.type !== 'Program'
) {
throw objectDestructuringError(
'unsupported object destructuring in non-block variable declaration (e.g. for-loop initializer or for-of/in)',
);
}
},
});

const rewriteBody = (body: acorn.AnyNode[]) => {
const nextBody: acorn.AnyNode[] = [];

for (const statement of body) {
if (statement.type === 'VariableDeclaration') {
const expanded = expandObjectPatternDeclaration(statement, sliceNode, getTmpId);
if (expanded) {
replaceNode(statement, expanded.replacement);
nextBody.push(...expanded.declarations);
continue;
}
}

if (statement.type === 'BlockStatement') {
rewriteBody(statement.body);
} else if (statement.type === 'IfStatement') {
if (statement.consequent.type === 'BlockStatement') {
rewriteBody(statement.consequent.body);
}
if (statement.alternate?.type === 'BlockStatement') {
rewriteBody(statement.alternate.body);
}
} else if (
statement.type === 'ForStatement' &&
statement.body.type === 'BlockStatement'
) {
rewriteBody(statement.body.body);
} else if (
statement.type === 'WhileStatement' &&
statement.body.type === 'BlockStatement'
) {
rewriteBody(statement.body.body);
}

nextBody.push(statement);
}

body.splice(0, body.length, ...nextBody);
};

if (node.body.type === 'BlockStatement') {
rewriteBody(node.body.body);
}
}

export function removeUseGpuDirective(node: FunctionNode) {
const cloned = structuredClone(node);

Expand Down Expand Up @@ -127,9 +328,13 @@ export const rollUpImpl = (rawOptions: Options) => {
(implementation.type === 'FunctionExpression' ||
implementation.type === 'ArrowFunctionExpression')
) {
tgslFunctionDefs.push({
def: removeUseGpuDirective(implementation),
});
const def = removeUseGpuDirective(implementation);
normalizeObjectDestructuring(
def,
(targetNode, content) => magicString.overwriteNode(targetNode as Node, content),
(targetNode) => magicString.sliceNode(targetNode as Node),
);
tgslFunctionDefs.push({ def });
this.skip();
}
}
Expand All @@ -141,8 +346,14 @@ export const rollUpImpl = (rawOptions: Options) => {
node.type === 'FunctionDeclaration'
) {
if (containsUseGpuDirective(node)) {
const def = removeUseGpuDirective(node);
normalizeObjectDestructuring(
def,
(targetNode, content) => magicString.overwriteNode(targetNode as Node, content),
(targetNode) => magicString.sliceNode(targetNode as Node),
);
tgslFunctionDefs.push({
def: removeUseGpuDirective(node),
def,
name: getFunctionName(node, parent),
});
this.skip();
Expand Down
Loading