Skip to content

chore: adding supporting decode for write logic #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 48 additions & 4 deletions lib/src/main/java/io/cloudquery/helper/ArrowHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,29 @@
import com.google.protobuf.ByteString;
import io.cloudquery.schema.Column;
import io.cloudquery.schema.Table;
import io.cloudquery.schema.Table.TableBuilder;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
import org.apache.arrow.vector.ipc.ArrowStreamReader;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;

public class ArrowHelper {
public static final String CQ_TABLE_NAME = "cq:table_name";
public static final String CQ_TABLE_TITLE = "cq:table_title";
public static final String CQ_TABLE_DESCRIPTION = "cq:table_description";
public static final String CQ_TABLE_DEPENDS_ON = "cq:table_depends_on";

public static ByteString encode(Table table) throws IOException {
try (BufferAllocator bufferAllocator = new RootAllocator()) {
Schema schema = toArrowSchema(table);
Expand All @@ -34,6 +43,15 @@ public static ByteString encode(Table table) throws IOException {
}
}

/**
 * Decodes an Arrow IPC stream held in {@code byteString} into a {@link Table}.
 *
 * <p>Only the stream's schema is consulted: the schema of the reader's {@link VectorSchemaRoot} is
 * handed to {@link #fromArrowSchema(Schema)}, and this method never asks the reader to load a
 * record batch.
 *
 * @param byteString serialized Arrow stream to read the schema from
 * @return the table built from the decoded schema
 * @throws IOException if the stream cannot be read
 */
public static Table decode(ByteString byteString) throws IOException {
  // Resources close in reverse declaration order: the reader first, then the allocator.
  try (BufferAllocator allocator = new RootAllocator();
      ArrowReader streamReader = new ArrowStreamReader(byteString.newInput(), allocator)) {
    VectorSchemaRoot root = streamReader.getVectorSchemaRoot();
    return fromArrowSchema(root.getSchema());
  }
}

public static Schema toArrowSchema(Table table) {
List<Column> columns = table.getColumns();
Field[] fields = new Field[columns.size()];
Expand All @@ -43,16 +61,42 @@ public static Schema toArrowSchema(Table table) {
fields[i] = field;
}
Map<String, String> metadata = new HashMap<>();
metadata.put("cq:table_name", table.getName());
metadata.put(CQ_TABLE_NAME, table.getName());
if (table.getTitle() != null) {
metadata.put("cq:table_title", table.getTitle());
metadata.put(CQ_TABLE_TITLE, table.getTitle());
}
if (table.getDescription() != null) {
metadata.put("cq:table_description", table.getDescription());
metadata.put(CQ_TABLE_DESCRIPTION, table.getDescription());
}
if (table.getParent() != null) {
metadata.put("cq:table_depends_on", table.getParent().getName());
metadata.put(CQ_TABLE_DEPENDS_ON, table.getParent().getName());
}
return new Schema(asList(fields), metadata);
}

public static Table fromArrowSchema(Schema schema) {
List<Column> columns = new ArrayList<>();
for (Field field : schema.getFields()) {
columns.add(Column.builder().name(field.getName()).type(field.getType()).build());
}

Map<String, String> metaData = schema.getCustomMetadata();
String name = metaData.get(CQ_TABLE_NAME);
String title = metaData.get(CQ_TABLE_TITLE);
String description = metaData.get(CQ_TABLE_DESCRIPTION);
String parent = metaData.get(CQ_TABLE_DEPENDS_ON);

TableBuilder tableBuilder = Table.builder().name(name).columns(columns);
if (title != null) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[non blocking] I think we can skip these null checks and pass everything to the builder. If the values are null, it won't change anything on the resulting table. WDYT?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that makes sense. I'll modify as part of the next PR.

tableBuilder.title(title);
}
if (description != null) {
tableBuilder.description(description);
}
if (parent != null) {
tableBuilder.parent(Table.builder().name(parent).build());
}

return tableBuilder.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

import com.google.protobuf.ByteString;
import io.cloudquery.helper.ArrowHelper;
import io.cloudquery.messages.WriteMigrateTable;
import io.cloudquery.plugin.BackendOptions;
import io.cloudquery.plugin.NewClientOptions;
import io.cloudquery.plugin.Plugin;
import io.cloudquery.plugin.v3.PluginGrpc.PluginImplBase;
import io.cloudquery.plugin.v3.Write;
import io.cloudquery.plugin.v3.Write.MessageMigrateTable;
import io.cloudquery.schema.Table;
import io.grpc.stub.StreamObserver;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

Expand Down Expand Up @@ -107,13 +110,23 @@ public void read(

@Override
public StreamObserver<Write.Request> write(StreamObserver<Write.Response> responseObserver) {
plugin.write();
return new StreamObserver<>() {
@Override
public void onNext(Write.Request request) {}
public void onNext(Write.Request request) {
Write.Request.MessageCase messageCase = request.getMessageCase();
try {
if (messageCase == Write.Request.MessageCase.MIGRATE_TABLE) {
plugin.write(processMigrateTableRequest(request));
}
} catch (IOException ex) {
onError(ex);
}
}

@Override
public void onError(Throwable t) {}
public void onError(Throwable t) {
responseObserver.onError(t);
}

@Override
public void onCompleted() {
Expand All @@ -131,4 +144,11 @@ public void close(
responseObserver.onNext(io.cloudquery.plugin.v3.Close.Response.newBuilder().build());
responseObserver.onCompleted();
}

/**
 * Converts the migrate-table payload of a write request into a {@link WriteMigrateTable}.
 *
 * <p>The serialized table bytes are decoded with {@link ArrowHelper#decode}, and the request's
 * {@code migrateForce} flag is carried over unchanged.
 *
 * @param request write request whose migrate-table message is read
 * @return message holding the decoded table and the force flag
 * @throws IOException if the serialized table cannot be decoded
 */
private WriteMigrateTable processMigrateTableRequest(Write.Request request) throws IOException {
  // Fetch the sub-message once and read both fields from it.
  MessageMigrateTable payload = request.getMigrateTable();
  return new WriteMigrateTable(
      ArrowHelper.decode(payload.getTable()), payload.getMigrateForce());
}
}
5 changes: 3 additions & 2 deletions lib/src/main/java/io/cloudquery/memdb/MemDB.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.cloudquery.memdb;

import io.cloudquery.messages.WriteMessage;
import io.cloudquery.plugin.BackendOptions;
import io.cloudquery.plugin.ClientNotInitializedException;
import io.cloudquery.plugin.NewClientOptions;
Expand Down Expand Up @@ -93,8 +94,8 @@ public void read() {
}

@Override
public void write() {
throw new UnsupportedOperationException("Unimplemented method 'Write'");
public void write(WriteMessage message) {
client.write(message);
}

@Override
Expand Down
43 changes: 43 additions & 0 deletions lib/src/main/java/io/cloudquery/memdb/MemDBClient.java
Original file line number Diff line number Diff line change
@@ -1,18 +1,61 @@
package io.cloudquery.memdb;

import io.cloudquery.messages.WriteMessage;
import io.cloudquery.messages.WriteMigrateTable;
import io.cloudquery.schema.ClientMeta;
import io.cloudquery.schema.Table;
import io.cloudquery.schema.TableColumnChange;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.apache.arrow.vector.VectorSchemaRoot;

/**
 * In-memory {@link ClientMeta} implementation that tracks table definitions and, per table name, a
 * list of stored {@link VectorSchemaRoot} entries.
 *
 * <p>Mutation of the two maps happens under the write lock of a {@link ReentrantReadWriteLock}.
 */
public class MemDBClient implements ClientMeta {
  private static final String ID = "memdb";

  private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
  /** Latest known definition of each table, keyed by table name. */
  private final Map<String, Table> tables = new HashMap<>();
  /** Stored data of each table, keyed by table name. */
  private final Map<String, List<VectorSchemaRoot>> memDB = new HashMap<>();

  public MemDBClient() {}

  /** Returns the fixed identifier {@code "memdb"}. */
  @Override
  public String getId() {
    return ID;
  }

  /**
   * Handles a write message. Only {@link WriteMigrateTable} messages are acted upon; any other
   * message type is ignored.
   *
   * @param message the message to process
   */
  @Override
  public void write(WriteMessage message) {
    if (message instanceof WriteMigrateTable migrateTable) {
      migrate(migrateTable);
    }
  }

  public void close() {
    // do nothing
  }

  /**
   * Registers the table carried by {@code migrateTable}, or replaces an existing one.
   *
   * <p>An unknown table is registered with an empty data list. A known table is left untouched
   * when {@link Table#getChanges} reports no changes against the stored definition; otherwise its
   * data list is replaced with an empty one and the new definition is stored. The message's
   * {@code migrateForce} flag is not consulted here.
   */
  private void migrate(WriteMigrateTable migrateTable) {
    lock.writeLock().lock();
    try {
      Table incoming = migrateTable.getTable();
      String tableName = incoming.getName();

      // Nothing to do only when the table already exists and its definition is unchanged.
      boolean alreadyStored = memDB.containsKey(tableName);
      if (alreadyStored && incoming.getChanges(tables.get(tableName)).isEmpty()) {
        return;
      }

      // New table or changed definition: start over with an empty data list.
      memDB.put(tableName, new ArrayList<>());
      tables.put(tableName, incoming);
    } finally {
      lock.writeLock().unlock();
    }
  }
}
3 changes: 3 additions & 0 deletions lib/src/main/java/io/cloudquery/messages/WriteMessage.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package io.cloudquery.messages;

/**
 * Base type for the messages accepted by {@code Plugin.write} and {@code ClientMeta.write}.
 *
 * <p>Declares no members of its own; concrete message kinds such as {@code WriteMigrateTable}
 * extend it and receivers distinguish them by subtype.
 */
public abstract class WriteMessage {}
12 changes: 12 additions & 0 deletions lib/src/main/java/io/cloudquery/messages/WriteMigrateTable.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package io.cloudquery.messages;

import io.cloudquery.schema.Table;
import lombok.AllArgsConstructor;
import lombok.Getter;

/**
 * Write message asking the destination to migrate a table.
 *
 * <p>The constructor (taking {@code table}, then {@code migrateForce}) and the getters are
 * generated by Lombok.
 */
@AllArgsConstructor
@Getter
public class WriteMigrateTable extends WriteMessage {
// Table definition to migrate to.
private Table table;
// Force flag copied from the incoming migrate-table request.
private boolean migrateForce;
}
3 changes: 2 additions & 1 deletion lib/src/main/java/io/cloudquery/plugin/Plugin.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.cloudquery.plugin;

import io.cloudquery.messages.WriteMessage;
import io.cloudquery.schema.ClientMeta;
import io.cloudquery.schema.SchemaException;
import io.cloudquery.schema.Table;
Expand Down Expand Up @@ -40,7 +41,7 @@ public abstract void sync(

public abstract void read();

public abstract void write();
public abstract void write(WriteMessage message);

public abstract void close();
}
4 changes: 4 additions & 0 deletions lib/src/main/java/io/cloudquery/schema/ClientMeta.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package io.cloudquery.schema;

import io.cloudquery.messages.WriteMessage;

/** Contract for a client that exposes an identifier and accepts write messages. */
public interface ClientMeta {
/** Returns this client's identifier. */
String getId();

/**
 * Processes a single write message.
 *
 * @param message the message to handle; implementations decide which subtypes they act upon
 */
void write(WriteMessage message);
}
85 changes: 85 additions & 0 deletions lib/src/test/java/io/cloudquery/helper/ArrowHelperTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package io.cloudquery.helper;

import static io.cloudquery.helper.ArrowHelper.CQ_TABLE_DEPENDS_ON;
import static io.cloudquery.helper.ArrowHelper.CQ_TABLE_DESCRIPTION;
import static io.cloudquery.helper.ArrowHelper.CQ_TABLE_NAME;
import static io.cloudquery.helper.ArrowHelper.CQ_TABLE_TITLE;
import static org.junit.jupiter.api.Assertions.assertEquals;

import com.google.protobuf.ByteString;
import io.cloudquery.schema.Column;
import io.cloudquery.schema.Table;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Test;

/**
 * Unit tests for {@link ArrowHelper}: table-to-schema conversion, schema-to-table conversion, and
 * an encode/decode round trip.
 *
 * <p>All assertions follow JUnit's {@code assertEquals(expected, actual)} argument order so that
 * failure messages report the expected and actual values correctly.
 */
public class ArrowHelperTest {

  /** Fixture table with name, title, description, parent and two UTF-8 columns populated. */
  public static final Table TEST_TABLE =
      Table.builder()
          .name("table1")
          .description("A simple test table")
          .title("Test table title")
          .parent(Table.builder().name("parent").build())
          .columns(
              List.of(
                  Column.builder().name("column1").type(ArrowType.Utf8.INSTANCE).build(),
                  Column.builder().name("column2").type(ArrowType.Utf8.INSTANCE).build()))
          .build();

  /** The Arrow schema carries one field per column and the table attributes as metadata. */
  @Test
  public void testToArrowSchema() {
    Schema arrowSchema = ArrowHelper.toArrowSchema(TEST_TABLE);

    // Pin the field count so an extra or missing field cannot go unnoticed.
    assertEquals(2, arrowSchema.getFields().size());
    assertEquals("column1", arrowSchema.getFields().get(0).getName());
    assertEquals("column2", arrowSchema.getFields().get(1).getName());

    assertEquals(
        Map.of(
            CQ_TABLE_NAME, "table1",
            CQ_TABLE_DESCRIPTION, "A simple test table",
            CQ_TABLE_TITLE, "Test table title",
            CQ_TABLE_DEPENDS_ON, "parent"),
        arrowSchema.getCustomMetadata());
  }

  /** A schema with only the table-name metadata yields a table with matching name and columns. */
  @Test
  public void testFromArrowSchema() {
    List<Field> fields =
        List.of(
            Field.nullable("column1", ArrowType.Utf8.INSTANCE),
            Field.nullable("column2", ArrowType.Utf8.INSTANCE));

    Schema schema = new Schema(fields, Map.of(CQ_TABLE_NAME, "table1"));

    Table table = ArrowHelper.fromArrowSchema(schema);

    assertEquals("table1", table.getName());

    // Without this check the loop below would pass vacuously on an empty column list.
    assertEquals(fields.size(), table.getColumns().size());
    for (int i = 0; i < fields.size(); i++) {
      Column column = table.getColumns().get(i);
      assertEquals(fields.get(i).getName(), column.getName());
      assertEquals(fields.get(i).getType(), column.getType());
    }
  }

  /** Encoding and then decoding the fixture preserves its attributes and columns. */
  @Test
  public void testRoundTrip() throws IOException {
    ByteString byteString = ArrowHelper.encode(TEST_TABLE);
    Table table = ArrowHelper.decode(byteString);

    assertEquals(TEST_TABLE.getName(), table.getName());
    assertEquals(TEST_TABLE.getDescription(), table.getDescription());
    assertEquals(TEST_TABLE.getTitle(), table.getTitle());
    assertEquals(TEST_TABLE.getParent().getName(), table.getParent().getName());

    // Compare sizes first so a truncated decode fails with a clear message instead of an
    // IndexOutOfBoundsException inside the loop.
    assertEquals(TEST_TABLE.getColumns().size(), table.getColumns().size());
    for (int i = 0; i < TEST_TABLE.getColumns().size(); i++) {
      assertEquals(TEST_TABLE.getColumns().get(i).getName(), table.getColumns().get(i).getName());
      assertEquals(TEST_TABLE.getColumns().get(i).getType(), table.getColumns().get(i).getType());
    }
  }
}
Loading