Skip to content

Commit 131a772

Browse files
Merge pull request #34 from project-codeflare/singleton_bug
2 parents 53f5542 + b625129 commit 131a772

File tree

3 files changed

+55
-3
lines changed

3 files changed

+55
-3
lines changed

codeflare/pipelines/Datamodel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,9 @@ def get_post_edges(self, node: Node):
837837
:return: Outgoing edges for the node
838838
"""
839839
post_edges = []
840-
post_nodes = self.__post_graph__[node]
840+
post_nodes = []
841+
if node in self.__post_graph__.keys():
842+
post_nodes = self.__post_graph__[node]
841843
# Empty post
842844
if not post_nodes:
843845
post_edges.append(Edge(node, None))

codeflare/pipelines/Runtime.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ class ExecutionType(Enum):
5858
"""
5959
FIT = 0,
6060
PREDICT = 1,
61-
SCORE = 2
61+
SCORE = 2,
62+
TRANSFORM = 3
6263

6364

6465
@ray.remote
@@ -140,7 +141,10 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
140141
res_Xref = ray.put(estimator.transform(X))
141142
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
142143
return result
143-
144+
elif mode == ExecutionType.TRANSFORM:
145+
res_Xref = ray.put(estimator.fit_transform(X))
146+
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
147+
return result
144148

145149
def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType, is_outputNode):
146150
"""
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import pytest
2+
import ray
3+
import pandas as pd
4+
import numpy as np
5+
import sklearn.base as base
6+
from sklearn.preprocessing import MinMaxScaler
7+
import codeflare.pipelines.Datamodel as dm
8+
import codeflare.pipelines.Runtime as rt
9+
from codeflare.pipelines.Datamodel import Xy
10+
from codeflare.pipelines.Datamodel import XYRef
11+
from codeflare.pipelines.Runtime import ExecutionType
12+
13+
def test_singleton():
14+
15+
ray.shutdown()
16+
ray.init()
17+
18+
## prepare the data
19+
X = np.random.randint(0,100,size=(10000, 4))
20+
y = np.random.randint(0,2,size=(10000, 1))
21+
22+
## initialize codeflare pipeline by first creating the nodes
23+
pipeline = dm.Pipeline()
24+
node_a = dm.EstimatorNode('a', MinMaxScaler())
25+
pipeline.add_node(node_a)
26+
27+
pipeline_input = dm.PipelineInput()
28+
xy = dm.Xy(X,y)
29+
pipeline_input.add_xy_arg(node_a, xy)
30+
31+
## execute the codeflare pipeline
32+
pipeline_output = rt.execute_pipeline(pipeline, ExecutionType.TRANSFORM, pipeline_input)
33+
34+
## retrieve node e
35+
node_a_output = pipeline_output.get_xyrefs(node_a)
36+
Xout = ray.get(node_a_output[0].get_Xref())
37+
yout = ray.get(node_a_output[0].get_yref())
38+
39+
assert Xout.shape[0] == 10000
40+
assert yout.shape[0] == 10000
41+
42+
ray.shutdown()
43+
44+
if __name__ == "__main__":
45+
sys.exit(pytest.main(["-v", __file__]))
46+

0 commit comments

Comments
 (0)