Skip to content

Commit b46b17e

Browse files
committed
chore: implment write insert logic
refs: #85
1 parent 0a470b7 commit b46b17e

File tree

11 files changed

+169
-32
lines changed

11 files changed

+169
-32
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ build
3131

3232
# Intellij
3333
.idea
34+
.cq

lib/src/main/java/io/cloudquery/helper/ArrowHelper.java

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static java.util.Arrays.asList;
44

55
import com.google.protobuf.ByteString;
6+
import io.cloudquery.scalar.ValidationException;
67
import io.cloudquery.schema.Column;
78
import io.cloudquery.schema.Resource;
89
import io.cloudquery.schema.Table;
@@ -221,14 +222,12 @@ public static Table fromArrowSchema(Schema schema) {
221222
String constraintName = metaData.get(CQ_EXTENSION_CONSTRAINT_NAME);
222223

223224
TableBuilder tableBuilder =
224-
Table.builder().name(name).constraintName(constraintName).columns(columns);
225-
226-
if (title != null) {
227-
tableBuilder.title(title);
228-
}
229-
if (description != null) {
230-
tableBuilder.description(description);
231-
}
225+
Table.builder()
226+
.name(name)
227+
.constraintName(constraintName)
228+
.columns(columns)
229+
.title(title)
230+
.description(description);
232231
if (parent != null) {
233232
tableBuilder.parent(Table.builder().name(parent).build());
234233
}
@@ -260,4 +259,25 @@ public static ByteString encode(Resource resource) throws IOException {
260259
}
261260
}
262261
}
262+
263+
public static Resource decodeResource(ByteString byteString)
264+
throws IOException, ValidationException {
265+
try (BufferAllocator bufferAllocator = new RootAllocator()) {
266+
try (ArrowStreamReader reader =
267+
new ArrowStreamReader(byteString.newInput(), bufferAllocator)) {
268+
VectorSchemaRoot vectorSchemaRoot = reader.getVectorSchemaRoot();
269+
reader.loadNextBatch();
270+
Schema schema = vectorSchemaRoot.getSchema();
271+
Table table = fromArrowSchema(schema);
272+
Resource resource = Resource.builder().table(table).build();
273+
for (int i = 0; i < vectorSchemaRoot.getSchema().getFields().size(); i++) {
274+
FieldVector vector = vectorSchemaRoot.getVector(i);
275+
String name = vector.getName();
276+
Object object = vector.getObject(0);
277+
resource.set(name, object);
278+
}
279+
return resource;
280+
}
281+
}
282+
}
263283
}

lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22

33
import com.google.protobuf.ByteString;
44
import io.cloudquery.helper.ArrowHelper;
5+
import io.cloudquery.messages.WriteInsert;
6+
import io.cloudquery.messages.WriteMessage;
57
import io.cloudquery.messages.WriteMigrateTable;
68
import io.cloudquery.plugin.BackendOptions;
79
import io.cloudquery.plugin.NewClientOptions;
810
import io.cloudquery.plugin.Plugin;
911
import io.cloudquery.plugin.v3.PluginGrpc.PluginImplBase;
1012
import io.cloudquery.plugin.v3.Write;
1113
import io.cloudquery.plugin.v3.Write.MessageMigrateTable;
14+
import io.cloudquery.scalar.ValidationException;
1215
import io.cloudquery.schema.Table;
1316
import io.grpc.stub.StreamObserver;
1417
import java.io.IOException;
@@ -120,8 +123,10 @@ public void onNext(Write.Request request) {
120123
try {
121124
if (messageCase == Write.Request.MessageCase.MIGRATE_TABLE) {
122125
plugin.write(processMigrateTableRequest(request));
126+
} else if (messageCase == Write.Request.MessageCase.INSERT) {
127+
plugin.write(processInsertRequest(request));
123128
}
124-
} catch (IOException ex) {
129+
} catch (IOException | ValidationException ex) {
125130
onError(ex);
126131
}
127132
}
@@ -154,4 +159,11 @@ private WriteMigrateTable processMigrateTableRequest(Write.Request request) thro
154159
boolean migrateForce = request.getMigrateTable().getMigrateForce();
155160
return new WriteMigrateTable(ArrowHelper.decode(byteString), migrateForce);
156161
}
162+
163+
private WriteMessage processInsertRequest(Write.Request request)
164+
throws IOException, ValidationException {
165+
Write.MessageInsert insert = request.getInsert();
166+
ByteString record = insert.getRecord();
167+
return new WriteInsert(ArrowHelper.decodeResource(record));
168+
}
157169
}
Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
package io.cloudquery.memdb;
22

3+
import io.cloudquery.messages.WriteInsert;
34
import io.cloudquery.messages.WriteMessage;
45
import io.cloudquery.messages.WriteMigrateTable;
56
import io.cloudquery.schema.ClientMeta;
7+
import io.cloudquery.schema.Resource;
68
import io.cloudquery.schema.Table;
79
import io.cloudquery.schema.TableColumnChange;
810
import java.util.ArrayList;
911
import java.util.HashMap;
1012
import java.util.List;
1113
import java.util.Map;
1214
import java.util.concurrent.locks.ReentrantReadWriteLock;
13-
import org.apache.arrow.vector.VectorSchemaRoot;
1415

1516
public class MemDBClient implements ClientMeta {
1617
private static final String id = "memdb";
1718

1819
private ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
1920
private Map<String, Table> tables = new HashMap<>();
20-
private Map<String, List<VectorSchemaRoot>> memDB = new HashMap<>();
21+
private Map<String, List<Resource>> memDB = new HashMap<>();
2122

2223
public MemDBClient() {}
2324

@@ -28,34 +29,69 @@ public String getId() {
2829

2930
@Override
3031
public void write(WriteMessage message) {
31-
if (message instanceof WriteMigrateTable migrateTable) {
32-
migrate(migrateTable);
32+
lock.writeLock().lock();
33+
try {
34+
if (message instanceof WriteMigrateTable migrateTable) {
35+
migrate(migrateTable);
36+
}
37+
if (message instanceof WriteInsert insert) {
38+
insert(insert);
39+
}
40+
} finally {
41+
lock.writeLock().unlock();
3342
}
3443
}
3544

36-
public void close() {
37-
// do nothing
45+
private void insert(WriteInsert insert) {
46+
String tableName = insert.getResource().getTable().getName();
47+
Table table = tables.get(tableName);
48+
overwrite(table, insert.getResource());
3849
}
3950

40-
private void migrate(WriteMigrateTable migrateTable) {
41-
lock.writeLock().lock();
42-
try {
43-
Table table = migrateTable.getTable();
44-
String tableName = table.getName();
45-
if (!memDB.containsKey(tableName)) {
46-
memDB.put(tableName, new ArrayList<>());
47-
tables.put(tableName, table);
48-
return;
49-
}
51+
private void overwrite(Table table, Resource resource) {
52+
String tableName = table.getName();
53+
List<Integer> pkIndexes = table.primaryKeyIndexes();
54+
if (pkIndexes.isEmpty()) {
55+
memDB.get(tableName).add(resource);
56+
return;
57+
}
5058

51-
List<TableColumnChange> changes = table.getChanges(tables.get(tableName));
52-
if (changes.isEmpty()) {
59+
for (int i = 0; i < memDB.get(tableName).size(); i++) {
60+
boolean found = true;
61+
for (int pkIndex : pkIndexes) {
62+
String s1 = resource.getTable().getColumns().get(pkIndex).getName();
63+
String s2 = memDB.get(tableName).get(i).getTable().getColumns().get(pkIndex).getName();
64+
if (!s1.equals(s2)) {
65+
found = false;
66+
}
67+
}
68+
if (found) {
69+
memDB.get(tableName).remove(i);
70+
memDB.get(tableName).add(resource);
5371
return;
5472
}
73+
}
74+
memDB.get(tableName).add(resource);
75+
}
76+
77+
public void close() {
78+
// do nothing
79+
}
80+
81+
private void migrate(WriteMigrateTable migrateTable) {
82+
Table table = migrateTable.getTable();
83+
String tableName = table.getName();
84+
if (!memDB.containsKey(tableName)) {
5585
memDB.put(tableName, new ArrayList<>());
5686
tables.put(tableName, table);
57-
} finally {
58-
lock.writeLock().unlock();
87+
return;
88+
}
89+
90+
List<TableColumnChange> changes = table.getChanges(tables.get(tableName));
91+
if (changes.isEmpty()) {
92+
return;
5993
}
94+
memDB.put(tableName, new ArrayList<>());
95+
tables.put(tableName, table);
6096
}
6197
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package io.cloudquery.messages;
2+
3+
import io.cloudquery.schema.Resource;
4+
import lombok.AllArgsConstructor;
5+
import lombok.Getter;
6+
7+
@AllArgsConstructor
8+
@Getter
9+
public class WriteInsert extends WriteMessage {
10+
private Resource resource;
11+
}

lib/src/main/java/io/cloudquery/scalar/Scalar.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ public static Scalar<?> fromArrowType(ArrowType arrowType) {
105105
case Duration -> {
106106
return new Duration();
107107
}
108+
case List -> {
109+
return new JSON();
110+
}
108111
}
109112

110113
if (arrowType instanceof ArrowType.ExtensionType extensionType) {

lib/src/main/java/io/cloudquery/schema/Table.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import static io.cloudquery.schema.TableColumnChangeType.ADD;
44
import static io.cloudquery.schema.TableColumnChangeType.REMOVE;
55
import static io.cloudquery.schema.TableColumnChangeType.UPDATE;
6+
import static java.util.stream.Collectors.toList;
67

78
import io.cloudquery.glob.Glob;
89
import io.cloudquery.schema.Column.ColumnBuilder;
@@ -14,6 +15,7 @@
1415
import java.util.Map;
1516
import java.util.Optional;
1617
import java.util.function.Predicate;
18+
import java.util.stream.IntStream;
1719
import lombok.Builder;
1820
import lombok.Getter;
1921
import lombok.NonNull;
@@ -173,6 +175,13 @@ public List<String> primaryKeys() {
173175
return columns.stream().filter(Column::isPrimaryKey).map(Column::getName).toList();
174176
}
175177

178+
public List<Integer> primaryKeyIndexes() {
179+
return IntStream.range(0, columns.size())
180+
.filter(i -> columns.get(i).isPrimaryKey())
181+
.boxed()
182+
.collect(toList());
183+
}
184+
176185
private Optional<Table> filterDfs(
177186
boolean parentMatched,
178187
Predicate<Table> include,

lib/src/main/java/io/cloudquery/server/PluginServe.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package io.cloudquery.server;
22

33
import io.cloudquery.plugin.Plugin;
4+
import io.cloudquery.types.Extensions;
45
import lombok.AccessLevel;
56
import lombok.Builder;
67
import lombok.NonNull;
@@ -12,6 +13,7 @@ public class PluginServe {
1213
@Builder.Default private String[] args = new String[] {};
1314

1415
public int Serve() {
16+
Extensions.registerExtensions();
1517
return new CommandLine(new RootCommand()).addSubcommand(new ServeCommand(plugin)).execute(args);
1618
}
1719
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package io.cloudquery.types;
2+
3+
import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry;
4+
5+
public class Extensions {
6+
public static void registerExtensions() {
7+
ExtensionTypeRegistry.register(new UUIDType());
8+
ExtensionTypeRegistry.register(new JSONType());
9+
}
10+
11+
private Extensions() {}
12+
}

lib/src/main/java/io/cloudquery/types/UUIDType.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
package io.cloudquery.types;
22

3-
import static org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;
4-
53
import java.nio.ByteBuffer;
64
import java.util.UUID;
75
import org.apache.arrow.memory.BufferAllocator;
@@ -61,6 +59,9 @@ public UUIDVector(String name, BufferAllocator allocator, FixedSizeBinaryVector
6159

6260
@Override
6361
public Object getObject(int index) {
62+
if (getUnderlyingVector().isSet(index) == 0) {
63+
return null;
64+
}
6465
final ByteBuffer bb = ByteBuffer.wrap(getUnderlyingVector().getObject(index));
6566
return new UUID(bb.getLong(), bb.getLong());
6667
}

lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,28 @@
33
import static org.mockito.ArgumentMatchers.any;
44
import static org.mockito.Mockito.verify;
55

6+
import com.google.protobuf.ByteString;
67
import io.cloudquery.helper.ArrowHelper;
8+
import io.cloudquery.messages.WriteInsert;
79
import io.cloudquery.messages.WriteMigrateTable;
810
import io.cloudquery.plugin.Plugin;
911
import io.cloudquery.plugin.v3.PluginGrpc;
1012
import io.cloudquery.plugin.v3.PluginGrpc.PluginStub;
1113
import io.cloudquery.plugin.v3.Write;
14+
import io.cloudquery.plugin.v3.Write.MessageInsert;
15+
import io.cloudquery.scalar.ValidationException;
16+
import io.cloudquery.schema.Column;
17+
import io.cloudquery.schema.Resource;
1218
import io.cloudquery.schema.Table;
1319
import io.grpc.Server;
1420
import io.grpc.inprocess.InProcessChannelBuilder;
1521
import io.grpc.inprocess.InProcessServerBuilder;
1622
import io.grpc.stub.StreamObserver;
1723
import io.grpc.testing.GrpcCleanupRule;
1824
import java.io.IOException;
25+
import java.util.List;
1926
import java.util.concurrent.CountDownLatch;
27+
import org.apache.arrow.vector.types.pojo.ArrowType;
2028
import org.junit.Rule;
2129
import org.junit.jupiter.api.BeforeEach;
2230
import org.junit.jupiter.api.Test;
@@ -47,7 +55,7 @@ public void setUp() throws IOException {
4755
}
4856

4957
@Test
50-
public void shouldSendWriteMigrateTableMessage() throws IOException, InterruptedException {
58+
public void shouldSendWriteMigrateTableMessage() throws Exception {
5159
NullResponseStream<Write.Response> responseObserver = new NullResponseStream<>();
5260

5361
StreamObserver<Write.Request> writeService = pluginStub.write(responseObserver);
@@ -58,6 +66,18 @@ public void shouldSendWriteMigrateTableMessage() throws IOException, Interrupted
5866
verify(plugin).write(any(WriteMigrateTable.class));
5967
}
6068

69+
@Test
70+
public void shouldSendWriteInsertMessage() throws Exception {
71+
NullResponseStream<Write.Response> responseObserver = new NullResponseStream<>();
72+
73+
StreamObserver<Write.Request> writeService = pluginStub.write(responseObserver);
74+
writeService.onNext(generateInsertMessage());
75+
writeService.onCompleted();
76+
responseObserver.await();
77+
78+
verify(plugin).write(any(WriteInsert.class));
79+
}
80+
6181
private static Write.Request generateMigrateTableMessage() throws IOException {
6282
Table table = Table.builder().name("test").build();
6383
return Write.Request.newBuilder()
@@ -66,6 +86,16 @@ private static Write.Request generateMigrateTableMessage() throws IOException {
6686
.build();
6787
}
6888

89+
private Write.Request generateInsertMessage() throws IOException, ValidationException {
90+
Column column = Column.builder().name("test_column").type(ArrowType.Utf8.INSTANCE).build();
91+
Table table = Table.builder().name("test").columns(List.of(column)).build();
92+
Resource resource = Resource.builder().table(table).build();
93+
resource.set("test_column", "test_data");
94+
ByteString byteString = ArrowHelper.encode(resource);
95+
MessageInsert messageInsert = MessageInsert.newBuilder().setRecord(byteString).build();
96+
return Write.Request.newBuilder().setInsert(messageInsert).build();
97+
}
98+
6999
private static class NullResponseStream<T> implements StreamObserver<T> {
70100
private final CountDownLatch countDownLatch = new CountDownLatch(1);
71101

0 commit comments

Comments
 (0)