Skip to content

Commit

Permalink
Manual tuple destructuring in some lemmas for QOI
Browse files Browse the repository at this point in the history
  • Loading branch information
mario-bucev authored and vkuncak committed May 2, 2022
1 parent 37ac32c commit 54a1301
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 17 deletions.
2 changes: 1 addition & 1 deletion qoi/stainless.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
vc-cache = true
timeout = 180
timeout = 300
strict-arithmetic = false
batched = true
solvers = "smt-cvc4,smt-z3,no-inc:smt-z3:z3 tactic.default_tactic=smt sat.euf=true"
Expand Down
53 changes: 43 additions & 10 deletions qoi/verified/decoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -705,10 +705,19 @@ object decoder {
require(pxPos0 + chan <= pixels.length)
require(inPos0 < inPos1 && inPos1 <= chunksLen)

val (index1, pixels1, decIter1) = decodeLoopPure(index, pixels, pxPrev, inPos0, inPos1, pxPos0)
val (indexNext, pixelNext, _, decIterNext) = decodeNextPure(index, pixels, pxPrev, inPos0, pxPos0)
val res1 = decodeLoopPure(index, pixels, pxPrev, inPos0, inPos1, pxPos0)
val index1 = res1._1
val pixels1 = res1._2
val decIter1 = res1._3
val resNext = decodeNextPure(index, pixels, pxPrev, inPos0, pxPos0)
val indexNext = resNext._1
val pixelNext = resNext._2
val decIterNext = resNext._4
require(decIterNext.inPos < inPos1 && decIterNext.pxPos < pixels.length && decIterNext.pxPos + chan <= pixels.length)
val (index2, pixels2, decIter2) = decodeLoopPure(indexNext, pixelNext, decIterNext.px, decIterNext.inPos, inPos1, decIterNext.pxPos)
val res2 = decodeLoopPure(indexNext, pixelNext, decIterNext.px, decIterNext.inPos, inPos1, decIterNext.pxPos)
val index2 = res2._1
val pixels2 = res2._2
val decIter2 = res2._3

{
()
Expand Down Expand Up @@ -741,10 +750,16 @@ object decoder {
require(arraysEq(bytes, bytes2, inPos0, untilInPos))

val ctx2 = Ctx(freshCopy(bytes2), w, h, chan)
val (ix1, pix1, decIter1) = decodeLoopPure(index, pixels, pxPrev, inPos0, untilInPos, pxPos0)(using ctx1)
val res1 = decodeLoopPure(index, pixels, pxPrev, inPos0, untilInPos, pxPos0)(using ctx1)
val ix1 = res1._1
val pix1 = res1._2
val decIter1 = res1._3
require(decIter1.pxPos <= pixels.length)
require(decIter1.inPos == untilInPos)
val (ix2, pix2, decIter2) = decodeLoopPure(index, pixels, pxPrev, inPos0, untilInPos, pxPos0)(using ctx2)
val res2 = decodeLoopPure(index, pixels, pxPrev, inPos0, untilInPos, pxPos0)(using ctx2)
val ix2 = res2._1
val pix2 = res2._2
val decIter2 = res2._3

{
assert(ctx2.bytes == bytes2)
Expand Down Expand Up @@ -817,11 +832,20 @@ object decoder {
require(pxPosInv(pxPos0))
require(pxPos0 + chan <= pixels.length)
require(inPos0 < inPos1 && inPos1 < inPos2 && inPos2 <= chunksLen)
val (index1, pixels1, decIter1) = decodeLoopPure(index, pixels, pxPrev, inPos0, inPos1, pxPos0)
val res1 = decodeLoopPure(index, pixels, pxPrev, inPos0, inPos1, pxPos0)
val index1 = res1._1
val pixels1 = res1._2
val decIter1 = res1._3
require(decIter1.pxPos < pixels.length && decIter1.pxPos + chan <= pixels.length)
require(decIter1.inPos == inPos1)
val (index2, pixels2, decIter2) = decodeLoopPure(index1, pixels1, decIter1.px, inPos1, inPos2, decIter1.pxPos)
val (index3, pixels3, decIter3) = decodeLoopPure(index, pixels, pxPrev, inPos0, inPos2, pxPos0)
val res2 = decodeLoopPure(index1, pixels1, decIter1.px, inPos1, inPos2, decIter1.pxPos)
val index2 = res2._1
val pixels2 = res2._2
val decIter2 = res2._3
val res3 = decodeLoopPure(index, pixels, pxPrev, inPos0, inPos2, pxPos0)
val index3 = res3._1
val pixels3 = res3._2
val decIter3 = res3._3

{
val (indexNext, pixelNext, _, decIterNext) = decodeNextPure(index, pixels, pxPrev, inPos0, pxPos0)
Expand Down Expand Up @@ -882,9 +906,18 @@ object decoder {
require(inPos0 < chunksLen)
require(bytes.length == bytes2.length)
val ctx2 = Ctx(freshCopy(bytes2), w, h, chan)
val (ix1, pix1, res1, decIter1) = decodeNextPure(index, pixels, pxPrev, inPos0, pxPos0)(using ctx1)
val resNext1 = decodeNextPure(index, pixels, pxPrev, inPos0, pxPos0)(using ctx1)
val ix1 = resNext1._1
val pix1 = resNext1._2
val res1 = resNext1._3
val decIter1 = resNext1._4

require(arraysEq(bytes, bytes2, inPos0, decIter1.inPos))
val (ix2, pix2, res2, decIter2) = decodeNextPure(index, pixels, pxPrev, inPos0, pxPos0)(using ctx2)
val resNext2 = decodeNextPure(index, pixels, pxPrev, inPos0, pxPos0)(using ctx2)
val ix2 = resNext2._1
val pix2 = resNext2._2
val res2 = resNext2._3
val decIter2 = resNext2._4

{
doDecodeNextBytesEqLemma(index, pxPrev, inPos0, bytes2)
Expand Down
15 changes: 9 additions & 6 deletions qoi/verified/encoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -302,9 +302,10 @@ object encoder {
assert(decodedPreRec.pixels.length == pixels.length)
assert(decodedPreRec.pixels.length == w * h * chan)
assert((w * h * chan) % chan == 0)
assert(decodedPreRec.pixels.length % chan == 0)
assert(0 <= decodedPreRec.pxPos)
assert(decodedPreRec.pxPos <= w * h * chan)
assert(decoder.pxPosInv(decodedPreRec.pxPos))
assert(decodedPreRec.pixels.length % chan == 0)
assert(decoder.pxPosInv(decodedPreRec.pxPos)) // Very slow (~110s)
val (ix2, pix2, decIter2) = decoder.decodeLoopPure(decodedPreRec.index, decodedPreRec.pixels, px, outPos2, outPosRes, decodedPreRec.pxPos)
assert(decIter2.pxPos == decoded.pxPos)
assert(decIter2.inPos == decoded.inPos)
Expand All @@ -318,8 +319,8 @@ object encoder {
assert(oldDecoded.pixels.length == w * h * chan)
assert((w * h * chan) % chan == 0)
assert(oldDecoded.pixels.length % chan == 0)
assert(0 <= oldDecoded.pxPos && oldDecoded.pxPos <= oldDecoded.pixels.length)
val (ix3, pix3, decIter3) = decoder.decodeLoopPure(oldDecoded.index, oldDecoded.pixels, pxPrev, outPos0, outPosRes, oldDecoded.pxPos)
assert(0 <= oldDecoded.pxPos && oldDecoded.pxPos <= oldDecoded.pixels.length) // Very slow (~120s)
val (ix3, pix3, decIter3) = decoder.decodeLoopPure(oldDecoded.index, oldDecoded.pixels, pxPrev, outPos0, outPosRes, oldDecoded.pxPos) // Precond 4 slow (~85s)

assert(outPosRes <= bytes.length - Padding)
assert(outPos0 < outPos2)
Expand All @@ -343,7 +344,7 @@ object encoder {
assert(ix3 == decoded.index)
assert(pix3 == decoded.pixels)

check(decodeLoopEncodeProp(bytes, pxPrev, outPos0, outPosRes, oldDecoded, pxRes, decoded))
check(decodeLoopEncodeProp(bytes, pxPrev, outPos0, outPosRes, oldDecoded, pxRes, decoded)) // Slow (~70s)
}
check(oldBytes.length == bytes.length)
check(decodeLoopEncodeProp(bytes, pxPrev, outPos0, outPosRes, oldDecoded, pxRes, decoded))
Expand Down Expand Up @@ -795,10 +796,12 @@ object encoder {
decoded.pixels = freshCopy(newDecoded.pixels)
decoded.inPos = newDecoded.inPos
decoded.pxPos = newDecoded.pxPos

check(index == decoded.index)
}

EncodeSingleStepResult(px, outPos2, run1)
}.ensuring { case EncodeSingleStepResult(px, outPos2, run1) =>
}.ensuring { case EncodeSingleStepResult(px, outPos2, run1) => // Wins the "slowest to verify" award (~210s)
// Bytes and index length are unchanged
bytes.length == maxSize &&&
index.length == 64 &&&
Expand Down

0 comments on commit 54a1301

Please sign in to comment.