From 3a3151f828fd873c4c06a19adbabcf874a7ddbd9 Mon Sep 17 00:00:00 2001 From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com> Date: Tue, 3 Sep 2024 20:51:23 +0300 Subject: [PATCH] Ensure ZPool finalizers are run when resource acquisition fails (#9163) --- .../shared/src/test/scala/zio/ZPoolSpec.scala | 51 +++++++++++++------ core/shared/src/main/scala/zio/ZPool.scala | 34 +++++-------- 2 files changed, 50 insertions(+), 35 deletions(-) diff --git a/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala b/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala index c149bebe3f6d..2ac13e5a5815 100644 --- a/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala +++ b/core-tests/shared/src/test/scala/zio/ZPoolSpec.scala @@ -1,6 +1,6 @@ package zio -import zio.test.TestAspect.{exceptJS, nonFlaky} +import zio.test.TestAspect._ import zio.test._ object ZPoolSpec extends ZIOBaseSpec { @@ -14,7 +14,7 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- count.get.repeatUntil(_ == 10) value <- count.get } yield assertTrue(value == 10) - } + + } @@ exceptJS(nonFlaky) + test("cleans up items when shut down") { for { count <- Ref.make(0) @@ -25,7 +25,7 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- scope.close(Exit.succeed(())) value <- count.get } yield assertTrue(value == 0) - } + + } @@ exceptJS(nonFlaky) + test("acquire one item") { for { count <- Ref.make(0) @@ -34,7 +34,7 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- count.get.repeatUntil(_ == 10) item <- ZIO.scoped(pool.get) } yield assertTrue(item == 1) - } + + } @@ exceptJS(nonFlaky) + test("reports failures via get") { for { count <- Ref.make(0) @@ -53,7 +53,7 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- ZIO.collectAll(List.fill(10)(pool.get)) result <- Live.live(ZIO.scoped(pool.get).disconnect.timeout(1.millis)) } yield assertTrue(result == None) - } + + } @@ exceptJS(nonFlaky) + test("reuse released items") { for { count <- Ref.make(0) @@ -62,7 +62,7 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- ZIO.scoped(pool.get).repeatN(99) result <- count.get } yield assertTrue(result == 10) - } + + } @@ exceptJS(nonFlaky) + test("invalidate item") { for { count <- Ref.make(0) @@ -73,7 +73,7 @@ object ZPoolSpec extends ZIOBaseSpec { result <- ZIO.scoped(pool.get) value <- count.get } yield assertTrue(result == 2 && value == 10) - } + + } @@ exceptJS(nonFlaky) + test("reallocate invalidated item on concurrent demand and maxed out pool size") { ZIO.scoped { for { @@ -87,7 +87,7 @@ object ZPoolSpec extends ZIOBaseSpec { finalizedCount <- finalized.get } yield assertTrue(allocatedCount == 10 && finalizedCount == 10) } - } + + } @@ exceptJS(nonFlaky) + test("invalidate all items in pool and check that pool.get doesn't hang forever") { for { allocated <- Ref.make(0) @@ -101,13 +101,13 @@ object ZPoolSpec extends ZIOBaseSpec { allocatedCount <- allocated.get finalizedCount <- finalized.get } yield assertTrue(result == 3 && allocatedCount == 4 && finalizedCount == 2) - } + + } @@ exceptJS(nonFlaky) + test("retry on failed acquire should not exhaust pool") { for { pool <- ZPool.make(ZIO.fail(new Exception("err")).as(1), 0 to 1, Duration.Infinity) error <- Live.live(ZIO.scoped(pool.get.retryN(5)).timeoutFail(new Exception("timeout"))(1.second).flip) } yield assertTrue(error.getMessage == "err") - } + + } @@ exceptJS(nonFlaky) + test("compositional retry") { def cond(i: Int) = if (i <= 10) ZIO.fail(i) else ZIO.succeed(i) for { @@ -117,7 +117,7 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- count.get.repeatUntil(_ == 10) result <- ZIO.scoped(pool.get).eventually } yield assertTrue(result == 11) - } + + } @@ exceptJS(nonFlaky) + test("max pool size") { for { promise <- Promise.make[Nothing, Unit] @@ -131,7 +131,7 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- TestClock.adjust(60.seconds) min <- count.get } yield assertTrue(min == 10 && max == 15) - } + + } @@ exceptJS(nonFlaky(10)) + test("shutdown robustness") { for { count <- Ref.make(0) @@ -157,7 +157,7 @@ object ZPoolSpec extends ZIOBaseSpec { for { _ <- ZIO.scoped(ZPool.make(ZIO.unit, 10 to 15, 60.seconds).uninterruptible) } yield assertCompletes - } + + } @@ exceptJS(nonFlaky) + test("make preserves interruptibility") { val get = ZIO.checkInterruptible(status => ZIO.succeed(status.isInterruptible)) for { @@ -165,6 +165,27 @@ object ZPoolSpec extends ZIOBaseSpec { _ <- pool.get.repeatN(9) interruptible <- pool.get } yield assertTrue(interruptible) - } - } + } @@ exceptJS(nonFlaky) + + test("doesn't leak resources when acquisition is interrupted") { + for { + latch <- Promise.make[Nothing, Unit] + latch2 <- Promise.make[Nothing, Unit] + incCounter = ZIO.acquireRelease(ZIO.unit)(_ => latch2.succeed(())) + pool <- ZPool.make(incCounter <* latch.succeed(()) <* ZIO.never, 0 to 100, Duration.Infinity) + f <- ZIO.scoped(pool.get).fork + _ <- latch.await + _ <- f.interrupt + _ <- latch2.await + } yield assertCompletes + } @@ withLiveClock @@ exceptJS(nonFlaky(1000)) + + test("doesn't leak resources when acquisition failed") { + for { + latch <- Promise.make[Nothing, Unit] + incCounter = ZIO.acquireRelease(ZIO.unit)(_ => latch.succeed(())) + pool <- ZPool.make(incCounter <* ZIO.fail("oh no"), 0 to 100, Duration.Infinity) + _ <- ZIO.scoped(pool.get).ignore + _ <- latch.await + } yield assertCompletes + } @@ withLiveClock @@ exceptJS(nonFlaky(1000)) + }.provideLayer(Scope.default) @@ timeout(30.seconds) } diff --git a/core/shared/src/main/scala/zio/ZPool.scala b/core/shared/src/main/scala/zio/ZPool.scala index 6c4dc1032cc6..d81bc8417448 100644 --- a/core/shared/src/main/scala/zio/ZPool.scala +++ b/core/shared/src/main/scala/zio/ZPool.scala @@ -175,10 +175,15 @@ object ZPool { case Exit.Success(item) => invalidated.get.flatMap { set => if (set.contains(item)) finalizeInvalid(attempted) *> acquire - else ZIO.succeed(attempted) + else Exit.succeed(attempted) } case _ => - ZIO.succeed(attempted) + state.modify { case State(size, free) => + if (size <= range.start) + attempted.finalizer *> allocate -> State(size, free + 1) + else + attempted.finalizer -> State(size - 1, free) + }.flatten *> Exit.succeed(attempted) } }, State(size, free - 1) @@ -201,14 +206,8 @@ object ZPool { track(attempted.result) *> getAndShutdown.whenZIO(isShuttingDown.get) } - - case Exit.Failure(_) => - state.modify { case State(size, free) => - if (size <= range.start) - allocate -> State(size, free + 1) - else - ZIO.unit -> State(size - 1, free) - }.flatten + case _ => + Exit.unit // Handled during acquire } def finalizeInvalid(attempted: Attempted[E, A]): UIO[Any] = @@ -224,19 +223,14 @@ object ZPool { def allocate: UIO[Any] = for { scope <- Scope.make - exit <- restore(scope.extend(creator)).exit - attempted <- ZIO.succeed(Attempted(exit, scope.close(Exit.succeed(())))) + exit <- scope.extend(restore(creator)).exit + attempted <- ZIO.succeed(Attempted(exit, scope.close(exit))) _ <- items.offer(attempted) _ <- track(attempted.result) _ <- getAndShutdown.whenZIO(isShuttingDown.get) } yield attempted - for { - releaseAndAttempted <- ZIO.acquireRelease(acquire)(release(_)).withEarlyRelease.disconnect - (release, attempted) = releaseAndAttempted - _ <- release.when(attempted.isFailure) - item <- attempted.toZIO - } yield item + ZIO.acquireRelease(acquire)(release).flatMap(_.result).disconnect } /** @@ -250,8 +244,8 @@ object ZPool { ( for { scope <- Scope.make - exit <- restore(scope.extend(creator)).exit - attempted <- ZIO.succeed(Attempted(exit, scope.close(Exit.succeed(())))) + exit <- scope.extend(restore(creator)).exit + attempted <- ZIO.succeed(Attempted(exit, scope.close(exit))) _ <- items.offer(attempted) _ <- track(attempted.result) _ <- getAndShutdown.whenZIO(isShuttingDown.get)