SPARK-414 support for delete #124

Draft · wants to merge 2 commits into base: main
@@ -17,8 +17,10 @@

package com.mongodb.spark.sql.connector;

import static com.mongodb.spark.sql.connector.mongodb.MongoSparkConnectorHelper.CATALOG;
import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertIterableEquals;

import com.mongodb.spark.sql.connector.beans.BoxedBean;
@@ -39,6 +41,7 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.jupiter.api.Test;

@@ -162,4 +165,33 @@ void testComplexBean() {
.collectAsList();
assertIterableEquals(dataSetOriginal, dataSetMongo);
}

@Test
void testCatalogAccessAndDelete() {
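// Writes six documents, reads them back through the SQL catalog, then deletes the two
// where booleanField is false and intField > 3.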
List<BoxedBean> dataSetOriginal =
asList(
new BoxedBean((byte) 1, (short) 2, 0, 4L, 5.0f, 6.0, true),
new BoxedBean((byte) 1, (short) 2, 1, 4L, 5.0f, 6.0, true),
new BoxedBean((byte) 1, (short) 2, 2, 4L, 5.0f, 6.0, true),
new BoxedBean((byte) 1, (short) 2, 3, 4L, 5.0f, 6.0, false),
new BoxedBean((byte) 1, (short) 2, 4, 4L, 5.0f, 6.0, false),
new BoxedBean((byte) 1, (short) 2, 5, 4L, 5.0f, 6.0, false));

SparkSession spark = getOrCreateSparkSession();
Encoder<BoxedBean> encoder = Encoders.bean(BoxedBean.class);
spark
.createDataset(dataSetOriginal, encoder)
.write()
.format("mongodb")
.mode("Overwrite")
.save();

String tableName = CATALOG + "." + HELPER.getDatabaseName() + "." + HELPER.getCollectionName();
List<Row> rows = spark.sql("select * from " + tableName).collectAsList();
assertEquals(6, rows.size());

spark.sql("delete from " + tableName + " where not booleanField and intField > 3");
rows = spark.sql("select * from " + tableName).collectAsList();
assertEquals(4, rows.size());
}
}
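For reference, a minimal standalone sketch (not part of this change) of the MongoDB query that the DELETE predicate in `testCatalogAccessAndDelete` is expected to translate to. The exact `Filter` instances Spark pushes down depend on its optimizer, so treat this as an approximation.

```java
import static com.mongodb.client.model.Filters.and;
import static com.mongodb.client.model.Filters.eq;
import static com.mongodb.client.model.Filters.gt;
import static com.mongodb.client.model.Filters.not;

import com.mongodb.MongoClientSettings;
import org.bson.BsonDocument;
import org.bson.conversions.Bson;

public final class DeletePredicateSketch {
  public static void main(final String[] args) {
    // The two deleted documents have booleanField == false and intField > 3.
    Bson expected = and(not(eq("booleanField", true)), gt("intField", 3));
    // Renders roughly as:
    // {"$and": [{"booleanField": {"$not": {"$eq": true}}}, {"intField": {"$gt": 3}}]}
    System.out.println(expected
        .toBsonDocument(BsonDocument.class, MongoClientSettings.getDefaultCodecRegistry())
        .toJson());
  }
}
```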
@@ -28,6 +28,7 @@
import com.mongodb.client.model.UpdateOptions;
import com.mongodb.client.model.Updates;
import com.mongodb.connection.ClusterType;
import com.mongodb.spark.sql.connector.MongoCatalog;
import com.mongodb.spark.sql.connector.config.MongoConfig;
import java.io.File;
import java.io.IOException;
@@ -62,6 +63,7 @@ public class MongoSparkConnectorHelper
"{_id: '%s', pk: '%s', dups: '%s', i: %d, s: '%s'}";
private static final String COMPLEX_SAMPLE_DATA_TEMPLATE =
"{_id: '%s', nested: {pk: '%s', dups: '%s', i: %d}, s: '%s'}";
public static final String CATALOG = "mongo_catalog";

private static final Logger LOGGER = LoggerFactory.getLogger(MongoSparkConnectorHelper.class);

@@ -146,6 +148,7 @@ public SparkConf getSparkConf() {
.set("spark.sql.streaming.checkpointLocation", getTempDirectory())
.set("spark.sql.streaming.forceDeleteTempCheckpointLocation", "true")
.set("spark.app.id", "MongoSparkConnector")
.set("spark.sql.catalog." + CATALOG, MongoCatalog.class.getCanonicalName())
.set(
MongoConfig.PREFIX + MongoConfig.CONNECTION_STRING_CONFIG,
getConnectionString().getConnectionString())
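The helper above registers the new `MongoCatalog` under the name `mongo_catalog` for the integration tests. A user application would opt in the same way through session configuration; the sketch below is illustrative only (the connection-URI key, database, and collection names are assumptions, not taken from this diff).

```java
import org.apache.spark.sql.SparkSession;

public final class MongoCatalogExample {
  public static void main(final String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("mongo-catalog-example")
        // Mirrors the SparkConf entry added to MongoSparkConnectorHelper above.
        .config("spark.sql.catalog.mongo_catalog",
            "com.mongodb.spark.sql.connector.MongoCatalog")
        // Assumed connection setting; use whichever MongoConfig key your deployment requires.
        .config("spark.mongodb.connection.uri", "mongodb://localhost:27017")
        .getOrCreate();

    // With the catalog registered, DELETE statements are planned against MongoTable,
    // which now implements SupportsDelete (see the MongoTable.java changes below).
    spark.sql("DELETE FROM mongo_catalog.test.coll WHERE intField > 3 AND NOT booleanField");
    spark.stop();
  }
}
```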
214 changes: 214 additions & 0 deletions src/main/java/com/mongodb/spark/sql/connector/ExpressionConverter.java
@@ -0,0 +1,214 @@
package com.mongodb.spark.sql.connector;

import static com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter.createObjectToBsonValue;
import static java.lang.String.format;

import com.mongodb.client.model.Filters;
import com.mongodb.spark.sql.connector.assertions.Assertions;
import com.mongodb.spark.sql.connector.config.WriteConfig;
import com.mongodb.spark.sql.connector.schema.RowToBsonDocumentConverter;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.sources.And;
import org.apache.spark.sql.sources.EqualNullSafe;
import org.apache.spark.sql.sources.EqualTo;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.sources.GreaterThanOrEqual;
import org.apache.spark.sql.sources.In;
import org.apache.spark.sql.sources.IsNotNull;
import org.apache.spark.sql.sources.IsNull;
import org.apache.spark.sql.sources.LessThan;
import org.apache.spark.sql.sources.LessThanOrEqual;
import org.apache.spark.sql.sources.Not;
import org.apache.spark.sql.sources.Or;
import org.apache.spark.sql.sources.StringContains;
import org.apache.spark.sql.sources.StringEndsWith;
import org.apache.spark.sql.sources.StringStartsWith;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.bson.BsonValue;
import org.bson.conversions.Bson;
import org.jetbrains.annotations.Nullable;
import org.jetbrains.annotations.VisibleForTesting;
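
/**
 * Converts Spark SQL data source {@link Filter}s into equivalent MongoDB query filters,
 * where a translation exists.
 */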

public class ExpressionConverter {
private final StructType schema;

public ExpressionConverter(final StructType schema) {
this.schema = schema;
}

public FilterAndPipelineStage processFilter(final Filter filter) {
Assertions.ensureArgument(() -> filter != null, () -> "Invalid argument: filter cannot be null");
if (filter instanceof And) {
And andFilter = (And) filter;
FilterAndPipelineStage eitherLeft = processFilter(andFilter.left());
FilterAndPipelineStage eitherRight = processFilter(andFilter.right());
if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) {
return new FilterAndPipelineStage(
filter, Filters.and(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage()));
}
} else if (filter instanceof EqualNullSafe) {
EqualNullSafe equalNullSafe = (EqualNullSafe) filter;
String fieldName = unquoteFieldName(equalNullSafe.attribute());
return new FilterAndPipelineStage(
filter,
getBsonValue(fieldName, equalNullSafe.value())
.map(bsonValue -> Filters.eq(fieldName, bsonValue))
.orElse(null));
} else if (filter instanceof EqualTo) {
EqualTo equalTo = (EqualTo) filter;
String fieldName = unquoteFieldName(equalTo.attribute());
return new FilterAndPipelineStage(
filter,
getBsonValue(fieldName, equalTo.value())
.map(bsonValue -> Filters.eq(fieldName, bsonValue))
.orElse(null));
} else if (filter instanceof GreaterThan) {
GreaterThan greaterThan = (GreaterThan) filter;
String fieldName = unquoteFieldName(greaterThan.attribute());
return new FilterAndPipelineStage(
filter,
getBsonValue(fieldName, greaterThan.value())
.map(bsonValue -> Filters.gt(fieldName, bsonValue))
.orElse(null));
} else if (filter instanceof GreaterThanOrEqual) {
GreaterThanOrEqual greaterThanOrEqual = (GreaterThanOrEqual) filter;
String fieldName = unquoteFieldName(greaterThanOrEqual.attribute());
return new FilterAndPipelineStage(
filter,
getBsonValue(fieldName, greaterThanOrEqual.value())
.map(bsonValue -> Filters.gte(fieldName, bsonValue))
.orElse(null));
} else if (filter instanceof In) {
In inFilter = (In) filter;
String fieldName = unquoteFieldName(inFilter.attribute());
List<BsonValue> values =
Arrays.stream(inFilter.values())
.map(v -> getBsonValue(fieldName, v))
.filter(Optional::isPresent)
.map(Optional::get)
.collect(Collectors.toList());

// Only push the In filter down if every value could be converted; otherwise leave the filtering to Spark.
Bson pipelineStage = null;
if (values.size() == inFilter.values().length) {
pipelineStage = Filters.in(fieldName, values);
}
return new FilterAndPipelineStage(filter, pipelineStage);
} else if (filter instanceof IsNull) {
IsNull isNullFilter = (IsNull) filter;
String fieldName = unquoteFieldName(isNullFilter.attribute());
return new FilterAndPipelineStage(filter, Filters.eq(fieldName, null));
} else if (filter instanceof IsNotNull) {
IsNotNull isNotNullFilter = (IsNotNull) filter;
String fieldName = unquoteFieldName(isNotNullFilter.attribute());
return new FilterAndPipelineStage(filter, Filters.ne(fieldName, null));
} else if (filter instanceof LessThan) {
LessThan lessThan = (LessThan) filter;
String fieldName = unquoteFieldName(lessThan.attribute());
return new FilterAndPipelineStage(
filter,
getBsonValue(fieldName, lessThan.value())
.map(bsonValue -> Filters.lt(fieldName, bsonValue))
.orElse(null));
} else if (filter instanceof LessThanOrEqual) {
LessThanOrEqual lessThanOrEqual = (LessThanOrEqual) filter;
String fieldName = unquoteFieldName(lessThanOrEqual.attribute());
return new FilterAndPipelineStage(
filter,
getBsonValue(fieldName, lessThanOrEqual.value())
.map(bsonValue -> Filters.lte(fieldName, bsonValue))
.orElse(null));
} else if (filter instanceof Not) {
Not notFilter = (Not) filter;
FilterAndPipelineStage notChild = processFilter(notFilter.child());
if (notChild.hasPipelineStage()) {
return new FilterAndPipelineStage(filter, Filters.not(notChild.pipelineStage));
}
} else if (filter instanceof Or) {
Or or = (Or) filter;
FilterAndPipelineStage eitherLeft = processFilter(or.left());
FilterAndPipelineStage eitherRight = processFilter(or.right());
if (eitherLeft.hasPipelineStage() && eitherRight.hasPipelineStage()) {
return new FilterAndPipelineStage(
filter, Filters.or(eitherLeft.getPipelineStage(), eitherRight.getPipelineStage()));
}
} else if (filter instanceof StringContains) {
StringContains stringContains = (StringContains) filter;
String fieldName = unquoteFieldName(stringContains.attribute());
return new FilterAndPipelineStage(
filter, Filters.regex(fieldName, format(".*%s.*", stringContains.value())));
} else if (filter instanceof StringEndsWith) {
StringEndsWith stringEndsWith = (StringEndsWith) filter;
String fieldName = unquoteFieldName(stringEndsWith.attribute());
return new FilterAndPipelineStage(
filter, Filters.regex(fieldName, format(".*%s$", stringEndsWith.value())));
} else if (filter instanceof StringStartsWith) {
StringStartsWith stringStartsWith = (StringStartsWith) filter;
String fieldName = unquoteFieldName(stringStartsWith.attribute());
return new FilterAndPipelineStage(
filter, Filters.regex(fieldName, format("^%s.*", stringStartsWith.value())));
}
return new FilterAndPipelineStage(filter, null);
}

@VisibleForTesting
static String unquoteFieldName(final String fieldName) {
// Spark automatically quotes hyphenated field names with backticks; unquote them so the MongoDB field name matches.
if (fieldName.contains("`")) {
return new Column(fieldName).toString();
}
return fieldName;
}

private Optional<BsonValue> getBsonValue(final String fieldName, final Object value) {
try {
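// Resolve the (possibly nested, dot-separated) field to its Spark data type, then convert the value to BSON using that type.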
StructType localSchema = schema;
DataType localDataType = localSchema;

for (String localFieldName : fieldName.split("\\.")) {
StructField localField = localSchema.apply(localFieldName);
localDataType = localField.dataType();
if (localField.dataType() instanceof StructType) {
localSchema = (StructType) localField.dataType();
}
}
RowToBsonDocumentConverter.ObjectToBsonValue objectToBsonValue =
createObjectToBsonValue(localDataType, WriteConfig.ConvertJson.FALSE, false);
return Optional.of(objectToBsonValue.apply(value));
} catch (Exception e) {
// The value could not be converted to BSON; let Spark evaluate this filter instead.
return Optional.empty();
}
}

/** FilterAndPipelineStage - pairs a Spark filter with the optional MongoDB pipeline stage generated from it. */
public static final class FilterAndPipelineStage {

private final Filter filter;
private final Bson pipelineStage;

private FilterAndPipelineStage(final Filter filter, @Nullable final Bson pipelineStage) {
this.filter = filter;
this.pipelineStage = pipelineStage;
}

public Filter getFilter() {
return filter;
}

public Bson getPipelineStage() {
return pipelineStage;
}

public boolean hasPipelineStage() {
return pipelineStage != null;
}
}
}
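A minimal sketch of exercising the converter directly, assuming the classes added in this PR are on the classpath; `ConverterExample` and its field names are illustrative.

```java
import com.mongodb.spark.sql.connector.ExpressionConverter;
import com.mongodb.spark.sql.connector.ExpressionConverter.FilterAndPipelineStage;
import org.apache.spark.sql.sources.GreaterThan;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public final class ConverterExample {
  public static void main(final String[] args) {
    StructType schema = new StructType()
        .add("intField", DataTypes.IntegerType)
        .add("booleanField", DataTypes.BooleanType);

    ExpressionConverter converter = new ExpressionConverter(schema);
    FilterAndPipelineStage result = converter.processFilter(new GreaterThan("intField", 3));

    if (result.hasPipelineStage()) {
      // Expected to be equivalent to {"intField": {"$gt": 3}}
      System.out.println(result.getPipelineStage());
    }
  }
}
```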
@@ -28,6 +28,7 @@
import com.mongodb.spark.sql.connector.config.ReadConfig;
import com.mongodb.spark.sql.connector.config.WriteConfig;
import com.mongodb.spark.sql.connector.exceptions.MongoSparkException;
import com.mongodb.spark.sql.connector.schema.InferSchema;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
@@ -239,7 +240,7 @@ public Table loadTable(final Identifier identifier) throws NoSuchTableException
properties.put(
MongoConfig.READ_PREFIX + MongoConfig.DATABASE_NAME_CONFIG, identifier.namespace()[0]);
properties.put(MongoConfig.READ_PREFIX + MongoConfig.COLLECTION_NAME_CONFIG, identifier.name());
return new MongoTable(MongoConfig.readConfig(properties));
return new MongoTable(InferSchema.inferSchema(new CaseInsensitiveStringMap(properties)), MongoConfig.readConfig(properties));
}

/**
31 changes: 30 additions & 1 deletion src/main/java/com/mongodb/spark/sql/connector/MongoTable.java
@@ -19,17 +19,25 @@

import static java.util.Arrays.asList;

import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.Filters;
import com.mongodb.spark.connector.Versions;
import com.mongodb.spark.sql.connector.ExpressionConverter.FilterAndPipelineStage;
import com.mongodb.spark.sql.connector.config.MongoConfig;
import com.mongodb.spark.sql.connector.config.ReadConfig;
import com.mongodb.spark.sql.connector.config.WriteConfig;
import com.mongodb.spark.sql.connector.read.MongoScanBuilder;
import com.mongodb.spark.sql.connector.write.MongoWriteBuilder;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.spark.sql.connector.catalog.SupportsDelete;
import org.apache.spark.sql.connector.catalog.SupportsRead;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.Table;
@@ -38,13 +46,16 @@
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
import org.apache.spark.sql.connector.write.WriteBuilder;
import org.apache.spark.sql.sources.Filter;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.bson.Document;
import org.bson.conversions.Bson;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Represents a MongoDB Collection. */
final class MongoTable implements Table, SupportsWrite, SupportsRead {
final class MongoTable implements Table, SupportsWrite, SupportsRead, SupportsDelete {
private static final Logger LOGGER = LoggerFactory.getLogger(MongoTable.class);
private static final Set<TableCapability> TABLE_CAPABILITY_SET = new HashSet<>(asList(
TableCapability.BATCH_WRITE,
@@ -179,4 +190,22 @@ public int hashCode() {
result = 31 * result + Arrays.hashCode(partitioning);
return result;
}

@Override
public void deleteWhere(final Filter[] filters) {
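// Translate each pushed-down Spark filter into a MongoDB filter; filters that cannot be translated are skipped.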
ExpressionConverter converter = new ExpressionConverter(schema);

List<Bson> stages = Arrays.stream(filters)
.map(converter::processFilter)
.filter(FilterAndPipelineStage::hasPipelineStage)
.map(FilterAndPipelineStage::getPipelineStage)
.collect(Collectors.toList());
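// Combine the translated filters into a single query and delete every matching document from the collection.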
Bson query = Filters.and(stages);
WriteConfig writeConfig = mongoConfig.toWriteConfig();

MongoClient mongoClient = writeConfig.getMongoClient();
MongoDatabase database = mongoClient.getDatabase(writeConfig.getDatabaseName());
MongoCollection<Document> collection = database.getCollection(writeConfig.getCollectionName());
collection.deleteMany(query);
}
}