diff --git a/.gitignore b/.gitignore index 4c348d6..449ced5 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,4 @@ build # Intellij .idea +.cq diff --git a/lib/src/main/java/io/cloudquery/helper/ArrowHelper.java b/lib/src/main/java/io/cloudquery/helper/ArrowHelper.java index 3dc8ff7..4f0d882 100644 --- a/lib/src/main/java/io/cloudquery/helper/ArrowHelper.java +++ b/lib/src/main/java/io/cloudquery/helper/ArrowHelper.java @@ -3,6 +3,7 @@ import static java.util.Arrays.asList; import com.google.protobuf.ByteString; +import io.cloudquery.scalar.ValidationException; import io.cloudquery.schema.Column; import io.cloudquery.schema.Resource; import io.cloudquery.schema.Table; @@ -221,14 +222,12 @@ public static Table fromArrowSchema(Schema schema) { String constraintName = metaData.get(CQ_EXTENSION_CONSTRAINT_NAME); TableBuilder tableBuilder = - Table.builder().name(name).constraintName(constraintName).columns(columns); - - if (title != null) { - tableBuilder.title(title); - } - if (description != null) { - tableBuilder.description(description); - } + Table.builder() + .name(name) + .constraintName(constraintName) + .columns(columns) + .title(title) + .description(description); if (parent != null) { tableBuilder.parent(Table.builder().name(parent).build()); } @@ -260,4 +259,23 @@ public static ByteString encode(Resource resource) throws IOException { } } } + + public static Resource decodeResource(ByteString byteString) + throws IOException, ValidationException { + try (BufferAllocator bufferAllocator = new RootAllocator()) { + try (ArrowStreamReader reader = + new ArrowStreamReader(byteString.newInput(), bufferAllocator)) { + VectorSchemaRoot vectorSchemaRoot = reader.getVectorSchemaRoot(); + reader.loadNextBatch(); + Resource resource = + Resource.builder().table(fromArrowSchema(vectorSchemaRoot.getSchema())).build(); + for (int i = 0; i < vectorSchemaRoot.getSchema().getFields().size(); i++) { + FieldVector vector = vectorSchemaRoot.getVector(i); + // TODO: We currently only support a single row + resource.set(vector.getName(), vector.getObject(0)); + } + return resource; + } + } + } } diff --git a/lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java b/lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java index 43cb1a7..dff65f1 100644 --- a/lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java +++ b/lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java @@ -2,6 +2,8 @@ import com.google.protobuf.ByteString; import io.cloudquery.helper.ArrowHelper; +import io.cloudquery.messages.WriteInsert; +import io.cloudquery.messages.WriteMessage; import io.cloudquery.messages.WriteMigrateTable; import io.cloudquery.plugin.BackendOptions; import io.cloudquery.plugin.NewClientOptions; @@ -9,6 +11,7 @@ import io.cloudquery.plugin.v3.PluginGrpc.PluginImplBase; import io.cloudquery.plugin.v3.Write; import io.cloudquery.plugin.v3.Write.MessageMigrateTable; +import io.cloudquery.scalar.ValidationException; import io.cloudquery.schema.Table; import io.grpc.stub.StreamObserver; import java.io.IOException; @@ -120,8 +123,10 @@ public void onNext(Write.Request request) { try { if (messageCase == Write.Request.MessageCase.MIGRATE_TABLE) { plugin.write(processMigrateTableRequest(request)); + } else if (messageCase == Write.Request.MessageCase.INSERT) { + plugin.write(processInsertRequest(request)); } - } catch (IOException ex) { + } catch (IOException | ValidationException ex) { onError(ex); } } @@ -154,4 +159,11 @@ private WriteMigrateTable processMigrateTableRequest(Write.Request request) thro boolean migrateForce = request.getMigrateTable().getMigrateForce(); return new WriteMigrateTable(ArrowHelper.decode(byteString), migrateForce); } + + private WriteMessage processInsertRequest(Write.Request request) + throws IOException, ValidationException { + Write.MessageInsert insert = request.getInsert(); + ByteString record = insert.getRecord(); + return new WriteInsert(ArrowHelper.decodeResource(record)); + } } diff --git a/lib/src/main/java/io/cloudquery/memdb/MemDBClient.java b/lib/src/main/java/io/cloudquery/memdb/MemDBClient.java index f5e5283..064852e 100644 --- a/lib/src/main/java/io/cloudquery/memdb/MemDBClient.java +++ b/lib/src/main/java/io/cloudquery/memdb/MemDBClient.java @@ -1,8 +1,10 @@ package io.cloudquery.memdb; +import io.cloudquery.messages.WriteInsert; import io.cloudquery.messages.WriteMessage; import io.cloudquery.messages.WriteMigrateTable; import io.cloudquery.schema.ClientMeta; +import io.cloudquery.schema.Resource; import io.cloudquery.schema.Table; import io.cloudquery.schema.TableColumnChange; import java.util.ArrayList; @@ -10,14 +12,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.locks.ReentrantReadWriteLock; -import org.apache.arrow.vector.VectorSchemaRoot; public class MemDBClient implements ClientMeta { private static final String id = "memdb"; private ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); private Map tables = new HashMap<>(); - private Map> memDB = new HashMap<>(); + private Map> memDB = new HashMap<>(); public MemDBClient() {} @@ -28,34 +29,69 @@ public String getId() { @Override public void write(WriteMessage message) { - if (message instanceof WriteMigrateTable migrateTable) { - migrate(migrateTable); + lock.writeLock().lock(); + try { + if (message instanceof WriteMigrateTable migrateTable) { + migrate(migrateTable); + } + if (message instanceof WriteInsert insert) { + insert(insert); + } + } finally { + lock.writeLock().unlock(); } } - public void close() { - // do nothing + private void insert(WriteInsert insert) { + String tableName = insert.getResource().getTable().getName(); + Table table = tables.get(tableName); + overwrite(table, insert.getResource()); } - private void migrate(WriteMigrateTable migrateTable) { - lock.writeLock().lock(); - try { - Table table = migrateTable.getTable(); - String tableName = table.getName(); - if (!memDB.containsKey(tableName)) { - memDB.put(tableName, new ArrayList<>()); - tables.put(tableName, table); - return; - } + private void overwrite(Table table, Resource resource) { + String tableName = table.getName(); + List pkIndexes = table.primaryKeyIndexes(); + if (pkIndexes.isEmpty()) { + memDB.get(tableName).add(resource); + return; + } - List changes = table.getChanges(tables.get(tableName)); - if (changes.isEmpty()) { + for (int i = 0; i < memDB.get(tableName).size(); i++) { + boolean found = true; + for (int pkIndex : pkIndexes) { + String s1 = resource.getTable().getColumns().get(pkIndex).getName(); + String s2 = memDB.get(tableName).get(i).getTable().getColumns().get(pkIndex).getName(); + if (!s1.equals(s2)) { + found = false; + } + } + if (found) { + memDB.get(tableName).remove(i); + memDB.get(tableName).add(resource); return; } + } + memDB.get(tableName).add(resource); + } + + public void close() { + // do nothing + } + + private void migrate(WriteMigrateTable migrateTable) { + Table table = migrateTable.getTable(); + String tableName = table.getName(); + if (!memDB.containsKey(tableName)) { memDB.put(tableName, new ArrayList<>()); tables.put(tableName, table); - } finally { - lock.writeLock().unlock(); + return; + } + + List changes = table.getChanges(tables.get(tableName)); + if (changes.isEmpty()) { + return; } + memDB.put(tableName, new ArrayList<>()); + tables.put(tableName, table); } } diff --git a/lib/src/main/java/io/cloudquery/messages/WriteInsert.java b/lib/src/main/java/io/cloudquery/messages/WriteInsert.java new file mode 100644 index 0000000..58ed31a --- /dev/null +++ b/lib/src/main/java/io/cloudquery/messages/WriteInsert.java @@ -0,0 +1,11 @@ +package io.cloudquery.messages; + +import io.cloudquery.schema.Resource; +import lombok.AllArgsConstructor; +import lombok.Getter; + +@AllArgsConstructor +@Getter +public class WriteInsert extends WriteMessage { + private Resource resource; +} diff --git a/lib/src/main/java/io/cloudquery/scalar/Scalar.java b/lib/src/main/java/io/cloudquery/scalar/Scalar.java index c996079..c2680fd 100644 --- a/lib/src/main/java/io/cloudquery/scalar/Scalar.java +++ b/lib/src/main/java/io/cloudquery/scalar/Scalar.java @@ -105,6 +105,9 @@ public static Scalar fromArrowType(ArrowType arrowType) { case Duration -> { return new Duration(); } + case List -> { + return new JSON(); + } } if (arrowType instanceof ArrowType.ExtensionType extensionType) { diff --git a/lib/src/main/java/io/cloudquery/schema/Table.java b/lib/src/main/java/io/cloudquery/schema/Table.java index 6700784..7ee9185 100644 --- a/lib/src/main/java/io/cloudquery/schema/Table.java +++ b/lib/src/main/java/io/cloudquery/schema/Table.java @@ -3,6 +3,7 @@ import static io.cloudquery.schema.TableColumnChangeType.ADD; import static io.cloudquery.schema.TableColumnChangeType.REMOVE; import static io.cloudquery.schema.TableColumnChangeType.UPDATE; +import static java.util.stream.Collectors.toList; import io.cloudquery.glob.Glob; import io.cloudquery.schema.Column.ColumnBuilder; @@ -14,6 +15,7 @@ import java.util.Map; import java.util.Optional; import java.util.function.Predicate; +import java.util.stream.IntStream; import lombok.Builder; import lombok.Getter; import lombok.NonNull; @@ -173,6 +175,13 @@ public List primaryKeys() { return columns.stream().filter(Column::isPrimaryKey).map(Column::getName).toList(); } + public List primaryKeyIndexes() { + return IntStream.range(0, columns.size()) + .filter(i -> columns.get(i).isPrimaryKey()) + .boxed() + .collect(toList()); + } + private Optional filterDfs( boolean parentMatched, Predicate
include, diff --git a/lib/src/main/java/io/cloudquery/server/PluginServe.java b/lib/src/main/java/io/cloudquery/server/PluginServe.java index 3698af6..f308d92 100644 --- a/lib/src/main/java/io/cloudquery/server/PluginServe.java +++ b/lib/src/main/java/io/cloudquery/server/PluginServe.java @@ -1,6 +1,7 @@ package io.cloudquery.server; import io.cloudquery.plugin.Plugin; +import io.cloudquery.types.Extensions; import lombok.AccessLevel; import lombok.Builder; import lombok.NonNull; @@ -12,6 +13,7 @@ public class PluginServe { @Builder.Default private String[] args = new String[] {}; public int Serve() { + Extensions.registerExtensions(); return new CommandLine(new RootCommand()).addSubcommand(new ServeCommand(plugin)).execute(args); } } diff --git a/lib/src/main/java/io/cloudquery/types/Extensions.java b/lib/src/main/java/io/cloudquery/types/Extensions.java new file mode 100644 index 0000000..2c4c560 --- /dev/null +++ b/lib/src/main/java/io/cloudquery/types/Extensions.java @@ -0,0 +1,12 @@ +package io.cloudquery.types; + +import org.apache.arrow.vector.types.pojo.ExtensionTypeRegistry; + +public class Extensions { + public static void registerExtensions() { + ExtensionTypeRegistry.register(new UUIDType()); + ExtensionTypeRegistry.register(new JSONType()); + } + + private Extensions() {} +} diff --git a/lib/src/main/java/io/cloudquery/types/UUIDType.java b/lib/src/main/java/io/cloudquery/types/UUIDType.java index 6f7350d..cdc6b69 100644 --- a/lib/src/main/java/io/cloudquery/types/UUIDType.java +++ b/lib/src/main/java/io/cloudquery/types/UUIDType.java @@ -1,7 +1,5 @@ package io.cloudquery.types; -import static org.apache.arrow.vector.types.pojo.ArrowType.ExtensionType; - import java.nio.ByteBuffer; import java.util.UUID; import org.apache.arrow.memory.BufferAllocator; @@ -61,6 +59,9 @@ public UUIDVector(String name, BufferAllocator allocator, FixedSizeBinaryVector @Override public Object getObject(int index) { + if (getUnderlyingVector().isSet(index) == 0) { + return null; + } final ByteBuffer bb = ByteBuffer.wrap(getUnderlyingVector().getObject(index)); return new UUID(bb.getLong(), bb.getLong()); } diff --git a/lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java b/lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java index fdaad00..d42fef0 100644 --- a/lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java +++ b/lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java @@ -3,12 +3,18 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.verify; +import com.google.protobuf.ByteString; import io.cloudquery.helper.ArrowHelper; +import io.cloudquery.messages.WriteInsert; import io.cloudquery.messages.WriteMigrateTable; import io.cloudquery.plugin.Plugin; import io.cloudquery.plugin.v3.PluginGrpc; import io.cloudquery.plugin.v3.PluginGrpc.PluginStub; import io.cloudquery.plugin.v3.Write; +import io.cloudquery.plugin.v3.Write.MessageInsert; +import io.cloudquery.scalar.ValidationException; +import io.cloudquery.schema.Column; +import io.cloudquery.schema.Resource; import io.cloudquery.schema.Table; import io.grpc.Server; import io.grpc.inprocess.InProcessChannelBuilder; @@ -16,7 +22,9 @@ import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import java.io.IOException; +import java.util.List; import java.util.concurrent.CountDownLatch; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.junit.Rule; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -47,7 +55,7 @@ public void setUp() throws IOException { } @Test - public void shouldSendWriteMigrateTableMessage() throws IOException, InterruptedException { + public void shouldSendWriteMigrateTableMessage() throws Exception { NullResponseStream responseObserver = new NullResponseStream<>(); StreamObserver writeService = pluginStub.write(responseObserver); @@ -58,6 +66,18 @@ public void shouldSendWriteMigrateTableMessage() throws IOException, Interrupted verify(plugin).write(any(WriteMigrateTable.class)); } + @Test + public void shouldSendWriteInsertMessage() throws Exception { + NullResponseStream responseObserver = new NullResponseStream<>(); + + StreamObserver writeService = pluginStub.write(responseObserver); + writeService.onNext(generateInsertMessage()); + writeService.onCompleted(); + responseObserver.await(); + + verify(plugin).write(any(WriteInsert.class)); + } + private static Write.Request generateMigrateTableMessage() throws IOException { Table table = Table.builder().name("test").build(); return Write.Request.newBuilder() @@ -66,6 +86,16 @@ private static Write.Request generateMigrateTableMessage() throws IOException { .build(); } + private Write.Request generateInsertMessage() throws IOException, ValidationException { + Column column = Column.builder().name("test_column").type(ArrowType.Utf8.INSTANCE).build(); + Table table = Table.builder().name("test").columns(List.of(column)).build(); + Resource resource = Resource.builder().table(table).build(); + resource.set("test_column", "test_data"); + ByteString byteString = ArrowHelper.encode(resource); + MessageInsert messageInsert = MessageInsert.newBuilder().setRecord(byteString).build(); + return Write.Request.newBuilder().setInsert(messageInsert).build(); + } + private static class NullResponseStream implements StreamObserver { private final CountDownLatch countDownLatch = new CountDownLatch(1);