Skip to content

Commit

Permalink
Add floating point arithmetic and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
agilot committed Feb 10, 2025
1 parent 7a35134 commit 611ee7d
Show file tree
Hide file tree
Showing 8 changed files with 445 additions and 6 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ lazy val nTestParallelism = {
def ghProject(repo: String, version: String) = RootProject(uri(s"${repo}#${version}"))

// lazy val smtlib = RootProject(file("../scala-smtlib")) // If you have a local copy of Scala-SMTLIB and would like to do some changes
lazy val smtlib = ghProject("https://github.com/epfl-lara/scala-smtlib.git", "51a44878858b427f1a4e5a5eb01d8f796898d812")
lazy val smtlib = ghProject("https://github.com/epfl-lara/scala-smtlib.git", "39745509132b01dc3291112c5259f5e77492d42c")

lazy val scriptName = settingKey[String]("Name of the generated 'inox' script")

Expand Down
15 changes: 15 additions & 0 deletions src/main/scala/inox/ast/Deconstructors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ trait TreeDeconstructor {
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
(_, _, _, _, _) => t.BVLiteral(signed, bits, size))
},
classOf[s.FPLiteral] -> { expr =>
val s.FPLiteral(exponent, significand, bits) = expr: @unchecked
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
(_, _, _, _, _) => t.FPLiteral(exponent, significand, bits))
},
classOf[s.IntegerLiteral] -> { expr =>
val s.IntegerLiteral(i) = expr: @unchecked
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
Expand Down Expand Up @@ -297,6 +302,11 @@ trait TreeDeconstructor {
(NoIdentifiers, NoVariables, Seq(e), NoTypes, NoFlags,
(_, _, es, _, _) => t.BVSignedToUnsigned(es(0)))
},
classOf[s.FPEquals] -> { expr =>
val s.FPEquals(t1, t2) = expr: @unchecked
(NoIdentifiers, NoVariables, Seq(t1, t2), NoTypes, NoFlags,
(_, _, es, _, _) => t.FPEquals(es(0), es(1)))
},
classOf[s.Tuple] -> { expr =>
val s.Tuple(args) = expr: @unchecked
(NoIdentifiers, NoVariables, args, NoTypes, NoFlags,
Expand Down Expand Up @@ -458,6 +468,11 @@ trait TreeDeconstructor {
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
(_, _, _, _, _) => t.BVType(signed, size))
},
classOf[s.FPType] -> { tpe =>
val s.FPType(exponent, significand) = tpe: @unchecked
(NoIdentifiers, NoVariables, NoExpressions, NoTypes, NoFlags,
(_, _, _, _, _) => t.FPType(exponent, significand))
},

// @nv: can't use `s.Untyped.getClass` as it is not yet created at this point
scala.reflect.classTag[s.Untyped.type].runtimeClass -> { _ =>
Expand Down
37 changes: 34 additions & 3 deletions src/main/scala/inox/ast/Expressions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -233,18 +233,42 @@ trait Expressions { self: Trees =>
/** $encodingof a floating point literal */
sealed case class FPLiteral(exponent: Int, significand: Int, value: BitSet) extends Literal[BitSet] {
override def getType(using Symbols) = FPType(exponent, significand)
def isNegative: Boolean = !isNaN && value(exponent + significand)
def isPositive: Boolean = !isNaN && !isNegative
def isZero: Boolean = !Range(1, significand + exponent).exists(value)
def isNumber: Boolean = !Range(significand, significand + exponent).forall(value)
def isNaN: Boolean = !isNumber && Range(1, significand).exists(value)
def isInfinite: Boolean = !isNumber && !isNaN
def toBV: BVLiteral = BVLiteral(true, value, exponent + significand)

def strictEquals(obj: Any): Boolean = obj match {
case lit @ FPLiteral(e2, s2, v2) => exponent == e2 && significand == s2 && value == v2
case _ => false
}

/** Semantic equality for FP */
def semEquals(obj: Any): Boolean = obj match
case lit @ FPLiteral(e2, s2, v2) =>
!isNaN && !lit.isNaN && ((isZero && lit.isZero) || strictEquals(obj))
case _ => strictEquals(obj)

override def equals(obj: Any): Boolean = strictEquals(obj)
}

object FPLiteral {
def fromBV(exponent: Int, significand: Int, bv: BVLiteral) = FPLiteral(exponent, significand, bv.value)
def fromBV(exponent: Int, significand: Int, bv: BVLiteral): FPLiteral = FPLiteral(exponent, significand, bv.value)
def plusZero(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet.empty)
def minusZero(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet(exponent + significand))
def NaN(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet(Range(significand - 1, exponent + significand)*))
def minusInfinity(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet(Range(significand, exponent + significand + 1)*))
def plusInfinity(exponent: Int, significand: Int) = FPLiteral(exponent, significand, BitSet(Range(significand, exponent + significand)*))
}

object Float32Literal {
def apply(value: Float): FPLiteral = FPLiteral.fromBV(8, 24, Int32Literal(java.lang.Float.floatToIntBits(value)))

def unapply(e: Expr): Option[Float] = e match {
case f @ FPLiteral(8, 24, b) if b.maxOption.getOrElse(-1) < 32 =>
case f @ FPLiteral(8, 24, b) if b.maxOption.getOrElse(-1) <= 32 =>
f.toBV match {
case Int32Literal(i) => Some(java.lang.Float.intBitsToFloat(i))
case _ => None
Expand All @@ -257,7 +281,7 @@ trait Expressions { self: Trees =>
def apply(value: Double): FPLiteral = FPLiteral.fromBV(11, 53, Int64Literal(java.lang.Double.doubleToLongBits(value)))

def unapply(e: Expr): Option[Double] = e match {
case f @ FPLiteral(11, 53, b) if b.maxOption.getOrElse(-1) < 64 =>
case f @ FPLiteral(11, 53, b) if b.maxOption.getOrElse(-1) <= 64 =>
f.toBV match {
case Int64Literal(i) => Some(java.lang.Double.longBitsToDouble(i))
case _ => None
Expand Down Expand Up @@ -617,6 +641,13 @@ trait Expressions { self: Trees =>
}
}

/* FP operaions */

sealed case class FPEquals(lhs: Expr, rhs: Expr) extends Expr with CachingTyped {
override protected def computeType(using Symbols): Type =
if getFPType(lhs, rhs).isTyped then BooleanType() else Untyped
}


/* Tuple operations */

Expand Down
6 changes: 6 additions & 0 deletions src/main/scala/inox/ast/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ trait Printer {
case Int32Literal(v) => p"$v"
case Int64Literal(v) => p"$v"
case BVLiteral(_, bits, size) => p"x${(size to 1 by -1).map(i => if (bits(i)) "1" else "0").mkString("")}"
case Float32Literal(f) => p"$f"
case Float64Literal(f) => p"$f"
case FPLiteral(exponent, significand, bits) =>
p"x${(exponent + significand to 1 by -1).map(i => if (bits(i)) "1" else "0").mkString("")}"
case IntegerLiteral(v) => p"$v"
case FractionLiteral(n, d) =>
if (d == 1) p"$n"
Expand Down Expand Up @@ -275,6 +279,8 @@ trait Printer {
case BVUnsignedToSigned(e) => p"$e.toSigned"
case BVSignedToUnsigned(e) => p"$e.toUnsigned"

case FPEquals(l, r) => p"$l === $r"

case fs @ FiniteSet(rs, _) => p"Set(${rs})"
case fs @ FiniteBag(rs, _) => p"Bag(${rs.toSeq})"
case fm @ FiniteMap(rs, dflt, _, _) =>
Expand Down
21 changes: 21 additions & 0 deletions src/main/scala/inox/solvers/smtlib/SMTLIBParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@ trait SMTLIBParser {
case _ => BVLiteral(true, n, size.intValue)
}

case FloatingPoint.FPLit(sign, exponent, significand) => otpe match {
case Some(FPType(eb, sb)) =>
(fromSMT(sign, Some(BVType(true, 1))),
fromSMT(exponent, Some(BVType(true, eb))),
fromSMT(significand, Some(BVType(true, sb - 1)))) match {
case (BVLiteral(true, bitset1, 1), BVLiteral(true, bitset2, `eb`), BVLiteral(true, bitset3, rem)) if rem == sb - 1 =>
FPLiteral(eb, sb, bitset1.map(_ + eb + sb - 1) ++ bitset2.map(_ + sb - 1) ++ bitset3)
case _ => throw new MissformedSMTException(term, "FP Lit has inconsistent components")
}
case _ => throw new MissformedSMTException(term, "FP Lit is not of type Float")
}
case FloatingPoint.PlusZero(exponent, significand) => FPLiteral.plusZero(exponent.toInt, significand.toInt)
case FloatingPoint.MinusZero(exponent, significand) => FPLiteral.minusZero(exponent.toInt, significand.toInt)
case FloatingPoint.NaN(exponent, significand) => FPLiteral.NaN(exponent.toInt, significand.toInt)
case FloatingPoint.PlusInfinity(exponent, significand) => FPLiteral.plusInfinity(exponent.toInt, significand.toInt)
case FloatingPoint.MinusInfinity(exponent, significand) => FPLiteral.minusInfinity(exponent.toInt, significand.toInt)


case SDecimal(value) =>
exprOps.normalizeFraction(FractionLiteral(
value.bigDecimal.movePointRight(value.scale).toBigInteger,
Expand Down Expand Up @@ -196,6 +214,9 @@ trait SMTLIBParser {
case Some(BVType(signed, _)) => signed
case _ => true
}, (i + 1).bigInteger.intValueExact))

case FloatingPoint.Eq(e1, e2) => fromSMTUnifyType(e1, e2, None)(FPEquals.apply)


case ArraysEx.Select(e1, e2) => otpe match {
case Some(tpe) =>
Expand Down
25 changes: 23 additions & 2 deletions src/main/scala/inox/solvers/smtlib/SMTLIBTarget.scala
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
case IntegerType() => Ints.IntSort()
case RealType() => Reals.RealSort()
case BVType(_,l) => FixedSizeBitVectors.BitVectorSort(l)
case FPType(e, s) => FloatingPoint.FloatingPointSort(e, s)
case CharType() => FixedSizeBitVectors.BitVectorSort(16)
case StringType() => Strings.StringSort()

Expand Down Expand Up @@ -277,6 +278,14 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {

case IntegerLiteral(i) => intToTerm(i)
case BVLiteral(_, bits, size) => FixedSizeBitVectors.BitVectorLit(List.range(1, size + 1).map(i => bits(size + 1 - i)))
case FPLiteral(e, s, bits) =>
val size = e + s
FloatingPoint.FPLit(
FixedSizeBitVectors.BitVectorLit(List.range(1, 2).map(i => bits(size + 1 - i))),
FixedSizeBitVectors.BitVectorLit(List.range(2, e + 2).map(i => bits(size + 1 - i))),
FixedSizeBitVectors.BitVectorLit(List.range(e + 2, size + 1).map(i => bits(size + 1 - i)))
)

case FractionLiteral(n, d) => Reals.Div(realToTerm(n), realToTerm(d))
case CharLiteral(c) => FixedSizeBitVectors.BitVectorLit(Hexadecimal.fromShort(c.toShort))
case BooleanLiteral(v) => Core.BoolConst(v)
Expand Down Expand Up @@ -372,35 +381,40 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {

case UMinus(u) => u.getType match {
case BVType(_,_) => FixedSizeBitVectors.Neg(toSMT(u))
case FPType(_, _) => FloatingPoint.Neg(toSMT(u))
case IntegerType() => Ints.Neg(toSMT(u))
case RealType() => Reals.Neg(toSMT(u))
}

case Equals(a, b) => Core.Equals(toSMT(a), toSMT(b))
case Implies(a, b) => Core.Implies(toSMT(a), toSMT(b))
case pl @ Plus(a, _) =>
case pl @ Plus(a, b) =>
val rec = flattenPlus(pl).map(toSMT)
a.getType match {
case BVType(_,_) => FixedSizeBitVectors.Add(rec)
case FPType(_, _) => FloatingPoint.Add(FloatingPoint.RNE(), toSMT(a), toSMT(b))
case IntegerType() => Ints.Add(rec)
case RealType() => Reals.Add(rec)
}
case Minus(a, b) => a.getType match {
case BVType(_,_) => FixedSizeBitVectors.Sub(toSMT(a), toSMT(b))
case FPType(_,_) => FloatingPoint.Sub(FloatingPoint.RNE(), toSMT(a), toSMT(b))
case IntegerType() => Ints.Sub(toSMT(a), toSMT(b))
case RealType() => Reals.Sub(toSMT(a), toSMT(b))
}
case tms @ Times(a, _) =>
case tms @ Times(a, b) =>
val rec = flattenTimes(tms).map(toSMT)
a.getType match {
case BVType(_,_) => FixedSizeBitVectors.Mul(rec)
case FPType(_,_) => FloatingPoint.Mul(FloatingPoint.RNE(), toSMT(a), toSMT(b))
case IntegerType() => Ints.Mul(rec)
case RealType() => Reals.Mul(rec)
}

case Division(a, b) => a.getType match {
case BVType(true, _) => FixedSizeBitVectors.SDiv(toSMT(a), toSMT(b))
case BVType(false, _) => FixedSizeBitVectors.UDiv(toSMT(a), toSMT(b))
case FPType(_,_) => FloatingPoint.Div(FloatingPoint.RNE(), toSMT(a), toSMT(b))
case IntegerType() =>
val ar = toSMT(a)
val br = toSMT(b)
Expand All @@ -415,6 +429,7 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
case Remainder(a, b) => a.getType match {
case BVType(true, _) => FixedSizeBitVectors.SRem(toSMT(a), toSMT(b))
case BVType(false, _) => FixedSizeBitVectors.URem(toSMT(a), toSMT(b))
case FPType(_, _) => FloatingPoint.Rem(toSMT(a), toSMT(b))
case IntegerType() =>
val q = toSMT(Division(a, b))
Ints.Sub(toSMT(a), Ints.Mul(toSMT(b), q))
Expand All @@ -440,27 +455,31 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
case LessThan(a, b) => a.getType match {
case BVType(true, _) => FixedSizeBitVectors.SLessThan(toSMT(a), toSMT(b))
case BVType(false, _) => FixedSizeBitVectors.ULessThan(toSMT(a), toSMT(b))
case FPType(_,_) => FloatingPoint.LessThan(toSMT(a), toSMT(b))
case IntegerType() => Ints.LessThan(toSMT(a), toSMT(b))
case RealType() => Reals.LessThan(toSMT(a), toSMT(b))
case CharType() => FixedSizeBitVectors.ULessThan(toSMT(a), toSMT(b))
}
case LessEquals(a, b) => a.getType match {
case BVType(true, _) => FixedSizeBitVectors.SLessEquals(toSMT(a), toSMT(b))
case BVType(false, _) => FixedSizeBitVectors.ULessEquals(toSMT(a), toSMT(b))
case FPType(_,_) => FloatingPoint.LessEquals(toSMT(a), toSMT(b))
case IntegerType() => Ints.LessEquals(toSMT(a), toSMT(b))
case RealType() => Reals.LessEquals(toSMT(a), toSMT(b))
case CharType() => FixedSizeBitVectors.ULessEquals(toSMT(a), toSMT(b))
}
case GreaterThan(a, b) => a.getType match {
case BVType(true, _) => FixedSizeBitVectors.SGreaterThan(toSMT(a), toSMT(b))
case BVType(false, _) => FixedSizeBitVectors.UGreaterThan(toSMT(a), toSMT(b))
case FPType(_,_) => FloatingPoint.GreaterThan(toSMT(a), toSMT(b))
case IntegerType() => Ints.GreaterThan(toSMT(a), toSMT(b))
case RealType() => Reals.GreaterThan(toSMT(a), toSMT(b))
case CharType() => FixedSizeBitVectors.UGreaterThan(toSMT(a), toSMT(b))
}
case GreaterEquals(a, b) => a.getType match {
case BVType(true, _) => FixedSizeBitVectors.SGreaterEquals(toSMT(a), toSMT(b))
case BVType(false, _) => FixedSizeBitVectors.UGreaterEquals(toSMT(a), toSMT(b))
case FPType(_,_) => FloatingPoint.GreaterEquals(toSMT(a), toSMT(b))
case IntegerType() => Ints.GreaterEquals(toSMT(a), toSMT(b))
case RealType() => Reals.GreaterEquals(toSMT(a), toSMT(b))
case CharType() => FixedSizeBitVectors.UGreaterEquals(toSMT(a), toSMT(b))
Expand All @@ -487,6 +506,8 @@ trait SMTLIBTarget extends SMTLIBParser with Interruptible with ADTManagers {
case BVUnsignedToSigned(e) => toSMT(e)
case BVSignedToUnsigned(e) => toSMT(e)

case FPEquals(a, b) => FloatingPoint.Eq(toSMT(a), toSMT(b))

case And(sub) => SmtLibConstructors.and(sub.map(toSMT))
case Or(sub) => SmtLibConstructors.or(sub.map(toSMT))
case IfExpr(cond, thenn, elze) => Core.ITE(toSMT(cond), toSMT(thenn), toSMT(elze))
Expand Down
Loading

0 comments on commit 611ee7d

Please sign in to comment.