SPARK-364 overwrite to keep collection options #123

Draft · wants to merge 2 commits into main
RoundTripTest.java
@@ -26,6 +26,7 @@
import com.mongodb.spark.sql.connector.beans.DateTimeBean;
import com.mongodb.spark.sql.connector.beans.PrimitiveBean;
import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;
import com.mongodb.spark.sql.connector.write.TruncateMode;
import java.sql.Date;
import java.sql.Timestamp;
import java.time.Instant;
@@ -41,6 +42,8 @@
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

public class RoundTripTest extends MongoSparkConnectorTestCase {

@@ -68,8 +71,9 @@ void testPrimitiveBean() {
assertIterableEquals(dataSetOriginal, dataSetMongo);
}

@Test
void testBoxedBean() {
@ParameterizedTest
@EnumSource(TruncateMode.class)
void testBoxedBean(TruncateMode mode) {
// Given
List<BoxedBean> dataSetOriginal =
singletonList(new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true));
@@ -79,7 +83,7 @@ void testBoxedBean() {
Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);

Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
dataset.write().format("mongodb").mode("Overwrite").save();
dataset.write().format("mongodb").mode("Overwrite").option("truncate_mode", mode.name()).save();

// Then
List<BoxedBean> dataSetMongo = spark
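For orientation, here is a minimal sketch of how a job could opt into the new behaviour, assuming the option keeps the truncate_mode name used in these tests; the connection URI, session settings, and sample data are illustrative, not part of this PR.

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public final class TruncateModeUsageExample {
  public static void main(final String[] args) {
    // Illustrative local session; the write connection URI is an assumption.
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("truncate-mode-example")
        .config("spark.mongodb.write.connection.uri", "mongodb://localhost:27017/test.coll")
        .getOrCreate();

    Dataset<Row> df = spark.range(10).toDF("value");

    // RECREATE keeps the collection options and indexes, DELETE_ALL only removes documents,
    // and DROP matches the previous behaviour of dropping the collection outright.
    df.write()
        .format("mongodb")
        .mode("Overwrite")
        .option("truncate_mode", "RECREATE")
        .save();

    spark.stop();
  }
}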
TruncateModesTest.java
@@ -0,0 +1,124 @@
package com.mongodb.spark.sql.connector.write;

import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.CreateCollectionOptions;
import com.mongodb.client.model.IndexOptions;
import com.mongodb.spark.sql.connector.beans.BoxedBean;
import com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorTestCase;

import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.bson.Document;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

public class TruncateModesTest extends MongoSparkConnectorTestCase {

public static final String INT_FIELD_INDEX = "intFieldIndex";
public static final String ID_INDEX = "_id_";

// private static final Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);

@BeforeEach
void setup() {
MongoDatabase database = getDatabase();
getCollection().drop();
CreateCollectionOptions createCollectionOptions =
new CreateCollectionOptions().capped(true).maxDocuments(1024).sizeInBytes(1024);
database.createCollection(getCollectionName(), createCollectionOptions);
MongoCollection<Document> collection = database.getCollection(getCollectionName());
collection.insertOne(new Document().append("intField", null));
collection.createIndex(
new Document().append("intField", 1), new IndexOptions().name(INT_FIELD_INDEX));
}

@Test
void testCollectionDroppedOnOverwrite() {
// Given
List<BoxedBean> dataSetOriginal = singletonList(getBoxedBean());

// When
SparkSession spark = getOrCreateSparkSession();
Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
dataset
.write()
.format("mongodb")
.mode("Overwrite")
.option("truncate_mode", TruncateMode.DROP.name())
.save();

// Then
List<BoxedBean> dataSetMongo =
spark.read().format("mongodb").schema(encoder.schema()).load().as(encoder).collectAsList();
assertIterableEquals(dataSetOriginal, dataSetMongo);

List<String> indexes =
getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>());
assertEquals(singletonList(ID_INDEX), indexes);
Document options = getCollectionOptions();
assertTrue(options.isEmpty());
}

@ParameterizedTest
@EnumSource(
value = TruncateMode.class,
names = {"DROP"},
mode = EnumSource.Mode.EXCLUDE)
void testOptionKeepingOverwrites(TruncateMode mode) {
// Given
List<BoxedBean> dataSetOriginal = singletonList(getBoxedBean());

// When
SparkSession spark = getOrCreateSparkSession();
Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
Dataset<BoxedBean> dataset = spark.createDataset(dataSetOriginal, encoder);
dataset.write().format("mongodb").mode("Overwrite").option("truncate_mode", mode.name()).save();

// Then
List<BoxedBean> dataSetMongo =
spark.read().format("mongodb").schema(encoder.schema()).load().as(encoder).collectAsList();
assertIterableEquals(dataSetOriginal, dataSetMongo);

List<String> indexes =
getCollection().listIndexes().map(it -> it.getString("name")).into(new ArrayList<>());
assertEquals(asList(ID_INDEX, INT_FIELD_INDEX), indexes);

Document options = getCollectionOptions();
assertTrue(options.getBoolean("capped"));
}

private @NotNull BoxedBean getBoxedBean() {
return new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, true);
}

private Document getCollectionOptions() {
Document getCollectionMeta =
new Document()
.append("listCollections", 1)
.append("filter", new Document().append("name", getCollectionName()));

Document foundMeta = getDatabase().runCommand(getCollectionMeta);
Document cursor = foundMeta.get("cursor", Document.class);
List<Document> firstBatch = cursor.getList("firstBatch", Document.class);
if (firstBatch.isEmpty()) {
// Collection does not exist; report no options rather than leaking the command document.
return new Document();
}

return firstBatch.get(0).get("options", Document.class);
}
}
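As an aside, the options lookup done here with a raw listCollections command (and repeated in TruncateMode.RECREATE below) could also be expressed through the driver's listCollections() iterable. A sketch under that assumption, not part of this change; the class and method names are illustrative:

import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.Filters;
import org.bson.Document;

final class CollectionOptionsLookup {
  private CollectionOptionsLookup() {}

  // Returns the stored options for the named collection, or an empty document if it does not exist.
  static Document optionsFor(final MongoDatabase database, final String collectionName) {
    Document collectionInfo =
        database.listCollections().filter(Filters.eq("name", collectionName)).first();
    return collectionInfo == null ? new Document() : collectionInfo.get("options", Document.class);
  }
}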
MongoBatchWrite.java
@@ -19,7 +19,6 @@

import static java.lang.String.format;

import com.mongodb.client.MongoCollection;
import com.mongodb.spark.sql.connector.config.WriteConfig;
import com.mongodb.spark.sql.connector.exceptions.DataException;
import java.util.Arrays;
@@ -62,7 +61,8 @@ final class MongoBatchWrite implements BatchWrite {
@Override
public DataWriterFactory createBatchWriterFactory(final PhysicalWriteInfo physicalWriteInfo) {
if (truncate) {
writeConfig.doWithCollection(MongoCollection::drop);
TruncateMode mode = TruncateMode.valueOf(writeConfig.getOrDefault("truncate_mode", "DROP"));
mode.truncate(writeConfig);
}
return new MongoDataWriterFactory(info.schema(), writeConfig);
}
@@ -88,8 +88,10 @@ public void commit(final WriterCommitMessage[] messages) {
@Override
public void abort(final WriterCommitMessage[] messages) {
long tasksCompleted = Arrays.stream(messages).filter(Objects::nonNull).count();
throw new DataException(format(
"Write aborted for: %s. %s/%s tasks completed.",
info.queryId(), tasksCompleted, messages.length));
throw new DataException(
format(
"Write aborted for: %s. %s/%s tasks completed.",
info.queryId(), tasksCompleted, messages.length));
}

}
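One subtlety in the dispatch above: TruncateMode.valueOf is case sensitive and throws a bare IllegalArgumentException for unrecognised values. A hypothetical, more forgiving parser is sketched below against the TruncateMode enum introduced in this PR; the helper class, method name, and error message are illustrative and not part of the change:

import com.mongodb.spark.sql.connector.write.TruncateMode;
import java.util.Arrays;
import java.util.Locale;

final class TruncateModeParsing {
  private TruncateModeParsing() {}

  // Hypothetical helper: tolerates lower-case values such as "delete_all" and
  // reports the accepted values when the option is misspelled.
  static TruncateMode parse(final String rawValue) {
    if (rawValue == null || rawValue.trim().isEmpty()) {
      return TruncateMode.DROP; // same default as the code above
    }
    try {
      return TruncateMode.valueOf(rawValue.trim().toUpperCase(Locale.ROOT));
    } catch (IllegalArgumentException e) {
      throw new IllegalArgumentException(
          "Invalid truncate_mode: '" + rawValue + "'. Expected one of "
              + Arrays.toString(TruncateMode.values()), e);
    }
  }
}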
TruncateMode.java
@@ -0,0 +1,62 @@
package com.mongodb.spark.sql.connector.write;

import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.spark.sql.connector.config.WriteConfig;
import java.util.ArrayList;
import java.util.List;
import org.bson.Document;

public enum TruncateMode {
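  // Drops the target collection outright; collection options and indexes are lost (the previous behaviour).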
DROP {
@Override
public void truncate(final WriteConfig writeConfig) {
writeConfig.doWithCollection(MongoCollection::drop);
}
},
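  // Removes every document but keeps the collection itself, so its options and indexes survive.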
DELETE_ALL {
@Override
public void truncate(final WriteConfig writeConfig) {
writeConfig.doWithCollection(collection -> collection.deleteMany(new Document()));
}
},
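  // Drops the collection, then re-creates it with the previously stored collection options and indexes.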
RECREATE {
@Override
public void truncate(final WriteConfig writeConfig) {
MongoClient mongoClient = writeConfig.getMongoClient();
MongoDatabase database = mongoClient.getDatabase(writeConfig.getDatabaseName());

String collectionName = writeConfig.getCollectionName();
Document getCollectionMeta =
new Document()
.append("listCollections", 1)
.append("filter", new Document().append("name", collectionName));

Document foundMeta = database.runCommand(getCollectionMeta);
Document cursor = foundMeta.get("cursor", Document.class);
List<Document> firstBatch = cursor.getList("firstBatch", Document.class);
if (firstBatch.isEmpty()) {
return;
}

Document collectionObj = firstBatch.get(0);
Document options = collectionObj.get("options", Document.class);
Document createCollectionWithOptions = new Document().append("create", collectionName);
createCollectionWithOptions.putAll(options);

MongoCollection<Document> originalCollection = database.getCollection(collectionName);
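      // Snapshot the existing indexes before dropping the collection so they can be restored below.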
List<Document> originalIndexes = originalCollection.listIndexes().into(new ArrayList<>());
originalCollection.drop();

database.runCommand(createCollectionWithOptions);
Document createIndexes = new Document()
.append("createIndexes", collectionName)
.append("indexes", originalIndexes);
// Note: createIndexes options such as writeConcern, commitQuorum, and comment are not carried over here.
database.runCommand(createIndexes);
}
};

public abstract void truncate(WriteConfig writeConfig);
}