Skip to content

Commit

Permalink
Detect and report inconsistencies in registration lifetimes
Browse files Browse the repository at this point in the history
A given type may match more than once when using convention registrations (i.e. typeof(IDisposable) and separately, by regex).

Ideally there should be no overlap for the same concrete type, as that may cause weird lifetime bugs due to the first one to register the implementation type to "win" (since we use TryAddXXX). So it should be a warning to have a case where this happens.

Fixes #115
  • Loading branch information
kzu committed Dec 6, 2024
1 parent 4db4b1b commit 529d225
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 10 deletions.
69 changes: 65 additions & 4 deletions src/CodeAnalysis.Tests/ConventionAnalyzerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Devlooped.Extensions.DependencyInjection;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Testing;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Testing;
using Xunit;
using Xunit.Abstractions;
Expand Down Expand Up @@ -54,7 +57,7 @@ public static void Main()
.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0")))
},
}.WithPreprocessorSymbols();
};

var expected = Verifier.Diagnostic(ConventionsAnalyzer.AssignableTypeOfRequired).WithLocation(0);
test.ExpectedDiagnostics.Add(expected);
Expand Down Expand Up @@ -98,7 +101,7 @@ public static void Main()
.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0")))
},
}.WithPreprocessorSymbols();
};

//var expected = Verifier.Diagnostic(ConventionsAnalyzer.AssignableTypeOfRequired).WithLocation(0);
//test.ExpectedDiagnostics.Add(expected);
Expand Down Expand Up @@ -145,12 +148,70 @@ public static void Main()
.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0")))
},
}.WithPreprocessorSymbols();
};

var expected = Verifier.Diagnostic(ConventionsAnalyzer.OpenGenericType).WithLocation(0);
test.ExpectedDiagnostics.Add(expected);

await test.RunAsync();
}

}
[Fact]
public async Task WarnIfAmbiguousLifetime()
{
var test = new CSharpSourceGeneratorTest<IncrementalGenerator, DefaultVerifier>
{
TestBehaviors = TestBehaviors.SkipGeneratedSourcesCheck,
TestCode =
"""
using System;
using Microsoft.Extensions.DependencyInjection;
public interface IRepository { }
public class MyRepository : IRepository { }
public static class Program
{
public static void Main()
{
var services = new ServiceCollection();
{|#0:services.AddServices(typeof(IRepository), ServiceLifetime.Scoped)|};
{|#1:services.AddServices("Repository", ServiceLifetime.Singleton)|};
}
}
""",
TestState =
{
AnalyzerConfigFiles =
{
("/.editorconfig",
"""
is_global = true
build_property.AddServicesExtension = true
""")
},
Sources =
{
StaticGenerator.AddServicesExtension,
StaticGenerator.ServiceAttribute,
StaticGenerator.ServiceAttributeT,
},
ReferenceAssemblies = new ReferenceAssemblies(
"net8.0",
new PackageIdentity(
"Microsoft.NETCore.App.Ref", "8.0.0"),
Path.Combine("ref", "net8.0"))
.AddPackages(ImmutableArray.Create(
new PackageIdentity("Microsoft.Extensions.DependencyInjection", "8.0.0")))
},
};

var expected = Verifier.Diagnostic(IncrementalGenerator.AmbiguousLifetime)
.WithArguments("MyRepository", "Scoped and Singleton")
.WithLocation(0).WithLocation(1);

test.ExpectedDiagnostics.Add(expected);

await test.RunAsync();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
<PackageReference Include="System.Composition.AttributedModel" Version="8.0.0" />
<PackageReference Include="System.Composition.Hosting" Version="8.0.0" />
<PackageReference Include="System.Composition.TypedParts" Version="8.0.0" />
<PackageReference Include="Microsoft.Bcl.HashCode" Version="6.0.0" GeneratePathProperty="true" />
<PackageReference Include="Microsoft.Bcl.TimeProvider" Version="8.0.1" GeneratePathProperty="true" />
</ItemGroup>

<ItemGroup>
Expand All @@ -32,6 +34,11 @@
<Import Project="..\DependencyInjection\Devlooped.Extensions.DependencyInjection.targets" />
<Import Project="..\SponsorLink\SponsorLink.Analyzer.Tests.targets" />

<ItemGroup>
<Analyzer Include="$(PkgMicrosoft_Bcl_HashCode)\lib\netstandard2.0\Microsoft.Bcl.HashCode.dll" />
<Analyzer Include="$(PkgMicrosoft_Bcl_TimeProvider)\lib\netstandard2.0\Microsoft.Bcl.TimeProvider.dll" />
</ItemGroup>

<!-- Force immediate reporting of status, no install-time grace period -->
<PropertyGroup>
<SponsorLinkNoInstallGrace>true</SponsorLinkNoInstallGrace>
Expand Down
45 changes: 39 additions & 6 deletions src/DependencyInjection/IncrementalGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.Extensions.DependencyInjection;
using KeyedService = (Microsoft.CodeAnalysis.INamedTypeSymbol Type, Microsoft.CodeAnalysis.TypedConstant? Key);

Expand All @@ -22,11 +21,21 @@ namespace Devlooped.Extensions.DependencyInjection;
[Generator(LanguageNames.CSharp)]
public class IncrementalGenerator : IIncrementalGenerator
{
class ServiceSymbol(INamedTypeSymbol type, int lifetime, TypedConstant? key)
public static DiagnosticDescriptor AmbiguousLifetime { get; } =
new DiagnosticDescriptor(
"DDI004",
"Ambiguous lifetime registration.",
"More than one registration matches {0} with lifetimes {1}.",
"Build",
DiagnosticSeverity.Warning,
isEnabledByDefault: true);

class ServiceSymbol(INamedTypeSymbol type, int lifetime, TypedConstant? key, Location? location)
{
public INamedTypeSymbol Type => type;
public int Lifetime => lifetime;
public TypedConstant? Key => key;
public Location? Location => location;

public override bool Equals(object? obj)
{
Expand All @@ -42,7 +51,7 @@ public override int GetHashCode()
=> HashCode.Combine(SymbolEqualityComparer.Default.GetHashCode(type), lifetime, key);
}

record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullNameExpression)
record ServiceRegistration(int Lifetime, TypeSyntax? AssignableTo, string? FullNameExpression, Location? Location)
{
Regex? regex;

Expand Down Expand Up @@ -175,7 +184,7 @@ bool IsExport(AttributeData attr)
}
}

services.Add(new(x, lifetime, key));
services.Add(new(x, lifetime, key, attr.ApplicationSyntaxReference?.GetSyntax().GetLocation()));
}

return services.ToImmutableArray();
Expand Down Expand Up @@ -220,7 +229,7 @@ bool IsExport(AttributeData attr)
if (registration!.FullNameExpression != null && !registration.Regex.IsMatch(typeSymbol.ToFullName(compilation)))
continue;

results.Add(new ServiceSymbol(typeSymbol, registration.Lifetime, null));
results.Add(new ServiceSymbol(typeSymbol, registration.Lifetime, null, registration.Location));
}

return results.ToImmutable();
Expand Down Expand Up @@ -259,6 +268,30 @@ void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, I
context.RegisterImplementationSourceOutput(
services.Where(x => x!.Lifetime == 2 && x.Key is not null).Select((x, _) => new KeyedService(x!.Type, x.Key!)).Collect().Combine(compilation),
(ctx, data) => AddPartial("AddKeyedTransient", ctx, data));

context.RegisterImplementationSourceOutput(services.Collect(), ReportInconsistencies);
}

void ReportInconsistencies(SourceProductionContext context, ImmutableArray<ServiceSymbol> array)
{
var grouped = array.GroupBy(x => x.Type, SymbolEqualityComparer.Default).Where(g => g.Count() > 1).ToImmutableArray();
if (grouped.Length == 0)
return;

foreach (var group in grouped)
{
// report if within the group, there are different lifetimes with the same key (or no key)
foreach (var keyed in group.GroupBy(x => x.Key?.Value).Where(g => g.Count() > 1))
{
var lifetimes = string.Join(", ", keyed.Select(x => x.Lifetime).Distinct().Select(x => x switch { 0 => "Singleton", 1 => "Scoped", 2 => "Transient", _ => "Unknown" });

var location = keyed.Where(x => x.Location != null).FirstOrDefault()?.Location;
var otherLocations = keyed.Where(x => x.Location != null).Skip(1).Select(x => x.Location!);

context.ReportDiagnostic(Diagnostic.Create(AmbiguousLifetime,
location, otherLocations, keyed.First().Type.ToDisplayString(), lifetimes));
}
}
}

static string? GetInvokedMethodName(InvocationExpressionSyntax invocation) => invocation.Expression switch
Expand Down Expand Up @@ -330,7 +363,7 @@ void RegisterServicesOutput(IncrementalGeneratorInitializationContext context, I

if (assignableTo != null || fullNameExpression != null)
{
return new ServiceRegistration(lifetime, assignableTo, fullNameExpression);
return new ServiceRegistration(lifetime, assignableTo, fullNameExpression, invocation.GetLocation());
}
}
return null;
Expand Down

0 comments on commit 529d225

Please sign in to comment.