From 8457c949ca223513221be30cf78bd08467ec1dc7 Mon Sep 17 00:00:00 2001 From: Wey Gu Date: Wed, 6 Sep 2023 14:04:48 +0800 Subject: [PATCH] fix e2e test --- examples/spark_engine.ipynb | 3 +- ng_ai/nebula_writer.py | 3 ++ .../notebooks/spark_engine_e2e.ipynb | 46 +++++++++++++++++++ .../integration/spark_engine_cases/writer.py | 3 +- tests/integration/test_e2e_spark_engine.py | 1 - 5 files changed, 51 insertions(+), 5 deletions(-) diff --git a/examples/spark_engine.ipynb b/examples/spark_engine.ipynb index 1480eaa..1852875 100644 --- a/examples/spark_engine.ipynb +++ b/examples/spark_engine.ipynb @@ -628,7 +628,7 @@ "outputs": [], "source": [ "writer = NebulaWriter(\n", - " data=df_result, sink=\"nebulagraph_vertex\", config=config, engine=\"spark\"\n", + " data=df_result, sink=\"nebulagraph_edge\", config=config, engine=\"spark\"\n", ")\n", "\n", "# map column louvain into property cluster_id\n", @@ -636,7 +636,6 @@ "\n", "writer.set_options(\n", " space=\"basketballplayer\",\n", - " type=\"edge\",\n", " edge_type=\"jaccard_similarity\",\n", " src_id=\"srcId\",\n", " dst_id=\"dstId\",\n", diff --git a/ng_ai/nebula_writer.py b/ng_ai/nebula_writer.py index 6f07787..3559232 100644 --- a/ng_ai/nebula_writer.py +++ b/ng_ai/nebula_writer.py @@ -320,6 +320,9 @@ def set_options(self, **kwargs): # case switch based on sink if self.sink in SPARK_NEBULA_SINKS: + kwargs["type"] = ( + "vertex" if self.sink == "nebulagraph_vertex" else "edge" + ) self._set_options_with_nebula(**kwargs) elif self.sink in SPARK_FILE_SINKS: self._set_options_with_file(**kwargs) diff --git a/tests/integration/notebooks/spark_engine_e2e.ipynb b/tests/integration/notebooks/spark_engine_e2e.ipynb index 8cda942..819c556 100644 --- a/tests/integration/notebooks/spark_engine_e2e.ipynb +++ b/tests/integration/notebooks/spark_engine_e2e.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "id": "8a6e2678", "metadata": {}, @@ -42,6 +43,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "170f8c38", "metadata": {}, @@ -70,6 +72,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "6051622d", "metadata": {}, @@ -116,6 +119,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "id": "311b49c4", "metadata": {}, @@ -153,6 +157,7 @@ "properties = {\"lpa\": \"cluster_id\"}\n", "\n", "writer.set_options(\n", + " space=\"basketballplayer\",\n", " tag=\"label_propagation\",\n", " vid_field=\"_id\",\n", " properties=properties,\n", @@ -162,6 +167,47 @@ "# write back to NebulaGraph\n", "writer.write()" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "cbe44f42", + "metadata": {}, + "source": [ + "## Test Writer insert mode in sink: `nebulagraph_edge`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "79068641", + "metadata": {}, + "outputs": [], + "source": [ + "df_result = df.algo.jaccard_similarity()\n", + "\n", + "writer = NebulaWriter(\n", + " data=df_result, sink=\"nebulagraph_edge\", config=config, engine=\"spark\"\n", + ")\n", + "\n", + "# map column louvain into property cluster_id\n", + "properties = {\"similarity\": \"similarity\"}\n", + "\n", + "writer.set_options(\n", + " space=\"basketballplayer\",\n", + " edge_type=\"jaccard_similarity\",\n", + " src_id=\"srcId\",\n", + " dst_id=\"dstId\",\n", + " src_id_policy=\"\",\n", + " dst_id_policy=\"\",\n", + " properties=properties,\n", + " batch_size=256,\n", + " write_mode=\"insert\",\n", + ")\n", + "\n", + "# write back to NebulaGraph\n", + "writer.write()" + ] } ], "metadata": { diff --git a/tests/integration/spark_engine_cases/writer.py b/tests/integration/spark_engine_cases/writer.py index 7983f87..94e06b0 100644 --- a/tests/integration/spark_engine_cases/writer.py +++ b/tests/integration/spark_engine_cases/writer.py @@ -43,7 +43,7 @@ df_result = df.algo.jaccard() writer = NebulaWriter( - data=df_result, sink="nebulagraph_vertex", config=config, engine="spark" + data=df_result, sink="nebulagraph_edge", config=config, engine="spark" ) # map column louvain into property cluster_id @@ -51,7 +51,6 @@ writer.set_options( space="basketballplayer", - type="edge", edge_type="jaccard_similarity", src_id="srcId", dst_id="dstId", diff --git a/tests/integration/test_e2e_spark_engine.py b/tests/integration/test_e2e_spark_engine.py index a13c861..fdbd6ee 100644 --- a/tests/integration/test_e2e_spark_engine.py +++ b/tests/integration/test_e2e_spark_engine.py @@ -107,7 +107,6 @@ def test_label_propagation_spark_engine_writer(): "MATCH (v:player) RETURN v.label_propagation.cluster_id LIMIT 1" ) print(result) - connection_pool.close() assert result.is_succeeded(), f"ERROR during query NebulaGraph: {result}" assert (