GH-3035: ParquetRewriter: Add a column renaming feature (#3036)
MaxNevermind authored Nov 13, 2024
1 parent 34359c9 commit 686f071
Showing 3 changed files with 363 additions and 106 deletions.
@@ -109,7 +109,7 @@
* Please note the schema of all <code>inputFiles</code> must be the same, otherwise the rewrite will fail.
* <p>
* <h2>Applying column transformations</h2>
* Some supported column transformations: pruning, masking, encrypting, changing a codec.
* Some supported column transformations: pruning, masking, renaming, encrypting, changing a codec.
* See {@link RewriteOptions} and {@link RewriteOptions.Builder} for the full list with description.
* <p>
* <h2><i>Joining</i> with extra files with a different schema</h2>
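
As a point of reference, here is a minimal usage sketch of driving the new renaming option end to end. The renameColumns(...) builder method name and the file paths are assumptions inferred from the getRenameColumns() accessor used in this change, not confirmed API; the Hadoop Configuration/Path builder variant is used for brevity.

import java.util.Collections;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.hadoop.rewrite.ParquetRewriter;
import org.apache.parquet.hadoop.rewrite.RewriteOptions;

public class RenameColumnExample {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Path in = new Path("file:///tmp/input.parquet");     // hypothetical input
        Path out = new Path("file:///tmp/renamed.parquet");  // hypothetical output

        RewriteOptions options = new RewriteOptions.Builder(conf, in, out)
                // Assumed builder method; expected to back the getRenameColumns()
                // accessor read by ParquetRewriter below.
                .renameColumns(Collections.singletonMap("oldName", "newName"))
                .build();

        // ParquetRewriter implements Closeable, so try-with-resources finalizes the output file.
        try (ParquetRewriter rewriter = new ParquetRewriter(options)) {
            rewriter.processBlocks(); // rewrites every row group, emitting the renamed column
        }
    }
}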
@@ -149,28 +149,30 @@ public class ParquetRewriter implements Closeable {
private final IndexCache.CacheStrategy indexCacheStrategy;
private final boolean overwriteInputWithJoinColumns;
private final InternalFileEncryptor nullColumnEncryptor;
private final Map<String, String> renamedColumns;

public ParquetRewriter(RewriteOptions options) throws IOException {
this.newCodecName = options.getNewCodecName();
this.indexCacheStrategy = options.getIndexCacheStrategy();
this.overwriteInputWithJoinColumns = options.getOverwriteInputWithJoinColumns();
this.renamedColumns = options.getRenameColumns();
ParquetConfiguration conf = options.getParquetConfiguration();
OutputFile out = options.getParquetOutputFile();
inputFiles.addAll(getFileReaders(options.getParquetInputFiles(), conf));
inputFilesToJoin.addAll(getFileReaders(options.getParquetInputFilesToJoin(), conf));
this.inputFiles.addAll(getFileReaders(options.getParquetInputFiles(), conf));
this.inputFilesToJoin.addAll(getFileReaders(options.getParquetInputFilesToJoin(), conf));
this.outSchema = pruneColumnsInSchema(getSchema(), options.getPruneColumns());
this.extraMetaData = getExtraMetadata(options);
ensureSameSchema(inputFiles);
ensureSameSchema(inputFilesToJoin);
ensureRowCount();
ensureRenamingCorrectness(outSchema, renamedColumns);
OutputFile out = options.getParquetOutputFile();
LOG.info(
"Start rewriting {} input file(s) {} to {}",
inputFiles.size() + inputFilesToJoin.size(),
Stream.concat(options.getParquetInputFiles().stream(), options.getParquetInputFilesToJoin().stream())
.collect(Collectors.toList()),
out);

this.outSchema = pruneColumnsInSchema(getSchema(), options.getPruneColumns());
this.extraMetaData = getExtraMetadata(options);

if (options.getMaskColumns() != null) {
this.maskColumns = new HashMap<>();
for (Map.Entry<String, MaskMode> col : options.getMaskColumns().entrySet()) {
@@ -184,9 +186,9 @@ public ParquetRewriter(RewriteOptions options) throws IOException {
}

ParquetFileWriter.Mode writerMode = ParquetFileWriter.Mode.CREATE;
writer = new ParquetFileWriter(
this.writer = new ParquetFileWriter(
out,
outSchema,
renamedColumns.isEmpty() ? outSchema : getSchemaWithRenamedColumns(this.outSchema),
writerMode,
DEFAULT_BLOCK_SIZE,
MAX_PADDING_SIZE_DEFAULT,
@@ -200,7 +202,8 @@ public ParquetRewriter(RewriteOptions options) throws IOException {
this.nullColumnEncryptor = null;
} else {
this.nullColumnEncryptor = new InternalFileEncryptor(options.getFileEncryptionProperties());
List<ColumnDescriptor> columns = outSchema.getColumns();
List<ColumnDescriptor> columns =
getSchemaWithRenamedColumns(this.outSchema).getColumns();
for (int i = 0; i < columns.size(); i++) {
writer.getEncryptor()
.getColumnSetup(ColumnPath.get(columns.get(i).getPath()), true, i);
@@ -223,8 +226,8 @@ public ParquetRewriter(
this.writer = writer;
this.outSchema = outSchema;
this.newCodecName = codecName;
extraMetaData = new HashMap<>(meta.getFileMetaData().getKeyValueMetaData());
extraMetaData.put(
this.extraMetaData = new HashMap<>(meta.getFileMetaData().getKeyValueMetaData());
this.extraMetaData.put(
ORIGINAL_CREATED_BY_KEY,
originalCreatedBy != null
? originalCreatedBy
@@ -239,6 +242,7 @@ public ParquetRewriter(
this.indexCacheStrategy = IndexCache.CacheStrategy.NONE;
this.overwriteInputWithJoinColumns = false;
this.nullColumnEncryptor = null;
this.renamedColumns = new HashMap<>();
}

private MessageType getSchema() {
@@ -266,6 +270,27 @@ private MessageType getSchema() {
}
}

private MessageType getSchemaWithRenamedColumns(MessageType schema) {
List<Type> fields = schema.getFields().stream()
.map(type -> {
if (!renamedColumns.containsKey(type.getName())) {
return type;
} else if (type.isPrimitive()) {
return new PrimitiveType(
type.getRepetition(),
type.asPrimitiveType().getPrimitiveTypeName(),
renamedColumns.get(type.getName()));
} else {
return new GroupType(
type.getRepetition(),
renamedColumns.get(type.getName()),
type.asGroupType().getFields());
}
})
.collect(Collectors.toList());
return new MessageType(schema.getName(), fields);
}
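
The helper above rebuilds only top-level fields, so a renamed GroupType keeps its original children and nested field names never change. Below is a standalone sketch of the same transformation, reimplemented because the method is private; class and field names are illustrative.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

public class SchemaRenameSketch {
    public static void main(String[] args) {
        MessageType schema = Types.buildMessage()
                .required(PrimitiveTypeName.INT64).named("id")
                .requiredGroup()
                    .required(PrimitiveTypeName.BINARY).named("street")
                .named("address")
                .named("record");

        Map<String, String> renamedColumns = Collections.singletonMap("id", "row_id");

        // Rebuild each top-level field, swapping only its name; groups are recreated
        // with their original children, mirroring getSchemaWithRenamedColumns.
        List<Type> fields = new ArrayList<>();
        for (Type field : schema.getFields()) {
            String name = renamedColumns.getOrDefault(field.getName(), field.getName());
            if (field.isPrimitive()) {
                fields.add(new PrimitiveType(
                        field.getRepetition(),
                        field.asPrimitiveType().getPrimitiveTypeName(),
                        name));
            } else {
                fields.add(new GroupType(field.getRepetition(), name, field.asGroupType().getFields()));
            }
        }

        // Prints the schema with "id" renamed to "row_id"; "address.street" is unchanged.
        System.out.println(new MessageType(schema.getName(), fields));
    }
}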

private Map<String, String> getExtraMetadata(RewriteOptions options) {
List<TransParquetFileReader> allFiles;
if (options.getIgnoreJoinFilesMetadata()) {
@@ -338,6 +363,21 @@ private void ensureSameSchema(Queue<TransParquetFileReader> inputFileReaders) {
}
}

private void ensureRenamingCorrectness(MessageType schema, Map<String, String> renameMap) {
Set<String> columns = schema.getFields().stream().map(Type::getName).collect(Collectors.toSet());
renameMap.forEach((src, dst) -> {
if (!columns.contains(src)) {
String msg = String.format("Column to rename '%s' is not found in input files schema", src);
LOG.error(msg);
throw new IllegalArgumentException(msg);
} else if (columns.contains(dst)) {
String msg = String.format("Renamed column target name '%s' is already present in a schema", dst);
LOG.error(msg);
throw new IllegalArgumentException(msg);
}
});
}

@Override
public void close() throws IOException {
writer.end(extraMetaData);
@@ -421,6 +461,27 @@ public void processBlocks() throws IOException {
if (readerToJoin != null) readerToJoin.close();
}

private ColumnPath normalizeFieldsInPath(ColumnPath path) {
if (renamedColumns.isEmpty()) {
return path;
} else {
String[] pathArray = path.toArray();
pathArray[0] = renamedColumns.getOrDefault(pathArray[0], pathArray[0]);
return ColumnPath.get(pathArray);
}
}

private PrimitiveType normalizeNameInType(PrimitiveType type) {
if (renamedColumns.isEmpty()) {
return type;
} else {
return new PrimitiveType(
type.getRepetition(),
type.asPrimitiveType().getPrimitiveTypeName(),
renamedColumns.getOrDefault(type.getName(), type.getName()));
}
}
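
Since only the first element of a column path can refer to a renamed top-level field, the normalization helpers above leave nested path parts untouched. A small sketch of the same idea follows; the field names are illustrative and the defensive array copy is an assumption, in case toArray() exposes internal state.

import java.util.Collections;
import java.util.Map;

import org.apache.parquet.hadoop.metadata.ColumnPath;

public class PathRenameSketch {
    public static void main(String[] args) {
        Map<String, String> renamedColumns = Collections.singletonMap("user", "account");

        // Mirrors normalizeFieldsInPath: rename only the first path element.
        ColumnPath original = ColumnPath.get("user", "address", "zip");
        String[] parts = original.toArray().clone(); // defensive copy
        parts[0] = renamedColumns.getOrDefault(parts[0], parts[0]);
        ColumnPath renamed = ColumnPath.get(parts);

        System.out.println(renamed.toDotString()); // account.address.zip
    }
}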

private void processBlock(
TransParquetFileReader reader,
int blockIdx,
@@ -431,7 +492,28 @@ private void processBlock(
if (chunk.isEncrypted()) {
throw new IOException("Column " + chunk.getPath().toDotString() + " is already encrypted");
}
ColumnDescriptor descriptor = outSchema.getColumns().get(outColumnIdx);

ColumnChunkMetaData chunkNormalized = chunk;
if (!renamedColumns.isEmpty()) {
// Keep an eye on this in case it gets stale due to ColumnChunkMetaData changes
chunkNormalized = ColumnChunkMetaData.get(
normalizeFieldsInPath(chunk.getPath()),
normalizeNameInType(chunk.getPrimitiveType()),
chunk.getCodec(),
chunk.getEncodingStats(),
chunk.getEncodings(),
chunk.getStatistics(),
chunk.getFirstDataPageOffset(),
chunk.getDictionaryPageOffset(),
chunk.getValueCount(),
chunk.getTotalSize(),
chunk.getTotalUncompressedSize(),
chunk.getSizeStatistics());
}

ColumnDescriptor descriptorOriginal = outSchema.getColumns().get(outColumnIdx);
ColumnDescriptor descriptorRenamed =
getSchemaWithRenamedColumns(outSchema).getColumns().get(outColumnIdx);
BlockMetaData blockMetaData = reader.getFooter().getBlocks().get(blockIdx);
String originalCreatedBy = reader.getFileMetaData().getCreatedBy();

@@ -443,13 +525,21 @@ private void processBlock(
// Mask column and compress it again.
MaskMode maskMode = maskColumns.get(chunk.getPath());
if (maskMode.equals(MaskMode.NULLIFY)) {
Type.Repetition repetition = descriptor.getPrimitiveType().getRepetition();
Type.Repetition repetition =
descriptorOriginal.getPrimitiveType().getRepetition();
if (repetition.equals(Type.Repetition.REQUIRED)) {
throw new IOException(
"Required column [" + descriptor.getPrimitiveType().getName() + "] cannot be nullified");
throw new IOException("Required column ["
+ descriptorOriginal.getPrimitiveType().getName() + "] cannot be nullified");
}
nullifyColumn(
reader, blockIdx, descriptor, chunk, writer, newCodecName, encryptColumn, originalCreatedBy);
reader,
blockIdx,
descriptorOriginal,
chunk,
writer,
newCodecName,
encryptColumn,
originalCreatedBy);
} else {
throw new UnsupportedOperationException("Only nullify is supported for now");
}
@@ -462,7 +552,7 @@ private void processBlock(
}

// Translate compression and/or encryption
writer.startColumn(descriptor, chunk.getValueCount(), newCodecName);
writer.startColumn(descriptorRenamed, chunk.getValueCount(), newCodecName);
processChunk(
reader,
blockMetaData.getRowCount(),
@@ -480,7 +570,8 @@ private void processBlock(
BloomFilter bloomFilter = indexCache.getBloomFilter(chunk);
ColumnIndex columnIndex = indexCache.getColumnIndex(chunk);
OffsetIndex offsetIndex = indexCache.getOffsetIndex(chunk);
writer.appendColumnChunk(descriptor, reader.getStream(), chunk, bloomFilter, columnIndex, offsetIndex);
writer.appendColumnChunk(
descriptorRenamed, reader.getStream(), chunkNormalized, bloomFilter, columnIndex, offsetIndex);
}
}

@@ -522,7 +613,7 @@ private void processChunk(
}

if (bloomFilter != null) {
writer.addBloomFilter(chunk.getPath().toDotString(), bloomFilter);
writer.addBloomFilter(normalizeFieldsInPath(chunk.getPath()).toDotString(), bloomFilter);
}

reader.setStreamPosition(chunk.getStartingPos());
@@ -580,7 +671,7 @@ private void processChunk(
dataPageAAD);
statistics = convertStatistics(
originalCreatedBy,
chunk.getPrimitiveType(),
normalizeNameInType(chunk.getPrimitiveType()),
headerV1.getStatistics(),
columnIndex,
pageOrdinal,
@@ -648,7 +739,7 @@ private void processChunk(
dataPageAAD);
statistics = convertStatistics(
originalCreatedBy,
chunk.getPrimitiveType(),
normalizeNameInType(chunk.getPrimitiveType()),
headerV2.getStatistics(),
columnIndex,
pageOrdinal,
@@ -887,7 +978,7 @@ private void nullifyColumn(
CompressionCodecFactory.BytesInputCompressor compressor = codecFactory.getCompressor(newCodecName);

// Create new schema that only has the current column
MessageType newSchema = newSchema(outSchema, descriptor);
MessageType newSchema = getSchemaWithRenamedColumns(newSchema(outSchema, descriptor));
ColumnChunkPageWriteStore cPageStore = new ColumnChunkPageWriteStore(
compressor,
newSchema,