From 67dbd96d2ae5a94d7ef82b56a9bcb750966bca58 Mon Sep 17 00:00:00 2001
From: Vasily Bondarenko
Date: Wed, 31 Jul 2024 14:18:05 +0100
Subject: [PATCH] SPARK-364 overwrite to keep collection options

---
 .../spark/sql/connector/RoundTripTest.java    |  10 +-
 .../connector/write/TruncateModesTest.java    | 124 ++++++++++++++++++
 .../sql/connector/write/MongoBatchWrite.java  |  12 +-
 .../sql/connector/write/TruncateMode.java     |  62 +++++++++
 4 files changed, 200 insertions(+), 8 deletions(-)
 create mode 100644 src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java
 create mode 100644 src/main/java/com/mongodb/spark/sql/connector/write/TruncateMode.java

diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java
index 284ad28a..611a0137 100644
--- a/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java
+++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/RoundTripTest.java
@@ -26,6 +26,7 @@
 import com.mongodb.spark.sql.connector.beans.DateTimeBean;
 import com.mongodb.spark.sql.connector.beans.PrimitiveBean;
 import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;
+import com.mongodb.spark.sql.connector.write.TruncateMode;
 import java.sql.Date;
 import java.sql.Timestamp;
 import java.time.Instant;
@@ -41,6 +42,8 @@
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.SparkSession;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
 
 public class RoundTripTest extends MongoSparkConnectorTestCase {
 
@@ -68,8 +71,9 @@ void testPrimitiveBean() {
     assertIterableEquals(dataSetOriginal, dataSetMongo);
   }
 
-  @Test
-  void testBoxedBean() {
+  @ParameterizedTest
+  @EnumSource(TruncateMode.class)
+  void testBoxedBean(TruncateMode mode) {
     // Given
     List<BoxedBean> dataSetOriginal =
         singletonList(new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true));
@@ -79,7 +83,7 @@ void testBoxedBean() {
     Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
     Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
 
-    dataset.write().format("mongodb").mode("Overwrite").save();
+    dataset.write().format("mongodb").mode("Overwrite").option("truncate_mode", mode.name()).save();
 
     // Then
     List<BoxedBean> dataSetMongo = spark
diff --git a/src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java b/src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java
new file mode 100644
index 00000000..ad564ff5
--- /dev/null
+++ b/src/integrationTest/java/com/mongodb/spark/sql/connector/write/TruncateModesTest.java
@@ -0,0 +1,124 @@
+package com.mongodb.spark.sql.connector.write;
+
+import static java.util.Arrays.asList;
+import static java.util.Collections.singletonList;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertIterableEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.MongoDatabase;
+import com.mongodb.client.model.CreateCollectionOptions;
+import com.mongodb.client.model.IndexOptions;
+import com.mongodb.spark.sql.connector.beans.BoxedBean;
+import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;
+
+import java.util.ArrayList;
+import java.util.List;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.SparkSession;
+import org.bson.Document;
+import org.jetbrains.annotations.NotNull;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.EnumSource;
+
+public class TruncateModesTest extends MongoSparkConnectorTestCase {
+
+  public static final String INT_FIELD_INDEX = "intFieldIndex";
+  public static final String ID_INDEX = "_id_";
+
+  // private static final Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
+
+  @BeforeEach
+  void setup() {
+    MongoDatabase database = getDatabase();
+    getCollection().drop();
+    CreateCollectionOptions createCollectionOptions =
+        new CreateCollectionOptions().capped(true).maxDocuments(1024).sizeInBytes(1024);
+    database.createCollection(getCollectionName(), createCollectionOptions);
+    MongoCollection<Document> collection = database.getCollection(getCollectionName());
+    collection.insertOne(new Document().append("intField", null));
+    collection.createIndex(
+        new Document().append("intField", 1), new IndexOptions().name(INT_FIELD_INDEX));
+  }
+
+  @Test
+  void testCollectionDroppedOnOverwrite() {
+    // Given
+    List<BoxedBean> dataSetOriginal = singletonList(getBoxedBean());
+
+    // When
+    SparkSession spark = getOrCreateSparkSession();
+    Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
+    Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
+    dataset
+        .write()
+        .format("mongodb")
+        .mode("Overwrite")
+        .option("truncate_mode", TruncateMode.DROP.name())
+        .save();
+
+    // Then
+    List<BoxedBean> dataSetMongo =
+        spark.read().format("mongodb").schema(encoder.schema()).load().as(encoder).collectAsList();
+    assertIterableEquals(dataSetOriginal, dataSetMongo);
+
+    List<String> indexes =
+        getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>());
+    assertEquals(singletonList(ID_INDEX), indexes);
+    Document options = getCollectionOptions();
+    assertTrue(options.isEmpty());
+  }
+
+  @ParameterizedTest
+  @EnumSource(
+      value = TruncateMode.class,
+      names = {"DROP"},
+      mode = EnumSource.Mode.EXCLUDE)
+  void testOptionKeepingOverwrites(TruncateMode mode) {
+    // Given
+    List<BoxedBean> dataSetOriginal = singletonList(getBoxedBean());
+
+    // When
+    SparkSession spark = getOrCreateSparkSession();
+    Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
+    Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
+    dataset.write().format("mongodb").mode("Overwrite").option("truncate_mode", mode.name()).save();
+
+    // Then
+    List<BoxedBean> dataSetMongo =
+        spark.read().format("mongodb").schema(encoder.schema()).load().as(encoder).collectAsList();
+    assertIterableEquals(dataSetOriginal, dataSetMongo);
+
+    List<String> indexes =
+        getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>());
+    assertEquals(asList(ID_INDEX, INT_FIELD_INDEX), indexes);
+
+    Document options = getCollectionOptions();
+    assertTrue(options.getBoolean("capped"));
+  }
+
+  private @NotNull BoxedBean getBoxedBean() {
+    return new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true);
+  }
+
+  private Document getCollectionOptions() {
+    Document getCollectionMeta =
+        new Document()
+            .append("listCollections", 1)
+            .append("filter", new Document().append("name", getCollectionName()));
+
+    Document foundMeta = getDatabase().runCommand(getCollectionMeta);
+    Document cursor = foundMeta.get("cursor", Document.class);
+    List<Document> firstBatch = cursor.getList("firstBatch", Document.class);
+    if (firstBatch.isEmpty()) {
+      throw new AssertionError("Collection not found: " + getCollectionName());
+    }
+
+    return firstBatch.get(0).get("options", Document.class);
+  }
+}
diff --git a/src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java b/src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java
index 09063f0d..52490083 100644
--- a/src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java
+++ b/src/main/java/com/mongodb/spark/sql/connector/write/MongoBatchWrite.java
@@ -19,7 +19,6 @@
 
 import static java.lang.String.format;
 
-import com.mongodb.client.MongoCollection;
 import com.mongodb.spark.sql.connector.config.WriteConfig;
 import com.mongodb.spark.sql.connector.exceptions.DataException;
 import java.util.Arrays;
@@ -62,7 +61,8 @@ final class MongoBatchWrite implements BatchWrite {
   @Override
   public DataWriterFactory createBatchWriterFactory(final PhysicalWriteInfo physicalWriteInfo) {
     if (truncate) {
-      writeConfig.doWithCollection(MongoCollection::drop);
+      TruncateMode mode = TruncateMode.valueOf(writeConfig.getOrDefault("truncate_mode", "DROP"));
+      mode.truncate(writeConfig);
     }
     return new MongoDataWriterFactory(info.schema(), writeConfig);
   }
@@ -88,8 +88,10 @@ public void commit(final WriterCommitMessage[] messages) {
   @Override
   public void abort(final WriterCommitMessage[] messages) {
     long tasksCompleted = Arrays.stream(messages).filter(Objects::nonNull).count();
-    throw new DataException(format(
-        "Write aborted for: %s. %s/%s tasks completed.",
-        info.queryId(), tasksCompleted, messages.length));
+    throw new DataException(
+        format(
+            "Write aborted for: %s. %s/%s tasks completed.",
+            info.queryId(), tasksCompleted, messages.length));
   }
+
 }
diff --git a/src/main/java/com/mongodb/spark/sql/connector/write/TruncateMode.java b/src/main/java/com/mongodb/spark/sql/connector/write/TruncateMode.java
new file mode 100644
index 00000000..12364b77
--- /dev/null
+++ b/src/main/java/com/mongodb/spark/sql/connector/write/TruncateMode.java
@@ -0,0 +1,62 @@
+package com.mongodb.spark.sql.connector.write;
+
+import com.mongodb.client.MongoClient;
+import com.mongodb.client.MongoCollection;
+import com.mongodb.client.MongoDatabase;
+import com.mongodb.spark.sql.connector.config.WriteConfig;
+import java.util.ArrayList;
+import java.util.List;
+import org.bson.Document;
+
+public enum TruncateMode {
+  DROP {
+    @Override
+    public void truncate(final WriteConfig writeConfig) {
+      writeConfig.doWithCollection(MongoCollection::drop);
+    }
+  },
+  DELETE_ALL {
+    @Override
+    public void truncate(final WriteConfig writeConfig) {
+      writeConfig.doWithCollection(collection -> collection.deleteMany(new Document()));
+    }
+  },
+  RECREATE {
+    @Override
+    public void truncate(final WriteConfig writeConfig) {
+      MongoClient mongoClient = writeConfig.getMongoClient();
+      MongoDatabase database = mongoClient.getDatabase(writeConfig.getDatabaseName());
+
+      String collectionName = writeConfig.getCollectionName();
+      Document getCollectionMeta =
+          new Document()
+              .append("listCollections", 1)
+              .append("filter", new Document().append("name", collectionName));
+
+      Document foundMeta = database.runCommand(getCollectionMeta);
+      Document cursor = foundMeta.get("cursor", Document.class);
+      List<Document> firstBatch = cursor.getList("firstBatch", Document.class);
+      if (firstBatch.isEmpty()) {
+        return;
+      }
+
+      Document collectionObj = firstBatch.get(0);
+      Document options = collectionObj.get("options", Document.class);
+      Document createCollectionWithOptions = new Document().append("create", collectionName);
+      createCollectionWithOptions.putAll(options);
+
+      MongoCollection<Document> originalCollection =
+          database.getCollection(collectionName);
+      List<Document> originalIndexes = originalCollection.listIndexes().into(new ArrayList<>());
+      originalCollection.drop();
+
+      database.runCommand(createCollectionWithOptions);
+      Document createIndexes = new Document()
+          .append("createIndexes", collectionName)
+          .append("indexes", originalIndexes);
+      // Note: the createIndexes command is issued without the writeConcern, commitQuorum and
+      // comment parameters.
+      database.runCommand(createIndexes);
+    }
+  };
+
+  public abstract void truncate(WriteConfig writeConfig);
+}
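
Usage sketch (not part of the patch): the example below shows how a Spark job could opt into the new behaviour. The "truncate_mode" option key, the TruncateMode values (DROP, DELETE_ALL, RECREATE) and the fact that the option is only consulted when the write truncates (for example SaveMode.Overwrite) come from the patch itself; the SparkSession setup, connection URI, database and collection names are illustrative placeholders.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public final class TruncateModeExample {
  public static void main(final String[] args) {
    // Placeholder connection settings; point these at a real deployment.
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("truncate-mode-example")
        .config("spark.mongodb.write.connection.uri", "mongodb://localhost:27017")
        .config("spark.mongodb.write.database", "test")
        .config("spark.mongodb.write.collection", "coll")
        .getOrCreate();

    // Any small dataset will do for the illustration.
    Dataset<Row> df = spark.range(10).toDF("value");

    // With Overwrite, RECREATE drops the collection but re-creates it with its original
    // options and indexes; DELETE_ALL only deletes the documents, leaving the collection
    // in place; DROP (the default) keeps the pre-patch behaviour of dropping the collection.
    df.write()
        .format("mongodb")
        .mode("Overwrite")
        .option("truncate_mode", "RECREATE")
        .save();

    spark.stop();
  }
}

Keeping DROP as the default means existing Overwrite jobs behave exactly as before unless they set truncate_mode explicitly.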