Skip to content

Commit

Permalink
typelevel#804 - proof of Vector issue outside of typelevel#803
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-twiner committed Mar 20, 2024
1 parent 526f896 commit 955ba82
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 3 deletions.
16 changes: 16 additions & 0 deletions dataset/src/test/scala/frameless/EncoderTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ object EncoderTests {
// Single-field row wrapping a java.time.Instant.
case class InstantRow(i: java.time.Instant)
// Single-field row wrapping a java.time.Duration.
case class DurationRow(d: java.time.Duration)
// Single-field row wrapping a java.time.Period.
case class PeriodRow(p: java.time.Period)

// Row whose only field is a Vector of case-class (X1[Int]) elements rather than primitives.
case class VectorOfObject(a: Vector[X1[Int]])
}

class EncoderTests extends TypedDatasetSuite with Matchers {
Expand All @@ -32,4 +34,18 @@ class EncoderTests extends TypedDatasetSuite with Matchers {
// Resolution-only check: the body just summons an implicit TypedEncoder[PeriodRow],
// so a missing encoder surfaces as a compile error rather than a runtime failure.
test("It should encode java.time.Period") {
implicitly[TypedEncoder[PeriodRow]]
}

// Builds a Dataset holding an X1-wrapped VectorOfObject while CODEGEN_FACTORY_MODE is
// forced to NO_CODEGEN (forceInterpreted), then reads the first row back and checks
// that the nested vector is unchanged.
test("It should encode a Vector of Objects") {
forceInterpreted {
// NOTE(review): these implicit vals carry no type annotation; adding one would
// presumably make each `implicitly[...]` eligible to resolve to the val itself —
// confirm before annotating or reordering them.
implicit val e = implicitly[TypedEncoder[VectorOfObject]]
implicit val te = TypedExpressionEncoder[VectorOfObject]
implicit val xe = implicitly[TypedEncoder[X1[VectorOfObject]]]
// Presumably the encoder picked up implicitly by createDataset below — TODO confirm.
implicit val xte = TypedExpressionEncoder[X1[VectorOfObject]]
// Twenty X1(1)..X1(20) elements.
val v = (1 to 20).map(X1(_)).toVector
val ds = {
sqlContext.createDataset(Seq(X1[VectorOfObject](VectorOfObject(v))))
}
// Outer `.a` unwraps X1, inner `.a` is the VectorOfObject field.
ds.head.a.a shouldBe v
}
}
}
71 changes: 68 additions & 3 deletions dataset/src/test/scala/frameless/package.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import java.time.format.DateTimeFormatter
import java.time.{ LocalDateTime => JavaLocalDateTime }
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.internal.SQLConf

import org.scalacheck.{ Arbitrary, Gen }
import java.time.format.DateTimeFormatter
import java.time.{LocalDateTime => JavaLocalDateTime}
import org.scalacheck.{Arbitrary, Gen}

package object frameless {

Expand Down Expand Up @@ -39,6 +41,14 @@ package object frameless {

def vectorGen[A: Arbitrary]: Gen[Vector[A]] = arbVector[A].arbitrary

/**
 * Arbitrary instance for `scala.collection.Seq[A]`.
 *
 * Generates a list of arbitrary `A` values and exposes it as a sequence
 * obtained through `toVector.toSeq`.
 */
implicit def arbSeq[A](
    implicit
    A: Arbitrary[A]
  ): Arbitrary[scala.collection.Seq[A]] = Arbitrary {
  for {
    elements <- Gen.listOf(A.arbitrary)
  } yield elements.toVector.toSeq
}

/** Generator of `scala.collection.Seq[A]`, taken from the [[arbSeq]] instance for `A`. */
def seqGen[A: Arbitrary]: Gen[scala.collection.Seq[A]] = {
  val seqArbitrary = arbSeq[A]
  seqArbitrary.arbitrary
}

implicit val arbUdtEncodedClass: Arbitrary[UdtEncodedClass] = Arbitrary {
for {
int <- Arbitrary.arbitrary[Int]
Expand Down Expand Up @@ -161,4 +171,59 @@ package object frameless {
}
res
}

// from Quality, which is from Spark test versions

// If code generation fails, debug in Spark's CodeGenerator at 1294 and 1299
// (presumably source line numbers, which vary by Spark version) and grab code.body.
/**
 * Evaluates `f` with `SQLConf.CODEGEN_FACTORY_MODE` set to `CODEGEN_ONLY`.
 *
 * The override is applied and undone by `withSQLConf`.
 *
 * @param f block to evaluate under the override
 * @tparam T result type of `f`
 * @return the value produced by `f`
 */
def forceCodeGen[T](f: => T): T =
  withSQLConf(
    SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.CODEGEN_ONLY.toString
  )(f)

/**
 * Evaluates `f` with `SQLConf.CODEGEN_FACTORY_MODE` set to `NO_CODEGEN`.
 *
 * The override is applied and undone by `withSQLConf`.
 *
 * @param f block to evaluate under the override
 * @tparam T result type of `f`
 * @return the value produced by `f`
 */
def forceInterpreted[T](f: => T): T =
  withSQLConf(
    SQLConf.CODEGEN_FACTORY_MODE.key -> CodegenObjectFactoryMode.NO_CODEGEN.toString
  )(f)

/**
 * Evaluates the same block twice: first through `forceCodeGen`, then through
 * `forceInterpreted`.
 *
 * @param f block to evaluate; being by-name, it is run once per mode
 * @tparam T result type of `f`
 * @return a pair of (result under `forceCodeGen`, result under `forceInterpreted`)
 */
def evalCodeGens[T](f: => T): (T, T) = {
  val viaCodeGen = forceCodeGen(f)
  val viaInterpreter = forceInterpreted(f)
  (viaCodeGen, viaInterpreter)
}

/**
 * Sets all SQL configurations specified in `pairs`, calls `f`, and then restores all SQL
 * configurations.
 *
 * The configuration object is the one returned by `SQLConf.get`. Keys that had a value
 * before the call get that value back; keys that had none are unset again. Restoration
 * runs in a `finally`, so it happens whether `f` returns or throws, and also when
 * applying one of the overrides fails part-way through.
 *
 * @param pairs configuration key/value pairs to apply while `f` runs
 * @param f block to evaluate with the overrides in place
 * @tparam T result type of `f`
 * @return the value produced by `f`
 */
protected def withSQLConf[T](pairs: (String, String)*)(f: => T): T = {
  val conf = SQLConf.get
  // Snapshot the prior state of every key before any override is applied.
  val previous = pairs.map {
    case (key, _) =>
      val existing =
        if (conf.contains(key)) Some(conf.getConfString(key)) else None
      key -> existing
  }
  try {
    // Applied inside the try so a failure while setting still triggers restoration.
    // Iterating `pairs` directly avoids Tuple2#zipped, which is deprecated in Scala 2.13.
    pairs.foreach { case (key, value) => conf.setConfString(key, value) }
    f
  } finally {
    previous.foreach {
      case (key, Some(value)) => conf.setConfString(key, value)
      case (key, None)        => conf.unsetConf(key)
    }
  }
}

}

0 comments on commit 955ba82

Please sign in to comment.