diff --git a/packages/compiler/src/core/checker.ts b/packages/compiler/src/core/checker.ts index 25dbc4efcd..c592dbe87f 100644 --- a/packages/compiler/src/core/checker.ts +++ b/packages/compiler/src/core/checker.ts @@ -222,7 +222,7 @@ export interface Checker { projection: ProjectionNode, args?: (Type | string | number | boolean)[], ): Type; - resolveIdentifier(node: IdentifierNode): Sym | undefined; + resolveIdentifier(node: IdentifierNode): Sym[] | undefined; resolveCompletions(node: IdentifierNode): Map; createType( typeDef: T, @@ -2541,7 +2541,7 @@ export function createChecker(program: Program, resolver: NameResolver): Checker return resolver.getSymbolLinks(s); } - function resolveIdentifier(id: IdentifierNode, mapper?: TypeMapper): Sym | undefined { + function resolveIdentifier(id: IdentifierNode, mapper?: TypeMapper): Sym[] | undefined { let sym: Sym | undefined; const { node, kind } = getIdentifierContext(id); @@ -2549,8 +2549,8 @@ export function createChecker(program: Program, resolver: NameResolver): Checker case IdentifierKind.ModelExpressionProperty: case IdentifierKind.ObjectLiteralProperty: const model = getReferencedModel(node as ModelPropertyNode | ObjectLiteralPropertyNode); - if (model) { - sym = getMemberSymbol(model.node!.symbol, id.sv); + if (model.length === 1) { + sym = getMemberSymbol(model[0].node!.symbol, id.sv); } else { return undefined; } @@ -2558,7 +2558,7 @@ export function createChecker(program: Program, resolver: NameResolver): Checker case IdentifierKind.ModelStatementProperty: case IdentifierKind.Declaration: const links = resolver.getNodeLinks(id); - return links.resolvedSymbol; + return links.resolvedSymbol === undefined ? undefined : [links.resolvedSymbol]; case IdentifierKind.Other: return undefined; @@ -2597,7 +2597,14 @@ export function createChecker(program: Program, resolver: NameResolver): Checker compilerAssert(false, "Unreachable"); } - return sym?.symbolSource ?? sym; + if (sym) { + if (sym.symbolSource) { + return [sym.symbolSource]; + } else { + return [sym]; + } + } + return undefined; //sym?.symbolSource ?? sym; } function getTemplateDeclarationsForArgument( @@ -2617,7 +2624,7 @@ export function createChecker(program: Program, resolver: NameResolver): Checker function getReferencedModel( propertyNode: ObjectLiteralPropertyNode | ModelPropertyNode, - ): Model | undefined { + ): Model[] { type ModelOrArrayValueNode = ArrayLiteralNode | ObjectLiteralNode; type ModelOrArrayTypeNode = ModelExpressionNode | TupleExpressionNode; type ModelOrArrayNode = ModelOrArrayValueNode | ModelOrArrayTypeNode; @@ -2660,9 +2667,8 @@ export function createChecker(program: Program, resolver: NameResolver): Checker refType = getReferencedTypeFromConstAssignment(foundNode as ModelOrArrayValueNode); break; } - return refType?.kind === "Model" || refType?.kind === "Tuple" - ? getNestedModel(refType, path) - : undefined; + + return getNestedModel(refType, path); function pushToModelPath(node: Node, preNode: Node | undefined, path: PathSeg[]) { if (node.kind === SyntaxKind.ArrayLiteral || node.kind === SyntaxKind.TupleExpression) { @@ -2681,38 +2687,75 @@ export function createChecker(program: Program, resolver: NameResolver): Checker } } - function getNestedModel( - modelOrTuple: Model | Tuple | undefined, - path: PathSeg[], - ): Model | undefined { - let cur: Type | undefined = modelOrTuple; - for (const seg of path) { + function getNestedModel(modelOrTupleOrUnion: Type | undefined, path: PathSeg[]): Model[] { + let cur = modelOrTupleOrUnion; + + if (cur && cur.kind !== "Model" && cur.kind !== "Tuple" && cur.kind !== "Union") { + return []; + } + + if (path.length === 0) { + // Handle union and model type nesting when path is empty switch (cur?.kind) { - case "Tuple": - if ( - seg.tupleIndex !== undefined && - seg.tupleIndex >= 0 && - seg.tupleIndex < cur.values.length - ) { - cur = cur.values[seg.tupleIndex]; - } else { - return undefined; - } - break; case "Model": - if (cur.name === "Array" && seg.tupleIndex !== undefined) { - cur = cur.templateMapper?.args[0] as Model; - } else if (cur.name !== "Array" && seg.propertyName) { - cur = cur.properties.get(seg.propertyName)?.type; - } else { - return undefined; + return [cur]; + case "Union": + const models: Model[] = []; + for (const variant of cur.variants.values()) { + if ( + variant.type.kind === "Model" || + variant.type.kind === "Tuple" || + variant.type.kind === "Union" + ) { + models.push(...(getNestedModel(variant.type, path) ?? [])); + } } - break; + return models; default: - return undefined; + return []; } } - return cur?.kind === "Model" ? cur : undefined; + + const seg = path[0]; + switch (cur?.kind) { + case "Tuple": + if ( + seg.tupleIndex !== undefined && + seg.tupleIndex >= 0 && + seg.tupleIndex < cur.values.length + ) { + return getNestedModel(cur.values[seg.tupleIndex], path.slice(1)); + } else { + return []; + } + + case "Model": + if (cur.name === "Array" && seg.tupleIndex !== undefined) { + cur = cur.templateMapper?.args[0] as Model; + } else if (cur.name !== "Array" && seg.propertyName) { + cur = cur.properties.get(seg.propertyName)?.type; + } else { + return []; + } + return getNestedModel(cur, path.slice(1)); + + case "Union": + // When seg.property name exists, it means that it is in the union model or tuple, + // and the corresponding model or tuple needs to be found recursively. + const models: Model[] = []; + for (const variant of cur.variants.values()) { + if ( + variant.type.kind === "Model" || + variant.type.kind === "Tuple" || + variant.type.kind === "Union" + ) { + models.push(...(getNestedModel(variant.type, path) ?? [])); + } + } + return models; + default: + return []; + } } function getReferencedTypeFromTemplateDeclaration(node: ModelOrArrayNode): Type | undefined { @@ -2816,7 +2859,7 @@ export function createChecker(program: Program, resolver: NameResolver): Checker return undefined; } - const decDecl: DecoratorDeclarationStatementNode | undefined = decSym.declarations.find( + const decDecl: DecoratorDeclarationStatementNode | undefined = decSym[0].declarations.find( (x): x is DecoratorDeclarationStatementNode => x.kind === SyntaxKind.DecoratorDeclarationStatement, ); @@ -2905,26 +2948,12 @@ export function createChecker(program: Program, resolver: NameResolver): Checker kind === IdentifierKind.ObjectLiteralProperty ) { const model = getReferencedModel(ancestor as ModelPropertyNode | ObjectLiteralPropertyNode); - if (!model) { + if (model.length <= 0) { return completions; } const curModelNode = ancestor.parent as ModelExpressionNode | ObjectLiteralNode; - - for (const prop of walkPropertiesInherited(model)) { - if ( - identifier.sv === prop.name || - !curModelNode.properties.find( - (p) => - (p.kind === SyntaxKind.ModelProperty || - p.kind === SyntaxKind.ObjectLiteralProperty) && - p.id.sv === prop.name, - ) - ) { - const sym = getMemberSymbol(model.node!.symbol, prop.name); - if (sym) { - addCompletion(prop.name, sym); - } - } + for (const curModel of model) { + addInheritedPropertyCompletions(curModel, curModelNode); } } else if (identifier.parent && identifier.parent.kind === SyntaxKind.MemberExpression) { let base = resolver.getNodeLinks(identifier.parent.base).resolvedSymbol; @@ -2986,6 +3015,28 @@ export function createChecker(program: Program, resolver: NameResolver): Checker return completions; + function addInheritedPropertyCompletions( + model: Model, + curModelNode: ModelExpressionNode | ObjectLiteralNode, + ) { + for (const prop of walkPropertiesInherited(model)) { + if ( + identifier.sv === prop.name || + !curModelNode.properties.find( + (p) => + (p.kind === SyntaxKind.ModelProperty || + p.kind === SyntaxKind.ObjectLiteralProperty) && + p.id.sv === prop.name, + ) + ) { + const sym = getMemberSymbol(model.node!.symbol, prop.name); + if (sym) { + addCompletion(prop.name, sym); + } + } + } + } + function addCompletions(table: SymbolTable | undefined) { if (!table) { return; diff --git a/packages/compiler/src/server/serverlib.ts b/packages/compiler/src/server/serverlib.ts index eb484b0da8..2a0108a3c7 100644 --- a/packages/compiler/src/server/serverlib.ts +++ b/packages/compiler/src/server/serverlib.ts @@ -522,7 +522,7 @@ export function createServer(host: ServerHost): Server { const markdown: MarkupContent = { kind: MarkupKind.Markdown, - value: sym ? getSymbolDetails(program, sym) : "", + value: sym ? getSymbolDetails(program, sym[0]) : "", }; return { contents: markdown, @@ -562,7 +562,10 @@ export function createServer(host: ServerHost): Server { const sym = program.checker.resolveIdentifier( node.target.kind === SyntaxKind.MemberExpression ? node.target.id : node.target, ); - const templateDeclNode = sym?.declarations[0]; + if (!sym) { + return undefined; + } + const templateDeclNode = sym[0].declarations[0]; if ( !templateDeclNode || !("templateParameters" in templateDeclNode) || @@ -584,7 +587,7 @@ export function createServer(host: ServerHost): Server { const help: SignatureHelp = { signatures: [ { - label: `${sym.name}<${parameters.map((x) => x.label).join(", ")}>`, + label: `${sym[0].name}<${parameters.map((x) => x.label).join(", ")}>`, parameters, activeParameter: Math.min(parameters.length - 1, argumentIndex), }, @@ -593,7 +596,7 @@ export function createServer(host: ServerHost): Server { activeParameter: 0, }; - const doc = getSymbolDetails(program, sym, { + const doc = getSymbolDetails(program, sym[0], { includeSignature: false, includeParameterTags: false, }); @@ -616,7 +619,7 @@ export function createServer(host: ServerHost): Server { return undefined; } - const decoratorDeclNode: DecoratorDeclarationStatementNode | undefined = sym.declarations.find( + const decoratorDeclNode: DecoratorDeclarationStatementNode | undefined = sym[0].declarations.find( (x): x is DecoratorDeclarationStatementNode => x.kind === SyntaxKind.DecoratorDeclarationStatement, ); @@ -670,7 +673,7 @@ export function createServer(host: ServerHost): Server { activeParameter: 0, }; - const doc = getSymbolDetails(program, sym, { + const doc = getSymbolDetails(program, sym[0], { includeSignature: false, includeParameterTags: false, }); @@ -746,7 +749,7 @@ export function createServer(host: ServerHost): Server { switch (node?.kind) { case SyntaxKind.Identifier: const sym = result.program.checker.resolveIdentifier(node); - return getLocations(sym?.declarations); + return getLocations(sym?sym[0].declarations:undefined); case SyntaxKind.StringLiteral: if (node.parent?.kind === SyntaxKind.ImportStatement) { return [await getImportLocation(node.value, result.script)]; @@ -894,7 +897,7 @@ export function createServer(host: ServerHost): Server { visitChildren(searchFile, function visit(node) { if (node.kind === SyntaxKind.Identifier) { const s = program.checker.resolveIdentifier(node); - if (s === sym || (sym.type && s?.type === sym.type)) { + if ( s === sym || (sym[0].type && s && s[0].type === sym[0].type)) { references.push(node); } } diff --git a/packages/compiler/test/server/completion.test.ts b/packages/compiler/test/server/completion.test.ts index a14847876a..07f0a8e861 100644 --- a/packages/compiler/test/server/completion.test.ts +++ b/packages/compiler/test/server/completion.test.ts @@ -664,6 +664,118 @@ describe("identifiers", () => { ); }); + it("completes union variants(models) of template parameters", async () => { + const completions = await complete( + ` + model Options { + a: string; + b: Nested; + } + model Nested { + c:Foo2; + } + model Foo1 { + foo1: string; + } + model Foo2 { + foo2: string; + } + + model Test {} + + alias A = Test<#{┆}>; + `, + ); + + check( + completions, + [ + { + label: "foo1", + insertText: "foo1", + kind: CompletionItemKind.Field, + documentation: { + kind: MarkupKind.Markdown, + value: "(model property)\n```typespec\nFoo1.foo1: string\n```", + }, + }, + { + label: "a", + insertText: "a", + kind: CompletionItemKind.Field, + documentation: { + kind: MarkupKind.Markdown, + value: "(model property)\n```typespec\nOptions.a: string\n```", + }, + }, + { + label: "b", + insertText: "b", + kind: CompletionItemKind.Field, + documentation: { + kind: MarkupKind.Markdown, + value: "(model property)\n```typespec\nOptions.b: Nested\n```", + }, + }, + ], + { + allowAdditionalCompletions: false, + }, + ); + }); + + it("completes specific type in union variants(models) of template parameters", async () => { + const completions = await complete( + ` + model Options { + a: string; + b: Nested; + } + model Nested { + c:Foo2; + d:string; + } + model Foo1 { + foo1: string; + } + model Foo2 { + foo2: string; + } + + model Test {} + + alias A = Test<#{a:"",b:#{┆}}>; + `, + ); + + check( + completions, + [ + { + label: "c", + insertText: "c", + kind: CompletionItemKind.Field, + documentation: { + kind: MarkupKind.Markdown, + value: "(model property)\n```typespec\nNested.c: Foo2\n```", + }, + }, + { + label: "d", + insertText: "d", + kind: CompletionItemKind.Field, + documentation: { + kind: MarkupKind.Markdown, + value: "(model property)\n```typespec\nNested.d: string\n```", + }, + }, + ], + { + allowAdditionalCompletions: false, + }, + ); + }); + it("completes namespace operations", async () => { const completions = await complete( `