Skip to content

Commit

Permalink
Ensure ZPool finalizers are run when resource acquisition fails (zio#…
Browse files Browse the repository at this point in the history
  • Loading branch information
kyri-petrou authored Sep 3, 2024
1 parent 36587a6 commit 3a3151f
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 35 deletions.
51 changes: 36 additions & 15 deletions core-tests/shared/src/test/scala/zio/ZPoolSpec.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package zio

import zio.test.TestAspect.{exceptJS, nonFlaky}
import zio.test.TestAspect._
import zio.test._

object ZPoolSpec extends ZIOBaseSpec {
Expand All @@ -14,7 +14,7 @@ object ZPoolSpec extends ZIOBaseSpec {
_ <- count.get.repeatUntil(_ == 10)
value <- count.get
} yield assertTrue(value == 10)
} +
} @@ exceptJS(nonFlaky) +
test("cleans up items when shut down") {
for {
count <- Ref.make(0)
Expand All @@ -25,7 +25,7 @@ object ZPoolSpec extends ZIOBaseSpec {
_ <- scope.close(Exit.succeed(()))
value <- count.get
} yield assertTrue(value == 0)
} +
} @@ exceptJS(nonFlaky) +
test("acquire one item") {
for {
count <- Ref.make(0)
Expand All @@ -34,7 +34,7 @@ object ZPoolSpec extends ZIOBaseSpec {
_ <- count.get.repeatUntil(_ == 10)
item <- ZIO.scoped(pool.get)
} yield assertTrue(item == 1)
} +
} @@ exceptJS(nonFlaky) +
test("reports failures via get") {
for {
count <- Ref.make(0)
Expand All @@ -53,7 +53,7 @@ object ZPoolSpec extends ZIOBaseSpec {
_ <- ZIO.collectAll(List.fill(10)(pool.get))
result <- Live.live(ZIO.scoped(pool.get).disconnect.timeout(1.millis))
} yield assertTrue(result == None)
} +
} @@ exceptJS(nonFlaky) +
test("reuse released items") {
for {
count <- Ref.make(0)
Expand All @@ -62,7 +62,7 @@ object ZPoolSpec extends ZIOBaseSpec {
_ <- ZIO.scoped(pool.get).repeatN(99)
result <- count.get
} yield assertTrue(result == 10)
} +
} @@ exceptJS(nonFlaky) +
test("invalidate item") {
for {
count <- Ref.make(0)
Expand All @@ -73,7 +73,7 @@ object ZPoolSpec extends ZIOBaseSpec {
result <- ZIO.scoped(pool.get)
value <- count.get
} yield assertTrue(result == 2 && value == 10)
} +
} @@ exceptJS(nonFlaky) +
test("reallocate invalidated item on concurrent demand and maxed out pool size") {
ZIO.scoped {
for {
Expand All @@ -87,7 +87,7 @@ object ZPoolSpec extends ZIOBaseSpec {
finalizedCount <- finalized.get
} yield assertTrue(allocatedCount == 10 && finalizedCount == 10)
}
} +
} @@ exceptJS(nonFlaky) +
test("invalidate all items in pool and check that pool.get doesn't hang forever") {
for {
allocated <- Ref.make(0)
Expand All @@ -101,13 +101,13 @@ object ZPoolSpec extends ZIOBaseSpec {
allocatedCount <- allocated.get
finalizedCount <- finalized.get
} yield assertTrue(result == 3 && allocatedCount == 4 && finalizedCount == 2)
} +
} @@ exceptJS(nonFlaky) +
test("retry on failed acquire should not exhaust pool") {
for {
pool <- ZPool.make(ZIO.fail(new Exception("err")).as(1), 0 to 1, Duration.Infinity)
error <- Live.live(ZIO.scoped(pool.get.retryN(5)).timeoutFail(new Exception("timeout"))(1.second).flip)
} yield assertTrue(error.getMessage == "err")
} +
} @@ exceptJS(nonFlaky) +
test("compositional retry") {
def cond(i: Int) = if (i <= 10) ZIO.fail(i) else ZIO.succeed(i)
for {
Expand All @@ -117,7 +117,7 @@ object ZPoolSpec extends ZIOBaseSpec {
_ <- count.get.repeatUntil(_ == 10)
result <- ZIO.scoped(pool.get).eventually
} yield assertTrue(result == 11)
} +
} @@ exceptJS(nonFlaky) +
test("max pool size") {
for {
promise <- Promise.make[Nothing, Unit]
Expand All @@ -131,7 +131,7 @@ object ZPoolSpec extends ZIOBaseSpec {
_ <- TestClock.adjust(60.seconds)
min <- count.get
} yield assertTrue(min == 10 && max == 15)
} +
} @@ exceptJS(nonFlaky(10)) +
test("shutdown robustness") {
for {
count <- Ref.make(0)
Expand All @@ -157,14 +157,35 @@ object ZPoolSpec extends ZIOBaseSpec {
for {
_ <- ZIO.scoped(ZPool.make(ZIO.unit, 10 to 15, 60.seconds).uninterruptible)
} yield assertCompletes
} +
} @@ exceptJS(nonFlaky) +
test("make preserves interruptibility") {
val get = ZIO.checkInterruptible(status => ZIO.succeed(status.isInterruptible))
for {
pool <- ZPool.make(get, 10 to 15, 60.seconds)
_ <- pool.get.repeatN(9)
interruptible <- pool.get
} yield assertTrue(interruptible)
}
}
} @@ exceptJS(nonFlaky) +
test("doesn't leak resources when acquisition is interrupted") {
for {
latch <- Promise.make[Nothing, Unit]
latch2 <- Promise.make[Nothing, Unit]
incCounter = ZIO.acquireRelease(ZIO.unit)(_ => latch2.succeed(()))
pool <- ZPool.make(incCounter <* latch.succeed(()) <* ZIO.never, 0 to 100, Duration.Infinity)
f <- ZIO.scoped(pool.get).fork
_ <- latch.await
_ <- f.interrupt
_ <- latch2.await
} yield assertCompletes
} @@ withLiveClock @@ exceptJS(nonFlaky(1000)) +
test("doesn't leak resources when acquisition failed") {
for {
latch <- Promise.make[Nothing, Unit]
incCounter = ZIO.acquireRelease(ZIO.unit)(_ => latch.succeed(()))
pool <- ZPool.make(incCounter <* ZIO.fail("oh no"), 0 to 100, Duration.Infinity)
_ <- ZIO.scoped(pool.get).ignore
_ <- latch.await
} yield assertCompletes
} @@ withLiveClock @@ exceptJS(nonFlaky(1000))
}.provideLayer(Scope.default) @@ timeout(30.seconds)
}
34 changes: 14 additions & 20 deletions core/shared/src/main/scala/zio/ZPool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,15 @@ object ZPool {
case Exit.Success(item) =>
invalidated.get.flatMap { set =>
if (set.contains(item)) finalizeInvalid(attempted) *> acquire
else ZIO.succeed(attempted)
else Exit.succeed(attempted)
}
case _ =>
ZIO.succeed(attempted)
state.modify { case State(size, free) =>
if (size <= range.start)
attempted.finalizer *> allocate -> State(size, free + 1)
else
attempted.finalizer -> State(size - 1, free)
}.flatten *> Exit.succeed(attempted)
}
},
State(size, free - 1)
Expand All @@ -201,14 +206,8 @@ object ZPool {
track(attempted.result) *>
getAndShutdown.whenZIO(isShuttingDown.get)
}

case Exit.Failure(_) =>
state.modify { case State(size, free) =>
if (size <= range.start)
allocate -> State(size, free + 1)
else
ZIO.unit -> State(size - 1, free)
}.flatten
case _ =>
Exit.unit // Handled during acquire
}

def finalizeInvalid(attempted: Attempted[E, A]): UIO[Any] =
Expand All @@ -224,19 +223,14 @@ object ZPool {
def allocate: UIO[Any] =
for {
scope <- Scope.make
exit <- restore(scope.extend(creator)).exit
attempted <- ZIO.succeed(Attempted(exit, scope.close(Exit.succeed(()))))
exit <- scope.extend(restore(creator)).exit
attempted <- ZIO.succeed(Attempted(exit, scope.close(exit)))
_ <- items.offer(attempted)
_ <- track(attempted.result)
_ <- getAndShutdown.whenZIO(isShuttingDown.get)
} yield attempted

for {
releaseAndAttempted <- ZIO.acquireRelease(acquire)(release(_)).withEarlyRelease.disconnect
(release, attempted) = releaseAndAttempted
_ <- release.when(attempted.isFailure)
item <- attempted.toZIO
} yield item
ZIO.acquireRelease(acquire)(release).flatMap(_.result).disconnect
}

/**
Expand All @@ -250,8 +244,8 @@ object ZPool {
(
for {
scope <- Scope.make
exit <- restore(scope.extend(creator)).exit
attempted <- ZIO.succeed(Attempted(exit, scope.close(Exit.succeed(()))))
exit <- scope.extend(restore(creator)).exit
attempted <- ZIO.succeed(Attempted(exit, scope.close(exit)))
_ <- items.offer(attempted)
_ <- track(attempted.result)
_ <- getAndShutdown.whenZIO(isShuttingDown.get)
Expand Down

0 comments on commit 3a3151f

Please sign in to comment.