
Commit

fix e2e test
wey-gu committed Sep 6, 2023
1 parent 2977ec0 commit 8457c94
Showing 5 changed files with 51 additions and 5 deletions.
examples/spark_engine.ipynb (3 changes: 1 addition & 2 deletions)
@@ -628,15 +628,14 @@
    "outputs": [],
    "source": [
     "writer = NebulaWriter(\n",
-    "    data=df_result, sink=\"nebulagraph_vertex\", config=config, engine=\"spark\"\n",
+    "    data=df_result, sink=\"nebulagraph_edge\", config=config, engine=\"spark\"\n",
     ")\n",
     "\n",
     "# map column louvain into property cluster_id\n",
     "properties = {\"similarity\": \"similarity\"}\n",
     "\n",
     "writer.set_options(\n",
     "    space=\"basketballplayer\",\n",
-    "    type=\"edge\",\n",
     "    edge_type=\"jaccard_similarity\",\n",
     "    src_id=\"srcId\",\n",
     "    dst_id=\"dstId\",\n",
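Both example changes in this file follow from one fact: jaccard similarity produces edge data (a score per pair of vertices), so it belongs in the nebulagraph_edge sink, and the compensating type="edge" option becomes redundant. A hedged sketch of the shape of the result DataFrame being written, with made-up values; only the column names (srcId, dstId, similarity) are taken from the set_options call above:

# Illustrative shape of the jaccard result; values are invented,
# column names match the writer options in the diff above.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df_result = spark.createDataFrame(
    [("player100", "player101", 0.18), ("player100", "player102", 0.09)],
    ["srcId", "dstId", "similarity"],
)
df_result.show()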
ng_ai/nebula_writer.py (3 changes: 3 additions & 0 deletions)
@@ -320,6 +320,9 @@ def set_options(self, **kwargs):

         # case switch based on sink
         if self.sink in SPARK_NEBULA_SINKS:
+            kwargs["type"] = (
+                "vertex" if self.sink == "nebulagraph_vertex" else "edge"
+            )
             self._set_options_with_nebula(**kwargs)
         elif self.sink in SPARK_FILE_SINKS:
             self._set_options_with_file(**kwargs)
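This is the substance of the fix: set_options now derives the type option from the sink name, so callers no longer pass type="vertex" or type="edge" themselves, which is why the stale type="edge" arguments are removed from the examples in this commit. A minimal standalone sketch of the idea; the free-function shape and the sink tuple below are illustrative, not ng_ai's actual internals (the real logic lives on NebulaWriter and then dispatches to _set_options_with_nebula):

# Sketch only: how the sink name can determine the element kind.
SPARK_NEBULA_SINKS = ("nebulagraph_vertex", "nebulagraph_edge")

def set_options(sink: str, **kwargs) -> dict:
    if sink in SPARK_NEBULA_SINKS:
        # The writer infers vertex vs. edge from the sink name,
        # so callers cannot pass a conflicting type= anymore.
        kwargs["type"] = "vertex" if sink == "nebulagraph_vertex" else "edge"
    return kwargs

opts = set_options(
    "nebulagraph_edge",
    space="basketballplayer",
    edge_type="jaccard_similarity",
)
assert opts["type"] == "edge"  # inferred, not supplied by the caller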
tests/integration/notebooks/spark_engine_e2e.ipynb (46 changes: 46 additions & 0 deletions)
@@ -1,6 +1,7 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "8a6e2678",
    "metadata": {},
@@ -42,6 +43,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "170f8c38",
    "metadata": {},
@@ -70,6 +72,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "6051622d",
    "metadata": {},
@@ -116,6 +119,7 @@
    ]
   },
   {
+   "attachments": {},
    "cell_type": "markdown",
    "id": "311b49c4",
    "metadata": {},
@@ -153,6 +157,7 @@
     "properties = {\"lpa\": \"cluster_id\"}\n",
     "\n",
     "writer.set_options(\n",
+    "    space=\"basketballplayer\",\n",
     "    tag=\"label_propagation\",\n",
     "    vid_field=\"_id\",\n",
     "    properties=properties,\n",
@@ -162,6 +167,47 @@
     "# write back to NebulaGraph\n",
     "writer.write()"
    ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "id": "cbe44f42",
+   "metadata": {},
+   "source": [
+    "## Test Writer insert mode in sink: `nebulagraph_edge`"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "79068641",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_result = df.algo.jaccard_similarity()\n",
+    "\n",
+    "writer = NebulaWriter(\n",
+    "    data=df_result, sink=\"nebulagraph_edge\", config=config, engine=\"spark\"\n",
+    ")\n",
+    "\n",
+    "# map column louvain into property cluster_id\n",
+    "properties = {\"similarity\": \"similarity\"}\n",
+    "\n",
+    "writer.set_options(\n",
+    "    space=\"basketballplayer\",\n",
+    "    edge_type=\"jaccard_similarity\",\n",
+    "    src_id=\"srcId\",\n",
+    "    dst_id=\"dstId\",\n",
+    "    src_id_policy=\"\",\n",
+    "    dst_id_policy=\"\",\n",
+    "    properties=properties,\n",
+    "    batch_size=256,\n",
+    "    write_mode=\"insert\",\n",
+    ")\n",
+    "\n",
+    "# write back to NebulaGraph\n",
+    "writer.write()"
+   ]
  }
 ],
 "metadata": {
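The new cell mirrors the edge-writing example from examples/spark_engine.ipynb so the insert write mode of the nebulagraph_edge sink is exercised end to end. A result check in the spirit of the existing vertex assertion might look like the sketch below; the connection settings and the FETCH statement are assumptions for illustration, not part of this commit:

# Hypothetical verification against a local NebulaGraph; host, port,
# credentials, and the sampled edge are all assumed values.
from nebula3.Config import Config
from nebula3.gclient.net import ConnectionPool

connection_pool = ConnectionPool()
assert connection_pool.init([("127.0.0.1", 9669)], Config())
with connection_pool.session_context("root", "nebula") as session:
    session.execute("USE basketballplayer")
    result = session.execute(
        'FETCH PROP ON jaccard_similarity "player100" -> "player101" '
        "YIELD properties(edge).similarity"
    )
    assert result.is_succeeded(), f"ERROR during query NebulaGraph: {result}"
connection_pool.close()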
tests/integration/spark_engine_cases/writer.py (3 changes: 1 addition & 2 deletions)
@@ -43,15 +43,14 @@
 df_result = df.algo.jaccard()

 writer = NebulaWriter(
-    data=df_result, sink="nebulagraph_vertex", config=config, engine="spark"
+    data=df_result, sink="nebulagraph_edge", config=config, engine="spark"
 )

 # map column louvain into property cluster_id
 properties = {"similarity": "similarity"}

 writer.set_options(
     space="basketballplayer",
-    type="edge",
     edge_type="jaccard_similarity",
     src_id="srcId",
     dst_id="dstId",
tests/integration/test_e2e_spark_engine.py (1 change: 0 additions & 1 deletion)
@@ -107,7 +107,6 @@ def test_label_propagation_spark_engine_writer():
         "MATCH (v:player) RETURN v.label_propagation.cluster_id LIMIT 1"
     )
     print(result)
-    connection_pool.close()

     assert result.is_succeeded(), f"ERROR during query NebulaGraph: {result}"
     assert (
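The dropped connection_pool.close() ran in the middle of the test module, ahead of later tests that reuse the shared pool; presumably that premature close is what broke the e2e run. A common way to keep one pool alive for a whole module and still close it exactly once is a pytest fixture, sketched here under assumed connection details rather than taken from this repo:

# Sketch of a module-scoped pool; address and credentials are assumptions.
import pytest
from nebula3.Config import Config
from nebula3.gclient.net import ConnectionPool

@pytest.fixture(scope="module")
def connection_pool():
    pool = ConnectionPool()
    assert pool.init([("127.0.0.1", 9669)], Config())
    yield pool
    pool.close()  # runs once, after every test in the module has finished

def test_query_runs(connection_pool):
    session = connection_pool.get_session("root", "nebula")
    try:
        result = session.execute("SHOW SPACES")
        assert result.is_succeeded()
    finally:
        session.release()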
