diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/FieldRef.java b/paimon-common/src/main/java/org/apache/paimon/predicate/FieldRef.java index 92a489d00eb8..a648c11ffaae 100644 --- a/paimon-common/src/main/java/org/apache/paimon/predicate/FieldRef.java +++ b/paimon-common/src/main/java/org/apache/paimon/predicate/FieldRef.java @@ -23,7 +23,10 @@ import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.annotation.JsonCreator; import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.annotation.JsonProperty; +import javax.annotation.Nullable; + import java.io.Serializable; +import java.util.Arrays; import java.util.Objects; /** A reference to a field in an input. */ @@ -34,19 +37,31 @@ public class FieldRef implements Serializable { private static final String FIELD_INDEX = "index"; private static final String FIELD_NAME = "name"; private static final String FIELD_TYPE = "type"; + private static final String FIELD_NESTED_INDEXES = "nestedIndexes"; + private static final String FIELD_NESTED_ARITIES = "nestedArities"; private final int index; private final String name; private final DataType type; + @Nullable private final int[] nestedIndexes; + @Nullable private final int[] nestedArities; + + public FieldRef(int index, String name, DataType type) { + this(index, name, type, null, null); + } @JsonCreator public FieldRef( @JsonProperty(FIELD_INDEX) int index, @JsonProperty(FIELD_NAME) String name, - @JsonProperty(FIELD_TYPE) DataType type) { + @JsonProperty(FIELD_TYPE) DataType type, + @JsonProperty(FIELD_NESTED_INDEXES) @Nullable int[] nestedIndexes, + @JsonProperty(FIELD_NESTED_ARITIES) @Nullable int[] nestedArities) { this.index = index; this.name = name; this.type = type; + this.nestedIndexes = nestedIndexes; + this.nestedArities = nestedArities; } @JsonProperty(FIELD_INDEX) @@ -64,6 +79,18 @@ public DataType type() { return type; } + @JsonProperty(FIELD_NESTED_INDEXES) + @Nullable + public int[] nestedIndexes() { + return nestedIndexes; + } + + @JsonProperty(FIELD_NESTED_ARITIES) + @Nullable + public int[] nestedArities() { + return nestedArities; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -75,12 +102,17 @@ public boolean equals(Object o) { FieldRef fieldRef = (FieldRef) o; return index == fieldRef.index && Objects.equals(name, fieldRef.name) - && Objects.equals(type, fieldRef.type); + && Objects.equals(type, fieldRef.type) + && Arrays.equals(nestedIndexes, fieldRef.nestedIndexes) + && Arrays.equals(nestedArities, fieldRef.nestedArities); } @Override public int hashCode() { - return Objects.hash(index, name, type); + int result = Objects.hash(index, name, type); + result = 31 * result + Arrays.hashCode(nestedIndexes); + result = 31 * result + Arrays.hashCode(nestedArities); + return result; } @Override diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/FieldTransform.java b/paimon-common/src/main/java/org/apache/paimon/predicate/FieldTransform.java index 5136dedec530..27947188ef1e 100644 --- a/paimon-common/src/main/java/org/apache/paimon/predicate/FieldTransform.java +++ b/paimon-common/src/main/java/org/apache/paimon/predicate/FieldTransform.java @@ -71,7 +71,34 @@ public DataType outputType() { @Override public Object transform(InternalRow row) { - return get(row, fieldRef.index(), fieldRef.type()); + if (fieldRef.nestedIndexes() == null || fieldRef.nestedIndexes().length == 0) { + return get(row, fieldRef.index(), fieldRef.type()); + } + + InternalRow currentRow = row; + if (currentRow == null) { + return null; + } + int[] indexes = fieldRef.nestedIndexes(); + int[] arities = fieldRef.nestedArities(); + + if (currentRow.isNullAt(fieldRef.index())) { + return null; + } + currentRow = currentRow.getRow(fieldRef.index(), arities[0]); + + for (int i = 0; i < indexes.length - 1; i++) { + if (currentRow == null || currentRow.isNullAt(indexes[i])) { + return null; + } + currentRow = currentRow.getRow(indexes[i], arities[i + 1]); + } + + if (currentRow == null) { + return null; + } + int leafIndex = indexes[indexes.length - 1]; + return get(currentRow, leafIndex, fieldRef.type()); } @Override diff --git a/paimon-common/src/main/java/org/apache/paimon/predicate/LeafPredicate.java b/paimon-common/src/main/java/org/apache/paimon/predicate/LeafPredicate.java index f68771bb7ef5..9c2c2eef2709 100644 --- a/paimon-common/src/main/java/org/apache/paimon/predicate/LeafPredicate.java +++ b/paimon-common/src/main/java/org/apache/paimon/predicate/LeafPredicate.java @@ -146,6 +146,9 @@ public boolean test( return !(function instanceof AlwaysFalse); } FieldRef fieldRef = fieldRefOptional.get(); + if (fieldRef.nestedIndexes() != null && fieldRef.nestedIndexes().length > 0) { + return true; + } int index = fieldRef.index(); DataType type = fieldRef.type(); diff --git a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java index 31e4c8aec92a..72e800b8ce4d 100644 --- a/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java +++ b/paimon-spark/paimon-spark-common/src/main/java/org/apache/paimon/spark/SparkFilterConverter.java @@ -18,8 +18,11 @@ package org.apache.paimon.spark; +import org.apache.paimon.predicate.FieldRef; +import org.apache.paimon.predicate.FieldTransform; import org.apache.paimon.predicate.Predicate; import org.apache.paimon.predicate.PredicateBuilder; +import org.apache.paimon.predicate.Transform; import org.apache.paimon.types.DataType; import org.apache.paimon.types.RowType; @@ -107,53 +110,55 @@ public Predicate convert(Filter filter) { return PredicateBuilder.alwaysFalse(); } else if (filter instanceof EqualTo) { EqualTo eq = (EqualTo) filter; - int index = fieldIndex(eq.attribute()); + FieldInfo fieldInfo = resolveField(eq.attribute()); if (isNaN(eq.value())) { - return builder.isNaN(index); + return builder.isNaN(fieldInfo.transform()); } - Object literal = convertLiteral(index, eq.value()); - return builder.equal(index, literal); + Object literal = convertLiteral(fieldInfo.type(), eq.value()); + return builder.equal(fieldInfo.transform(), literal); } else if (filter instanceof EqualNullSafe) { EqualNullSafe eq = (EqualNullSafe) filter; - int index = fieldIndex(eq.attribute()); + FieldInfo fieldInfo = resolveField(eq.attribute()); if (eq.value() == null) { - return builder.isNull(index); + return builder.isNull(fieldInfo.transform()); } else { - Object literal = convertLiteral(index, eq.value()); - return builder.equal(index, literal); + Object literal = convertLiteral(fieldInfo.type(), eq.value()); + return builder.equal(fieldInfo.transform(), literal); } } else if (filter instanceof GreaterThan) { GreaterThan gt = (GreaterThan) filter; - int index = fieldIndex(gt.attribute()); - Object literal = convertLiteral(index, gt.value()); - return builder.greaterThan(index, literal); + FieldInfo fieldInfo = resolveField(gt.attribute()); + Object literal = convertLiteral(fieldInfo.type(), gt.value()); + return builder.greaterThan(fieldInfo.transform(), literal); } else if (filter instanceof GreaterThanOrEqual) { GreaterThanOrEqual gt = (GreaterThanOrEqual) filter; - int index = fieldIndex(gt.attribute()); - Object literal = convertLiteral(index, gt.value()); - return builder.greaterOrEqual(index, literal); + FieldInfo fieldInfo = resolveField(gt.attribute()); + Object literal = convertLiteral(fieldInfo.type(), gt.value()); + return builder.greaterOrEqual(fieldInfo.transform(), literal); } else if (filter instanceof LessThan) { LessThan lt = (LessThan) filter; - int index = fieldIndex(lt.attribute()); - Object literal = convertLiteral(index, lt.value()); - return builder.lessThan(index, literal); + FieldInfo fieldInfo = resolveField(lt.attribute()); + Object literal = convertLiteral(fieldInfo.type(), lt.value()); + return builder.lessThan(fieldInfo.transform(), literal); } else if (filter instanceof LessThanOrEqual) { LessThanOrEqual lt = (LessThanOrEqual) filter; - int index = fieldIndex(lt.attribute()); - Object literal = convertLiteral(index, lt.value()); - return builder.lessOrEqual(index, literal); + FieldInfo fieldInfo = resolveField(lt.attribute()); + Object literal = convertLiteral(fieldInfo.type(), lt.value()); + return builder.lessOrEqual(fieldInfo.transform(), literal); } else if (filter instanceof In) { In in = (In) filter; - int index = fieldIndex(in.attribute()); + FieldInfo fieldInfo = resolveField(in.attribute()); return builder.in( - index, + fieldInfo.transform(), Arrays.stream(in.values()) - .map(v -> convertLiteral(index, v)) + .map(v -> convertLiteral(fieldInfo.type(), v)) .collect(Collectors.toList())); } else if (filter instanceof IsNull) { - return builder.isNull(fieldIndex(((IsNull) filter).attribute())); + FieldInfo fieldInfo = resolveField(((IsNull) filter).attribute()); + return builder.isNull(fieldInfo.transform()); } else if (filter instanceof IsNotNull) { - return builder.isNotNull(fieldIndex(((IsNotNull) filter).attribute())); + FieldInfo fieldInfo = resolveField(((IsNotNull) filter).attribute()); + return builder.isNotNull(fieldInfo.transform()); } else if (filter instanceof And) { And and = (And) filter; return PredicateBuilder.and(convert(and.left()), convert(and.right())); @@ -168,19 +173,19 @@ public Predicate convert(Filter filter) { } } else if (filter instanceof StringStartsWith) { StringStartsWith startsWith = (StringStartsWith) filter; - int index = fieldIndex(startsWith.attribute()); - Object literal = convertLiteral(index, startsWith.value()); - return builder.startsWith(index, literal); + FieldInfo fieldInfo = resolveField(startsWith.attribute()); + Object literal = convertLiteral(fieldInfo.type(), startsWith.value()); + return builder.startsWith(fieldInfo.transform(), literal); } else if (filter instanceof StringEndsWith) { StringEndsWith endsWith = (StringEndsWith) filter; - int index = fieldIndex(endsWith.attribute()); - Object literal = convertLiteral(index, endsWith.value()); - return builder.endsWith(index, literal); + FieldInfo fieldInfo = resolveField(endsWith.attribute()); + Object literal = convertLiteral(fieldInfo.type(), endsWith.value()); + return builder.endsWith(fieldInfo.transform(), literal); } else if (filter instanceof StringContains) { StringContains contains = (StringContains) filter; - int index = fieldIndex(contains.attribute()); - Object literal = convertLiteral(index, contains.value()); - return builder.contains(index, literal); + FieldInfo fieldInfo = resolveField(contains.attribute()); + Object literal = convertLiteral(fieldInfo.type(), contains.value()); + return builder.contains(fieldInfo.transform(), literal); } throw new UnsupportedOperationException( @@ -198,7 +203,8 @@ private static boolean isNaN(Object value) { } public Object convertLiteral(String field, Object value) { - return convertLiteral(fieldIndex(field), value); + FieldInfo fieldInfo = resolveField(field); + return convertLiteral(fieldInfo.type(), value); } public String convertString(String field, Object value) { @@ -206,18 +212,81 @@ public String convertString(String field, Object value) { return literal == null ? null : literal.toString(); } - private int fieldIndex(String field) { - int index = rowType.getFieldIndex(field); - // TODO: support nested field - if (index == -1) { + private static class FieldInfo { + private final Transform transform; + private final DataType type; + + public FieldInfo(Transform transform, DataType type) { + this.transform = transform; + this.type = type; + } + + public Transform transform() { + return transform; + } + + public DataType type() { + return type; + } + } + + private FieldInfo resolveField(String field) { + String[] parts = field.split("\\."); + int topLevelIndex = rowType.getFieldIndex(parts[0]); + if (topLevelIndex == -1) { throw new UnsupportedOperationException( - String.format("Nested field '%s' is unsupported.", field)); + String.format("Field '%s' is not found in table schema.", parts[0])); + } + + DataType fieldType; + int[] nestedIndexes = null; + int[] nestedArities = null; + if (parts.length == 1) { + fieldType = rowType.getTypeAt(topLevelIndex); + } else { + fieldType = getNestedFieldType(rowType, parts); + if (fieldType == null) { + throw new UnsupportedOperationException( + String.format("Nested field '%s' is unsupported.", field)); + } + nestedIndexes = new int[parts.length - 1]; + nestedArities = new int[parts.length - 1]; + DataType currentType = rowType.getTypeAt(topLevelIndex); + for (int i = 0; i < parts.length - 1; i++) { + RowType currentSelection = (RowType) currentType; + nestedArities[i] = currentSelection.getFieldCount(); + String nextPart = parts[i + 1]; + int nextIndex = currentSelection.getFieldIndex(nextPart); + nestedIndexes[i] = nextIndex; + currentType = currentSelection.getTypeAt(nextIndex); + } + } + + Transform transform = + new FieldTransform( + new FieldRef( + topLevelIndex, field, fieldType, nestedIndexes, nestedArities)); + return new FieldInfo(transform, fieldType); + } + + private DataType getNestedFieldType(RowType rowType, String[] path) { + DataType currentType = rowType; + for (String part : path) { + if (currentType instanceof RowType) { + RowType currentSelection = (RowType) currentType; + int idx = currentSelection.getFieldIndex(part); + if (idx == -1) { + return null; + } + currentType = currentSelection.getTypeAt(idx); + } else { + return null; + } } - return index; + return currentType; } - private Object convertLiteral(int index, Object value) { - DataType type = rowType.getTypeAt(index); + private Object convertLiteral(DataType type, Object value) { return convertJavaObject(type, value); } } diff --git a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java index 8b5457c9dff6..b3ed4ed5cda7 100644 --- a/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java +++ b/paimon-spark/paimon-spark-ut/src/test/java/org/apache/paimon/spark/SparkFilterConverterTest.java @@ -277,4 +277,44 @@ public void testIgnoreFailure() { assertThat(converter.convert(not, true)).isNull(); assertThat(converter.convertIgnoreFailure(not)).isNull(); } + + @Test + public void testNestedField() { + RowType nestedType = + new RowType( + Arrays.asList( + new DataField(0, "b", new IntType()), + new DataField(1, "c", new VarCharType(10)))); + RowType rowType = new RowType(Arrays.asList(new DataField(0, "a", nestedType))); + SparkFilterConverter converter = new SparkFilterConverter(rowType); + + EqualTo eq = EqualTo.apply("a.b", 1); + Predicate actual = converter.convert(eq); + assertThat(actual.toString()).isEqualTo("Equal(a.b, 1)"); + + IsNull isNull = IsNull.apply("a.c"); + Predicate actualIsNull = converter.convert(isNull); + assertThat(actualIsNull.toString()).isEqualTo("IsNull(a.c)"); + + GenericRow nestedRow = new GenericRow(2); + nestedRow.setField(0, 1); + nestedRow.setField(1, fromString("paimon")); + GenericRow row = new GenericRow(1); + row.setField(0, nestedRow); + + assertThat(actual.test(row)).isTrue(); + assertThat(actualIsNull.test(row)).isFalse(); + + nestedRow.setField(0, 2); + assertThat(actual.test(row)).isFalse(); + + nestedRow.setField(1, null); + assertThat(actualIsNull.test(row)).isTrue(); + + GenericRow min = new GenericRow(1); + min.setField(0, nestedRow); + GenericRow max = new GenericRow(1); + max.setField(0, nestedRow); + assertThat(actual.test(1L, min, max, new GenericArray(new long[] {0L}))).isTrue(); + } }