Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3531,6 +3531,124 @@ void testBulkUpdateAllOperatorTypes() throws Exception {
}
}

@Test
@DisplayName(
    "Should efficiently batch updates across multiple key groups with complex operations")
void testBulkUpdateMultipleGroupsComplexOperations() throws Exception {
  // Three key groups, each sharing an identical list of update operations, so the
  // implementation can batch every group into a single multi-row statement.
  Map<Key, java.util.Collection<SubDocumentUpdate>> updates = new LinkedHashMap<>();

  // ===== Group 1: Top-level primitive + top-level array (3 keys: 1, 5, 8) =====
  // All have item="Soap" - these should be batched together
  // This tests: SET on primitive field, APPEND_TO_LIST on array field
  List<SubDocumentUpdate> group1Updates =
      List.of(
          SubDocumentUpdate.of("price", 99), // SET operator (top-level primitive)
          SubDocumentUpdate.builder()
              .subDocument("tags")
              .operator(UpdateOperator.APPEND_TO_LIST)
              .subDocumentValue(SubDocumentValue.of(new String[] {"updated-tag", "batch-test"}))
              .build()); // APPEND_TO_LIST on top-level array

  updates.put(rawKey("1"), group1Updates);
  updates.put(rawKey("5"), group1Updates);
  updates.put(rawKey("8"), group1Updates);

  // ===== Group 2: Nested JSONB updates (2 keys: 3, 7) =====
  // Both have props - these should be batched together
  // This tests: SET on nested JSONB fields
  List<SubDocumentUpdate> group2Updates =
      List.of(
          SubDocumentUpdate.builder()
              .subDocument("props.brand")
              .operator(UpdateOperator.SET)
              .subDocumentValue(SubDocumentValue.of("PremiumBrand"))
              .build(), // SET on nested JSONB primitive
          SubDocumentUpdate.builder()
              .subDocument("props.size")
              .operator(UpdateOperator.SET)
              .subDocumentValue(SubDocumentValue.of("XL"))
              .build()); // SET on another nested field

  updates.put(rawKey("3"), group2Updates);
  updates.put(rawKey("7"), group2Updates);

  // ===== Group 3: ADD operator + REMOVE_ALL_FROM_LIST (2 keys: 2, 6) =====
  // Both have quantity and tags - these should be batched together
  // This tests: ADD on numeric field, REMOVE_ALL_FROM_LIST on array
  List<SubDocumentUpdate> group3Updates =
      List.of(
          SubDocumentUpdate.builder()
              .subDocument("quantity")
              .operator(UpdateOperator.ADD)
              .subDocumentValue(SubDocumentValue.of(100))
              .build(), // ADD to numeric field
          SubDocumentUpdate.builder()
              .subDocument("tags")
              .operator(UpdateOperator.REMOVE_ALL_FROM_LIST)
              .subDocumentValue(SubDocumentValue.of(new String[] {"glass", "plastic"}))
              .build()); // REMOVE_ALL_FROM_LIST

  updates.put(rawKey("2"), group3Updates);
  updates.put(rawKey("6"), group3Updates);

  // Execute bulk update - should have 3 groups with 2-3 keys each
  BulkUpdateResult result = flatCollection.bulkUpdate(updates, UpdateOptions.builder().build());

  // Total unique keys: 1, 2, 3, 5, 6, 7, 8 = 7 keys
  assertEquals(7, result.getUpdatedCount(), "Should update 7 rows");

  // Verify keys 1, 5, 8 have Group 1 updates (top-level primitive + array)
  for (String id : List.of("1", "5", "8")) {
    try (CloseableIterator<Document> iter = flatCollection.find(queryById(id))) {
      assertTrue(iter.hasNext());
      JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
      assertEquals(99, json.get("price").asInt(), "Key " + id + " price should be 99");
      JsonNode tags = json.get("tags");
      List<String> tagList = new ArrayList<>();
      tags.forEach(t -> tagList.add(t.asText()));
      assertTrue(
          tagList.contains("updated-tag"), "Key " + id + " should contain 'updated-tag'");
      assertTrue(tagList.contains("batch-test"), "Key " + id + " should contain 'batch-test'");
    }
  }

  // Verify keys 3, 7 have Group 2 updates (nested JSONB)
  for (String id : List.of("3", "7")) {
    try (CloseableIterator<Document> iter = flatCollection.find(queryById(id))) {
      assertTrue(iter.hasNext());
      JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
      JsonNode props = json.get("props");
      assertNotNull(props, "Key " + id + " should have props");
      assertEquals(
          "PremiumBrand",
          props.get("brand").asText(),
          "Key " + id + " brand should be updated");
      assertEquals("XL", props.get("size").asText(), "Key " + id + " size should be XL");
    }
  }

  // Verify keys 2, 6 have Group 3 updates (ADD + REMOVE_ALL_FROM_LIST).
  // ADD is cumulative (original quantity + 100: 1 -> 101 for key 2, 5 -> 105 for key 6),
  // and REMOVE_ALL_FROM_LIST strips BOTH "glass" and "plastic" from every key in the
  // group, so both absences are asserted for both keys.
  Map<String, Integer> expectedQuantities = Map.of("2", 101, "6", 105);
  for (Map.Entry<String, Integer> expected : expectedQuantities.entrySet()) {
    String id = expected.getKey();
    try (CloseableIterator<Document> iter = flatCollection.find(queryById(id))) {
      assertTrue(iter.hasNext());
      JsonNode json = OBJECT_MAPPER.readTree(iter.next().toJson());
      assertEquals(
          expected.getValue().intValue(),
          json.get("quantity").asInt(),
          "Key " + id + " quantity should reflect the ADD operation");
      JsonNode tags = json.get("tags");
      List<String> tagList = new ArrayList<>();
      tags.forEach(t -> tagList.add(t.asText()));
      assertFalse(tagList.contains("glass"), "Key " + id + " should not have 'glass' tag");
      assertFalse(tagList.contains("plastic"), "Key " + id + " should not have 'plastic' tag");
    }
  }
}

@Test
@DisplayName("Should handle edge cases: empty map, null map, non-existent keys")
void testBulkUpdateEdgeCases() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
Expand Down Expand Up @@ -874,59 +875,34 @@ public BulkUpdateResult bulkUpdate(

String tableName = tableIdentifier.getTableName();
String quotedPkColumn = PostgresUtils.wrapFieldNamesWithDoubleQuotes(getPKForTable(tableName));

Set<Key> updatedKeys = new HashSet<>();

long batchUpdateTimestamp = System.currentTimeMillis();

try (Connection connection = client.getPooledConnection()) {
for (Map.Entry<Key, Collection<SubDocumentUpdate>> entry : updates.entrySet()) {
Key key = entry.getKey();
Collection<SubDocumentUpdate> keyUpdates = entry.getValue();
// Group keys by their "SQL shape" (same update operations)
Map<String, KeyUpdateGroup> keyGroups = groupKeysByUpdateShape(updates, tableName);

if (keyUpdates == null || keyUpdates.isEmpty()) {
continue;
}
int totalUpdated = 0;

try (Connection connection = client.getPooledConnection()) {
// Execute one multi-row UPDATE per group (or fallback to single-key if group size = 1)
for (Map.Entry<String, KeyUpdateGroup> entry : keyGroups.entrySet()) {
try {
boolean updated =
updateSingleKey(
connection, key, keyUpdates, tableName, quotedPkColumn, batchUpdateTimestamp);
if (updated) {
updatedKeys.add(key);
}
int updated =
executeBatchUpdate(
connection, entry.getValue(), tableName, quotedPkColumn, batchUpdateTimestamp);
totalUpdated += updated;
} catch (Exception e) {
LOGGER.warn("Failed to update key {}: {}", key, e.getMessage());
// Continue with other keys - no cross-key atomicity
LOGGER.warn(
"Failed to update key group (size: {}): {}",
entry.getValue().getKeys().size(),
e.getMessage());
// Continue with other groups - no cross-group atomicity
}
}
} catch (SQLException e) {
throw new IOException("Failed to get connection for bulk update", e);
}

return new BulkUpdateResult(updatedKeys.size());
}

private boolean updateSingleKey(
Connection connection,
Key key,
Collection<SubDocumentUpdate> keyUpdates,
String tableName,
String quotedPkColumn,
long keyUpdateTimestamp)
throws IOException, SQLException {

updateValidator.validate(keyUpdates);
Map<String, String> resolvedColumns = resolvePathsToColumns(keyUpdates, tableName);

return executeKeyUpdate(
connection,
key,
keyUpdates,
tableName,
quotedPkColumn,
resolvedColumns,
keyUpdateTimestamp);
return new BulkUpdateResult(totalUpdated);
}

private boolean executeKeyUpdate(
Expand Down Expand Up @@ -972,6 +948,178 @@ private boolean executeKeyUpdate(
}
}

/**
 * Partitions the requested updates into groups of keys whose update operations share an
 * identical SQL "shape". Every group can later be flushed as one multi-row statement.
 */
private Map<String, KeyUpdateGroup> groupKeysByUpdateShape(
    Map<Key, Collection<SubDocumentUpdate>> updates, String tableName) {

  // LinkedHashMap preserves first-seen order, keeping group execution deterministic.
  Map<String, KeyUpdateGroup> groupsByShape = new LinkedHashMap<>();

  updates.forEach(
      (key, subDocumentUpdates) -> {
        // A key with nothing to apply contributes no group.
        if (subDocumentUpdates == null || subDocumentUpdates.isEmpty()) {
          return;
        }

        try {
          updateValidator.validate(subDocumentUpdates);
          Map<String, String> columnsByPath =
              resolvePathsToColumns(subDocumentUpdates, tableName);

          String shape = computeUpdateShapeKey(subDocumentUpdates, columnsByPath);

          groupsByShape
              .computeIfAbsent(shape, unused -> new KeyUpdateGroup(columnsByPath))
              .addKeyWithUpdates(key, subDocumentUpdates);

        } catch (Exception e) {
          // Best-effort semantics: a key that fails validation or column resolution is
          // skipped, mirroring bulkUpdate's per-key failure tolerance.
          LOGGER.warn("Failed to group key {}: {}", key, e.getMessage());
        }
      });

  return groupsByShape;
}

/**
 * Computes a canonical key describing the SQL "shape" of a set of updates: the resolved
 * column, operator, and sub-document path of each update, in the collection's iteration
 * order.
 *
 * <p>The shape key is deliberately order-sensitive. The batch executor builds its SQL
 * template from the first key of a group and then binds parameters for every other key by
 * walking that key's own update collection in order. If the key were computed from a sorted
 * view, two keys carrying the same updates in different iteration orders would share one
 * group while binding parameters against mismatched SET fragments. Keeping iteration order
 * in the key guarantees every member of a group is position-compatible, at the cost of a
 * slightly smaller batching opportunity for differently-ordered but equivalent updates.
 */
private String computeUpdateShapeKey(
    Collection<SubDocumentUpdate> updates, Map<String, String> resolvedColumns) {

  StringBuilder sb = new StringBuilder();
  for (SubDocumentUpdate update : updates) {
    String path = update.getSubDocument().getPath();
    String column = resolvedColumns.get(path);
    sb.append(column)
        .append(":")
        .append(update.getOperator())
        .append(":")
        .append(path)
        .append(";");
  }

  return sb.toString();
}

/**
 * Executes a batch UPDATE for all keys in the group using JDBC batching. All keys in the group
 * share the same SQL structure, so we can use a single PreparedStatement.
 *
 * <p>Returns the number of batch entries whose driver-reported row count was positive.
 * NOTE(review): entries reporting {@code Statement.SUCCESS_NO_INFO} (-2) are counted as NOT
 * updated; if the driver ever batches with unknown counts this undercounts — confirm the
 * configured driver returns concrete counts for batched UPDATEs.
 */
private int executeBatchUpdate(
    Connection connection,
    KeyUpdateGroup keyGroup,
    String tableName,
    String quotedPkColumn,
    long epochMillis)
    throws SQLException {

  // Parallel lists: keys.get(i) is updated with allKeyUpdates.get(i).
  List<Key> keys = keyGroup.getKeys();
  List<Collection<SubDocumentUpdate>> allKeyUpdates = keyGroup.getKeyUpdates();
  Map<String, String> resolvedColumns = keyGroup.getResolvedColumns();

  // Use the first key's updates to build the SQL template
  Collection<SubDocumentUpdate> templateUpdates = allKeyUpdates.get(0);
  List<String> setFragments = new ArrayList<>();
  List<Object> templateParams = new ArrayList<>();

  boolean hasUpdates =
      buildSetClauseFragments(
          connection, templateUpdates, tableName, resolvedColumns, setFragments, templateParams);

  // Nothing resolved to a SET fragment: no statement to run for this group.
  if (!hasUpdates) {
    return 0;
  }

  appendLastUpdatedTimestamp(setFragments, templateParams, tableName, epochMillis);

  // Build UPDATE SQL (same for all keys in this group)
  String sql =
      String.format(
          "UPDATE %s SET %s WHERE %s = ?",
          tableIdentifier, String.join(", ", setFragments), quotedPkColumn);

  LOGGER.debug("Executing batch update SQL: {} for {} keys", sql, keys.size());

  // Use JDBC batching to execute all updates in one round-trip
  try (PreparedStatement ps = connection.prepareStatement(sql)) {
    for (int i = 0; i < keys.size(); i++) {
      Key key = keys.get(i);
      Collection<SubDocumentUpdate> keyUpdates = allKeyUpdates.get(i);

      // Build parameters for this specific key.
      // NOTE(review): the SET-clause template comes from key 0, but these parameters are
      // produced in THIS key's collection-iteration order; correctness relies on every
      // key in the group listing its updates in the same order as the template key —
      // verify the grouping logic guarantees this.
      List<String> keySetFragments = new ArrayList<>();
      List<Object> keyParams = new ArrayList<>();
      buildSetClauseFragments(
          connection, keyUpdates, tableName, resolvedColumns, keySetFragments, keyParams);

      // Add timestamp parameter
      // NOTE(review): presumably appendLastUpdatedTimestamp added a placeholder to the
      // template exactly when this lookup succeeds; if the two ever disagree the
      // parameter count will not match the template — confirm against that helper.
      if (lastUpdatedTsColumn != null) {
        Optional<PostgresColumnMetadata> colMeta =
            schemaRegistry.getColumnOrRefresh(tableName, lastUpdatedTsColumn);
        if (colMeta.isPresent()) {
          Object timestampValue =
              convertTimestampForType(epochMillis, colMeta.get().getPostgresType());
          keyParams.add(timestampValue);
        }
      }

      // Bind parameters for this key; the primary key is always the last placeholder.
      int idx = 1;
      for (Object param : keyParams) {
        ps.setObject(idx++, param);
      }
      ps.setObject(idx, key.toString()); // WHERE clause parameter

      ps.addBatch();
    }

    // One network round-trip for the whole group; each entry reports its own row count.
    int[] results = ps.executeBatch();
    int totalUpdated = 0;
    for (int result : results) {
      if (result > 0) {
        totalUpdated++;
      }
    }

    LOGGER.debug("Batch update affected {} rows out of {} keys", totalUpdated, keys.size());
    return totalUpdated;
  } catch (SQLException e) {
    // Rethrow after logging so the caller can apply its per-group failure tolerance.
    LOGGER.warn("Failed to execute batch update. SQL: {}, Error: {}", sql, e.getMessage());
    throw e;
  }
}

/**
 * Holds a group of keys that share the same update shape.
 *
 * <p>Maintains two parallel lists — {@code keys.get(i)} is updated with
 * {@code keyUpdates.get(i)} — so getters return unmodifiable views to keep callers from
 * breaking that invariant. Not thread-safe; instances are confined to a single bulk-update
 * call.
 */
private static class KeyUpdateGroup {
  // Sub-document path -> physical column mapping shared by every key in the group.
  private final Map<String, String> resolvedColumns;
  private final List<Key> keys = new ArrayList<>();
  private final List<Collection<SubDocumentUpdate>> keyUpdates = new ArrayList<>();

  KeyUpdateGroup(Map<String, String> resolvedColumns) {
    this.resolvedColumns = resolvedColumns;
  }

  /** Registers a key with its updates, keeping the parallel lists in lockstep. */
  void addKeyWithUpdates(Key key, Collection<SubDocumentUpdate> updates) {
    keys.add(key);
    keyUpdates.add(updates);
  }

  Map<String, String> getResolvedColumns() {
    return resolvedColumns;
  }

  /** Returns a read-only view; mutation must go through {@link #addKeyWithUpdates}. */
  List<Key> getKeys() {
    return Collections.unmodifiableList(keys);
  }

  /** Returns a read-only view aligned index-for-index with {@link #getKeys()}. */
  List<Collection<SubDocumentUpdate>> getKeyUpdates() {
    return Collections.unmodifiableList(keyUpdates);
  }
}

/**
* Validates all updates and resolves column names.
*
Expand Down
Loading