diff --git a/src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs b/src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs index 9ee32d9..edfbe9d 100644 --- a/src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs +++ b/src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs @@ -3,10 +3,13 @@ using System.Collections.Immutable; using System.IO; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Devlooped.Extensions.DependencyInjection; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Testing; +using Microsoft.CodeAnalysis.Diagnostics; using Microsoft.CodeAnalysis.Testing; using Xunit; using Xunit.Abstractions; @@ -54,7 +57,7 @@ public static void Main() .AddPackages(ImmutableArray.Create( new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) }, - }.WithPreprocessorSymbols(); + }; var expected = Verifier.Diagnostic(ConventionsAnalyzer.AssignableTypeOfRequired).WithLocation(0); test.ExpectedDiagnostics.Add(expected); @@ -98,7 +101,7 @@ public static void Main() .AddPackages(ImmutableArray.Create( new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) }, - }.WithPreprocessorSymbols(); + }; //var expected = Verifier.Diagnostic(ConventionsAnalyzer.AssignableTypeOfRequired).WithLocation(0); //test.ExpectedDiagnostics.Add(expected); @@ -145,7 +148,7 @@ public static void Main() .AddPackages(ImmutableArray.Create( new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) }, - }.WithPreprocessorSymbols(); + }; var expected = Verifier.Diagnostic(ConventionsAnalyzer.OpenGenericType).WithLocation(0); test.ExpectedDiagnostics.Add(expected); @@ -153,4 +156,62 @@ public static void Main() await test.RunAsync(); } -} + [Fact] + public async Task WarnIfAmbiguousLifetime() + { + var test = new CSharpSourceGeneratorTest + { + TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck, + TestCode = + """ + using System; + using Microsoft.Extensions.DependencyInjection; + + public interface IRepository { } + public class MyRepository : IRepository { } + + public static class Program + { + public static void Main() + { + var services = new ServiceCollection(); + {|#0:services.AddServices(typeof(IRepository), ServiceLifetime.Scoped)|}; + {|#1:services.AddServices("Repository", ServiceLifetime.Singleton)|}; + } + } + """, + TestState = + { + AnalyzerConfigFiles = + { + ("/.editorconfig", + """ + is_global = true + build_property.AddServicesExtension = true + """) + }, + Sources = + { + StaticGenerator.AddServicesExtension, + StaticGenerator.ServiceAttribute, + StaticGenerator.ServiceAttributeT, + }, + ReferenceAssemblies = new ReferenceAssemblies( + "net8.0", + new PackageIdentity( + "Microsoft.NETCore.App.Ref", "8.0.0"), + Path.Combine("ref", "net8.0")) + .AddPackages(ImmutableArray.Create( + new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0"))) + }, + }; + + var expected = Verifier.Diagnostic(IncrementalGenerator.AmbiguousLifetime) + .WithArguments("MyRepository", "Scoped and Singleton") + .WithLocation(0).WithLocation(1); + + test.ExpectedDiagnostics.Add(expected); + + await test.RunAsync(); + } +} \ No newline at end of file diff --git a/src/DependencyInjection.Tests/DependencyInjection.Tests.csproj b/src/DependencyInjection.Tests/DependencyInjection.Tests.csproj index bbeb3ef..390b88f 100644 --- a/src/DependencyInjection.Tests/DependencyInjection.Tests.csproj +++ b/src/DependencyInjection.Tests/DependencyInjection.Tests.csproj @@ -17,6 +17,8 @@ + + @@ -32,6 +34,11 @@ + + + + + true diff --git a/src/DependencyInjection/IncrementalGenerator.cs b/src/DependencyInjection/IncrementalGenerator.cs index dd9ee2d..069594d 100644 --- a/src/DependencyInjection/IncrementalGenerator.cs +++ b/src/DependencyInjection/IncrementalGenerator.cs @@ -9,7 +9,6 @@ using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Diagnostics; using Microsoft.Extensions.DependencyInjection; using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol Type, Microsoft.CodeAnalysis.TypedConstant? Key); @@ -22,11 +21,21 @@ namespace Devlooped.Extensions.DependencyInjection; [Generator(LanguageNames.CSharp)] public class IncrementalGenerator : IIncrementalGenerator { - class ServiceSymbol(INamedTypeSymbol type, int lifetime, TypedConstant? key) + public static DiagnosticDescriptor AmbiguousLifetime { get; } = + new DiagnosticDescriptor( + "DDI004", + "Ambiguous lifetime registration.", + "More than one registration matches {0} with lifetimes {1}.", + "Build", + DiagnosticSeverity.Warning, + isEnabledByDefault: true); + + class ServiceSymbol(INamedTypeSymbol type, int lifetime, TypedConstant? key, Location? location) { public INamedTypeSymbol Type => type; public int Lifetime => lifetime; public TypedConstant? Key => key; + public Location? Location => location; public override bool Equals(object? obj) { @@ -42,7 +51,7 @@ public override int GetHashCode() => HashCode.Combine(SymbolEqualityComparer.Default.GetHashCode(type), lifetime, key); } - record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullNameExpression) + record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullNameExpression, Location? Location) { Regex? regex; @@ -175,7 +184,7 @@ bool IsExport(AttributeData attr) } } - services.Add(new(x, lifetime, key)); + services.Add(new(x, lifetime, key, attr.ApplicationSyntaxReference?.GetSyntax().GetLocation())); } return services.ToImmutableArray(); @@ -220,7 +229,7 @@ bool IsExport(AttributeData attr) if (registration!.FullNameExpression != null && !registration.Regex.IsMatch(typeSymbol.ToFullName(compilation))) continue; - results.Add(new ServiceSymbol(typeSymbol, registration.Lifetime, null)); + results.Add(new ServiceSymbol(typeSymbol, registration.Lifetime, null, registration.Location)); } return results.ToImmutable(); @@ -259,6 +268,30 @@ void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, I context.RegisterImplementationSourceOutput( services.Where(x => x!.Lifetime == 2 && x.Key is not null).Select((x, _) => new KeyedService(x!.Type, x.Key!)).Collect().Combine(compilation), (ctx, data) => AddPartial("AddKeyedTransient", ctx, data)); + + context.RegisterImplementationSourceOutput(services.Collect(), ReportInconsistencies); + } + + void ReportInconsistencies(SourceProductionContext context, ImmutableArray array) + { + var grouped = array.GroupBy(x => x.Type, SymbolEqualityComparer.Default).Where(g => g.Count() > 1).ToImmutableArray(); + if (grouped.Length == 0) + return; + + foreach (var group in grouped) + { + // report if within the group, there are different lifetimes with the same key (or no key) + foreach (var keyed in group.GroupBy(x => x.Key?.Value).Where(g => g.Count() > 1)) + { + var lifetimes = string.Join(", ", keyed.Select(x => x.Lifetime).Distinct().Select(x => x switch { 0 => "Singleton", 1 => "Scoped", 2 => "Transient", _ => "Unknown" }); + + var location = keyed.Where(x => x.Location != null).FirstOrDefault()?.Location; + var otherLocations = keyed.Where(x => x.Location != null).Skip(1).Select(x => x.Location!); + + context.ReportDiagnostic(Diagnostic.Create(AmbiguousLifetime, + location, otherLocations, keyed.First().Type.ToDisplayString(), lifetimes)); + } + } } static string? GetInvokedMethodName(InvocationExpressionSyntax invocation) => invocation.Expression switch @@ -330,7 +363,7 @@ void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, I if (assignableTo != null || fullNameExpression != null) { - return new ServiceRegistration(lifetime, assignableTo, fullNameExpression); + return new ServiceRegistration(lifetime, assignableTo, fullNameExpression, invocation.GetLocation()); } } return null;