diff --git a/src/ModelContextProtocol.Analyzers/EquatableArray.cs b/src/ModelContextProtocol.Analyzers/EquatableArray.cs new file mode 100644 index 000000000..458e557cf --- /dev/null +++ b/src/ModelContextProtocol.Analyzers/EquatableArray.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; + +namespace ModelContextProtocol.Analyzers; + +/// An immutable, equatable array. +/// The type of values in the array. +internal readonly struct EquatableArray : IEnumerable, IEquatable> +{ + /// The underlying array. + private readonly T[]? _array; + + /// The source to enumerate and wrap. + public EquatableArray(IEnumerable source) => _array = source.ToArray(); + + /// The source to wrap. + public EquatableArray(T[] array) => _array = array; + + /// Gets a reference to an item at a specified position within the array. + /// The index of the item to retrieve a reference to. + /// A reference to an item at a specified position within the array. + public ref readonly T this[int index] => ref NonNullArray[index]; + + /// Gets the backing array. + private T[] NonNullArray => _array ?? []; + + /// Gets the length of the current array. + public int Length => NonNullArray.Length; + + /// + public bool Equals(EquatableArray other) => NonNullArray.SequenceEqual(other.NonNullArray); + + /// + public override bool Equals(object? obj) => obj is EquatableArray array && Equals(array); + + /// + public override int GetHashCode() + { + int hash = 17; + foreach (T item in NonNullArray) + { + hash = hash * 31 + (item?.GetHashCode() ?? 0); + } + + return hash; + } + + /// + IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable)NonNullArray).GetEnumerator(); + + /// + IEnumerator IEnumerable.GetEnumerator() => NonNullArray.GetEnumerator(); +} diff --git a/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs b/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs index 40109158d..992a8394b 100644 --- a/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs +++ b/src/ModelContextProtocol.Analyzers/XmlToDescriptionGenerator.cs @@ -5,6 +5,7 @@ using System.CodeDom.Compiler; using System.Collections.Immutable; using System.Text; +using System.Xml; using System.Xml.Linq; namespace ModelContextProtocol.Analyzers; @@ -24,92 +25,218 @@ public sealed class XmlToDescriptionGenerator : IIncrementalGenerator public void Initialize(IncrementalGeneratorInitializationContext context) { - // Use ForAttributeWithMetadataName for each MCP attribute type - var toolMethods = CreateProviderForAttribute(context, McpServerToolAttributeName); - var promptMethods = CreateProviderForAttribute(context, McpServerPromptAttributeName); - var resourceMethods = CreateProviderForAttribute(context, McpServerResourceAttributeName); - - // Combine all three providers - var allMethods = toolMethods - .Collect() - .Combine(promptMethods.Collect()) - .Combine(resourceMethods.Collect()) + // Extract method information for all MCP tools, prompts, and resources. + // The transform extracts all necessary data upfront so the output doesn't depend on the compilation. + var allMethods = CreateProviderForAttribute(context, McpServerToolAttributeName).Collect() + .Combine(CreateProviderForAttribute(context, McpServerPromptAttributeName).Collect()) + .Combine(CreateProviderForAttribute(context, McpServerResourceAttributeName).Collect()) .Select(static (tuple, _) => { - var ((tool, prompt), resource) = tuple; - return tool.AddRange(prompt).AddRange(resource); + var ((tools, prompts), resources) = tuple; + return new EquatableArray(tools.Concat(prompts).Concat(resources)); }); - // Combine with compilation to get well-known type symbols. - var compilationAndMethods = context.CompilationProvider.Combine(allMethods); + // Report diagnostics for all methods. + context.RegisterSourceOutput( + allMethods, + static (spc, methods) => + { + foreach (var method in methods) + { + foreach (var diagnostic in method.Diagnostics) + { + spc.ReportDiagnostic(CreateDiagnostic(diagnostic)); + } + } + }); - // Write out the source for all methods. - context.RegisterSourceOutput(compilationAndMethods, static (spc, source) => Execute(source.Left, source.Right, spc)); + // Generate source code only for methods that need generation. + context.RegisterSourceOutput( + allMethods.Select(static (methods, _) => new EquatableArray(methods.Where(m => m.NeedsGeneration))), + static (spc, methods) => + { + if (methods.Length > 0) + { + spc.AddSource(GeneratedFileName, SourceText.From(GenerateSourceFile(methods), Encoding.UTF8)); + } + }); } + private static Diagnostic CreateDiagnostic(DiagnosticInfo info) => + Diagnostic.Create(info.Id switch + { + "MCP001" => Diagnostics.InvalidXmlDocumentation, + "MCP002" => Diagnostics.McpMethodMustBePartial, + _ => throw new InvalidOperationException($"Unknown diagnostic ID: {info.Id}") + }, info.Location, info.MessageArgs); + private static IncrementalValuesProvider CreateProviderForAttribute( IncrementalGeneratorInitializationContext context, string attributeMetadataName) => context.SyntaxProvider.ForAttributeWithMetadataName( attributeMetadataName, static (node, _) => node is MethodDeclarationSyntax, - static (ctx, ct) => + static (ctx, _) => ExtractMethodInfo((MethodDeclarationSyntax)ctx.TargetNode, (IMethodSymbol)ctx.TargetSymbol, ctx.SemanticModel.Compilation)); + + private static MethodToGenerate ExtractMethodInfo( + MethodDeclarationSyntax methodDeclaration, + IMethodSymbol methodSymbol, + Compilation compilation) + { + bool isPartial = methodDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword); + var descriptionAttribute = compilation.GetTypeByMetadataName(DescriptionAttributeName); + + // Try to extract XML documentation + var (xmlDocs, hasInvalidXml) = TryExtractXmlDocumentation(methodSymbol); + + // For non-partial methods, check if we should report a diagnostic + if (!isPartial) + { + // Report invalid XML diagnostic only if the method would have generated content + if (hasInvalidXml) { - var methodDeclaration = (MethodDeclarationSyntax)ctx.TargetNode; - var methodSymbol = (IMethodSymbol)ctx.TargetSymbol; - return new MethodToGenerate(methodDeclaration, methodSymbol); - }); + // We can't know if it would have been generatable, so skip for non-partial + return MethodToGenerate.Empty; + } + + // Check if this non-partial method has generatable content - if so, report diagnostic + if (xmlDocs is not null && descriptionAttribute is not null && + HasGeneratableContent(xmlDocs, methodSymbol, descriptionAttribute)) + { + return MethodToGenerate.CreateDiagnosticOnly( + new DiagnosticInfo("MCP002", methodDeclaration.Identifier.GetLocation(), methodSymbol.Name)); + } + + return MethodToGenerate.Empty; + } + + // For partial methods with invalid XML, report diagnostic but still generate partial declaration. + EquatableArray diagnostics = hasInvalidXml ? + new EquatableArray(ImmutableArray.Create(new DiagnosticInfo("MCP001", methodSymbol.Locations.FirstOrDefault(), methodSymbol.Name))) : + default; + + bool needsMethodDescription = xmlDocs is not null && + !string.IsNullOrWhiteSpace(xmlDocs.MethodDescription) && + (descriptionAttribute is null || !HasAttribute(methodSymbol, descriptionAttribute)); + + bool needsReturnDescription = xmlDocs is not null && + !string.IsNullOrWhiteSpace(xmlDocs.Returns) && + (descriptionAttribute is null || + methodSymbol.GetReturnTypeAttributes().All(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, descriptionAttribute))); + + // Extract method info for partial methods + var modifiers = methodDeclaration.Modifiers + .Where(m => !m.IsKind(SyntaxKind.AsyncKeyword)) + .Select(m => m.Text); + string modifiersStr = string.Join(" ", modifiers); + string returnType = methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat); + string methodName = methodSymbol.Name; + + // Extract parameters + var parameterSyntaxList = methodDeclaration.ParameterList.Parameters; + ParameterInfo[] parameters = new ParameterInfo[methodSymbol.Parameters.Length]; + for (int i = 0; i < methodSymbol.Parameters.Length; i++) + { + var param = methodSymbol.Parameters[i]; + var paramSyntax = i < parameterSyntaxList.Count ? parameterSyntaxList[i] : null; + + parameters[i] = new ParameterInfo( + ParameterType: param.Type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat), + Name: param.Name, + HasDescriptionAttribute: descriptionAttribute is not null && HasAttribute(param, descriptionAttribute), + XmlDescription: xmlDocs?.Parameters.TryGetValue(param.Name, out var pd) == true && !string.IsNullOrWhiteSpace(pd) ? pd : null, + DefaultValue: paramSyntax?.Default?.ToFullString().Trim()); + } - private static void Execute(Compilation compilation, ImmutableArray methods, SourceProductionContext context) + return new MethodToGenerate( + NeedsGeneration: true, + TypeInfo: ExtractTypeInfo(methodSymbol.ContainingType), + Modifiers: modifiersStr, + ReturnType: returnType, + MethodName: methodName, + Parameters: new EquatableArray(parameters), + MethodDescription: needsMethodDescription ? xmlDocs?.MethodDescription : null, + ReturnDescription: needsReturnDescription ? xmlDocs?.Returns : null, + Diagnostics: diagnostics); + } + + /// Checks if XML documentation would generate any Description attributes for a method. + private static bool HasGeneratableContent(XmlDocumentation xmlDocs, IMethodSymbol methodSymbol, INamedTypeSymbol descriptionAttribute) { - if (methods.IsDefaultOrEmpty || - compilation.GetTypeByMetadataName(DescriptionAttributeName) is not { } descriptionAttribute) + if (!string.IsNullOrWhiteSpace(xmlDocs.MethodDescription) && !HasAttribute(methodSymbol, descriptionAttribute)) { - return; + return true; } - // Gather a list of all methods needing generation. - List<(IMethodSymbol MethodSymbol, MethodDeclarationSyntax MethodDeclaration, XmlDocumentation? XmlDocs)> methodsToGenerate = new(methods.Length); - foreach (var methodModel in methods) + if (!string.IsNullOrWhiteSpace(xmlDocs.Returns) && + methodSymbol.GetReturnTypeAttributes().All(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, descriptionAttribute))) { - var xmlDocs = ExtractXmlDocumentation(methodModel.MethodSymbol, methodModel.MethodDeclaration, context); + return true; + } - // Generate implementation for partial methods. - if (methodModel.MethodDeclaration.Modifiers.Any(SyntaxKind.PartialKeyword)) - { - methodsToGenerate.Add((methodModel.MethodSymbol, methodModel.MethodDeclaration, xmlDocs)); - } - else if (xmlDocs is not null && HasGeneratableContent(xmlDocs, methodModel.MethodSymbol, descriptionAttribute)) + foreach (var param in methodSymbol.Parameters) + { + if (!HasAttribute(param, descriptionAttribute) && + xmlDocs.Parameters.TryGetValue(param.Name, out var paramDoc) && + !string.IsNullOrWhiteSpace(paramDoc)) { - // The method is not partial but has XML docs that would generate attributes; issue a diagnostic. - context.ReportDiagnostic(Diagnostic.Create( - Diagnostics.McpMethodMustBePartial, - methodModel.MethodDeclaration.Identifier.GetLocation(), - methodModel.MethodSymbol.Name)); + return true; } } - // Generate a single file with all partial declarations. - if (methodsToGenerate.Count > 0) + return false; + } + + private static TypeInfo ExtractTypeInfo(INamedTypeSymbol? typeSymbol) + { + if (typeSymbol is null) { - string source = GenerateSourceFile(compilation, methodsToGenerate, descriptionAttribute); - context.AddSource(GeneratedFileName, SourceText.From(source, Encoding.UTF8)); + return new TypeInfo(string.Empty, default); } + + // Build list of nested types from innermost to outermost + var typesBuilder = ImmutableArray.CreateBuilder(); + for (var current = typeSymbol; current is not null; current = current.ContainingType) + { + var typeDecl = current.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as TypeDeclarationSyntax; + string typeKeyword; + if (typeDecl is RecordDeclarationSyntax rds) + { + string classOrStruct = rds.ClassOrStructKeyword.ValueText; + if (string.IsNullOrEmpty(classOrStruct)) + { + classOrStruct = "class"; + } + typeKeyword = $"{typeDecl.Keyword.ValueText} {classOrStruct}"; + } + else + { + typeKeyword = typeDecl?.Keyword.ValueText ?? "class"; + } + + typesBuilder.Add(new TypeDeclarationInfo(current.Name, typeKeyword)); + } + + // Reverse to get outermost first + typesBuilder.Reverse(); + + string ns = typeSymbol.ContainingNamespace.IsGlobalNamespace ? "" : typeSymbol.ContainingNamespace.ToDisplayString(); + return new TypeInfo(ns, new EquatableArray(typesBuilder.ToImmutable())); } - private static XmlDocumentation? ExtractXmlDocumentation(IMethodSymbol methodSymbol, MethodDeclarationSyntax methodDeclaration, SourceProductionContext context) + private static (XmlDocumentation? Docs, bool HasInvalidXml) TryExtractXmlDocumentation(IMethodSymbol methodSymbol) { string? xmlDoc = methodSymbol.GetDocumentationCommentXml(); if (string.IsNullOrWhiteSpace(xmlDoc)) { - return null; + return (null, false); } try { if (XDocument.Parse(xmlDoc).Element("member") is not { } memberElement) { - return null; + return (null, false); } var summary = CleanXmlDocText(memberElement.Element("summary")?.Value); @@ -134,19 +261,11 @@ private static void Execute(Compilation compilation, ImmutableArray methods, - INamedTypeSymbol descriptionAttribute) + private static string GenerateSourceFile(EquatableArray methods) { StringWriter sw = new(); IndentedTextWriter writer = new(sw); @@ -183,10 +299,7 @@ private static string GenerateSourceFile( writer.WriteLine(); // Group methods by namespace and containing type - var groupedMethods = methods.GroupBy(m => - m.MethodSymbol.ContainingNamespace.Name == compilation.GlobalNamespace.Name ? "" : - m.MethodSymbol.ContainingNamespace?.ToDisplayString() ?? - ""); + var groupedMethods = methods.GroupBy(m => m.TypeInfo.Namespace); bool firstNamespace = true; foreach (var namespaceGroup in groupedMethods) @@ -197,7 +310,7 @@ private static string GenerateSourceFile( } firstNamespace = false; - // Check if this is the global namespace (methods with null ContainingNamespace) + // Check if this is the global namespace bool isGlobalNamespace = string.IsNullOrEmpty(namespaceGroup.Key); if (!isGlobalNamespace) { @@ -206,15 +319,10 @@ private static string GenerateSourceFile( writer.Indent++; } - // Group by containing type within namespace + // Group by containing type within namespace (using structural equality for TypeInfo) bool isFirstTypeInNamespace = true; - foreach (var typeGroup in namespaceGroup.GroupBy(m => m.MethodSymbol.ContainingType, SymbolEqualityComparer.Default)) + foreach (var typeGroup in namespaceGroup.GroupBy(m => m.TypeInfo)) { - if (typeGroup.Key is not INamedTypeSymbol containingType) - { - continue; - } - if (!isFirstTypeInNamespace) { writer.WriteLine(); @@ -222,7 +330,7 @@ private static string GenerateSourceFile( isFirstTypeInNamespace = false; // Write out the type, which could include parent types. - AppendNestedTypeDeclarations(writer, containingType, typeGroup, descriptionAttribute); + AppendNestedTypeDeclarations(writer, typeGroup.Key, typeGroup); } if (!isGlobalNamespace) @@ -237,50 +345,23 @@ private static string GenerateSourceFile( private static void AppendNestedTypeDeclarations( IndentedTextWriter writer, - INamedTypeSymbol typeSymbol, - IGrouping typeGroup, - INamedTypeSymbol descriptionAttribute) + TypeInfo typeInfo, + IEnumerable methods) { - // Build stack of nested types from innermost to outermost - Stack types = []; - for (var current = typeSymbol; current is not null; current = current.ContainingType) - { - types.Push(current); - } - // Generate type declarations from outermost to innermost - int nestingCount = types.Count; - while (types.Count > 0) + int nestingCount = typeInfo.Types.Length; + foreach (var type in typeInfo.Types) { - // Get the type keyword and handle records - var type = types.Pop(); - var typeDecl = type.DeclaringSyntaxReferences.FirstOrDefault()?.GetSyntax() as TypeDeclarationSyntax; - string typeKeyword; - if (typeDecl is RecordDeclarationSyntax rds) - { - string classOrStruct = rds.ClassOrStructKeyword.ValueText; - if (string.IsNullOrEmpty(classOrStruct)) - { - classOrStruct = "class"; - } - - typeKeyword = $"{typeDecl.Keyword.ValueText} {classOrStruct}"; - } - else - { - typeKeyword = typeDecl?.Keyword.ValueText ?? "class"; - } - - writer.WriteLine($"partial {typeKeyword} {type.Name}"); + writer.WriteLine($"partial {type.TypeKeyword} {type.Name}"); writer.WriteLine("{"); writer.Indent++; } // Generate methods for this type. bool firstMethodInType = true; - foreach (var (methodSymbol, methodDeclaration, xmlDocs) in typeGroup) + foreach (var method in methods) { - AppendMethodDeclaration(writer, methodSymbol, methodDeclaration, xmlDocs, descriptionAttribute, firstMethodInType); + AppendMethodDeclaration(writer, method, firstMethodInType); firstMethodInType = false; } @@ -294,10 +375,7 @@ private static void AppendNestedTypeDeclarations( private static void AppendMethodDeclaration( IndentedTextWriter writer, - IMethodSymbol methodSymbol, - MethodDeclarationSyntax methodDeclaration, - XmlDocumentation? xmlDocs, - INamedTypeSymbol descriptionAttribute, + MethodToGenerate method, bool firstMethodInType) { if (!firstMethodInType) @@ -305,64 +383,50 @@ private static void AppendMethodDeclaration( writer.WriteLine(); } - // Add the Description attribute for method if needed and documentation exists - if (xmlDocs is not null && - !string.IsNullOrWhiteSpace(xmlDocs.MethodDescription) && - !HasAttribute(methodSymbol, descriptionAttribute)) + // Add the Description attribute for method if needed + if (!string.IsNullOrWhiteSpace(method.MethodDescription)) { - writer.WriteLine($"[Description(\"{EscapeString(xmlDocs.MethodDescription)}\")]"); + writer.WriteLine($"[Description(\"{EscapeString(method.MethodDescription!)}\")]"); } - // Add return: Description attribute if needed and documentation exists - if (xmlDocs is not null && - !string.IsNullOrWhiteSpace(xmlDocs.Returns) && - methodSymbol.GetReturnTypeAttributes().All(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, descriptionAttribute))) + // Add return: Description attribute if needed + if (!string.IsNullOrWhiteSpace(method.ReturnDescription)) { - writer.WriteLine($"[return: Description(\"{EscapeString(xmlDocs.Returns)}\")]"); + writer.WriteLine($"[return: Description(\"{EscapeString(method.ReturnDescription!)}\")]"); } - // Copy modifiers from original method syntax, excluding 'async' which is invalid on partial declarations (CS1994). - // Add return type (without nullable annotations). - // Add method name. - var modifiers = methodDeclaration.Modifiers - .Where(m => !m.IsKind(SyntaxKind.AsyncKeyword)) - .Select(m => m.Text); - writer.Write(string.Join(" ", modifiers)); + // Write method signature + writer.Write(method.Modifiers); writer.Write(' '); - writer.Write(methodSymbol.ReturnType.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat)); + writer.Write(method.ReturnType); writer.Write(' '); - writer.Write(methodSymbol.Name); + writer.Write(method.MethodName); // Add parameters with their Description attributes. writer.Write("("); - var parameterSyntaxList = methodDeclaration.ParameterList.Parameters; - for (int i = 0; i < methodSymbol.Parameters.Length; i++) + for (int i = 0; i < method.Parameters.Length; i++) { - IParameterSymbol param = methodSymbol.Parameters[i]; - ParameterSyntax? paramSyntax = i < parameterSyntaxList.Count ? parameterSyntaxList[i] : null; + var param = method.Parameters[i]; if (i > 0) { writer.Write(", "); } - if (xmlDocs is not null && - !HasAttribute(param, descriptionAttribute) && - xmlDocs.Parameters.TryGetValue(param.Name, out var paramDoc) && - !string.IsNullOrWhiteSpace(paramDoc)) + if (!param.HasDescriptionAttribute && !string.IsNullOrWhiteSpace(param.XmlDescription)) { - writer.Write($"[Description(\"{EscapeString(paramDoc)}\")] "); + writer.Write($"[Description(\"{EscapeString(param.XmlDescription!)}\")] "); } - writer.Write(param.Type.ToDisplayString(SymbolDisplayFormat.MinimallyQualifiedFormat)); + writer.Write(param.ParameterType); writer.Write(' '); writer.Write(param.Name); - // Preserve default parameter values from the original syntax. - if (paramSyntax?.Default is { } defaultValue) + // Preserve default parameter values + if (!string.IsNullOrEmpty(param.DefaultValue)) { writer.Write(' '); - writer.Write(defaultValue.ToFullString().Trim()); + writer.Write(param.DefaultValue); } } writer.WriteLine(");"); @@ -391,39 +455,86 @@ private static string EscapeString(string text) => .Replace("\n", "\\n") .Replace("\t", "\\t"); - /// Checks if XML documentation would generate any Description attributes for a method. - private static bool HasGeneratableContent(XmlDocumentation xmlDocs, IMethodSymbol methodSymbol, INamedTypeSymbol descriptionAttribute) + // Cache-friendly data structures - these hold only primitive data, no symbols or syntax + + /// Represents a method that may need Description attributes generated. + private readonly record struct MethodToGenerate( + bool NeedsGeneration, + TypeInfo TypeInfo, + string Modifiers, + string ReturnType, + string MethodName, + EquatableArray Parameters, + string? MethodDescription, + string? ReturnDescription, + EquatableArray Diagnostics) : IEquatable { - // Check if method description would be generated - if (!string.IsNullOrWhiteSpace(xmlDocs.MethodDescription) && !HasAttribute(methodSymbol, descriptionAttribute)) - { - return true; - } + public static MethodToGenerate Empty => new( + NeedsGeneration: false, + TypeInfo: default, + Modifiers: string.Empty, + ReturnType: string.Empty, + MethodName: string.Empty, + Parameters: default, + MethodDescription: null, + ReturnDescription: null, + Diagnostics: default); + + public static MethodToGenerate CreateDiagnosticOnly(DiagnosticInfo diagnostic) => new( + NeedsGeneration: false, + TypeInfo: default, + Modifiers: string.Empty, + ReturnType: string.Empty, + MethodName: string.Empty, + Parameters: default, + MethodDescription: null, + ReturnDescription: null, + Diagnostics: new([diagnostic])); + } - // Check if return description would be generated - if (!string.IsNullOrWhiteSpace(xmlDocs.Returns) && - methodSymbol.GetReturnTypeAttributes().All(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, descriptionAttribute))) - { - return true; - } + /// Holds information about a method parameter. + private readonly record struct ParameterInfo( + string ParameterType, + string Name, + bool HasDescriptionAttribute, + string? XmlDescription, + string? DefaultValue); + + /// Holds information about a type containing MCP methods. + private readonly record struct TypeInfo( + string Namespace, + EquatableArray Types); + + /// Holds information about a type declaration. + private readonly record struct TypeDeclarationInfo( + string Name, + string TypeKeyword); + + /// Holds diagnostic information to be reported. + private readonly struct DiagnosticInfo + { + public string Id { get; } + public Location? Location { get; } + public string MethodName { get; } - // Check if any parameter descriptions would be generated - foreach (var param in methodSymbol.Parameters) + public DiagnosticInfo(string id, Location? location, string methodName) { - if (!HasAttribute(param, descriptionAttribute) && - xmlDocs.Parameters.TryGetValue(param.Name, out var paramDoc) && - !string.IsNullOrWhiteSpace(paramDoc)) - { - return true; - } + Id = id; + Location = location; + MethodName = methodName; } - return false; - } + public object?[] MessageArgs => [MethodName]; - /// Represents a method that may need Description attributes generated. - private readonly record struct MethodToGenerate(MethodDeclarationSyntax MethodDeclaration, IMethodSymbol MethodSymbol); + // For incremental generator caching, we compare only the logical content, not the Location object + public bool Equals(DiagnosticInfo other) => + Id == other.Id && MethodName == other.MethodName; + + public override bool Equals(object? obj) => obj is DiagnosticInfo other && Equals(other); + + public override int GetHashCode() => (Id, MethodName).GetHashCode(); + } - /// Holds extracted XML documentation for a method. + /// Holds extracted XML documentation for a method (used only during extraction, not cached). private sealed record XmlDocumentation(string MethodDescription, string Returns, Dictionary Parameters); } diff --git a/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs b/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs index bffb24647..0feacd0b2 100644 --- a/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs +++ b/tests/ModelContextProtocol.Analyzers.Tests/XmlToDescriptionGeneratorTests.cs @@ -1796,4 +1796,383 @@ private class GeneratorRunResult public List Diagnostics { get; set; } = []; public Compilation? Compilation { get; set; } } + + [Fact] + public void Caching_WithIdenticalCompilation_AllOutputsCached() + { + // This tests that running the same compilation twice uses cached results + const string Source = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// Test tool + [McpServerTool] + public static partial string TestMethod(string input) => input; + } + """; + + var compilation = CreateCompilation(Source); + var driver = CreateTrackedDriver(); + + // Run #1 + driver = driver.RunGenerators(compilation, TestContext.Current.CancellationToken); + var result1 = driver.GetRunResult(); + Assert.Single(result1.Results); + Assert.Single(result1.Results[0].GeneratedSources); + + // Run #2 with same compilation - should be fully cached + driver = driver.RunGenerators(compilation, TestContext.Current.CancellationToken); + var result2 = driver.GetRunResult(); + Assert.Single(result2.Results); + + var allOutputs = result2.Results[0].TrackedSteps.Values + .SelectMany(steps => steps.SelectMany(step => step.Outputs)) + .ToList(); + Assert.NotEmpty(allOutputs); + Assert.All(allOutputs, output => + Assert.True(output.Reason is IncrementalStepRunReason.Cached or IncrementalStepRunReason.Unchanged, + $"Expected Cached or Unchanged but got {output.Reason}")); + } + + [Fact] + public void Caching_WithNewCompilationSameSource_OutputsCached() + { + const string Source = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// Test tool + [McpServerTool] + public static partial string TestMethod(string input) => input; + } + """; + + var driver = CreateTrackedDriver(); + + // Run #1 with first compilation + var compilation1 = CreateCompilation(Source); + driver = driver.RunGenerators(compilation1, TestContext.Current.CancellationToken); + var result1 = driver.GetRunResult(); + Assert.Single(result1.Results); + + // Run #2 with NEW compilation from same source + // This creates new syntax trees and new symbol instances + var compilation2 = CreateCompilation(Source); + Assert.NotSame(compilation1, compilation2); // Verify these are different instances + + driver = driver.RunGenerators(compilation2, TestContext.Current.CancellationToken); + var result2 = driver.GetRunResult(); + Assert.Single(result2.Results); + + // The source generation output should be cached because the extracted data is semantically identical + var sourceOutputSteps = result2.Results[0].TrackedSteps + .Where(kvp => kvp.Key.Contains("SourceOutput") || kvp.Key.Contains("RegisterSourceOutput")) + .SelectMany(kvp => kvp.Value.SelectMany(step => step.Outputs)) + .ToList(); + + // At minimum, check that we're not regenerating everything from scratch + var allOutputs = result2.Results[0].TrackedSteps.Values + .SelectMany(steps => steps.SelectMany(step => step.Outputs)) + .ToList(); + Assert.NotEmpty(allOutputs); + + // With proper value equality, the final output should be unchanged + // (the source text should be identical even if intermediate steps ran) + Assert.Equal( + result1.Results[0].GeneratedSources[0].SourceText.ToString(), + result2.Results[0].GeneratedSources[0].SourceText.ToString()); + } + + [Fact] + public void Caching_WithUnrelatedFileChange_McpMethodCached() + { + // Adding an unrelated file should not cause MCP method extraction to re-run + const string McpSource = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// Test tool + [McpServerTool] + public static partial string TestMethod(string input) => input; + } + """; + + const string UnrelatedSource1 = """ + namespace Other; + public class Unrelated { public int Value { get; set; } } + """; + + const string UnrelatedSource2 = """ + namespace Other; + public class Unrelated { public int Value { get; set; } public string Name { get; set; } } + """; + + var driver = CreateTrackedDriver(); + + // Run #1 with MCP file + unrelated file + driver = driver.RunGenerators(CreateCompilation(McpSource, UnrelatedSource1), TestContext.Current.CancellationToken); + var result1 = driver.GetRunResult(); + Assert.Single(result1.Results); + var output1 = result1.Results[0].GeneratedSources[0].SourceText.ToString(); + + // Run #2 with MCP file + MODIFIED unrelated file + driver = driver.RunGenerators(CreateCompilation(McpSource, UnrelatedSource2), TestContext.Current.CancellationToken); + var result2 = driver.GetRunResult(); + Assert.Single(result2.Results); + var output2 = result2.Results[0].GeneratedSources[0].SourceText.ToString(); + + // Output should be identical + Assert.Equal(output1, output2); + + // Check that ForAttributeWithMetadataName steps for the MCP method are cached + var forAttributeSteps = result2.Results[0].TrackedSteps + .Where(kvp => kvp.Key.Contains("ForAttributeWithMetadataName")) + .SelectMany(kvp => kvp.Value.SelectMany(step => step.Outputs)) + .ToList(); + + // The MCP method should be cached since it didn't change + Assert.Contains(forAttributeSteps, output => + output.Reason is IncrementalStepRunReason.Cached or IncrementalStepRunReason.Unchanged); + } + + [Fact] + public void Caching_WithXmlDocChange_OutputRegenerated() + { + // Changing XML docs should cause regeneration + const string Source1 = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// Original description + [McpServerTool] + public static partial string TestMethod(string input) => input; + } + """; + + const string Source2 = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// Modified description + [McpServerTool] + public static partial string TestMethod(string input) => input; + } + """; + + var driver = CreateTrackedDriver(); + + // Run #1 + driver = driver.RunGenerators(CreateCompilation(Source1), TestContext.Current.CancellationToken); + var result1 = driver.GetRunResult(); + var output1 = result1.Results[0].GeneratedSources[0].SourceText.ToString(); + Assert.Contains("Original description", output1); + + // Run #2 with modified XML docs + driver = driver.RunGenerators(CreateCompilation(Source2), TestContext.Current.CancellationToken); + var result2 = driver.GetRunResult(); + var output2 = result2.Results[0].GeneratedSources[0].SourceText.ToString(); + Assert.Contains("Modified description", output2); + Assert.DoesNotContain("Original description", output2); + + // Verify that there was actual regeneration (not just cached) + var allOutputs = result2.Results[0].TrackedSteps.Values + .SelectMany(steps => steps.SelectMany(step => step.Outputs)) + .ToList(); + Assert.Contains(allOutputs, output => + output.Reason is IncrementalStepRunReason.Modified or IncrementalStepRunReason.New); + } + + [Fact] + public void Caching_WithAddedMethod_ExistingMethodCached() + { + const string Source1 = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// First tool + [McpServerTool] + public static partial string FirstMethod(string input) => input; + } + """; + + const string Source2 = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class TestTools + { + /// First tool + [McpServerTool] + public static partial string FirstMethod(string input) => input; + + /// Second tool + [McpServerTool] + public static partial string SecondMethod(string input) => input; + } + """; + + var driver = CreateTrackedDriver(); + + // Run #1 + driver = driver.RunGenerators(CreateCompilation(Source1), TestContext.Current.CancellationToken); + var result1 = driver.GetRunResult(); + Assert.Single(result1.Results[0].GeneratedSources); + + // Run #2 with added method + driver = driver.RunGenerators(CreateCompilation(Source2), TestContext.Current.CancellationToken); + var result2 = driver.GetRunResult(); + Assert.Single(result2.Results[0].GeneratedSources); + + var output2 = result2.Results[0].GeneratedSources[0].SourceText.ToString(); + Assert.Contains("First tool", output2); + Assert.Contains("Second tool", output2); + + // The ForAttributeWithMetadataName step should have some cached outputs (the first method) + var forAttributeSteps = result2.Results[0].TrackedSteps + .Where(kvp => kvp.Key.Contains("ForAttributeWithMetadataName")) + .SelectMany(kvp => kvp.Value.SelectMany(step => step.Outputs)) + .ToList(); + + // Should have both cached (first method) and new (second method) outputs + Assert.Contains(forAttributeSteps, output => + output.Reason is IncrementalStepRunReason.Cached or IncrementalStepRunReason.Unchanged); + Assert.Contains(forAttributeSteps, output => + output.Reason is IncrementalStepRunReason.New or IncrementalStepRunReason.Modified); + } + + [Fact] + public void Caching_MultipleMethodsAcrossFiles_IndependentCaching() + { + const string File1 = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class Tools1 + { + /// Tool in file 1 + [McpServerTool] + public static partial string Method1(string input) => input; + } + """; + + const string File2Original = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class Tools2 + { + /// Tool in file 2 + [McpServerTool] + public static partial string Method2(string input) => input; + } + """; + + const string File2Modified = """ + using ModelContextProtocol.Server; + namespace Test; + + [McpServerToolType] + public partial class Tools2 + { + /// Modified tool in file 2 + [McpServerTool] + public static partial string Method2(string input) => input; + } + """; + + var driver = CreateTrackedDriver(); + + // Run #1 + driver = driver.RunGenerators(CreateCompilation(File1, File2Original), TestContext.Current.CancellationToken); + var result1 = driver.GetRunResult(); + Assert.Single(result1.Results[0].GeneratedSources); + + // Run #2 - only File2 changed + driver = driver.RunGenerators(CreateCompilation(File1, File2Modified), TestContext.Current.CancellationToken); + var result2 = driver.GetRunResult(); + Assert.Single(result2.Results[0].GeneratedSources); + + var output2 = result2.Results[0].GeneratedSources[0].SourceText.ToString(); + Assert.Contains("Tool in file 1", output2); // Unchanged + Assert.Contains("Modified tool in file 2", output2); // Changed + + // Method1 extraction should be cached, Method2 should be modified + var forAttributeSteps = result2.Results[0].TrackedSteps + .Where(kvp => kvp.Key.Contains("ForAttributeWithMetadataName")) + .SelectMany(kvp => kvp.Value.SelectMany(step => step.Outputs)) + .ToList(); + + Assert.Contains(forAttributeSteps, output => + output.Reason is IncrementalStepRunReason.Cached or IncrementalStepRunReason.Unchanged); + Assert.Contains(forAttributeSteps, output => + output.Reason is IncrementalStepRunReason.Modified or IncrementalStepRunReason.New); + } + + /// + /// Creates a compilation with the specified source and standard references. + /// Each call creates a NEW compilation instance to ensure we're testing value equality, not reference equality. + /// + private static CSharpCompilation CreateCompilation(params string[] sources) + { + var syntaxTrees = sources.Select(s => CSharpSyntaxTree.ParseText(s)).ToArray(); + + var runtimePath = Path.GetDirectoryName(typeof(object).Assembly.Location)!; + List referenceList = + [ + MetadataReference.CreateFromFile(typeof(object).Assembly.Location), + MetadataReference.CreateFromFile(typeof(System.ComponentModel.DescriptionAttribute).Assembly.Location), + MetadataReference.CreateFromFile(Path.Combine(runtimePath, "System.Runtime.dll")), + MetadataReference.CreateFromFile(Path.Combine(runtimePath, "netstandard.dll")), + ]; + + try + { + var coreAssemblyPath = Path.Combine(AppContext.BaseDirectory, "ModelContextProtocol.Core.dll"); + if (File.Exists(coreAssemblyPath)) + { + referenceList.Add(MetadataReference.CreateFromFile(coreAssemblyPath)); + } + } + catch + { + // Ignore + } + + return CSharpCompilation.Create( + "TestAssembly", + syntaxTrees, + referenceList, + new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)); + } + + /// + /// Creates a generator driver with step tracking enabled. + /// + private static GeneratorDriver CreateTrackedDriver() => + CSharpGeneratorDriver.Create( + generators: [new XmlToDescriptionGenerator().AsSourceGenerator()], + driverOptions: new GeneratorDriverOptions( + disabledOutputs: default, + trackIncrementalGeneratorSteps: true)); }