Support Spark 3.
Kudos to @paulfryze for paving the path for this change.

Fixes #97
thesamet committed Oct 1, 2020
1 parent fe2a232 commit a468d34
Showing 8 changed files with 81 additions and 84 deletions.
6 changes: 3 additions & 3 deletions build.sbt
@@ -15,11 +15,11 @@ lazy val sparkSqlScalaPB = project
name := "sparksql-scalapb",
crossScalaVersions := Seq(Scala212),
libraryDependencies ++= Seq(
"org.typelevel" %% "frameless-dataset" % "0.8.0",
"org.typelevel" %% "frameless-dataset" % "0.9.0",
"com.thesamet.scalapb" %% "scalapb-runtime" % scalapbVersion,
"com.thesamet.scalapb" %% "scalapb-runtime" % scalapbVersion % "protobuf",
"org.apache.spark" %% "spark-sql" % "2.4.7" % "provided",
"org.apache.spark" %% "spark-sql" % "2.4.7" % "test",
"org.apache.spark" %% "spark-sql" % "3.0.1" % "provided",
"org.apache.spark" %% "spark-sql" % "3.0.1" % "test",
"org.scalatest" %% "scalatest" % "3.2.2" % "test",
"org.scalatestplus" %% "scalacheck-1-14" % "3.2.2.0" % "test",
"com.github.alexarchambault" %% "scalacheck-shapeless_1.14" % "1.2.5" % "test"
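For a downstream project, the matching dependency bump would look roughly like the sketch below. This is only a hedged illustration: the Spark 3.0.1 coordinate comes from the diff above, while the sparksql-scalapb organization and the version placeholder are assumptions (the organization is inferred from the scalapb-runtime coordinate in this build).

    libraryDependencies ++= Seq(
      // assumed artifact coordinates; replace the placeholder with the release cut from this commit
      "com.thesamet.scalapb" %% "sparksql-scalapb" % "<version built from this commit>",
      // Spark 3 is now the supported line, matching the build change above
      "org.apache.spark" %% "spark-sql" % "3.0.1" % "provided"
    )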
2 changes: 1 addition & 1 deletion project/plugins.sbt
@@ -2,7 +2,7 @@ addSbtPlugin("com.jsuereth" % "sbt-pgp" % "2.0.1")

addSbtPlugin("com.github.gseitz" % "sbt-release" % "1.0.13")

addSbtPlugin("com.thesamet" % "sbt-protoc" % "0.99.34")
addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.0-RC2")

addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "3.9.4")

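For reference, a typical downstream wiring of sbt-protoc with ScalaPB looks like the sketch below. This is a hedged illustration of the standard setup, not content from this commit; the compilerplugin version is a placeholder.

    // project/plugins.sbt (downstream project, illustrative)
    addSbtPlugin("com.thesamet" % "sbt-protoc" % "1.0.0-RC2")
    libraryDependencies += "com.thesamet.scalapb" %% "compilerplugin" % "<scalapb version>"

    // build.sbt (downstream project, illustrative)
    Compile / PB.targets := Seq(
      scalapb.gen() -> (Compile / sourceManaged).value / "scalapb"
    )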
17 changes: 17 additions & 0 deletions shell.nix
@@ -0,0 +1,17 @@
{pkgs ? import <nixpkgs> {
  config = {
    packageOverrides = pkgs: {
      sbt = pkgs.sbt.override { jre = pkgs.openjdk11; };
    };
  };
}} :
pkgs.mkShell {
  buildInputs = [
    pkgs.sbt
    pkgs.openjdk11
    pkgs.nodejs

    # keep this line if you use bash
    pkgs.bashInteractive
  ];
}
@@ -23,6 +23,10 @@ import scalapb.descriptors._
import org.apache.spark.sql.catalyst.expressions.objects.CatalystToExternalMap
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.catalyst.expressions.UnaryExpression
import org.apache.spark.sql.catalyst.expressions.Unevaluable
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.catalyst.expressions.objects.UnresolvedCatalystToExternalMap

trait FromCatalystHelpers {
def protoSql: ProtoSQL
@@ -38,16 +42,24 @@
cmp.scalaDescriptor.fields.map { fd =>
fieldFromCatalyst(cmp, fd, input)
}
else
else {
cmp.scalaDescriptor.fields.map { fd =>
val newPath = addToPath(input, schemaOptions.columnNaming.fieldName(fd))
fieldFromCatalyst(cmp, fd, newPath)
}
StaticInvoke(
JavaHelpers.getClass,
ObjectType(classOf[PValue]),
"mkPMessage",
Literal.fromObject(cmp) :: CreateArray(args) :: Nil
}
val outputType = ObjectType(classOf[PValue])
val mapArgs =
StaticInvoke(
JavaHelpers.getClass,
ObjectType(classOf[Map[FieldDescriptor, PValue]]),
"mkMap",
Literal.fromObject(cmp) :: CreateArray(args) :: Nil
)
If(
IsNull(input),
Literal.create(null, outputType),
NewInstance(classOf[PMessage], mapArgs :: Nil, outputType)
)
}

@@ -72,13 +84,14 @@
val mapEntryCmp = cmp.messageCompanionForFieldNumber(fd.number)
val keyDesc = mapEntryCmp.scalaDescriptor.findFieldByNumber(1).get
val valDesc = mapEntryCmp.scalaDescriptor.findFieldByNumber(2).get
val objs = MyCatalystToExternalMap(
val urobjs = MyUnresolvedCatalystToExternalMap(
input,
(in: Expression) => singleFieldValueFromCatalyst(mapEntryCmp, keyDesc, in),
(in: Expression) => singleFieldValueFromCatalyst(mapEntryCmp, valDesc, in),
input,
ProtoSQL.dataTypeFor(fd).asInstanceOf[MapType],
classOf[Vector[(Any, Any)]]
)
val objs = MyCatalystToExternalMap(urobjs)
StaticInvoke(
JavaHelpers.getClass,
ObjectType(classOf[PValue]),
@@ -151,53 +164,31 @@
}

def addToPath(path: Expression, name: String): Expression = {
val res = path match {
case _: BoundReference =>
UnresolvedAttribute.quoted(name)
case _ =>
UnresolvedExtractValue(path, expressions.Literal(name))
}
res
UnresolvedExtractValue(path, expressions.Literal(name))
}
}

object MyCatalystToExternalMap {
private val curId = new java.util.concurrent.atomic.AtomicInteger()
case class MyUnresolvedCatalystToExternalMap(
child: Expression,
@transient keyFunction: Expression => Expression,
@transient valueFunction: Expression => Expression,
mapType: MapType,
collClass: Class[_]
)

/**
* Construct an instance of CatalystToExternalMap case class.
*
* @param keyFunction The function applied on the key collection elements.
* @param valueFunction The function applied on the value collection elements.
* @param inputData An expression that when evaluated returns a map object.
* @param mapType,
* @param collClass The type of the resulting collection.
*/
def apply(
keyFunction: Expression => Expression,
valueFunction: Expression => Expression,
inputData: Expression,
mapType: MapType,
collClass: Class[_]
): CatalystToExternalMap = {
val id = curId.getAndIncrement()
val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id"
val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id"
val valueLoopIsNull = if (mapType.valueContainsNull) {
s"CatalystToExternalMap_valueLoopIsNull$id"
} else {
"false"
}
val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
object MyCatalystToExternalMap {
def apply(u: MyUnresolvedCatalystToExternalMap): CatalystToExternalMap = {
val mapType = u.mapType
val keyLoopVar = LambdaVariable("CatalystToExternalMap_key", mapType.keyType, nullable = false)
val valueLoopVar =
LambdaVariable("CatalystToExternalMap_value", mapType.valueType, mapType.valueContainsNull)
CatalystToExternalMap(
keyLoopValue,
keyFunction(keyLoopVar),
valueLoopValue,
valueLoopIsNull,
valueFunction(valueLoopVar),
inputData,
collClass
keyLoopVar,
u.keyFunction(keyLoopVar),
valueLoopVar,
u.valueFunction(valueLoopVar),
u.child,
u.collClass
)
}
}
17 changes: 6 additions & 11 deletions sparksql-scalapb/src/main/scala/scalapb/spark/JavaHelpers.scala
@@ -79,17 +79,12 @@ object JavaHelpers {
}
}

def asPValue(h: Object): PValue = h.asInstanceOf[PValue]

def mkPMessage(cmp: GeneratedMessageCompanion[_], args: ArrayData): PValue = {
// returning Any to ensure the any-value doesn't get unwrapped in runtime.
PMessage(
cmp.scalaDescriptor.fields
.zip(args.array)
.filterNot(_._2 == PEmpty)
.toMap
.asInstanceOf[Map[FieldDescriptor, PValue]]
)
def mkMap(cmp: GeneratedMessageCompanion[_], args: ArrayData): Map[FieldDescriptor, PValue] = {
cmp.scalaDescriptor.fields
.zip(args.array)
.filterNot(_._2 == PEmpty)
.toMap
.asInstanceOf[Map[FieldDescriptor, PValue]]
}

def mkPRepeated(args: ArrayData): PValue = {
@@ -50,6 +50,7 @@ import scalapb.descriptors.Descriptor
import com.google.protobuf.wrappers.Int32Value
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.Metadata

class ProtoSQL(val schemaOptions: SchemaOptions) extends Udfs {
self =>
19 changes: 8 additions & 11 deletions sparksql-scalapb/src/main/scala/scalapb/spark/TypedEncoders.scala
@@ -10,6 +10,13 @@ import scalapb._
import scalapb.descriptors.{PValue, Reads}

import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.catalyst.WalkedTypePath
import org.apache.spark.sql.catalyst.expressions.objects.NewInstance
import scalapb.descriptors.PMessage
import org.apache.spark.sql.catalyst.expressions.CreateArray

trait TypedEncoders extends FromCatalystHelpers with ToCatalystHelpers with Serializable {
class MessageTypedEncoder[T <: GeneratedMessage](implicit
@@ -25,14 +32,6 @@ trait TypedEncoders extends FromCatalystHelpers with ToCatalystHelpers with Seri
def fromCatalyst(path: Expression): Expression = {
val expr = pmessageFromCatalyst(cmp, path)

val pmsg =
StaticInvoke(
JavaHelpers.getClass,
ObjectType(classOf[PValue]),
"asPValue",
expr :: Nil
)

val reads = Invoke(
Literal.fromObject(cmp),
"messageReads",
@@ -42,9 +41,7 @@ trait TypedEncoders extends FromCatalystHelpers with ToCatalystHelpers with Seri

val read = Invoke(reads, "read", ObjectType(classOf[Function[_, _]]))

val out = Invoke(read, "apply", ObjectType(ct.runtimeClass), pmsg :: Nil)

If(IsNull(path), Literal.create(null, out.dataType), out)
Invoke(read, "apply", ObjectType(ct.runtimeClass), expr :: Nil)
}

override def toCatalyst(path: Expression): Expression = {
@@ -29,8 +29,7 @@ class AllTypesSpec
def verifyTypes[
T <: GeneratedMessage: Arbitrary: GeneratedMessageCompanion: ClassTag
](
protoSQL: ProtoSQL,
skipCoalesce: Boolean = false
protoSQL: ProtoSQL
): Unit =
forAll { (n: Seq[T]) =>
import protoSQL.implicits._
@@ -44,15 +43,12 @@
// Creates dataframe using encoder serialization:
val ds2 = spark.createDataset(n)
ds2.collect() must contain theSameElementsAs (n)
if (!skipCoalesce) {
ds2.toDF.coalesce(1).except(df1.coalesce(1)).count() must be(0)
}
}

def verifyTypes[
T <: GeneratedMessage: Arbitrary: GeneratedMessageCompanion: ClassTag
]: Unit =
verifyTypes[T](ProtoSQL, false)
verifyTypes[T](ProtoSQL)

"AllTypes" should "work for int32" in {
verifyTypes[AT2.Int32Test]
@@ -126,7 +122,7 @@
}

it should "work for maps" in {
verifyTypes[AT2.MapTypes](ProtoSQL, true)
verifyTypes[AT3.MapTypes](ProtoSQL, true)
verifyTypes[AT2.MapTypes](ProtoSQL)
verifyTypes[AT3.MapTypes](ProtoSQL)
}
}
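The spec above exercises the encoders the same way an application would. A minimal usage sketch, assuming the scalapb.spark.Implicits import provided by this library and a hypothetical ScalaPB-generated message class MyEvent:

    // Hedged usage sketch (not part of this commit); MyEvent stands in for any
    // ScalaPB-generated message class.
    import org.apache.spark.sql.SparkSession
    import scalapb.spark.Implicits._   // assumed import that brings the ScalaPB-aware encoders into scope
    import scalapb.spark.ProtoSQL

    def roundTrip(spark: SparkSession, events: Seq[MyEvent]): Unit = {
      val ds = spark.createDataset(events)              // encoder-based path, like ds2 in the spec
      val df = ProtoSQL.createDataFrame(spark, events)  // schema-based path, like df1 in the spec
      assert(ds.collect().toSet == events.toSet)        // round-trips the same elements
    }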
