
chore: implement write insert logic #93

Merged: 1 commit, Aug 24, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -31,3 +31,4 @@ build

# Intellij
.idea
.cq
34 changes: 26 additions & 8 deletions lib/src/main/java/io/cloudquery/helper/ArrowHelper.java
@@ -3,6 +3,7 @@
import static java.util.Arrays.asList;

import com.google.protobuf.ByteString;
import io.cloudquery.scalar.ValidationException;
import io.cloudquery.schema.Column;
import io.cloudquery.schema.Resource;
import io.cloudquery.schema.Table;
@@ -221,14 +222,12 @@ public static Table fromArrowSchema(Schema schema) {
String constraintName = metaData.get(CQ_EXTENSION_CONSTRAINT_NAME);

TableBuilder tableBuilder =
Table.builder().name(name).constraintName(constraintName).columns(columns);

if (title != null) {
tableBuilder.title(title);
}
if (description != null) {
tableBuilder.description(description);
}
Table.builder()
Member commented: Nice 👍

.name(name)
.constraintName(constraintName)
.columns(columns)
.title(title)
.description(description);
if (parent != null) {
tableBuilder.parent(Table.builder().name(parent).build());
}
@@ -260,4 +259,23 @@ public static ByteString encode(Resource resource) throws IOException {
}
}
}

public static Resource decodeResource(ByteString byteString)
throws IOException, ValidationException {
try (BufferAllocator bufferAllocator = new RootAllocator()) {
try (ArrowStreamReader reader =
new ArrowStreamReader(byteString.newInput(), bufferAllocator)) {
VectorSchemaRoot vectorSchemaRoot = reader.getVectorSchemaRoot();
reader.loadNextBatch();
Member commented: I think we would need to loop on loadNextBatch until it returns false, see https://arrow.apache.org/docs/java/reference/org/apache/arrow/vector/ipc/ArrowStreamReader.html#loadNextBatch-- Though maybe not an issue at the moment since we always send a single record batch?

Collaborator (author) replied: Currently the method signature implies only a single item, Resource decodeResource(ByteString byteString). Do you think I should generalise this to a list?

Member replied: Let's keep it as is for now and change it if we run into issues.

Resource resource =
Resource.builder().table(fromArrowSchema(vectorSchemaRoot.getSchema())).build();
for (int i = 0; i < vectorSchemaRoot.getSchema().getFields().size(); i++) {
FieldVector vector = vectorSchemaRoot.getVector(i);
// TODO: We currently only support a single row
resource.set(vector.getName(), vector.getObject(0));
}
return resource;
}
}
}
}
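
Note on the loadNextBatch discussion above: a generalised decoder that handles multiple record batches and rows could look roughly like the sketch below. This is not part of the PR; decodeResources is a hypothetical helper, and it assumes java.util.List/ArrayList imports plus VectorSchemaRoot.getRowCount() for the per-batch row count, alongside the calls already used in decodeResource.

  public static List<Resource> decodeResources(ByteString byteString)
      throws IOException, ValidationException {
    List<Resource> resources = new ArrayList<>();
    try (BufferAllocator bufferAllocator = new RootAllocator();
        ArrowStreamReader reader =
            new ArrowStreamReader(byteString.newInput(), bufferAllocator)) {
      VectorSchemaRoot vectorSchemaRoot = reader.getVectorSchemaRoot();
      Table table = fromArrowSchema(vectorSchemaRoot.getSchema());
      // loadNextBatch() returns false once the stream is exhausted.
      while (reader.loadNextBatch()) {
        for (int row = 0; row < vectorSchemaRoot.getRowCount(); row++) {
          Resource resource = Resource.builder().table(table).build();
          for (int col = 0; col < vectorSchemaRoot.getSchema().getFields().size(); col++) {
            FieldVector vector = vectorSchemaRoot.getVector(col);
            resource.set(vector.getName(), vector.getObject(row));
          }
          resources.add(resource);
        }
      }
    }
    return resources;
  }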
@@ -2,13 +2,16 @@

import com.google.protobuf.ByteString;
import io.cloudquery.helper.ArrowHelper;
import io.cloudquery.messages.WriteInsert;
import io.cloudquery.messages.WriteMessage;
import io.cloudquery.messages.WriteMigrateTable;
import io.cloudquery.plugin.BackendOptions;
import io.cloudquery.plugin.NewClientOptions;
import io.cloudquery.plugin.Plugin;
import io.cloudquery.plugin.v3.PluginGrpc.PluginImplBase;
import io.cloudquery.plugin.v3.Write;
import io.cloudquery.plugin.v3.Write.MessageMigrateTable;
import io.cloudquery.scalar.ValidationException;
import io.cloudquery.schema.Table;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
@@ -120,8 +123,10 @@ public void onNext(Write.Request request) {
try {
if (messageCase == Write.Request.MessageCase.MIGRATE_TABLE) {
plugin.write(processMigrateTableRequest(request));
} else if (messageCase == Write.Request.MessageCase.INSERT) {
plugin.write(processInsertRequest(request));
}
} catch (IOException ex) {
} catch (IOException | ValidationException ex) {
onError(ex);
}
}
@@ -154,4 +159,11 @@ private WriteMigrateTable processMigrateTableRequest(Write.Request request) thro
boolean migrateForce = request.getMigrateTable().getMigrateForce();
return new WriteMigrateTable(ArrowHelper.decode(byteString), migrateForce);
}

private WriteMessage processInsertRequest(Write.Request request)
throws IOException, ValidationException {
Write.MessageInsert insert = request.getInsert();
ByteString record = insert.getRecord();
return new WriteInsert(ArrowHelper.decodeResource(record));
}
}
76 changes: 56 additions & 20 deletions lib/src/main/java/io/cloudquery/memdb/MemDBClient.java
@@ -1,23 +1,24 @@
package io.cloudquery.memdb;

import io.cloudquery.messages.WriteInsert;
import io.cloudquery.messages.WriteMessage;
import io.cloudquery.messages.WriteMigrateTable;
import io.cloudquery.schema.ClientMeta;
import io.cloudquery.schema.Resource;
import io.cloudquery.schema.Table;
import io.cloudquery.schema.TableColumnChange;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.arrow.vector.VectorSchemaRoot;

public class MemDBClient implements ClientMeta {
private static final String id = "memdb";

private ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
private Map<String, Table> tables = new HashMap<>();
private Map<String, List<VectorSchemaRoot>> memDB = new HashMap<>();
private Map<String, List<Resource>> memDB = new HashMap<>();

public MemDBClient() {}

@@ -28,34 +29,69 @@ public String getId() {

@Override
public void write(WriteMessage message) {
if (message instanceof WriteMigrateTable migrateTable) {
migrate(migrateTable);
lock.writeLock().lock();
try {
if (message instanceof WriteMigrateTable migrateTable) {
migrate(migrateTable);
}
if (message instanceof WriteInsert insert) {
insert(insert);
}
} finally {
lock.writeLock().unlock();
}
}

public void close() {
// do nothing
private void insert(WriteInsert insert) {
String tableName = insert.getResource().getTable().getName();
Table table = tables.get(tableName);
overwrite(table, insert.getResource());
}

private void migrate(WriteMigrateTable migrateTable) {
lock.writeLock().lock();
try {
Table table = migrateTable.getTable();
String tableName = table.getName();
if (!memDB.containsKey(tableName)) {
memDB.put(tableName, new ArrayList<>());
tables.put(tableName, table);
return;
}
private void overwrite(Table table, Resource resource) {
String tableName = table.getName();
List<Integer> pkIndexes = table.primaryKeyIndexes();
if (pkIndexes.isEmpty()) {
memDB.get(tableName).add(resource);
return;
}

List<TableColumnChange> changes = table.getChanges(tables.get(tableName));
if (changes.isEmpty()) {
for (int i = 0; i < memDB.get(tableName).size(); i++) {
boolean found = true;
for (int pkIndex : pkIndexes) {
String s1 = resource.getTable().getColumns().get(pkIndex).getName();
String s2 = memDB.get(tableName).get(i).getTable().getColumns().get(pkIndex).getName();
if (!s1.equals(s2)) {
found = false;
}
}
if (found) {
memDB.get(tableName).remove(i);
memDB.get(tableName).add(resource);
return;
}
}
memDB.get(tableName).add(resource);
}

public void close() {
// do nothing
}

private void migrate(WriteMigrateTable migrateTable) {
Table table = migrateTable.getTable();
String tableName = table.getName();
if (!memDB.containsKey(tableName)) {
memDB.put(tableName, new ArrayList<>());
tables.put(tableName, table);
} finally {
lock.writeLock().unlock();
return;
}

List<TableColumnChange> changes = table.getChanges(tables.get(tableName));
if (changes.isEmpty()) {
return;
}
memDB.put(tableName, new ArrayList<>());
tables.put(tableName, table);
}
}
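
For orientation, the write path added here can be exercised roughly as in the sketch below: after a MigrateTable message creates the table, a second insert that matches on the primary-key columns goes through overwrite() and replaces the earlier row instead of appending. The primaryKey(true) builder call is an assumption (the diff only shows the Column::isPrimaryKey accessor); everything else mirrors calls visible in this PR.

  // Sketch only; assumes Column's Lombok builder exposes a primaryKey flag.
  static void memDbWriteSketch() throws Exception {
    Table users =
        Table.builder()
            .name("users")
            .columns(
                List.of(
                    Column.builder().name("id").type(ArrowType.Utf8.INSTANCE).primaryKey(true).build(),
                    Column.builder().name("name").type(ArrowType.Utf8.INSTANCE).build()))
            .build();

    MemDBClient client = new MemDBClient();
    client.write(new WriteMigrateTable(users, false));

    Resource first = Resource.builder().table(users).build();
    first.set("id", "1");
    first.set("name", "alice");
    client.write(new WriteInsert(first));

    // Same table and primary-key column: overwrite() removes the stored row
    // before adding this one, so the table still holds a single resource.
    Resource second = Resource.builder().table(users).build();
    second.set("id", "1");
    second.set("name", "alice (renamed)");
    client.write(new WriteInsert(second));
  }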
11 changes: 11 additions & 0 deletions lib/src/main/java/io/cloudquery/messages/WriteInsert.java
@@ -0,0 +1,11 @@
package io.cloudquery.messages;

import io.cloudquery.schema.Resource;
import lombok.AllArgsConstructor;
import lombok.Getter;

@AllArgsConstructor
@Getter
public class WriteInsert extends WriteMessage {
private Resource resource;
}
3 changes: 3 additions & 0 deletions lib/src/main/java/io/cloudquery/scalar/Scalar.java
@@ -105,6 +105,9 @@ public static Scalar<?> fromArrowType(ArrowType arrowType) {
case Duration -> {
return new Duration();
}
case List -> {
return new JSON();
Collaborator (author) commented: For now we will map arrow types of List to the JSON extension type.

}
}

if (arrowType instanceof ArrowType.ExtensionType extensionType) {
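
A small illustration of the mapping described in the comment above (sketch, not from the PR):

  // Arrow List types currently resolve to the JSON extension scalar.
  Scalar<?> scalar = Scalar.fromArrowType(new ArrowType.List());
  // scalar is an instance of the JSON scalar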
9 changes: 9 additions & 0 deletions lib/src/main/java/io/cloudquery/schema/Table.java
@@ -3,6 +3,7 @@
import static io.cloudquery.schema.TableColumnChangeType.ADD;
import static io.cloudquery.schema.TableColumnChangeType.REMOVE;
import static io.cloudquery.schema.TableColumnChangeType.UPDATE;
import static java.util.stream.Collectors.toList;

import io.cloudquery.glob.Glob;
import io.cloudquery.schema.Column.ColumnBuilder;
@@ -14,6 +15,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.IntStream;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
@@ -173,6 +175,13 @@ public List<String> primaryKeys() {
return columns.stream().filter(Column::isPrimaryKey).map(Column::getName).toList();
}

public List<Integer> primaryKeyIndexes() {
return IntStream.range(0, columns.size())
.filter(i -> columns.get(i).isPrimaryKey())
.boxed()
.collect(toList());
}

private Optional<Table> filterDfs(
boolean parentMatched,
Predicate<Table> include,
2 changes: 2 additions & 0 deletions lib/src/main/java/io/cloudquery/server/PluginServe.java
@@ -1,6 +1,7 @@
package io.cloudquery.server;

import io.cloudquery.plugin.Plugin;
import io.cloudquery.types.Extensions;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.NonNull;
@@ -12,6 +13,7 @@ public class PluginServe {
@Builder.Default private String[] args = new String[] {};

public int Serve() {
Extensions.registerExtensions();
return new CommandLine(new RootCommand()).addSubcommand(new ServeCommand(plugin)).execute(args);
}
}
12 changes: 12 additions & 0 deletions lib/src/main/java/io/cloudquery/types/Extensions.java
@@ -0,0 +1,12 @@
package io.cloudquery.types;

import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry;

public class Extensions {
public static void registerExtensions() {
ExtensionTypeRegistry.register(new UUIDType());
ExtensionTypeRegistry.register(new JSONType());
}

private Extensions() {}
}
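
Registration needs to run before any Arrow decode that touches the UUID or JSON extension fields, which is why PluginServe.Serve() calls it first (see above). A standalone sketch of the same ordering (decodeWithExtensions is hypothetical):

  static Resource decodeWithExtensions(ByteString record) throws Exception {
    // Without registration Arrow falls back to each extension field's raw
    // storage type, so register UUID/JSON before decoding.
    Extensions.registerExtensions();
    return ArrowHelper.decodeResource(record);
  }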
5 changes: 3 additions & 2 deletions lib/src/main/java/io/cloudquery/types/UUIDType.java
@@ -1,7 +1,5 @@
package io.cloudquery.types;

import static org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType;

import java.nio.ByteBuffer;
import java.util.UUID;
import org.apache.arrow.memory.BufferAllocator;
@@ -61,6 +59,9 @@ public UUIDVector(String name, BufferAllocator allocator, FixedSizeBinaryVector

@Override
public Object getObject(int index) {
if (getUnderlyingVector().isSet(index) == 0) {
Member commented: Nice

return null;
}
final ByteBuffer bb = ByteBuffer.wrap(getUnderlyingVector().getObject(index));
return new UUID(bb.getLong(), bb.getLong());
}
@@ -3,20 +3,28 @@
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;

import com.google.protobuf.ByteString;
import io.cloudquery.helper.ArrowHelper;
import io.cloudquery.messages.WriteInsert;
import io.cloudquery.messages.WriteMigrateTable;
import io.cloudquery.plugin.Plugin;
import io.cloudquery.plugin.v3.PluginGrpc;
import io.cloudquery.plugin.v3.PluginGrpc.PluginStub;
import io.cloudquery.plugin.v3.Write;
import io.cloudquery.plugin.v3.Write.MessageInsert;
import io.cloudquery.scalar.ValidationException;
import io.cloudquery.schema.Column;
import io.cloudquery.schema.Resource;
import io.cloudquery.schema.Table;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.junit.Rule;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
@@ -47,7 +55,7 @@ public void setUp() throws IOException {
}

@Test
public void shouldSendWriteMigrateTableMessage() throws IOException, InterruptedException {
public void shouldSendWriteMigrateTableMessage() throws Exception {
NullResponseStream<Write.Response> responseObserver = new NullResponseStream<>();

StreamObserver<Write.Request> writeService = pluginStub.write(responseObserver);
@@ -58,6 +66,18 @@ public void shouldSendWriteMigrateTableMessage() throws IOException, Interrupted
verify(plugin).write(any(WriteMigrateTable.class));
}

@Test
public void shouldSendWriteInsertMessage() throws Exception {
NullResponseStream<Write.Response> responseObserver = new NullResponseStream<>();

StreamObserver<Write.Request> writeService = pluginStub.write(responseObserver);
writeService.onNext(generateInsertMessage());
writeService.onCompleted();
responseObserver.await();

verify(plugin).write(any(WriteInsert.class));
}

private static Write.Request generateMigrateTableMessage() throws IOException {
Table table = Table.builder().name("test").build();
return Write.Request.newBuilder()
@@ -66,6 +86,16 @@ private static Write.Request generateMigrateTableMessage() throws IOException {
.build();
}

private Write.Request generateInsertMessage() throws IOException, ValidationException {
Column column = Column.builder().name("test_column").type(ArrowType.Utf8.INSTANCE).build();
Table table = Table.builder().name("test").columns(List.of(column)).build();
Resource resource = Resource.builder().table(table).build();
resource.set("test_column", "test_data");
ByteString byteString = ArrowHelper.encode(resource);
MessageInsert messageInsert = MessageInsert.newBuilder().setRecord(byteString).build();
return Write.Request.newBuilder().setInsert(messageInsert).build();
}

private static class NullResponseStream<T> implements StreamObserver<T> {
private final CountDownLatch countDownLatch = new CountDownLatch(1);
