Skip to content

Commit 35fde68

Browse files
feat: add ExpandRel support to core and spark
Signed-off-by: Andrew Coleman <[email protected]>
1 parent 79f3779 commit 35fde68

File tree

19 files changed

+359
-35
lines changed

19 files changed

+359
-35
lines changed

core/src/main/java/io/substrait/dsl/SubstraitBuilder.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import io.substrait.plan.Plan;
2222
import io.substrait.relation.Aggregate;
2323
import io.substrait.relation.Cross;
24+
import io.substrait.relation.Expand;
2425
import io.substrait.relation.Fetch;
2526
import io.substrait.relation.Filter;
2627
import io.substrait.relation.Join;
@@ -313,6 +314,23 @@ private Project project(
313314
return Project.builder().input(input).expressions(expressions).remap(remap).build();
314315
}
315316

317+
public Expand expand(Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Rel input) {
318+
return expand(fieldsFn, Optional.empty(), input);
319+
}
320+
321+
public Expand expand(
322+
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn, Rel.Remap remap, Rel input) {
323+
return expand(fieldsFn, Optional.of(remap), input);
324+
}
325+
326+
private Expand expand(
327+
Function<Rel, Iterable<? extends Expand.ExpandField>> fieldsFn,
328+
Optional<Rel.Remap> remap,
329+
Rel input) {
330+
var fields = fieldsFn.apply(input);
331+
return Expand.builder().input(input).fields(fields).remap(remap).build();
332+
}
333+
316334
public Set set(Set.SetOp op, Rel... inputs) {
317335
return set(op, Optional.empty(), inputs);
318336
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package io.substrait.hint;
2+
3+
import io.substrait.proto.RelCommon;
4+
import java.util.List;
5+
import java.util.Optional;
6+
import org.immutables.value.Value;
7+
8+
@Value.Immutable
9+
public abstract class Hint {
10+
public abstract Optional<String> getAlias();
11+
12+
public abstract List<String> getOutputNames();
13+
14+
public RelCommon.Hint toProto() {
15+
var builder = RelCommon.Hint.newBuilder().addAllOutputNames(getOutputNames());
16+
getAlias().ifPresent(builder::setAlias);
17+
return builder.build();
18+
}
19+
20+
public static ImmutableHint.Builder builder() {
21+
return ImmutableHint.builder();
22+
}
23+
}

core/src/main/java/io/substrait/relation/AbstractRelVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ public OUTPUT visit(Project project) throws EXCEPTION {
5353
return visitFallback(project);
5454
}
5555

56+
@Override
57+
public OUTPUT visit(Expand expand) throws EXCEPTION {
58+
return visitFallback(expand);
59+
}
60+
5661
@Override
5762
public OUTPUT visit(Sort sort) throws EXCEPTION {
5863
return visitFallback(sort);
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package io.substrait.relation;
2+
3+
import io.substrait.expression.Expression;
4+
import io.substrait.type.Type;
5+
import io.substrait.type.TypeCreator;
6+
import java.util.List;
7+
import java.util.stream.Stream;
8+
import org.immutables.value.Value;
9+
10+
@Value.Enclosing
11+
@Value.Immutable
12+
public abstract class Expand extends SingleInputRel {
13+
static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(Expand.class);
14+
15+
public abstract List<ExpandField> getFields();
16+
17+
@Override
18+
public Type.Struct deriveRecordType() {
19+
Type.Struct initial = getInput().getRecordType();
20+
return TypeCreator.of(initial.nullable())
21+
.struct(Stream.concat(initial.fields().stream(), Stream.of(TypeCreator.REQUIRED.I64)));
22+
}
23+
24+
@Override
25+
public <O, E extends Exception> O accept(RelVisitor<O, E> visitor) throws E {
26+
return visitor.visit(this);
27+
}
28+
29+
public static ImmutableExpand.Builder builder() {
30+
return ImmutableExpand.builder();
31+
}
32+
33+
public interface ExpandField {
34+
Type getType();
35+
}
36+
37+
@Value.Immutable
38+
public abstract static class ConsistentField implements ExpandField {
39+
public abstract Expression getExpression();
40+
41+
public Type getType() {
42+
return getExpression().getType();
43+
}
44+
45+
public static ImmutableExpand.ConsistentField.Builder builder() {
46+
return ImmutableExpand.ConsistentField.builder();
47+
}
48+
}
49+
50+
@Value.Immutable
51+
public abstract static class SwitchingField implements ExpandField {
52+
public abstract List<Expression> getDuplicates();
53+
54+
public Type getType() {
55+
return getDuplicates().get(0).getType();
56+
}
57+
58+
public static ImmutableExpand.SwitchingField.Builder builder() {
59+
return ImmutableExpand.SwitchingField.builder();
60+
}
61+
}
62+
}

core/src/main/java/io/substrait/relation/ProtoRelConverter.java

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import io.substrait.extension.AdvancedExtension;
88
import io.substrait.extension.ExtensionLookup;
99
import io.substrait.extension.SimpleExtension;
10+
import io.substrait.hint.Hint;
1011
import io.substrait.proto.AggregateRel;
1112
import io.substrait.proto.ConsistentPartitionWindowRel;
1213
import io.substrait.proto.CrossRel;
14+
import io.substrait.proto.ExpandRel;
1315
import io.substrait.proto.ExtensionLeafRel;
1416
import io.substrait.proto.ExtensionMultiRel;
1517
import io.substrait.proto.ExtensionSingleRel;
@@ -87,6 +89,9 @@ public Rel from(io.substrait.proto.Rel rel) {
8789
case PROJECT -> {
8890
return newProject(rel.getProject());
8991
}
92+
case EXPAND -> {
93+
return newExpand(rel.getExpand());
94+
}
9095
case CROSS -> {
9196
return newCross(rel.getCross());
9297
}
@@ -155,7 +160,10 @@ protected Filter newFilter(FilterRel rel) {
155160
}
156161

157162
protected NamedStruct newNamedStruct(ReadRel rel) {
158-
var namedStruct = rel.getBaseSchema();
163+
return newNamedStruct(rel.getBaseSchema());
164+
}
165+
166+
protected NamedStruct newNamedStruct(io.substrait.proto.NamedStruct namedStruct) {
159167
var struct = namedStruct.getStruct();
160168
return ImmutableNamedStruct.builder()
161169
.names(namedStruct.getNamesList())
@@ -389,6 +397,38 @@ protected Project newProject(ProjectRel rel) {
389397
return builder.build();
390398
}
391399

400+
protected Expand newExpand(ExpandRel rel) {
401+
var input = from(rel.getInput());
402+
var converter = new ProtoExpressionConverter(lookup, extensions, input.getRecordType(), this);
403+
var builder =
404+
Expand.builder()
405+
.input(input)
406+
.fields(
407+
rel.getFieldsList().stream()
408+
.map(
409+
expandField ->
410+
switch (expandField.getFieldTypeCase()) {
411+
case CONSISTENT_FIELD -> Expand.ConsistentField.builder()
412+
.expression(converter.from(expandField.getConsistentField()))
413+
.build();
414+
case SWITCHING_FIELD -> Expand.SwitchingField.builder()
415+
.duplicates(
416+
expandField.getSwitchingField().getDuplicatesList().stream()
417+
.map(converter::from)
418+
.collect(java.util.stream.Collectors.toList()))
419+
.build();
420+
case FIELDTYPE_NOT_SET -> throw new UnsupportedOperationException(
421+
"Expand fields not set");
422+
})
423+
.collect(java.util.stream.Collectors.toList()));
424+
425+
builder
426+
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
427+
.remap(optionalRelmap(rel.getCommon()))
428+
.hint(optionalHint(rel.getCommon()));
429+
return builder.build();
430+
}
431+
392432
protected Aggregate newAggregate(AggregateRel rel) {
393433
var input = from(rel.getInput());
394434
var protoExprConverter =
@@ -647,6 +687,16 @@ protected static Optional<Rel.Remap> optionalRelmap(io.substrait.proto.RelCommon
647687
relCommon.hasEmit() ? Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null);
648688
}
649689

690+
protected static Optional<Hint> optionalHint(io.substrait.proto.RelCommon relCommon) {
691+
if (!relCommon.hasHint()) return Optional.empty();
692+
var hint = relCommon.getHint();
693+
var builder = Hint.builder().addAllOutputNames(hint.getOutputNamesList());
694+
if (!hint.getAlias().isEmpty()) {
695+
builder.alias(hint.getAlias());
696+
}
697+
return Optional.of(builder.build());
698+
}
699+
650700
protected Optional<AdvancedExtension> optionalAdvancedExtension(
651701
io.substrait.proto.RelCommon relCommon) {
652702
return Optional.ofNullable(

core/src/main/java/io/substrait/relation/Rel.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.substrait.relation;
22

33
import io.substrait.extension.AdvancedExtension;
4+
import io.substrait.hint.Hint;
45
import io.substrait.type.Type;
56
import io.substrait.type.TypeCreator;
67
import java.util.List;
@@ -21,6 +22,8 @@ public interface Rel {
2122

2223
List<Rel> getInputs();
2324

25+
Optional<Hint> getHint();
26+
2427
@Value.Immutable
2528
public abstract static class Remap {
2629
public abstract List<Integer> indices();

core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,11 @@ public Optional<Rel> visit(Project project) throws EXCEPTION {
201201
.build());
202202
}
203203

204+
@Override
205+
public Optional<Rel> visit(Expand expand) throws EXCEPTION {
206+
throw new UnsupportedOperationException();
207+
}
208+
204209
@Override
205210
public Optional<Rel> visit(Sort sort) throws EXCEPTION {
206211
var input = sort.getInput().accept(this);

0 commit comments

Comments
 (0)