FLINK-33759: add support for nested array with row/map type
ukby1234 authored and JingGe committed May 27, 2024
1 parent 87b7193 commit 57b2005
Showing 2 changed files with 212 additions and 10 deletions.
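For orientation, here is a minimal sketch of the usage this commit enables, adapted from the new tests further down; the API calls mirror those tests, while the class name, method name, and output path are illustrative only. Before this change, writing a row whose field is ARRAY<ROW<INT>> (or ARRAY<MAP<...>>, ARRAY<ARRAY<...>>) emitted no data for the nested elements because the ArrayData code path in the writer was a no-op.

    // Hypothetical example class; adapted from the ParquetRowDataWriterTest changes below.
    import org.apache.flink.api.common.serialization.BulkWriter;
    import org.apache.flink.core.fs.FSDataOutputStream;
    import org.apache.flink.core.fs.FileSystem;
    import org.apache.flink.core.fs.Path;
    import org.apache.flink.formats.parquet.ParquetWriterFactory;
    import org.apache.flink.formats.parquet.row.ParquetRowDataBuilder;
    import org.apache.flink.table.data.RowData;
    import org.apache.flink.table.data.util.DataFormatConverters;
    import org.apache.flink.table.types.logical.ArrayType;
    import org.apache.flink.table.types.logical.IntType;
    import org.apache.flink.table.types.logical.RowType;
    import org.apache.flink.table.types.utils.TypeConversions;
    import org.apache.flink.types.Row;
    import org.apache.hadoop.conf.Configuration;

    class NestedArrayRowParquetExample {

        @SuppressWarnings("unchecked")
        static void writeNestedRows(Path path) throws Exception {
            // ROW<INT, ARRAY<ROW<INT>>> -- the shape covered by NESTED_ARRAY_ROW_TYPE in the tests.
            RowType schema =
                    RowType.of(new IntType(), new ArrayType(true, RowType.of(new IntType())));

            ParquetWriterFactory<RowData> factory =
                    ParquetRowDataBuilder.createWriterFactory(schema, new Configuration(), true);
            DataFormatConverters.DataFormatConverter<RowData, Row> converter =
                    DataFormatConverters.getConverterForDataType(
                            TypeConversions.fromLogicalToDataType(schema));

            FSDataOutputStream out =
                    path.getFileSystem().create(path, FileSystem.WriteMode.OVERWRITE);
            BulkWriter<RowData> writer = factory.create(out);

            // Each array element is itself a row; with the old no-op ArrayData overload the
            // element content was never written out.
            writer.addElement(converter.toInternal(Row.of(1, new Row[] {Row.of(42)})));

            writer.flush();
            writer.finish();
            out.close();
        }
    }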
@@ -381,9 +381,16 @@ private MapWriter(LogicalType keyType, LogicalType valueType, GroupType groupTyp

@Override
public void write(RowData row, int ordinal) {
recordConsumer.startGroup();
writeMapData(row.getMap(ordinal));
}

MapData mapData = row.getMap(ordinal);
@Override
public void write(ArrayData arrayData, int ordinal) {
writeMapData(arrayData.getMap(ordinal));
}

private void writeMapData(MapData mapData) {
recordConsumer.startGroup();

if (mapData != null && mapData.size() > 0) {
recordConsumer.startField(repeatedGroupName, 0);
@@ -412,9 +419,6 @@ public void write(RowData row, int ordinal) {
}
recordConsumer.endGroup();
}

@Override
public void write(ArrayData arrayData, int ordinal) {}
}

/** It writes an array type field to parquet. */
@@ -438,8 +442,16 @@ private ArrayWriter(LogicalType t, GroupType groupType) {

@Override
public void write(RowData row, int ordinal) {
writeArrayData(row.getArray(ordinal));
}

@Override
public void write(ArrayData arrayData, int ordinal) {
writeArrayData(arrayData.getArray(ordinal));
}

private void writeArrayData(ArrayData arrayData) {
recordConsumer.startGroup();
ArrayData arrayData = row.getArray(ordinal);
int listLength = arrayData.size();

if (listLength > 0) {
@@ -458,9 +470,6 @@ public void write(RowData row, int ordinal) {
}
recordConsumer.endGroup();
}

@Override
public void write(ArrayData arrayData, int ordinal) {}
}

/** It writes a row type field to parquet. */
@@ -500,7 +509,12 @@ public void write(RowData row, int ordinal) {
}

@Override
public void write(ArrayData arrayData, int ordinal) {}
public void write(ArrayData arrayData, int ordinal) {
recordConsumer.startGroup();
RowData rowData = arrayData.getRow(ordinal, fieldWriters.length);
write(rowData);
recordConsumer.endGroup();
}
}

private void writeTimestamp(RecordConsumer recordConsumer, TimestampData timestampData) {
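The writer changes above all follow one pattern: each complex-type writer (MapWriter, ArrayWriter, RowWriter) gains a real implementation of the ArrayData overload, and both entry points delegate to a shared helper, so a map, array, or row is serialized the same way whether it sits in a row field or inside an outer array. Below is a minimal sketch of that pattern, not a verbatim excerpt; the FieldWriter interface here is a stand-in for the writer's internal one, which is assumed and not shown in this diff. The remaining hunks cover the new tests in ParquetRowDataWriterTest.

    // Sketch of the delegation pattern applied above (illustrative, not commit code).
    import org.apache.flink.table.data.ArrayData;
    import org.apache.flink.table.data.RowData;

    interface FieldWriter {
        void write(RowData row, int ordinal);      // value sits in a row field

        void write(ArrayData array, int ordinal);  // value sits in an array element
    }

    class SketchArrayWriter implements FieldWriter {

        @Override
        public void write(RowData row, int ordinal) {
            writeArrayData(row.getArray(ordinal));
        }

        @Override
        public void write(ArrayData array, int ordinal) {
            // Previously a no-op, which is why ARRAY<ARRAY<...>>, ARRAY<MAP<...>> and
            // ARRAY<ROW<...>> fields lost their nested elements.
            writeArrayData(array.getArray(ordinal));
        }

        private void writeArrayData(ArrayData arrayData) {
            // Container-independent serialization: startGroup(), write each element,
            // endGroup() -- see the real writeArrayData in the hunks above.
        }
    }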
@@ -101,6 +101,20 @@ class ParquetRowDataWriterTest {
new VarCharType(VarCharType.MAX_LENGTH)),
RowType.of(new VarCharType(VarCharType.MAX_LENGTH), new IntType()));

private static final RowType NESTED_ARRAY_MAP_TYPE =
RowType.of(
new IntType(),
new ArrayType(true, new ArrayType(true, new IntType())),
new ArrayType(
true,
new MapType(
true,
new VarCharType(VarCharType.MAX_LENGTH),
new VarCharType(VarCharType.MAX_LENGTH))));

private static final RowType NESTED_ARRAY_ROW_TYPE =
RowType.of(new IntType(), new ArrayType(true, RowType.of(new IntType())));

@SuppressWarnings("unchecked")
private static final DataFormatConverters.DataFormatConverter<RowData, Row> CONVERTER_COMPLEX =
DataFormatConverters.getConverterForDataType(
@@ -111,13 +125,29 @@ class ParquetRowDataWriterTest {
DataFormatConverters.getConverterForDataType(
TypeConversions.fromLogicalToDataType(ROW_TYPE));

@SuppressWarnings("unchecked")
private static final DataFormatConverters.DataFormatConverter<RowData, Row>
NESTED_ARRAY_MAP_CONVERTER =
DataFormatConverters.getConverterForDataType(
TypeConversions.fromLogicalToDataType(NESTED_ARRAY_MAP_TYPE));

@SuppressWarnings("unchecked")
private static final DataFormatConverters.DataFormatConverter<RowData, Row>
NESTED_ARRAY_ROW_CONVERTER =
DataFormatConverters.getConverterForDataType(
TypeConversions.fromLogicalToDataType(NESTED_ARRAY_ROW_TYPE));

@Test
void testTypes(@TempDir java.nio.file.Path folder) throws Exception {
Configuration conf = new Configuration();
innerTest(folder, conf, true);
innerTest(folder, conf, false);
complexTypeTest(folder, conf, true);
complexTypeTest(folder, conf, false);
nestedArrayAndMapTest(folder, conf, true);
nestedArrayAndMapTest(folder, conf, false);
nestedArrayAndRowTest(folder, conf, true);
nestedArrayAndRowTest(folder, conf, false);
}

@Test
@@ -128,6 +158,10 @@ void testCompression(@TempDir java.nio.file.Path folder) throws Exception {
innerTest(folder, conf, false);
complexTypeTest(folder, conf, true);
complexTypeTest(folder, conf, false);
nestedArrayAndMapTest(folder, conf, true);
nestedArrayAndMapTest(folder, conf, false);
nestedArrayAndRowTest(folder, conf, true);
nestedArrayAndRowTest(folder, conf, false);
}

@Test
@@ -230,6 +264,71 @@ public void complexTypeTest(java.nio.file.Path folder, Configuration conf, boole
assertThat(fileContent).isEqualTo(rows);
}

public void nestedArrayAndMapTest(
java.nio.file.Path folder, Configuration conf, boolean utcTimestamp) throws Exception {
Path path = new Path(folder.toString(), UUID.randomUUID().toString());
int number = 1000;
List<Row> rows = new ArrayList<>(number);

for (int i = 0; i < number; i++) {
Integer v = i;
Map<String, String> mp1 = new HashMap<>();
mp1.put(null, "val_" + i);
Map<String, String> mp2 = new HashMap<>();
mp2.put("key_" + i, null);
mp2.put("key@" + i, "val@" + i);

rows.add(
Row.of(
v,
new Integer[][] {{i, i + 1, null}, {i, i + 2, null}, null},
new Map[] {null, mp1, mp2}));
}

ParquetWriterFactory<RowData> factory =
ParquetRowDataBuilder.createWriterFactory(
NESTED_ARRAY_MAP_TYPE, conf, utcTimestamp);
BulkWriter<RowData> writer =
factory.create(path.getFileSystem().create(path, FileSystem.WriteMode.OVERWRITE));
for (int i = 0; i < number; i++) {
writer.addElement(NESTED_ARRAY_MAP_CONVERTER.toInternal(rows.get(i)));
}
writer.flush();
writer.finish();

File file = new File(path.getPath());
final List<Row> fileContent = readNestedArrayAndMap(file);
assertThat(fileContent).isEqualTo(rows);
}

public void nestedArrayAndRowTest(
java.nio.file.Path folder, Configuration conf, boolean utcTimestamp) throws Exception {
Path path = new Path(folder.toString(), UUID.randomUUID().toString());
int number = 1000;
List<Row> rows = new ArrayList<>(number);

for (int i = 0; i < number; i++) {
Integer v = i;
Integer v1 = i + number + 1;
rows.add(Row.of(v, new Row[] {Row.of(v1)}));
}

ParquetWriterFactory<RowData> factory =
ParquetRowDataBuilder.createWriterFactory(
NESTED_ARRAY_ROW_TYPE, conf, utcTimestamp);
BulkWriter<RowData> writer =
factory.create(path.getFileSystem().create(path, FileSystem.WriteMode.OVERWRITE));
for (int i = 0; i < number; i++) {
writer.addElement(NESTED_ARRAY_ROW_CONVERTER.toInternal(rows.get(i)));
}
writer.flush();
writer.finish();

File file = new File(path.getPath());
final List<Row> fileContent = readNestedArrayAndRowParquetFile(file);
assertThat(fileContent).isEqualTo(rows);
}

private static List<Row> readParquetFile(File file) throws IOException {
InputFile inFile =
HadoopInputFile.fromPath(
@@ -260,6 +359,95 @@ private static List<Row> readParquetFile(File file) throws IOException {
return results;
}

// TODO: If the parquet vectorized reader supports nested arrays or maps, remove this function
private static List<Row> readNestedArrayAndMap(File file) throws IOException {
InputFile inFile =
HadoopInputFile.fromPath(
new org.apache.hadoop.fs.Path(file.toURI()), new Configuration());

ArrayList<Row> results = new ArrayList<>();
try (ParquetReader<GenericRecord> reader =
AvroParquetReader.<GenericRecord>builder(inFile).build()) {
GenericRecord next;
while ((next = reader.read()) != null) {
Integer c0 = (Integer) next.get(0);

// read array<array<int>>
List<Integer[]> nestedArray = new ArrayList<>();
ArrayList<GenericData.Record> recordList =
(ArrayList<GenericData.Record>) next.get(1);
recordList.forEach(
record -> {
ArrayList<GenericData.Record> origVals =
(ArrayList<GenericData.Record>) record.get(0);
List<Integer> intArrays = (origVals == null) ? null : new ArrayList<>();
if (origVals != null) {
origVals.forEach(
r -> {
intArrays.add((Integer) r.get(0));
});
}
nestedArray.add(
origVals == null ? null : intArrays.toArray(new Integer[0]));
});

// read array<map<String, String>>
List<Map<String, String>> nestedMap = new ArrayList<>();
recordList = (ArrayList<GenericData.Record>) next.get(2);
recordList.forEach(
record -> {
Map<Utf8, Utf8> origMp = (Map<Utf8, Utf8>) record.get(0);
Map<String, String> mp = (origMp == null) ? null : new HashMap<>();
if (origMp != null) {
for (Utf8 key : origMp.keySet()) {
String k = key == null ? null : key.toString();
String v =
origMp.get(key) == null
? null
: origMp.get(key).toString();
mp.put(k, v);
}
}
nestedMap.add(mp);
});

Row row =
Row.of(
c0,
nestedArray.toArray(new Integer[0][0]),
nestedMap.toArray(new Map[0]));
results.add(row);
}
}

return results;
}

private static List<Row> readNestedArrayAndRowParquetFile(File file) throws IOException {
InputFile inFile =
HadoopInputFile.fromPath(
new org.apache.hadoop.fs.Path(file.toURI()), new Configuration());

ArrayList<Row> results = new ArrayList<>();
try (ParquetReader<GenericRecord> reader =
AvroParquetReader.<GenericRecord>builder(inFile).build()) {
GenericRecord next;
while ((next = reader.read()) != null) {
Integer c0 = (Integer) next.get(0);
List<Row> nestedArray = new ArrayList<>();
ArrayList<GenericData.Record> recordList =
(ArrayList<GenericData.Record>) next.get(1);
for (GenericData.Record record : recordList) {
nestedArray.add(Row.of(((GenericData.Record) record.get(0)).get(0)));
}
Row row = Row.of(c0, nestedArray.toArray(new Row[0]));
results.add(row);
}
}

return results;
}

private LocalDateTime toDateTime(Integer v) {
v = (v > 0 ? v : -v) % 1000;
return LocalDateTime.now().plusNanos(v).plusSeconds(v);
