From 196867df14e45e609ba665011fc57db7aaa87343 Mon Sep 17 00:00:00 2001 From: Martin Norbury Date: Tue, 22 Aug 2023 14:24:46 +0100 Subject: [PATCH] chore: adding supporting decode for write logic refs: #85 --- .../io/cloudquery/helper/ArrowHelper.java | 52 ++++++++++- .../servers/plugin/v3/PluginServer.java | 26 +++++- .../main/java/io/cloudquery/memdb/MemDB.java | 5 +- .../java/io/cloudquery/memdb/MemDBClient.java | 43 +++++++++ .../io/cloudquery/messages/WriteMessage.java | 3 + .../messages/WriteMigrateTable.java | 12 +++ .../java/io/cloudquery/plugin/Plugin.java | 3 +- .../java/io/cloudquery/schema/ClientMeta.java | 4 + .../io/cloudquery/helper/ArrowHelperTest.java | 85 ++++++++++++++++++ .../servers/plugin/v3/PluginServerTest.java | 87 +++++++++++++++++++ 10 files changed, 310 insertions(+), 10 deletions(-) create mode 100644 lib/src/main/java/io/cloudquery/messages/WriteMessage.java create mode 100644 lib/src/main/java/io/cloudquery/messages/WriteMigrateTable.java create mode 100644 lib/src/test/java/io/cloudquery/helper/ArrowHelperTest.java create mode 100644 lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java diff --git a/lib/src/main/java/io/cloudquery/helper/ArrowHelper.java b/lib/src/main/java/io/cloudquery/helper/ArrowHelper.java index f3bc626..f68aec2 100644 --- a/lib/src/main/java/io/cloudquery/helper/ArrowHelper.java +++ b/lib/src/main/java/io/cloudquery/helper/ArrowHelper.java @@ -5,20 +5,29 @@ import com.google.protobuf.ByteString; import io.cloudquery.schema.Column; import io.cloudquery.schema.Table; +import io.cloudquery.schema.Table.TableBuilder; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.channels.Channels; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ArrowReader; +import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.ipc.ArrowStreamWriter; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; public class ArrowHelper { + public static final String CQ_TABLE_NAME = "cq:table_name"; + public static final String CQ_TABLE_TITLE = "cq:table_title"; + public static final String CQ_TABLE_DESCRIPTION = "cq:table_description"; + public static final String CQ_TABLE_DEPENDS_ON = "cq:table_depends_on"; + public static ByteString encode(Table table) throws IOException { try (BufferAllocator bufferAllocator = new RootAllocator()) { Schema schema = toArrowSchema(table); @@ -34,6 +43,15 @@ public static ByteString encode(Table table) throws IOException { } } + public static Table decode(ByteString byteString) throws IOException { + try (BufferAllocator bufferAllocator = new RootAllocator()) { + try (ArrowReader reader = new ArrowStreamReader(byteString.newInput(), bufferAllocator)) { + VectorSchemaRoot vectorSchemaRoot = reader.getVectorSchemaRoot(); + return fromArrowSchema(vectorSchemaRoot.getSchema()); + } + } + } + public static Schema toArrowSchema(Table table) { List columns = table.getColumns(); Field[] fields = new Field[columns.size()]; @@ -43,16 +61,42 @@ public static Schema toArrowSchema(Table table) { fields[i] = field; } Map metadata = new HashMap<>(); - metadata.put("cq:table_name", table.getName()); + metadata.put(CQ_TABLE_NAME, table.getName()); if (table.getTitle() != null) { - metadata.put("cq:table_title", table.getTitle()); + metadata.put(CQ_TABLE_TITLE, table.getTitle()); } if (table.getDescription() != null) { - metadata.put("cq:table_description", table.getDescription()); + metadata.put(CQ_TABLE_DESCRIPTION, table.getDescription()); } if (table.getParent() != null) { - metadata.put("cq:table_depends_on", table.getParent().getName()); + metadata.put(CQ_TABLE_DEPENDS_ON, table.getParent().getName()); } return new Schema(asList(fields), metadata); } + + public static Table fromArrowSchema(Schema schema) { + List columns = new ArrayList<>(); + for (Field field : schema.getFields()) { + columns.add(Column.builder().name(field.getName()).type(field.getType()).build()); + } + + Map metaData = schema.getCustomMetadata(); + String name = metaData.get(CQ_TABLE_NAME); + String title = metaData.get(CQ_TABLE_TITLE); + String description = metaData.get(CQ_TABLE_DESCRIPTION); + String parent = metaData.get(CQ_TABLE_DEPENDS_ON); + + TableBuilder tableBuilder = Table.builder().name(name).columns(columns); + if (title != null) { + tableBuilder.title(title); + } + if (description != null) { + tableBuilder.description(description); + } + if (parent != null) { + tableBuilder.parent(Table.builder().name(parent).build()); + } + + return tableBuilder.build(); + } } diff --git a/lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java b/lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java index 3a1ed6c..0439fb5 100644 --- a/lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java +++ b/lib/src/main/java/io/cloudquery/internal/servers/plugin/v3/PluginServer.java @@ -2,13 +2,16 @@ import com.google.protobuf.ByteString; import io.cloudquery.helper.ArrowHelper; +import io.cloudquery.messages.WriteMigrateTable; import io.cloudquery.plugin.BackendOptions; import io.cloudquery.plugin.NewClientOptions; import io.cloudquery.plugin.Plugin; import io.cloudquery.plugin.v3.PluginGrpc.PluginImplBase; import io.cloudquery.plugin.v3.Write; +import io.cloudquery.plugin.v3.Write.MessageMigrateTable; import io.cloudquery.schema.Table; import io.grpc.stub.StreamObserver; +import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -107,13 +110,23 @@ public void read( @Override public StreamObserver write(StreamObserver responseObserver) { - plugin.write(); return new StreamObserver<>() { @Override - public void onNext(Write.Request request) {} + public void onNext(Write.Request request) { + Write.Request.MessageCase messageCase = request.getMessageCase(); + try { + if (messageCase == Write.Request.MessageCase.MIGRATE_TABLE) { + plugin.write(processMigrateTableRequest(request)); + } + } catch (IOException ex) { + onError(ex); + } + } @Override - public void onError(Throwable t) {} + public void onError(Throwable t) { + responseObserver.onError(t); + } @Override public void onCompleted() { @@ -131,4 +144,11 @@ public void close( responseObserver.onNext(io.cloudquery.plugin.v3.Close.Response.newBuilder().build()); responseObserver.onCompleted(); } + + private WriteMigrateTable processMigrateTableRequest(Write.Request request) throws IOException { + MessageMigrateTable migrateTable = request.getMigrateTable(); + ByteString byteString = migrateTable.getTable(); + boolean migrateForce = request.getMigrateTable().getMigrateForce(); + return new WriteMigrateTable(ArrowHelper.decode(byteString), migrateForce); + } } diff --git a/lib/src/main/java/io/cloudquery/memdb/MemDB.java b/lib/src/main/java/io/cloudquery/memdb/MemDB.java index dbb0348..593cf83 100644 --- a/lib/src/main/java/io/cloudquery/memdb/MemDB.java +++ b/lib/src/main/java/io/cloudquery/memdb/MemDB.java @@ -1,5 +1,6 @@ package io.cloudquery.memdb; +import io.cloudquery.messages.WriteMessage; import io.cloudquery.plugin.BackendOptions; import io.cloudquery.plugin.ClientNotInitializedException; import io.cloudquery.plugin.NewClientOptions; @@ -93,8 +94,8 @@ public void read() { } @Override - public void write() { - throw new UnsupportedOperationException("Unimplemented method 'Write'"); + public void write(WriteMessage message) { + client.write(message); } @Override diff --git a/lib/src/main/java/io/cloudquery/memdb/MemDBClient.java b/lib/src/main/java/io/cloudquery/memdb/MemDBClient.java index 9a640b5..f5e5283 100644 --- a/lib/src/main/java/io/cloudquery/memdb/MemDBClient.java +++ b/lib/src/main/java/io/cloudquery/memdb/MemDBClient.java @@ -1,10 +1,24 @@ package io.cloudquery.memdb; +import io.cloudquery.messages.WriteMessage; +import io.cloudquery.messages.WriteMigrateTable; import io.cloudquery.schema.ClientMeta; +import io.cloudquery.schema.Table; +import io.cloudquery.schema.TableColumnChange; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import org.apache.arrow.vector.VectorSchemaRoot; public class MemDBClient implements ClientMeta { private static final String id = "memdb"; + private ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + private Map tables = new HashMap<>(); + private Map> memDB = new HashMap<>(); + public MemDBClient() {} @Override @@ -12,7 +26,36 @@ public String getId() { return id; } + @Override + public void write(WriteMessage message) { + if (message instanceof WriteMigrateTable migrateTable) { + migrate(migrateTable); + } + } + public void close() { // do nothing } + + private void migrate(WriteMigrateTable migrateTable) { + lock.writeLock().lock(); + try { + Table table = migrateTable.getTable(); + String tableName = table.getName(); + if (!memDB.containsKey(tableName)) { + memDB.put(tableName, new ArrayList<>()); + tables.put(tableName, table); + return; + } + + List changes = table.getChanges(tables.get(tableName)); + if (changes.isEmpty()) { + return; + } + memDB.put(tableName, new ArrayList<>()); + tables.put(tableName, table); + } finally { + lock.writeLock().unlock(); + } + } } diff --git a/lib/src/main/java/io/cloudquery/messages/WriteMessage.java b/lib/src/main/java/io/cloudquery/messages/WriteMessage.java new file mode 100644 index 0000000..81e6129 --- /dev/null +++ b/lib/src/main/java/io/cloudquery/messages/WriteMessage.java @@ -0,0 +1,3 @@ +package io.cloudquery.messages; + +public abstract class WriteMessage {} diff --git a/lib/src/main/java/io/cloudquery/messages/WriteMigrateTable.java b/lib/src/main/java/io/cloudquery/messages/WriteMigrateTable.java new file mode 100644 index 0000000..bf8123b --- /dev/null +++ b/lib/src/main/java/io/cloudquery/messages/WriteMigrateTable.java @@ -0,0 +1,12 @@ +package io.cloudquery.messages; + +import io.cloudquery.schema.Table; +import lombok.AllArgsConstructor; +import lombok.Getter; + +@AllArgsConstructor +@Getter +public class WriteMigrateTable extends WriteMessage { + private Table table; + private boolean migrateForce; +} diff --git a/lib/src/main/java/io/cloudquery/plugin/Plugin.java b/lib/src/main/java/io/cloudquery/plugin/Plugin.java index f04623d..5c02701 100644 --- a/lib/src/main/java/io/cloudquery/plugin/Plugin.java +++ b/lib/src/main/java/io/cloudquery/plugin/Plugin.java @@ -1,5 +1,6 @@ package io.cloudquery.plugin; +import io.cloudquery.messages.WriteMessage; import io.cloudquery.schema.ClientMeta; import io.cloudquery.schema.SchemaException; import io.cloudquery.schema.Table; @@ -40,7 +41,7 @@ public abstract void sync( public abstract void read(); - public abstract void write(); + public abstract void write(WriteMessage message); public abstract void close(); } diff --git a/lib/src/main/java/io/cloudquery/schema/ClientMeta.java b/lib/src/main/java/io/cloudquery/schema/ClientMeta.java index cf0b64a..b0e382a 100644 --- a/lib/src/main/java/io/cloudquery/schema/ClientMeta.java +++ b/lib/src/main/java/io/cloudquery/schema/ClientMeta.java @@ -1,5 +1,9 @@ package io.cloudquery.schema; +import io.cloudquery.messages.WriteMessage; + public interface ClientMeta { String getId(); + + void write(WriteMessage message); } diff --git a/lib/src/test/java/io/cloudquery/helper/ArrowHelperTest.java b/lib/src/test/java/io/cloudquery/helper/ArrowHelperTest.java new file mode 100644 index 0000000..cdb041f --- /dev/null +++ b/lib/src/test/java/io/cloudquery/helper/ArrowHelperTest.java @@ -0,0 +1,85 @@ +package io.cloudquery.helper; + +import static io.cloudquery.helper.ArrowHelper.CQ_TABLE_DEPENDS_ON; +import static io.cloudquery.helper.ArrowHelper.CQ_TABLE_DESCRIPTION; +import static io.cloudquery.helper.ArrowHelper.CQ_TABLE_NAME; +import static io.cloudquery.helper.ArrowHelper.CQ_TABLE_TITLE; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import com.google.protobuf.ByteString; +import io.cloudquery.schema.Column; +import io.cloudquery.schema.Table; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.jupiter.api.Test; + +public class ArrowHelperTest { + + public static final Table TEST_TABLE = + Table.builder() + .name("table1") + .description("A simple test table") + .title("Test table title") + .parent(Table.builder().name("parent").build()) + .columns( + List.of( + Column.builder().name("column1").type(ArrowType.Utf8.INSTANCE).build(), + Column.builder().name("column2").type(ArrowType.Utf8.INSTANCE).build())) + .build(); + + @Test + public void testToArrowSchema() { + Schema arrowSchema = ArrowHelper.toArrowSchema(TEST_TABLE); + + assertEquals(arrowSchema.getFields().get(0).getName(), "column1"); + assertEquals(arrowSchema.getFields().get(1).getName(), "column2"); + + assertEquals( + arrowSchema.getCustomMetadata(), + Map.of( + CQ_TABLE_NAME, "table1", + CQ_TABLE_DESCRIPTION, "A simple test table", + CQ_TABLE_TITLE, "Test table title", + CQ_TABLE_DEPENDS_ON, "parent")); + } + + @Test + public void testFromArrowSchema() { + List fields = + List.of( + Field.nullable("column1", ArrowType.Utf8.INSTANCE), + Field.nullable("column2", ArrowType.Utf8.INSTANCE)); + + Schema schema = new Schema(fields, Map.of(CQ_TABLE_NAME, "table1")); + + Table table = ArrowHelper.fromArrowSchema(schema); + + assertEquals(table.getName(), "table1"); + + for (int i = 0; i < table.getColumns().size(); i++) { + Column column = table.getColumns().get(i); + assertEquals(column.getName(), fields.get(i).getName()); + assertEquals(column.getType(), fields.get(i).getType()); + } + } + + @Test + public void testRoundTrip() throws IOException { + ByteString byteString = ArrowHelper.encode(TEST_TABLE); + Table table = ArrowHelper.decode(byteString); + + assertEquals(table.getName(), TEST_TABLE.getName()); + assertEquals(table.getDescription(), TEST_TABLE.getDescription()); + assertEquals(table.getTitle(), TEST_TABLE.getTitle()); + assertEquals(table.getParent().getName(), TEST_TABLE.getParent().getName()); + + for (int i = 0; i < TEST_TABLE.getColumns().size(); i++) { + assertEquals(TEST_TABLE.getColumns().get(i).getName(), table.getColumns().get(i).getName()); + assertEquals(TEST_TABLE.getColumns().get(i).getType(), table.getColumns().get(i).getType()); + } + } +} diff --git a/lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java b/lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java new file mode 100644 index 0000000..fdaad00 --- /dev/null +++ b/lib/src/test/java/io/cloudquery/internal/servers/plugin/v3/PluginServerTest.java @@ -0,0 +1,87 @@ +package io.cloudquery.internal.servers.plugin.v3; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; + +import io.cloudquery.helper.ArrowHelper; +import io.cloudquery.messages.WriteMigrateTable; +import io.cloudquery.plugin.Plugin; +import io.cloudquery.plugin.v3.PluginGrpc; +import io.cloudquery.plugin.v3.PluginGrpc.PluginStub; +import io.cloudquery.plugin.v3.Write; +import io.cloudquery.schema.Table; +import io.grpc.Server; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.GrpcCleanupRule; +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import org.junit.Rule; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +@ExtendWith(MockitoExtension.class) +public class PluginServerTest { + + @Mock private Plugin plugin; + + @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); + + private PluginStub pluginStub; + + @BeforeEach + public void setUp() throws IOException { + PluginServer pluginServer = new PluginServer(plugin); + + String generatedName = InProcessServerBuilder.generateName(); + Server server = InProcessServerBuilder.forName(generatedName).addService(pluginServer).build(); + server.start(); + + InProcessChannelBuilder inProcessChannelBuilder = + InProcessChannelBuilder.forName(generatedName).directExecutor(); + pluginStub = PluginGrpc.newStub(grpcCleanupRule.register(inProcessChannelBuilder.build())); + } + + @Test + public void shouldSendWriteMigrateTableMessage() throws IOException, InterruptedException { + NullResponseStream responseObserver = new NullResponseStream<>(); + + StreamObserver writeService = pluginStub.write(responseObserver); + writeService.onNext(generateMigrateTableMessage()); + writeService.onCompleted(); + responseObserver.await(); + + verify(plugin).write(any(WriteMigrateTable.class)); + } + + private static Write.Request generateMigrateTableMessage() throws IOException { + Table table = Table.builder().name("test").build(); + return Write.Request.newBuilder() + .setMigrateTable( + Write.MessageMigrateTable.newBuilder().setTable(ArrowHelper.encode(table)).build()) + .build(); + } + + private static class NullResponseStream implements StreamObserver { + private final CountDownLatch countDownLatch = new CountDownLatch(1); + + @Override + public void onNext(T value) {} + + @Override + public void onError(Throwable t) {} + + @Override + public void onCompleted() { + countDownLatch.countDown(); + } + + public void await() throws InterruptedException { + countDownLatch.await(); + } + } +}