
Commit 0de3a59

fix(spark): incorrect conversion of expand relation
In the expand relation, the projection expressions are stored in a two-dimensional array. The Spark matrix must be transposed to map the expressions into Substrait, and vice versa; I hadn't noticed this earlier. Also, the remap field should not be used, because the outputs are defined directly in the projections array.

Signed-off-by: Andrew Coleman <[email protected]>
1 parent e3139c6 commit 0de3a59
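To see why the transpose is needed: Spark's Expand.projections is row-major (one inner sequence per generated output row), while Substrait's expand relation is column-major (one SwitchingField per output column, whose duplicates list that column's expression in each row). A minimal sketch of the two layouts, using plain strings as hypothetical stand-ins for the real expression types:

object ExpandLayoutSketch extends App {
  // Row-major, as Spark stores it: projections(row)(column).
  val sparkProjections: Seq[Seq[String]] = Seq(
    Seq("a", "b", "lit(0)"), // output row 0
    Seq("a", "null", "lit(1)") // output row 1
  )

  // Column-major, as Substrait stores it: duplicates(column)(row).
  val substraitFields: Seq[Seq[String]] = sparkProjections.transpose
  println(substraitFields) // List(List(a, a), List(b, null), List(lit(0), lit(1)))

  // Transposing again recovers the Spark layout, which is why the
  // conversion applies .transpose in both directions.
  assert(substraitFields.transpose == sparkProjections)
}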

File tree

5 files changed: +15 −10 lines


core/src/main/java/io/substrait/relation/Expand.java

Lines changed: 4 additions & 3 deletions
@@ -4,7 +4,6 @@
 import io.substrait.type.Type;
 import io.substrait.type.TypeCreator;
 import java.util.List;
-import java.util.stream.Stream;
 import org.immutables.value.Value;
 
 @Value.Enclosing
@@ -18,7 +17,7 @@ public abstract class Expand extends SingleInputRel {
   public Type.Struct deriveRecordType() {
     Type.Struct initial = getInput().getRecordType();
     return TypeCreator.of(initial.nullable())
-        .struct(Stream.concat(initial.fields().stream(), Stream.of(TypeCreator.REQUIRED.I64)));
+        .struct(getFields().stream().map(ExpandField::getType));
   }
 
   @Override
@@ -52,7 +51,9 @@ public abstract static class SwitchingField implements ExpandField {
     public abstract List<Expression> getDuplicates();
 
     public Type getType() {
-      return getDuplicates().get(0).getType();
+      var nullable = getDuplicates().stream().anyMatch(d -> d.getType().nullable());
+      var type = getDuplicates().get(0).getType();
+      return nullable ? TypeCreator.asNullable(type) : TypeCreator.asNotNullable(type);
     }
 
     public static ImmutableExpand.SwitchingField.Builder builder() {
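The reworked getType() makes a switching field's type nullable whenever any of its duplicates can produce null, since the output column takes each duplicate's value in turn. A rough Scala rendering of the same rule (Dup is a hypothetical stand-in, not a real Substrait type):

object FieldNullabilitySketch extends App {
  // Hypothetical stand-in for a duplicate expression, reduced to its
  // type name and nullability.
  case class Dup(dataType: String, nullable: Boolean)

  // Mirrors SwitchingField.getType(): take the first duplicate's type,
  // widened to nullable if ANY duplicate can yield null.
  def fieldType(duplicates: Seq[Dup]): Dup =
    duplicates.head.copy(nullable = duplicates.exists(_.nullable))

  // A column switching between a required value and a nullable one
  // must itself be nullable, even though its first duplicate is not.
  println(fieldType(Seq(Dup("i64", nullable = false), Dup("i64", nullable = true))))
  // prints: Dup(i64,true)
}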

spark/src/main/scala/io/substrait/spark/SparkExtension.scala

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ object SparkExtension {
   private val EXTENSION_COLLECTION: SimpleExtension.ExtensionCollection =
     SimpleExtension.loadDefaults()
 
+  val COLLECTION: SimpleExtension.ExtensionCollection = EXTENSION_COLLECTION.merge(SparkImpls)
+
   lazy val SparkScalarFunctions: Seq[SimpleExtension.ScalarFunctionVariant] = {
     val ret = new collection.mutable.ArrayBuffer[SimpleExtension.ScalarFunctionVariant]()
     ret.appendAll(EXTENSION_COLLECTION.scalarFunctions().asScala)

spark/src/main/scala/io/substrait/spark/logical/ToLogicalPlan.scala

Lines changed: 3 additions & 4 deletions
@@ -277,14 +277,13 @@ class ToLogicalPlan(spark: SparkSession) extends DefaultRelVisitor[LogicalPlan]
     }
 
     // An output column is nullable if any of the projections can assign null to it
-    val types = projections.transpose.map(p => (p.head.dataType, p.exists(_.nullable)))
-
-    val output = types
+    val output = projections
+      .map(p => (p.head.dataType, p.exists(_.nullable)))
       .zip(names)
       .map { case (t, name) => StructField(name, t._1, t._2) }
       .map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
 
-    Expand(projections, output, child)
+    Expand(projections.transpose, output, child)
   }
 }
 

spark/src/main/scala/io/substrait/spark/logical/ToSubstraitRel.scala

Lines changed: 1 addition & 2 deletions
@@ -290,7 +290,7 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
   }
 
   override def visitExpand(p: Expand): relation.Rel = {
-    val fields = p.projections.map(
+    val fields = p.projections.transpose.map(
       proj => {
         relation.Expand.SwitchingField.builder
           .duplicates(
@@ -302,7 +302,6 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
     val names = p.output.map(_.name)
 
     relation.Expand.builder
-      .remap(relation.Rel.Remap.offset(p.child.output.size, names.size))
       .fields(fields.asJava)
       .hint(Hint.builder.addAllOutputNames(names.asJava).build())
       .input(visit(p.child))

spark/src/test/scala/io/substrait/spark/SubstraitPlanTestBase.scala

Lines changed: 5 additions & 1 deletion
@@ -26,7 +26,7 @@ import io.substrait.debug.TreePrinter
 import io.substrait.extension.ExtensionCollector
 import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter}
 import io.substrait.proto
-import io.substrait.relation.RelProtoConverter
+import io.substrait.relation.{ProtoRelConverter, RelProtoConverter}
 import org.scalactic.Equality
 import org.scalactic.source.Position
 import org.scalatest.Succeeded
@@ -93,6 +93,10 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
     require(logicalPlan2.resolved);
     val pojoRel2 = new ToSubstraitRel().visit(logicalPlan2)
 
+    val extensionCollector = new ExtensionCollector;
+    val proto = new RelProtoConverter(extensionCollector).toProto(pojoRel)
+    new ProtoRelConverter(extensionCollector, SparkExtension.COLLECTION).from(proto)
+
     pojoRel2.shouldEqualPlainly(pojoRel)
     logicalPlan2
   }
