diff --git a/src/Polly.Core/PredicateBuilder.TResult.cs b/src/Polly.Core/PredicateBuilder.TResult.cs index 521db8fa5f3..fbbb5536291 100644 --- a/src/Polly.Core/PredicateBuilder.TResult.cs +++ b/src/Polly.Core/PredicateBuilder.TResult.cs @@ -36,6 +36,10 @@ public PredicateBuilder Handle(Func predi /// /// The type of the inner exception to handle. /// The same instance of the for chaining. + /// + /// This method will also handle any exception found for of + /// an , or at any level of nesting within an . + /// public PredicateBuilder HandleInner() where TException : Exception => HandleInner(static _ => true); @@ -46,12 +50,46 @@ public PredicateBuilder HandleInner() /// The predicate function to use for handling the inner exception. /// The same instance of the for chaining. /// Thrown when the is . + /// + /// This method will also handle any exception found for of + /// an , or at any level of nesting within an . + /// public PredicateBuilder HandleInner(Func predicate) where TException : Exception { Guard.NotNull(predicate); - return Add(outcome => outcome.Exception?.InnerException is TException innerException && predicate(innerException)); + return Add(outcome => HandleInner(outcome.Exception, predicate)); + + static bool HandleInner(Exception? exception, Func predicate) + { + if (exception is AggregateException aggregate) + { + foreach (var innerException in aggregate.Flatten().InnerExceptions) + { + if (HandleNested(predicate, innerException)) + { + return true; + } + } + } + + return HandleNested(predicate, exception); + + static bool HandleNested(Func predicate, Exception? current) + { + if (current is null) + { + return false; + } + else if (current is TException exceptionOfT) + { + return predicate(exceptionOfT); + } + + return HandleNested(predicate, current.InnerException); + } + } } /// @@ -89,7 +127,7 @@ public PredicateBuilder HandleResult(TResult result, IEqualityComparer< { 0 => throw new InvalidOperationException("No predicates were configured. There must be at least one predicate added."), 1 => _predicates[0], - _ => CreatePredicate(_predicates.ToArray()), + _ => CreatePredicate([.. _predicates]), }; internal Func> Build() @@ -100,19 +138,19 @@ internal Func> Build() return args => new ValueTask(predicate(args.Outcome)); } - private static Predicate> CreatePredicate(Predicate>[] predicates) - => outcome => - { - foreach (var predicate in predicates) - { - if (predicate(outcome)) - { - return true; - } - } - - return false; - }; + private static Predicate> CreatePredicate(Predicate>[] predicates) => + outcome => + { + foreach (var predicate in predicates) + { + if (predicate(outcome)) + { + return true; + } + } + + return false; + }; private PredicateBuilder Add(Predicate> predicate) { diff --git a/test/Polly.Core.Tests/PredicateBuilderTests.cs b/test/Polly.Core.Tests/PredicateBuilderTests.cs index 3d3913b7c3e..1a4876cf7f9 100644 --- a/test/Polly.Core.Tests/PredicateBuilderTests.cs +++ b/test/Polly.Core.Tests/PredicateBuilderTests.cs @@ -9,24 +9,47 @@ public class PredicateBuilderTests { public static TheoryData>, Outcome, bool> HandleResultData = new() { - { builder => builder.HandleResult("val"), Outcome.FromResult("val"), true }, - { builder => builder.HandleResult("val"), Outcome.FromResult("val2"), false }, - { builder => builder.HandleResult("val"), Outcome.FromException(new InvalidOperationException()), false }, - { builder => builder.HandleResult("val", StringComparer.OrdinalIgnoreCase) ,Outcome.FromResult("VAL"), true }, - { builder => builder.HandleResult(r => r == "val"), Outcome.FromResult("val"), true }, - { builder => builder.HandleResult(r => r == "val2"), Outcome.FromResult("val"), false }, - { builder => builder.Handle(), Outcome.FromException(new InvalidOperationException()), true }, - { builder => builder.Handle(), Outcome.FromException(new FormatException()), false }, - { builder => builder.Handle(e => false), Outcome.FromException(new InvalidOperationException()), false }, - { builder => builder.HandleInner(e => false), Outcome.FromException(new InvalidOperationException()), false }, - { builder => builder.HandleInner(), Outcome.FromResult("value"), false }, - { builder => builder.Handle(), Outcome.FromResult("value"), false }, - { builder => builder.Handle().HandleResult("value"), Outcome.FromResult("value"), true }, - { builder => builder.Handle().HandleResult("value"), Outcome.FromResult("value2"), false }, - { builder => builder.HandleInner(), Outcome.FromException(new InvalidOperationException("dummy", new FormatException() )), true }, - { builder => builder.HandleInner(e => false), Outcome.FromException(new InvalidOperationException("dummy", new FormatException() )), false }, - { builder => builder.HandleInner(e => e.Message == "m"), Outcome.FromException(new InvalidOperationException("dummy", new FormatException("m") )), true }, - { builder => builder.HandleInner(e => e.Message == "x"), Outcome.FromException(new InvalidOperationException("dummy", new FormatException("m") )), false }, + { builder => builder.HandleResult("val"), CreateOutcome("val"), true }, + { builder => builder.HandleResult("val"), CreateOutcome("val2"), false }, + { builder => builder.HandleResult("val"), CreateOutcome(new InvalidOperationException()), false }, + { builder => builder.HandleResult("val", StringComparer.OrdinalIgnoreCase), CreateOutcome("VAL"), true }, + { builder => builder.HandleResult(r => r == "val"), CreateOutcome("val"), true }, + { builder => builder.HandleResult(r => r == "val2"), CreateOutcome("val"), false }, + { builder => builder.Handle(), CreateOutcome(new InvalidOperationException()), true }, + { builder => builder.Handle(), CreateOutcome(new FormatException()), false }, + { builder => builder.Handle(e => false), CreateOutcome(new InvalidOperationException()), false }, + { builder => builder.HandleInner(e => false), CreateOutcome(new InvalidOperationException()), false }, + { builder => builder.HandleInner(), CreateOutcome("value"), false }, + { builder => builder.Handle(), CreateOutcome("value"), false }, + { builder => builder.Handle().HandleResult("value"), CreateOutcome("value"), true }, + { builder => builder.Handle().HandleResult("value"), CreateOutcome("value2"), false }, + { builder => builder.HandleInner(), CreateOutcome(new InvalidOperationException("dummy", new FormatException() )), true }, + { builder => builder.HandleInner(e => false), CreateOutcome(new InvalidOperationException("dummy", new FormatException() )), false }, + { builder => builder.HandleInner(e => e.Message == "m"), CreateOutcome(new InvalidOperationException("dummy", new FormatException("m") )), true }, + { builder => builder.HandleInner(e => e.Message == "x"), CreateOutcome(new InvalidOperationException("dummy", new FormatException("m") )), false }, +#pragma warning disable CA2201 + //// See https://github.com/App-vNext/Polly/issues/2161 + { builder => builder.HandleInner(), CreateOutcome(new InvalidOperationException("1")), true }, + { builder => builder.HandleInner(), CreateOutcome(new Exception("1", new InvalidOperationException("2"))), true }, + { builder => builder.HandleInner(), CreateOutcome(new FormatException("1", new InvalidOperationException("2"))), true }, + { builder => builder.HandleInner(), CreateOutcome(new Exception("1", new Exception("2", new InvalidOperationException("3")))), true }, + { builder => builder.HandleInner(), CreateOutcome(new AggregateException("1", new Exception("2a"), new InvalidOperationException("2b"))), true }, + { builder => builder.HandleInner(), CreateOutcome(new AggregateException("1", new Exception("2", new InvalidOperationException("3")))), true }, + { builder => builder.HandleInner(), CreateOutcome(new AggregateException("1", new FormatException("2", new NotSupportedException("3")))), false }, + { builder => builder.HandleInner(), CreateOutcome(new AggregateException("1")), false }, + { builder => builder.HandleInner(ex => ex.Message is "3"), CreateOutcome(new AggregateException("1", new FormatException("2", new NotSupportedException("3")))), false }, + { builder => builder.HandleInner(ex => ex.Message is "unreachable"), CreateOutcome(new AggregateException("1", new FormatException("2", new NotSupportedException("3")))), false }, + { builder => builder.HandleInner(ex => ex.Message is "1"), CreateOutcome(new InvalidOperationException("1")), true }, + { builder => builder.HandleInner(ex => ex.Message is "2"), CreateOutcome(new Exception("1", new InvalidOperationException("2"))), true }, + { builder => builder.HandleInner(ex => ex.Message is "3"), CreateOutcome(new Exception("1", new Exception("2", new InvalidOperationException("3")))), true }, + { builder => builder.HandleInner(ex => ex.Message is "2b"), CreateOutcome(new AggregateException("1", new Exception("2a"), new InvalidOperationException("2b"))), true }, + { builder => builder.HandleInner(ex => ex.Message is "3"), CreateOutcome(new AggregateException("1", new Exception("2", new InvalidOperationException("3")))), true }, + { builder => builder.HandleInner(ex => ex.Message is "unreachable"), CreateOutcome(new InvalidOperationException("1")), false }, + { builder => builder.HandleInner(ex => ex.Message is "unreachable"), CreateOutcome(new Exception("1", new InvalidOperationException("2"))), false }, + { builder => builder.HandleInner(ex => ex.Message is "unreachable"), CreateOutcome(new Exception("1", new Exception("2", new InvalidOperationException("3")))), false }, + { builder => builder.HandleInner(ex => ex.Message is "unreachable"), CreateOutcome(new AggregateException("1", new Exception("2a"), new InvalidOperationException("2b"))), false }, + { builder => builder.HandleInner(ex => ex.Message is "unreachable"), CreateOutcome(new AggregateException("1", new Exception("2", new InvalidOperationException("3")))), false }, +#pragma warning restore CA2201 }; [Fact] @@ -66,7 +89,7 @@ public async Task Operator_RetryStrategyOptions_Ok() ShouldHandle = new PredicateBuilder().HandleResult("error") }; - var handled = await options.ShouldHandle(new RetryPredicateArguments(ResilienceContextPool.Shared.Get(), Outcome.FromResult("error"), 0)); + var handled = await options.ShouldHandle(new RetryPredicateArguments(ResilienceContextPool.Shared.Get(), CreateOutcome("error"), 0)); handled.Should().BeTrue(); } @@ -79,7 +102,7 @@ public async Task Operator_FallbackStrategyOptions_Ok() ShouldHandle = new PredicateBuilder().HandleResult("error") }; - var handled = await options.ShouldHandle(new(ResilienceContextPool.Shared.Get(), Outcome.FromResult("error"))); + var handled = await options.ShouldHandle(new(ResilienceContextPool.Shared.Get(), CreateOutcome("error"))); handled.Should().BeTrue(); } @@ -92,7 +115,7 @@ public async Task Operator_HedgingStrategyOptions_Ok() ShouldHandle = new PredicateBuilder().HandleResult("error") }; - var handled = await options.ShouldHandle(new(ResilienceContextPool.Shared.Get(), Outcome.FromResult("error"))); + var handled = await options.ShouldHandle(new(ResilienceContextPool.Shared.Get(), CreateOutcome("error"))); handled.Should().BeTrue(); } @@ -105,8 +128,14 @@ public async Task Operator_AdvancedCircuitBreakerStrategyOptions_Ok() ShouldHandle = new PredicateBuilder().HandleResult("error") }; - var handled = await options.ShouldHandle(new(ResilienceContextPool.Shared.Get(), Outcome.FromResult("error"))); + var handled = await options.ShouldHandle(new(ResilienceContextPool.Shared.Get(), CreateOutcome("error"))); handled.Should().BeTrue(); } + + private static Outcome CreateOutcome(Exception exception) + => Outcome.FromException(exception); + + private static Outcome CreateOutcome(string result) + => Outcome.FromResult(result); }