Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import org.apache.paimon.shade.jackson2.com.fasterxml.jackson.annotation.JsonProperty;

import javax.annotation.Nullable;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;

/** A reference to a field in an input. */
Expand All @@ -34,19 +37,31 @@ public class FieldRef implements Serializable {
private static final String FIELD_INDEX = "index";
private static final String FIELD_NAME = "name";
private static final String FIELD_TYPE = "type";
private static final String FIELD_NESTED_INDEXES = "nestedIndexes";
private static final String FIELD_NESTED_ARITIES = "nestedArities";

private final int index;
private final String name;
private final DataType type;
@Nullable private final int[] nestedIndexes;
@Nullable private final int[] nestedArities;

public FieldRef(int index, String name, DataType type) {
this(index, name, type, null, null);
}

@JsonCreator
public FieldRef(
@JsonProperty(FIELD_INDEX) int index,
@JsonProperty(FIELD_NAME) String name,
@JsonProperty(FIELD_TYPE) DataType type) {
@JsonProperty(FIELD_TYPE) DataType type,
@JsonProperty(FIELD_NESTED_INDEXES) @Nullable int[] nestedIndexes,
@JsonProperty(FIELD_NESTED_ARITIES) @Nullable int[] nestedArities) {
this.index = index;
this.name = name;
this.type = type;
this.nestedIndexes = nestedIndexes;
this.nestedArities = nestedArities;
}

@JsonProperty(FIELD_INDEX)
Expand All @@ -64,6 +79,18 @@ public DataType type() {
return type;
}

@JsonProperty(FIELD_NESTED_INDEXES)
@Nullable
public int[] nestedIndexes() {
return nestedIndexes;
}

@JsonProperty(FIELD_NESTED_ARITIES)
@Nullable
public int[] nestedArities() {
return nestedArities;
}

@Override
public boolean equals(Object o) {
if (this == o) {
Expand All @@ -75,12 +102,17 @@ public boolean equals(Object o) {
FieldRef fieldRef = (FieldRef) o;
return index == fieldRef.index
&& Objects.equals(name, fieldRef.name)
&& Objects.equals(type, fieldRef.type);
&& Objects.equals(type, fieldRef.type)
&& Arrays.equals(nestedIndexes, fieldRef.nestedIndexes)
&& Arrays.equals(nestedArities, fieldRef.nestedArities);
}

@Override
public int hashCode() {
return Objects.hash(index, name, type);
int result = Objects.hash(index, name, type);
result = 31 * result + Arrays.hashCode(nestedIndexes);
result = 31 * result + Arrays.hashCode(nestedArities);
return result;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,34 @@ public DataType outputType() {

@Override
public Object transform(InternalRow row) {
return get(row, fieldRef.index(), fieldRef.type());
if (fieldRef.nestedIndexes() == null || fieldRef.nestedIndexes().length == 0) {
return get(row, fieldRef.index(), fieldRef.type());
}

InternalRow currentRow = row;
if (currentRow == null) {
return null;
}
int[] indexes = fieldRef.nestedIndexes();
int[] arities = fieldRef.nestedArities();

if (currentRow.isNullAt(fieldRef.index())) {
return null;
}
currentRow = currentRow.getRow(fieldRef.index(), arities[0]);

for (int i = 0; i < indexes.length - 1; i++) {
if (currentRow == null || currentRow.isNullAt(indexes[i])) {
return null;
}
currentRow = currentRow.getRow(indexes[i], arities[i + 1]);
}

if (currentRow == null) {
return null;
}
int leafIndex = indexes[indexes.length - 1];
return get(currentRow, leafIndex, fieldRef.type());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ public boolean test(
return !(function instanceof AlwaysFalse);
}
FieldRef fieldRef = fieldRefOptional.get();
if (fieldRef.nestedIndexes() != null && fieldRef.nestedIndexes().length > 0) {
return true;
}
int index = fieldRef.index();
DataType type = fieldRef.type();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@

package org.apache.paimon.spark;

import org.apache.paimon.predicate.FieldRef;
import org.apache.paimon.predicate.FieldTransform;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.predicate.Transform;
import org.apache.paimon.types.DataType;
import org.apache.paimon.types.RowType;

Expand Down Expand Up @@ -107,53 +110,55 @@ public Predicate convert(Filter filter) {
return PredicateBuilder.alwaysFalse();
} else if (filter instanceof EqualTo) {
EqualTo eq = (EqualTo) filter;
int index = fieldIndex(eq.attribute());
FieldInfo fieldInfo = resolveField(eq.attribute());
if (isNaN(eq.value())) {
return builder.isNaN(index);
return builder.isNaN(fieldInfo.transform());
}
Object literal = convertLiteral(index, eq.value());
return builder.equal(index, literal);
Object literal = convertLiteral(fieldInfo.type(), eq.value());
return builder.equal(fieldInfo.transform(), literal);
} else if (filter instanceof EqualNullSafe) {
EqualNullSafe eq = (EqualNullSafe) filter;
int index = fieldIndex(eq.attribute());
FieldInfo fieldInfo = resolveField(eq.attribute());
if (eq.value() == null) {
return builder.isNull(index);
return builder.isNull(fieldInfo.transform());
} else {
Object literal = convertLiteral(index, eq.value());
return builder.equal(index, literal);
Object literal = convertLiteral(fieldInfo.type(), eq.value());
return builder.equal(fieldInfo.transform(), literal);
}
} else if (filter instanceof GreaterThan) {
GreaterThan gt = (GreaterThan) filter;
int index = fieldIndex(gt.attribute());
Object literal = convertLiteral(index, gt.value());
return builder.greaterThan(index, literal);
FieldInfo fieldInfo = resolveField(gt.attribute());
Object literal = convertLiteral(fieldInfo.type(), gt.value());
return builder.greaterThan(fieldInfo.transform(), literal);
} else if (filter instanceof GreaterThanOrEqual) {
GreaterThanOrEqual gt = (GreaterThanOrEqual) filter;
int index = fieldIndex(gt.attribute());
Object literal = convertLiteral(index, gt.value());
return builder.greaterOrEqual(index, literal);
FieldInfo fieldInfo = resolveField(gt.attribute());
Object literal = convertLiteral(fieldInfo.type(), gt.value());
return builder.greaterOrEqual(fieldInfo.transform(), literal);
} else if (filter instanceof LessThan) {
LessThan lt = (LessThan) filter;
int index = fieldIndex(lt.attribute());
Object literal = convertLiteral(index, lt.value());
return builder.lessThan(index, literal);
FieldInfo fieldInfo = resolveField(lt.attribute());
Object literal = convertLiteral(fieldInfo.type(), lt.value());
return builder.lessThan(fieldInfo.transform(), literal);
} else if (filter instanceof LessThanOrEqual) {
LessThanOrEqual lt = (LessThanOrEqual) filter;
int index = fieldIndex(lt.attribute());
Object literal = convertLiteral(index, lt.value());
return builder.lessOrEqual(index, literal);
FieldInfo fieldInfo = resolveField(lt.attribute());
Object literal = convertLiteral(fieldInfo.type(), lt.value());
return builder.lessOrEqual(fieldInfo.transform(), literal);
} else if (filter instanceof In) {
In in = (In) filter;
int index = fieldIndex(in.attribute());
FieldInfo fieldInfo = resolveField(in.attribute());
return builder.in(
index,
fieldInfo.transform(),
Arrays.stream(in.values())
.map(v -> convertLiteral(index, v))
.map(v -> convertLiteral(fieldInfo.type(), v))
.collect(Collectors.toList()));
} else if (filter instanceof IsNull) {
return builder.isNull(fieldIndex(((IsNull) filter).attribute()));
FieldInfo fieldInfo = resolveField(((IsNull) filter).attribute());
return builder.isNull(fieldInfo.transform());
} else if (filter instanceof IsNotNull) {
return builder.isNotNull(fieldIndex(((IsNotNull) filter).attribute()));
FieldInfo fieldInfo = resolveField(((IsNotNull) filter).attribute());
return builder.isNotNull(fieldInfo.transform());
} else if (filter instanceof And) {
And and = (And) filter;
return PredicateBuilder.and(convert(and.left()), convert(and.right()));
Expand All @@ -168,19 +173,19 @@ public Predicate convert(Filter filter) {
}
} else if (filter instanceof StringStartsWith) {
StringStartsWith startsWith = (StringStartsWith) filter;
int index = fieldIndex(startsWith.attribute());
Object literal = convertLiteral(index, startsWith.value());
return builder.startsWith(index, literal);
FieldInfo fieldInfo = resolveField(startsWith.attribute());
Object literal = convertLiteral(fieldInfo.type(), startsWith.value());
return builder.startsWith(fieldInfo.transform(), literal);
} else if (filter instanceof StringEndsWith) {
StringEndsWith endsWith = (StringEndsWith) filter;
int index = fieldIndex(endsWith.attribute());
Object literal = convertLiteral(index, endsWith.value());
return builder.endsWith(index, literal);
FieldInfo fieldInfo = resolveField(endsWith.attribute());
Object literal = convertLiteral(fieldInfo.type(), endsWith.value());
return builder.endsWith(fieldInfo.transform(), literal);
} else if (filter instanceof StringContains) {
StringContains contains = (StringContains) filter;
int index = fieldIndex(contains.attribute());
Object literal = convertLiteral(index, contains.value());
return builder.contains(index, literal);
FieldInfo fieldInfo = resolveField(contains.attribute());
Object literal = convertLiteral(fieldInfo.type(), contains.value());
return builder.contains(fieldInfo.transform(), literal);
}

throw new UnsupportedOperationException(
Expand All @@ -198,26 +203,90 @@ private static boolean isNaN(Object value) {
}

public Object convertLiteral(String field, Object value) {
return convertLiteral(fieldIndex(field), value);
FieldInfo fieldInfo = resolveField(field);
return convertLiteral(fieldInfo.type(), value);
}

public String convertString(String field, Object value) {
Object literal = convertLiteral(field, value);
return literal == null ? null : literal.toString();
}

private int fieldIndex(String field) {
int index = rowType.getFieldIndex(field);
// TODO: support nested field
if (index == -1) {
private static class FieldInfo {
private final Transform transform;
private final DataType type;

public FieldInfo(Transform transform, DataType type) {
this.transform = transform;
this.type = type;
}

public Transform transform() {
return transform;
}

public DataType type() {
return type;
}
}

private FieldInfo resolveField(String field) {
String[] parts = field.split("\\.");
int topLevelIndex = rowType.getFieldIndex(parts[0]);
if (topLevelIndex == -1) {
throw new UnsupportedOperationException(
String.format("Nested field '%s' is unsupported.", field));
String.format("Field '%s' is not found in table schema.", parts[0]));
}

DataType fieldType;
int[] nestedIndexes = null;
int[] nestedArities = null;
if (parts.length == 1) {
fieldType = rowType.getTypeAt(topLevelIndex);
} else {
fieldType = getNestedFieldType(rowType, parts);
if (fieldType == null) {
throw new UnsupportedOperationException(
String.format("Nested field '%s' is unsupported.", field));
}
nestedIndexes = new int[parts.length - 1];
nestedArities = new int[parts.length - 1];
DataType currentType = rowType.getTypeAt(topLevelIndex);
for (int i = 0; i < parts.length - 1; i++) {
RowType currentSelection = (RowType) currentType;
nestedArities[i] = currentSelection.getFieldCount();
String nextPart = parts[i + 1];
int nextIndex = currentSelection.getFieldIndex(nextPart);
nestedIndexes[i] = nextIndex;
currentType = currentSelection.getTypeAt(nextIndex);
}
}

Transform transform =
new FieldTransform(
new FieldRef(
topLevelIndex, field, fieldType, nestedIndexes, nestedArities));
return new FieldInfo(transform, fieldType);
}

private DataType getNestedFieldType(RowType rowType, String[] path) {
DataType currentType = rowType;
for (String part : path) {
if (currentType instanceof RowType) {
RowType currentSelection = (RowType) currentType;
int idx = currentSelection.getFieldIndex(part);
if (idx == -1) {
return null;
}
currentType = currentSelection.getTypeAt(idx);
} else {
return null;
}
}
return index;
return currentType;
}

private Object convertLiteral(int index, Object value) {
DataType type = rowType.getTypeAt(index);
private Object convertLiteral(DataType type, Object value) {
return convertJavaObject(type, value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,44 @@ public void testIgnoreFailure() {
assertThat(converter.convert(not, true)).isNull();
assertThat(converter.convertIgnoreFailure(not)).isNull();
}

@Test
public void testNestedField() {
RowType nestedType =
new RowType(
Arrays.asList(
new DataField(0, "b", new IntType()),
new DataField(1, "c", new VarCharType(10))));
RowType rowType = new RowType(Arrays.asList(new DataField(0, "a", nestedType)));
SparkFilterConverter converter = new SparkFilterConverter(rowType);

EqualTo eq = EqualTo.apply("a.b", 1);
Predicate actual = converter.convert(eq);
assertThat(actual.toString()).isEqualTo("Equal(a.b, 1)");

IsNull isNull = IsNull.apply("a.c");
Predicate actualIsNull = converter.convert(isNull);
assertThat(actualIsNull.toString()).isEqualTo("IsNull(a.c)");

GenericRow nestedRow = new GenericRow(2);
nestedRow.setField(0, 1);
nestedRow.setField(1, fromString("paimon"));
GenericRow row = new GenericRow(1);
row.setField(0, nestedRow);

assertThat(actual.test(row)).isTrue();
assertThat(actualIsNull.test(row)).isFalse();

nestedRow.setField(0, 2);
assertThat(actual.test(row)).isFalse();

nestedRow.setField(1, null);
assertThat(actualIsNull.test(row)).isTrue();

GenericRow min = new GenericRow(1);
min.setField(0, nestedRow);
GenericRow max = new GenericRow(1);
max.setField(0, nestedRow);
assertThat(actual.test(1L, min, max, new GenericArray(new long[] {0L}))).isTrue();
}
}
Loading