Feat/fix oneof #63

Open · wants to merge 2 commits into base: master
1 change: 1 addition & 0 deletions .gitignore
@@ -8,3 +8,4 @@ dist/
.pytest_cache
/public
/build
.venv/
48 changes: 35 additions & 13 deletions ml_pipeline_engine/dag/manager.py
@@ -238,11 +238,15 @@ def _get_reduced_dag(
dest: NodeId,
is_recurrent: bool = False,
is_oneof: bool = False,
graph: t.Optional[DiGraph] = None,
) -> DiGraph:
"""
Get filtered and connected subgraph
"""

if graph is None:
graph = self.dag.graph

def _filter(u: str, v: str) -> bool:
"""
Delete edges with EdgeField.case_branch from subgraph_view
@@ -251,7 +255,7 @@ def _filter(u: str, v: str) -> bool:
u - source node of the edge
v - target node of the edge
"""
return not self.dag.graph.edges[u, v].get(EdgeField.case_branch)
return not graph.edges[u, v].get(EdgeField.case_branch)

def _filter_node(u: str) -> bool:
"""
@@ -260,10 +264,10 @@ def _filter_node(u: str) -> bool:
Args:
u - Node
"""
return not self.dag.graph.nodes[u].get(NodeField.is_oneof_child)
return not graph.nodes[u].get(NodeField.is_oneof_child)

return get_connected_subgraph(
dag=nx.subgraph_view(self.dag.graph, filter_edge=_filter, filter_node=_filter_node),
dag=nx.subgraph_view(graph, filter_edge=_filter, filter_node=_filter_node),
source=source,
dest=dest,
is_recurrent=is_recurrent,
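
Reviewer's note: a minimal sketch of how these lazy filters behave, on a toy graph with plain string keys standing in for EdgeField.case_branch and NodeField.is_oneof_child (hypothetical example, not part of this PR):

import networkx as nx

g = nx.DiGraph()
g.add_edge('a', 'b')
g.add_edge('a', 'c', case_branch=True)   # dropped by _filter
g.add_node('d', is_oneof_child=True)     # dropped by _filter_node
g.add_edge('b', 'd')

view = nx.subgraph_view(
    g,
    filter_edge=lambda u, v: not g.edges[u, v].get('case_branch'),
    filter_node=lambda n: not g.nodes[n].get('is_oneof_child'),
)
assert set(view.edges) == {('a', 'b')}   # only the uncased, non-oneof edge survives

Because subgraph_view is a lazy view, a precomputed OneOf subgraph passed via the new graph parameter composes with these filters without copying.
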
@@ -278,10 +282,15 @@ def _get_reduced_dag_input_one_of(
"""
Get the connected subgraph feeding a OneOf branch
"""

return get_connected_subgraph(
nx.subgraph_view(self.dag.graph), source, dest, is_oneof=True,
)
# Collect dest and all of its ancestors (the nodes required to compute it)
subdag_nodes = nx.ancestors(self.dag.graph, dest)
subdag_nodes.add(dest)
# Build a mutable subgraph restricted to these nodes
subdag = self.dag.graph.subgraph(subdag_nodes).copy()
# dest itself must be executable here, so drop its oneof-child flag
attrs = {dest: {NodeField.is_oneof_child: False}}
nx.set_node_attributes(subdag, attrs)
return get_connected_subgraph(subdag, source, dest, is_oneof=True)
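
The ancestor-closure step in isolation, as a toy sketch (hypothetical node names; .copy() matters because subgraph() returns a view whose attributes are shared with the parent graph):

import networkx as nx

g = nx.DiGraph([('src', 'a'), ('a', 'dest'), ('src', 'unrelated')])
nodes = nx.ancestors(g, 'dest') | {'dest'}   # {'src', 'a', 'dest'}
sub = g.subgraph(nodes).copy()               # detach from g so attribute edits stay local
nx.set_node_attributes(sub, {'dest': {'is_oneof_child': False}})
assert 'unrelated' not in sub
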

def _add_case_result(self, switch_node_id: NodeId) -> None:
"""
@@ -327,6 +336,9 @@ async def _execute_node(
try:
logger.info('Preparing node for the execution, node_id=%s', node_id)

# Fail fast: re-raise an error that a failed OneOf deferred to this node
if self._node_storage.exists_node_unhandled_error(node_id):
raise self._node_storage.get_node_unhandled_error(node_id)

result = await self.__execute_node(
node_id=node_id,
force_default=force_default,
@@ -544,7 +556,6 @@ async def _run_oneof(self, node_id: NodeId) -> t.Any:
"""
Run OneOf subgraph. Returns the first non-error result or specific error
"""

logger.debug('Prepare OneOf DAG node_id=%s', node_id)

for idx, subgraph_node_id in enumerate(self.dag.graph.nodes[node_id][NodeField.oneof_nodes]):
@@ -556,6 +567,9 @@ async def _run_oneof(self, node_id: NodeId) -> t.Any:

logger.debug('Prepare [%s]%s to start. OneOf result node %s', idx, oneof_dag, node_id)

oneof_dag = self._get_reduced_dag(
source=self.dag.input_node,
dest=subgraph_node_id,
graph=oneof_dag,
is_oneof=True,
)

self._create_task(coro=self._run_dag(dag=oneof_dag), name=str(oneof_dag))

await self._lock_manager.wait_for_condition(
@@ -571,20 +585,28 @@

if not self.__has_subgraph_error(oneof_dag):

# The node_id is a synthetic node and cannot be executed anywhere. Hence, we should copy the
# result of the last successful subgraph and unlock everything related to the synthetic node.
# Copy the result of the successful subgraph to the synthetic node
self._node_storage.copy_node_result(subgraph_node_id, node_id)

# Unlock the synthetic node and its descendants
await self.__unlock_itself(node_id)
await self.__unlock_descendants(node_id)
await self.__unlock_run_method()

logger.debug('The %s has succeeded', oneof_dag)
return

await self.__raise_exc(
OneOfDoesNotHaveResultError(node_id),
)
# All branches have errors
error = OneOfDoesNotHaveResultError(node_id)
self._node_storage.set_node_result(node_id, error)

# Unlock the synthetic node and its descendants
await self.__unlock_itself(node_id)
await self.__unlock_descendants(node_id)
await self.__unlock_run_method()

# Defer the error to the synthetic node's first successor so it fails fast when executed
successors = tuple(self.dag.graph.successors(node_id))
self._node_storage.set_node_unhandled_error(successors[0], error)
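
Conceptually, _run_oneof implements a "first branch without an error wins" policy. A pure-Python sketch of just that policy (hypothetical helper, not the engine's API — the engine records OneOfDoesNotHaveResultError on the successor instead of raising inline):

import typing as t

def first_successful(branches: t.Sequence[t.Callable[[], t.Any]]) -> t.Any:
    # Try each branch in declaration order; return the first result that
    # does not raise, and surface an error only if every branch failed.
    last_error: t.Optional[Exception] = None
    for branch in branches:
        try:
            return branch()
        except Exception as exc:
            last_error = exc
    raise last_error or RuntimeError('OneOf has no branches')

assert first_successful([lambda: 1 / 0, lambda: 2]) == 2
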

async def _run_switch(self, dag: DiGraph, node_id: NodeId) -> t.Any:
"""
11 changes: 11 additions & 0 deletions ml_pipeline_engine/dag/storage.py
Expand Up @@ -52,10 +52,21 @@ class DAGNodeStorage:
switch_results: HiddenDict = field(default_factory=HiddenDict)
recurrent_subgraph: HiddenDict = field(default_factory=HiddenDict)
waiting_list: HiddenDict = field(default_factory=HiddenDict)
node_unhandled_errors: HiddenDict = field(default_factory=HiddenDict)

def set_node_result(self, node_id: NodeId, data: t.Any) -> None:
self.node_results.set(node_id, data)

def set_node_unhandled_error(self, node_id: NodeId, data: t.Any) -> None:
self.node_unhandled_errors.set(node_id, data)

def get_node_unhandled_error(self, node_id: NodeId, with_hidden: bool = False) -> t.Any:
return self.node_unhandled_errors.get(node_id, with_hidden)

def exists_node_unhandled_error(self, node_id: NodeId, with_hidden: bool = False) -> bool:
unhandled_error = self.get_node_unhandled_error(node_id, with_hidden)
return isinstance(unhandled_error, Exception)

def get_node_result(self, node_id: NodeId, with_hidden: bool = False) -> t.Any:
return self.node_results.get(node_id, with_hidden)

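Usage sketch for the new unhandled-error API (assuming DAGNodeStorage constructs with its field defaults, as the dataclass fields suggest):

from ml_pipeline_engine.dag.storage import DAGNodeStorage

storage = DAGNodeStorage()
storage.set_node_unhandled_error('node_a', ValueError('all OneOf branches failed'))
assert storage.exists_node_unhandled_error('node_a')

# The consumer side (_execute_node above) re-raises the stored error;
# here we just inspect it.
error = storage.get_node_unhandled_error('node_a')
assert isinstance(error, ValueError)
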
2 changes: 1 addition & 1 deletion ml_pipeline_viewer/src/theme.ts
Expand Up @@ -11,5 +11,5 @@ export const GlobalStyle = createGlobalStyle`
margin: 0;
min-width: 1200px;
}

`
Empty file added requirements.txt
Empty file.
68 changes: 68 additions & 0 deletions tests/dag/oneof/test_nested_input_one_first_level_failure.py
@@ -0,0 +1,68 @@
import typing as t

from ml_pipeline_engine.dag_builders.annotation.marks import Input
from ml_pipeline_engine.dag_builders.annotation.marks import InputOneOf
from ml_pipeline_engine.node import ProcessorBase
from ml_pipeline_engine.types import PipelineChartLike


class SomeInput(ProcessorBase):
name = 'input'

def process(self, base_num: int, other_num: int) -> dict:
return {
'base_num': base_num,
'other_num': other_num,
}

class SomeFeature0(ProcessorBase):
name = 'some_feature0'

async def process(self, ds_value: Input(SomeInput)) -> int:
return ds_value

class FirstDataSource(ProcessorBase):
name = 'some_data_source'

def process(self, _: Input(SomeInput), inp: Input(SomeFeature0)) -> int:
# Fail deliberately so the OneOf falls back to SecondDataSource
raise Exception

class SecondDataSource(ProcessorBase):
name = 'some_data_source_second'

def process(self, _: Input(SomeInput)) -> int:
return 2

class SomeFeature(ProcessorBase):
name = 'some_feature'

def process(self, ds_value: InputOneOf([FirstDataSource, SecondDataSource]), inp: Input(SomeInput)) -> int:
return ds_value

class SomeFeature2(ProcessorBase):
name = 'some_feature2'

async def process(self, ds_value: Input(SomeFeature)) -> int:
return ds_value

class FallbackFeature(ProcessorBase):
name = 'fallback_feature'

def process(self) -> int:
return 125

class SomeVectorizer(ProcessorBase):
name = 'vectorizer'

def process(self, feature_value: InputOneOf([SomeFeature, FallbackFeature])) -> int:
return feature_value + 20


async def test_nested_input_one_of_first_level_failure_dag(
build_chart: t.Callable[..., PipelineChartLike],
) -> None:
chart = build_chart(input_node=SomeInput, output_node=SomeVectorizer)
result = await chart.run(input_kwargs=dict(base_num=10, other_num=5))

assert result.value == 22  # SecondDataSource result (2) + 20 from the vectorizer
68 changes: 68 additions & 0 deletions tests/dag/oneof/test_nested_input_one_first_level_success.py
@@ -0,0 +1,68 @@
import typing as t

from ml_pipeline_engine.dag_builders.annotation.marks import Input
from ml_pipeline_engine.dag_builders.annotation.marks import InputOneOf
from ml_pipeline_engine.node import ProcessorBase
from ml_pipeline_engine.types import PipelineChartLike


class SomeInput(ProcessorBase):
name = 'input'

def process(self, base_num: int, other_num: int) -> dict:
return {
'base_num': base_num,
'other_num': other_num,
}

class SomeFeature0(ProcessorBase):
name = 'some_feature0'

async def process(self, ds_value: Input(SomeInput)) -> int:
return ds_value

class FirstDataSource(ProcessorBase):
name = 'some_data_source'

def process(self, _: Input(SomeInput), inp: Input(SomeFeature0)) -> int:
return 1

class SecondDataSource(ProcessorBase):
name = 'some_data_source_second'

def process(self, _: Input(SomeInput)) -> int:
return 2

class SomeFeature(ProcessorBase):
name = 'some_feature'

def process(self, ds_value: InputOneOf([FirstDataSource, SecondDataSource]), inp: Input(SomeInput)) -> int:
return ds_value

class SomeFeature2(ProcessorBase):
name = 'some_feature2'

async def process(self, ds_value: Input(SomeFeature)) -> int:
return ds_value

class FallbackFeature(ProcessorBase):
name = 'fallback_feature'

def process(self) -> int:
return 125

class SomeVectorizer(ProcessorBase):
name = 'vectorizer'

def process(self, feature_value: InputOneOf([SomeFeature, FallbackFeature])) -> int:
return feature_value + 20


async def test_nested_input_one_of_first_level_success_dag(
build_chart: t.Callable[..., PipelineChartLike],
) -> None:
chart = build_chart(input_node=SomeInput, output_node=SomeVectorizer)
result = await chart.run(input_kwargs=dict(base_num=10, other_num=5))

assert result.value == 21  # FirstDataSource result (1) + 20 from the vectorizer
85 changes: 85 additions & 0 deletions tests/dag/oneof/test_nested_input_one_only_necessary_branches.py
@@ -0,0 +1,85 @@
import typing as t

from ml_pipeline_engine.dag_builders.annotation.marks import Input
from ml_pipeline_engine.dag_builders.annotation.marks import InputOneOf
from ml_pipeline_engine.node import ProcessorBase
from ml_pipeline_engine.types import PipelineChartLike

count_calls = 0

class SomeInput(ProcessorBase):
name = 'input'

def process(self, base_num: int, other_num: int) -> dict:
return {
'base_num': base_num,
'other_num': other_num,
}

class SomeFeature0(ProcessorBase):
name = 'some_feature0'

async def process(self, ds_value: Input(SomeInput)) -> int:
return ds_value

class FirstDataSource(ProcessorBase):
name = 'some_data_source'

def process(self, _: Input(SomeInput), inp: Input(SomeFeature0)) -> int:
global count_calls
count_calls += 1
return 1

class SecondDataSource(ProcessorBase):
name = 'some_data_source_second'

def process(self, _: Input(SomeInput)) -> int:
global count_calls
count_calls += 1
return 2

class SomeFeature(ProcessorBase):
name = 'some_feature'

def process(self, ds_value: InputOneOf([FirstDataSource, SecondDataSource]), inp: Input(SomeInput)) -> int:
global count_calls
count_calls += 1
return ds_value

class SomeFeature2(ProcessorBase):
name = 'some_feature2'

async def process(self, ds_value: Input(SomeFeature)) -> int:
return ds_value

class FallbackFeature(ProcessorBase):
name = 'fallback_feature'

def process(self) -> int:
global count_calls
count_calls += 1
return 125

class SomeVectorizer(ProcessorBase):
name = 'vectorizer'

def process(self, feature_value: InputOneOf([SomeFeature, FallbackFeature])) -> int:
return feature_value + 20

class SomeMLModel(ProcessorBase):
name = 'some_model'

def process(self, vec_value: Input(SomeVectorizer)) -> float:
return (vec_value + 30) / 100


async def test_nested_input_one_of_only_necessary_branches_dag(
build_chart: t.Callable[..., PipelineChartLike],
) -> None:
chart = build_chart(input_node=SomeInput, output_node=SomeMLModel)
result = await chart.run(input_kwargs=dict(base_num=10, other_num=5))

# Only FirstDataSource and SomeFeature should have run
assert count_calls == 2