Skip to content

Commit 9eacb4a

Browse files
authored
Fix propagation of FiberRef modified within a ZQuery (#515)
1 parent 4e9262e commit 9eacb4a

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

zio-query/shared/src/main/scala/zio/query/ZQuery.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -639,17 +639,16 @@ final class ZQuery[-R, +E, +A] private (private val step: ZIO[R, Nothing, Result
639639
}
640640
state.setFiberRefs(newRefs)
641641
restore(runToZIO).exitWith { exit =>
642-
val curRefs = state.getFiberRefs(false)
642+
var curRefs = state.getFiberRefs(false)
643643
if (curRefs eq newRefs) {
644644
// Cheap and common: FiberRefs were not modified during the execution so we just replace them with the old ones
645645
state.setFiberRefs(oldRefs)
646646
} else {
647-
// FiberRefs were mdified so we need to manually revert each one
648-
var revertedRefs = oldRefs
649-
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(currentCache)
650-
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(currentScope)
651-
revertedRefs = resetRef(fid, oldRefs, revertedRefs)(disabledCache)
652-
state.setFiberRefs(revertedRefs)
647+
// FiberRefs were modified so we need to manually revert each one
648+
curRefs = resetRef(fid, oldRefs, curRefs)(currentCache)
649+
curRefs = resetRef(fid, oldRefs, curRefs)(currentScope)
650+
curRefs = resetRef(fid, oldRefs, curRefs)(disabledCache)
651+
state.setFiberRefs(curRefs)
653652
}
654653
scope.closeAndExitWith(exit)
655654
}

zio-query/shared/src/test/scala/zio/query/ZQuerySpec.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,18 @@ object ZQuerySpec extends ZIOBaseSpec {
402402
} yield (c1, c2)
403403

404404
q.run.map { case (c1, c2) => assertTrue(c1 != QueryScope.NoOp, c1 == c2) }
405+
},
406+
test("propagates FiberRef changes") {
407+
val ref = FiberRef.unsafe.make("a")(Unsafe)
408+
for {
409+
_ <- ZQuery.fromZIO(ref.update(_ + "b")).run
410+
res1 <- ref.get
411+
_ <- ZQuery.fromZIO(ref.update(_ + "c")).run
412+
res2 <- ref.get
413+
_ <- ref.set("d")
414+
_ <- ZQuery.fromZIO(ref.update(_ + "e")).run
415+
res3 <- ref.get
416+
} yield assertTrue(res1 == "ab", res2 == "abc", res3 == "de")
405417
}
406418
),
407419
suite("catchAllZIO")(

0 commit comments

Comments
 (0)