From 7f752b32bc2fd5b3f9aaa93128a91b30202e3c44 Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Tue, 16 Jul 2024 08:15:23 -0700
Subject: [PATCH 01/12] Added placeholder tests for proposed methods

* Test: test_add_task_restart_policy_patterns
* Test: test_get_task_restart_policy_patterns
* Test: test_remove_task_restart_policy_patterns
* Test: test_clear_task_restart_policy_patterns
* Test: test_task_resolve_restarts
---
 .../interface/client/test_client.py | 28 +++++++++++++++++++
 1 file changed, 28 insertions(+)

diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py
index ceb968f4..c39ce4f8 100644
--- a/alchemiscale/tests/integration/interface/client/test_client.py
+++ b/alchemiscale/tests/integration/interface/client/test_client.py
@@ -2123,3 +2123,31 @@ def test_get_task_failures(
 
         # TODO: can we mix in a success in here somewhere?
         # not possible with current BrokenProtocol, unfortunately
+
+    # TaskRestartPolicy client methods
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_add_task_restart_policy_patterns(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_get_task_restart_policy_patterns(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_remove_task_restart_policy_patterns(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_clear_task_restart_policy_patterns(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_task_resolve_restarts(
+        self,
+        scope_test,
+        n4js_preloaded,
+        user_client: client.AlchemiscaleClient,
+        network_tyk2_failure,
+    ):
+        raise NotImplementedError

From dd8f0e967ebfd313f2a815fd5a8576b8a1e86552 Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Tue, 16 Jul 2024 08:19:30 -0700
Subject: [PATCH 02/12] Added models for new node types

* TaskRestartPattern
* TaskRestartPolicy
* TaskHistory
---
 alchemiscale/storage/models.py | 80 +++++++++++++++++++++++++++++++++-
 1 file changed, 78 insertions(+), 2 deletions(-)

diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py
index c9b000b8..25a4c3d3 100644
--- a/alchemiscale/storage/models.py
+++ b/alchemiscale/storage/models.py
@@ -8,12 +8,12 @@
 from copy import copy
 from datetime import datetime
 from enum import Enum
-from typing import Union, Dict, Optional
+from typing import Union, Optional, List
 from uuid import uuid4
 import hashlib
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 from gufe.tokenization import GufeTokenizable, GufeKey
 
 from ..models import ScopedKey, Scope
@@ -143,6 +143,82 @@ def _defaults(cls):
         return super()._defaults()
 
 
+# TODO: fill in docstrings
+class TaskRestartPattern(GufeTokenizable):
+    """A pattern to compare returned Task tracebacks to.
+
+    Attributes
+    ----------
+    pattern: str
+        A regular expression pattern that can match to returned tracebacks of errored Tasks.
+    retry_count: int
+        The number of times the pattern can trigger a restart for a Task.
+    """
+
+    pattern: str
+    retry_count: int
+
+    def __init__(self, pattern: str):
+        self.pattern = pattern
+
+    def _gufe_tokenize(self):
+        return hashlib.md5(self.pattern).hexdigest()
+
+    def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return False
+        return self.pattern == other.pattern
+
+
+# TODO: fill in docstrings
+class TaskRestartPolicy(GufeTokenizable):
+    """Restart policy that enforces a TaskHub.
+
+    Attributes
+    ----------
+    taskhub: str
+        ScopedKey of the TaskHub this TaskRestartPolicy enforces.
+    """
+
+    taskhub: str
+
+    def __init__(self, taskhub: ScopedKey):
+        self.taskhub = taskhub
+
+    def _gufe_tokenize(self):
+        return hashlib.md5(
+            self.__class__.__qualname__ + str(self.taskhub), usedforsecurity=False
+        ).hexdigest()
+
+
+# TODO: fill in docstrings
+class TaskHistory(GufeTokenizable):
+    """History attached to a `Task`.
+
+    Attributes
+    ----------
+    task: str
+        ScopedKey of the Task this TaskHistory corresponds to.
+    tracebacks: List[str]
+        The history of tracebacks returned with the newest entries appearing at the end of the list.
+    times_restarted: int
+        The number of times the task has been restarted.
+    """
+
+    task: str
+    tracebacks: list
+    times_restarted: int
+
+    def __init__(self, task: ScopedKey, tracebacks: List[str]):
+        self.task = task
+        self.tracebacks = tracebacks
+
+    def _gufe_tokenize(self):
+        return hashlib.md5(
+            self.__class__.__qualname__ + str(self.task), usedforsecurity=False
+        ).hexdigest()
+
+
 class TaskHub(GufeTokenizable):
     """

From da17e45913e3bc55498012323cebbe6d4ee2ebd4 Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Wed, 17 Jul 2024 14:59:31 -0700
Subject: [PATCH 03/12] Updated new GufeTokenizable models in statestore

* Removed TaskRestartPolicy and TaskHistory
* Added Traceback
---
 alchemiscale/storage/models.py     | 77 +++++++++++++-----------------
 alchemiscale/storage/statestore.py | 16 ++++++-
 2 files changed, 47 insertions(+), 46 deletions(-)

diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py
index 25a4c3d3..b9090160 100644
--- a/alchemiscale/storage/models.py
+++ b/alchemiscale/storage/models.py
@@ -151,18 +151,34 @@ class TaskRestartPattern(GufeTokenizable):
     ----------
     pattern: str
         A regular expression pattern that can match to returned tracebacks of errored Tasks.
-    retry_count: int
+    max_retries: int
         The number of times the pattern can trigger a restart for a Task.
     """
 
     pattern: str
-    retry_count: int
+    max_retries: int
 
-    def __init__(self, pattern: str):
+    def __init__(self, pattern: str, max_retries: int):
         self.pattern = pattern
 
+        if not isinstance(max_retries, int) or max_retries <= 0:
+            raise ValueError("`max_retries` must have a positive integer value.")
+
+        self.max_retries = max_retries
+
+    # TODO: these hashes can overlap across TaskHubs
     def _gufe_tokenize(self):
-        return hashlib.md5(self.pattern).hexdigest()
+        return hashlib.md5(self.pattern.encode()).hexdigest()
+
+    @classmethod
+    def _defaults(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def _from_dict(cls, dct):
+        return cls(**dct)
+
+    def _to_dict(self):
+        return {"pattern": self.pattern, "max_retries": self.max_retries}
 
     def __eq__(self, other):
         if not isinstance(other, self.__class__):
@@ -170,53 +186,24 @@ def __eq__(self, other):
         return self.pattern == other.pattern
 
 
-# TODO: fill in docstrings
-class TaskRestartPolicy(GufeTokenizable):
-    """Restart policy that enforces a TaskHub.
-
-    Attributes
-    ----------
-    taskhub: str
-        ScopedKey of the TaskHub this TaskRestartPolicy enforces.
-    """
-
-    taskhub: str
+class Traceback(GufeTokenizable):
 
-    def __init__(self, taskhub: ScopedKey):
-        self.taskhub = taskhub
+    def __init__(self, tracebacks: List[str]):
+        self.tracebacks = tracebacks
 
     def _gufe_tokenize(self):
-        return hashlib.md5(
-            self.__class__.__qualname__ + str(self.taskhub), usedforsecurity=False
-        ).hexdigest()
-
-
-# TODO: fill in docstrings
-class TaskHistory(GufeTokenizable):
-    """History attached to a `Task`.
-
-    Attributes
-    ----------
-    task: str
-        ScopedKey of the Task this TaskHistory corresponds to.
-    tracebacks: List[str]
-        The history of tracebacks returned with the newest entries appearing at the end of the list.
-    times_restarted: int
-        The number of times the task has been restarted.
-    """
-
-    task: str
-    tracebacks: list
-    times_restarted: int
-
-    def __init__(self, task: ScopedKey, tracebacks: List[str]):
-        self.task = task
-        self.tracebacks = tracebacks
-
-    def _gufe_tokenize(self):
-        return hashlib.md5(
-            self.__class__.__qualname__ + str(self.task), usedforsecurity=False
-        ).hexdigest()
+        return hashlib.md5(str(self.tracebacks).encode()).hexdigest()
+
+    @classmethod
+    def _defaults(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def _from_dict(cls, dct):
+        return Traceback(**dct)
+
+    def _to_dict(self):
+        return {"tracebacks": self.tracebacks}
 
 
 class TaskHub(GufeTokenizable):
diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py
index 1ffd4f4a..3ec0aa5e 100644
--- a/alchemiscale/storage/statestore.py
+++ b/alchemiscale/storage/statestore.py
@@ -24,10 +24,10 @@
     ComputeServiceRegistration,
     NetworkMark,
     NetworkStateEnum,
+    ProtocolDAGResultRef,
     Task,
     TaskHub,
     TaskStatusEnum,
-    ProtocolDAGResultRef,
 )
 from ..strategies import Strategy
 from ..models import Scope, ScopedKey
@@ -2703,6 +2703,20 @@ def err_msg(t, status):
 
         return self._set_task_status(tasks, q, err_msg, raise_error=raise_error)
 
+    ## task restart policy
+
+    # TODO: fill in docstring
+    def add_task_restart_policy_patterns(
+        self, taskhub: ScopedKey, patterns: List[str], number_of_retries: int
+    ):
+        """Add a list of restart policy patterns to a `TaskHub` along with the number of retries allowed.
+
+        Parameters
+        ----------
+
+        """
+        raise NotImplementedError
+
     ## authentication
 
     def create_credentialed_entity(self, entity: CredentialedEntity):

From b7f63d4909e9e5ee3852a0dc7efa1e29e6327566 Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Wed, 17 Jul 2024 15:41:08 -0700
Subject: [PATCH 04/12] Added placeholder unit tests for new models
---
 .../tests/unit/test_storage_models.py | 34 +++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/alchemiscale/tests/unit/test_storage_models.py b/alchemiscale/tests/unit/test_storage_models.py
index 36678b9a..68c9b8c7 100644
--- a/alchemiscale/tests/unit/test_storage_models.py
+++ b/alchemiscale/tests/unit/test_storage_models.py
@@ -38,3 +38,37 @@ def test_suggested_states_message(self):
         assert len(suggested_states) == len(NetworkStateEnum)
         for state in suggested_states:
             NetworkStateEnum(state)
+
+
+class TestTaskRestartPattern(object):
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_empty_pattern(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_negative_max_retries(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_non_int_max_retries(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_to_dict(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_from_dict(self):
+        raise NotImplementedError
+
+
+class TestTraceback(object):
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_to_dict(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_from_dict(self):
+        raise NotImplementedError
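Note: the TODO in PATCH 03 ("these hashes can overlap across TaskHubs") can be illustrated directly, since the gufe key there is derived from the pattern string alone. A minimal sketch using only hashlib, not part of the series; the regex string is made up for illustration:

    import hashlib

    # mirrors TaskRestartPattern._gufe_tokenize as of PATCH 03
    key_a = hashlib.md5("RuntimeError: NaN.*".encode()).hexdigest()
    key_b = hashlib.md5("RuntimeError: NaN.*".encode()).hexdigest()

    # identical patterns destined for different TaskHubs collide on the same key,
    # and would therefore merge into the same Neo4j node
    assert key_a == key_b

This is what PATCH 06 later resolves by folding the TaskHub's ScopedKey into the hash.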
From 6a167f13cd532b51c0f41a2741a947b5c073946d Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Thu, 18 Jul 2024 13:40:28 -0700
Subject: [PATCH 05/12] Added validation and unit tests for storage models

* TaskRestartPattern: Confirm that the input pattern is a string type
  and that it is not empty.
* Traceback: Confirm that the input is a list of strings and that none
  of them are empty.
---
 alchemiscale/storage/models.py        |  15 +++
 .../tests/unit/test_storage_models.py | 117 +++++++++++++++---
 2 files changed, 116 insertions(+), 16 deletions(-)

diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py
index b9090160..fae7af93 100644
--- a/alchemiscale/storage/models.py
+++ b/alchemiscale/storage/models.py
@@ -159,6 +159,10 @@ class TaskRestartPattern(GufeTokenizable):
     max_retries: int
 
     def __init__(self, pattern: str, max_retries: int):
+
+        if not isinstance(pattern, str) or pattern == "":
+            raise ValueError("`pattern` must be a non-empty string")
+
         self.pattern = pattern
 
         if not isinstance(max_retries, int) or max_retries <= 0:
@@ -193,6 +197,17 @@ class Traceback(GufeTokenizable):
 
     def __init__(self, tracebacks: List[str]):
+        value_error = ValueError(
+            "`tracebacks` must be a non-empty list of string values"
+        )
+        if not isinstance(tracebacks, list) or tracebacks == []:
+            raise value_error
+        else:
+            # every entry must be a non-empty string
+            all_string_values = all([isinstance(value, str) for value in tracebacks])
+            if not all_string_values or "" in tracebacks:
+                raise value_error
+
         self.tracebacks = tracebacks
 
     def _gufe_tokenize(self):
diff --git a/alchemiscale/tests/unit/test_storage_models.py b/alchemiscale/tests/unit/test_storage_models.py
index 68c9b8c7..02fe188e 100644
--- a/alchemiscale/tests/unit/test_storage_models.py
+++ b/alchemiscale/tests/unit/test_storage_models.py
@@ -1,6 +1,11 @@
 import pytest
 
-from alchemiscale.storage.models import NetworkStateEnum, NetworkMark
+from alchemiscale.storage.models import (
+    NetworkStateEnum,
+    NetworkMark,
+    TaskRestartPattern,
+    Traceback,
+)
 from alchemiscale import ScopedKey
 
 
@@ -42,33 +47,113 @@ def test_suggested_states_message(self):
 
 class TestTaskRestartPattern(object):
 
-    @pytest.mark.xfail(raises=NotImplementedError)
+    pattern_value_error = "`pattern` must be a non-empty string"
+    max_retries_value_error = "`max_retries` must have a positive integer value."
+
     def test_empty_pattern(self):
-        raise NotImplementedError
+        with pytest.raises(ValueError, match=self.pattern_value_error):
+            _ = TaskRestartPattern("", 3)
+
+    def test_non_string_pattern(self):
+        with pytest.raises(ValueError, match=self.pattern_value_error):
+            _ = TaskRestartPattern(None, 3)
+
+        with pytest.raises(ValueError, match=self.pattern_value_error):
+            _ = TaskRestartPattern([], 3)
+
+    def test_non_positive_max_retries(self):
 
-    @pytest.mark.xfail(raises=NotImplementedError)
-    def test_negative_max_retries(self):
-        raise NotImplementedError
+        with pytest.raises(ValueError, match=self.max_retries_value_error):
+            TaskRestartPattern("Example pattern", 0)
+
+        with pytest.raises(ValueError, match=self.max_retries_value_error):
+            TaskRestartPattern("Example pattern", -1)
 
-    @pytest.mark.xfail(raises=NotImplementedError)
     def test_non_int_max_retries(self):
-        raise NotImplementedError
+        with pytest.raises(ValueError, match=self.max_retries_value_error):
+            TaskRestartPattern("Example pattern", 4.0)
 
-    @pytest.mark.xfail(raises=NotImplementedError)
     def test_to_dict(self):
-        raise NotImplementedError
+        trp = TaskRestartPattern("Example pattern", 3)
+        dict_trp = trp.to_dict()
+
+        assert len(dict_trp.keys()) == 5
+
+        assert dict_trp.pop("__qualname__") == "TaskRestartPattern"
+        assert dict_trp.pop("__module__") == "alchemiscale.storage.models"
+
+        # light test of the version key
+        try:
+            dict_trp.pop(":version:")
+        except KeyError:
+            raise AssertionError("expected to find :version:")
+
+        expected = {"pattern": "Example pattern", "max_retries": 3}
+
+        assert expected == dict_trp
 
-    @pytest.mark.xfail(raises=NotImplementedError)
     def test_from_dict(self):
-        raise NotImplementedError
+
+        original_pattern = "Example pattern"
+        original_max_retries = 3
+
+        trp_orig = TaskRestartPattern(original_pattern, original_max_retries)
+        trp_dict = trp_orig.to_dict()
+        trp_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(trp_dict)
+
+        assert trp_reconstructed.pattern == original_pattern
+        assert trp_reconstructed.max_retries == original_max_retries
 
 
 class TestTraceback(object):
 
-    @pytest.mark.xfail(raises=NotImplementedError)
+    valid_entry = ["traceback1", "traceback2", "traceback3"]
+    tracebacks_value_error = "`tracebacks` must be a non-empty list of string values"
+
+    def test_empty_string_element(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback(self.valid_entry + [""])
+
+    def test_non_list_parameter(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback(None)
+
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback(100)
+
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback("not a list, but still an iterable that yields strings")
+
+    def test_list_non_string_elements(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback(self.valid_entry + [None])
+
+    def test_empty_list(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback([])
+
     def test_to_dict(self):
-        raise NotImplementedError
+        tb = Traceback(self.valid_entry)
+        tb_dict = tb.to_dict()
+
+        assert len(tb_dict) == 4
+
+        assert tb_dict.pop("__qualname__") == "Traceback"
+        assert tb_dict.pop("__module__") == "alchemiscale.storage.models"
+
+        # light test of the version key
+        try:
+            tb_dict.pop(":version:")
+        except KeyError:
+            raise AssertionError("expected to find :version:")
+
+        expected = {"tracebacks": self.valid_entry}
+
+        assert expected == tb_dict
 
-    @pytest.mark.xfail(raises=NotImplementedError)
     def test_from_dict(self):
-        raise NotImplementedError
+        tb_orig = Traceback(self.valid_entry)
+        tb_dict = tb_orig.to_dict()
+        tb_reconstructed: Traceback = Traceback.from_dict(tb_dict)
+
+        assert tb_reconstructed.tracebacks == self.valid_entry
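Note: a quick interactive sketch of the validation added here, assuming the models as of this patch (before `taskhub_scoped_key` becomes a required argument in PATCH 06); the regex string is illustrative only:

    from alchemiscale.storage.models import TaskRestartPattern, Traceback

    # a non-empty pattern with a positive retry budget constructs fine
    trp = TaskRestartPattern("RuntimeError: NaN.*", max_retries=3)

    # each of these raises ValueError with the messages tested above:
    # TaskRestartPattern("", 3)    -> `pattern` must be a non-empty string
    # TaskRestartPattern("x", 0)   -> `max_retries` must have a positive integer value.
    # Traceback([])                -> `tracebacks` must be a non-empty list of string values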
From a10e2355196debf7ba3ebc466040c35193374541 Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Mon, 22 Jul 2024 14:16:17 -0700
Subject: [PATCH 06/12] Added `taskhub_scoped_key` to `TaskRestartPattern`

Similar to `TaskHub`s, the `TaskRestartPattern` needs additional hashed
data to uniquely identify it as a Neo4j node (via the gufe key). The
unit tests have been updated to reflect this change.
---
 alchemiscale/storage/models.py        | 20 ++++++--
 .../tests/unit/test_storage_models.py | 50 +++++++++++++++----
 2 files changed, 57 insertions(+), 13 deletions(-)

diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py
index fae7af93..3dc69e0d 100644
--- a/alchemiscale/storage/models.py
+++ b/alchemiscale/storage/models.py
@@ -153,12 +153,17 @@ class TaskRestartPattern(GufeTokenizable):
         A regular expression pattern that can match to returned tracebacks of errored Tasks.
     max_retries: int
         The number of times the pattern can trigger a restart for a Task.
+    taskhub_scoped_key: str
+        The TaskHub the pattern is bound to. This is needed to properly set a unique Gufe key.
     """
 
     pattern: str
     max_retries: int
+    taskhub_scoped_key: str
 
-    def __init__(self, pattern: str, max_retries: int):
+    def __init__(
+        self, pattern: str, max_retries: int, taskhub_scoped_key: Union[str, ScopedKey]
+    ):
 
         if not isinstance(pattern, str) or pattern == "":
             raise ValueError("`pattern` must be a non-empty string")
@@ -169,9 +174,11 @@ def __init__(self, pattern: str, max_retries: int):
             raise ValueError("`max_retries` must have a positive integer value.")
         self.max_retries = max_retries
 
-    # TODO: these hashes can overlap across TaskHubs
+        self.taskhub_scoped_key = str(taskhub_scoped_key)
+
     def _gufe_tokenize(self):
-        return hashlib.md5(self.pattern.encode()).hexdigest()
+        key_string = self.pattern + self.taskhub_scoped_key
+        return hashlib.md5(key_string.encode()).hexdigest()
 
     @classmethod
     def _defaults(cls):
@@ -182,8 +189,13 @@ def _from_dict(cls, dct):
         return cls(**dct)
 
     def _to_dict(self):
-        return {"pattern": self.pattern, "max_retries": self.max_retries}
+        return {
+            "pattern": self.pattern,
+            "max_retries": self.max_retries,
+            "taskhub_scoped_key": self.taskhub_scoped_key,
+        }
 
+    # TODO: should this also compare taskhub scoped keys?
     def __eq__(self, other):
         if not isinstance(other, self.__class__):
             return False
diff --git a/alchemiscale/tests/unit/test_storage_models.py b/alchemiscale/tests/unit/test_storage_models.py
index 02fe188e..55dc872f 100644
--- a/alchemiscale/tests/unit/test_storage_models.py
+++ b/alchemiscale/tests/unit/test_storage_models.py
@@ -52,35 +52,61 @@ class TestTaskRestartPattern(object):
 
     def test_empty_pattern(self):
         with pytest.raises(ValueError, match=self.pattern_value_error):
-            _ = TaskRestartPattern("", 3)
+            _ = TaskRestartPattern(
+                "", 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project"
+            )
 
     def test_non_string_pattern(self):
         with pytest.raises(ValueError, match=self.pattern_value_error):
-            _ = TaskRestartPattern(None, 3)
+            _ = TaskRestartPattern(
+                None, 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project"
+            )
 
         with pytest.raises(ValueError, match=self.pattern_value_error):
-            _ = TaskRestartPattern([], 3)
+            _ = TaskRestartPattern(
+                [], 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project"
+            )
 
     def test_non_positive_max_retries(self):
 
         with pytest.raises(ValueError, match=self.max_retries_value_error):
-            TaskRestartPattern("Example pattern", 0)
+            TaskRestartPattern(
+                "Example pattern",
+                0,
+                "FakeScopedKey-1234-fake_org-fake_campaign-fake_project",
+            )
 
         with pytest.raises(ValueError, match=self.max_retries_value_error):
-            TaskRestartPattern("Example pattern", -1)
+            TaskRestartPattern(
+                "Example pattern",
+                -1,
+                "FakeScopedKey-1234-fake_org-fake_campaign-fake_project",
+            )
 
     def test_non_int_max_retries(self):
         with pytest.raises(ValueError, match=self.max_retries_value_error):
-            TaskRestartPattern("Example pattern", 4.0)
+            TaskRestartPattern(
+                "Example pattern",
+                4.0,
+                "FakeScopedKey-1234-fake_org-fake_campaign-fake_project",
+            )
 
     def test_to_dict(self):
-        trp = TaskRestartPattern("Example pattern", 3)
+        trp = TaskRestartPattern(
+            "Example pattern",
+            3,
+            "FakeScopedKey-1234-fake_org-fake_campaign-fake_project",
+        )
         dict_trp = trp.to_dict()
 
-        assert len(dict_trp.keys()) == 5
+        assert len(dict_trp.keys()) == 6
 
         assert dict_trp.pop("__qualname__") == "TaskRestartPattern"
         assert dict_trp.pop("__module__") == "alchemiscale.storage.models"
+        assert (
+            dict_trp.pop("taskhub_scoped_key")
+            == "FakeScopedKey-1234-fake_org-fake_campaign-fake_project"
+        )
 
         # light test of the version key
         try:
@@ -96,13 +122,19 @@ def test_from_dict(self):
 
         original_pattern = "Example pattern"
         original_max_retries = 3
+        original_taskhub_scoped_key = (
+            "FakeScopedKey-1234-fake_org-fake_campaign-fake_project"
+        )
 
-        trp_orig = TaskRestartPattern(original_pattern, original_max_retries)
+        trp_orig = TaskRestartPattern(
+            original_pattern, original_max_retries, original_taskhub_scoped_key
+        )
         trp_dict = trp_orig.to_dict()
         trp_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(trp_dict)
 
         assert trp_reconstructed.pattern == original_pattern
         assert trp_reconstructed.max_retries == original_max_retries
+        assert trp_reconstructed.taskhub_scoped_key == original_taskhub_scoped_key
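Note: the effect of folding the TaskHub's ScopedKey into the hash can be seen with hashlib alone. A sketch mirroring the `_gufe_tokenize` above; the hub strings are hypothetical:

    import hashlib

    pattern = "RuntimeError: NaN.*"
    hub_a = "TaskHub-aaaa-org-campaign-project"
    hub_b = "TaskHub-bbbb-org-campaign-project"

    # key_string = self.pattern + self.taskhub_scoped_key
    key_a = hashlib.md5((pattern + hub_a).encode()).hexdigest()
    key_b = hashlib.md5((pattern + hub_b).encode()).hexdigest()

    # the same pattern bound to different TaskHubs now yields distinct gufe keys,
    # so each hub gets its own TaskRestartPattern node in Neo4j
    assert key_a != key_b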
From b99d8ef3a2f83fd68b3e756d59f6fd6b2db92345 Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Mon, 22 Jul 2024 14:30:07 -0700
Subject: [PATCH 07/12] Added `statestore` methods for restart patterns

`statestore` methods have been added to modify the database state:

* add_task_restart_patterns
* remove_task_restart_patterns
* get_task_restart_patterns

Tests were added for each method in the integration tests for the
statestore.
---
 alchemiscale/storage/statestore.py                 |  84 +++++++++-
 .../integration/storage/test_statestore.py         | 158 ++++++++++++++++++
 2 files changed, 240 insertions(+), 2 deletions(-)

diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py
index 3ec0aa5e..74390bdf 100644
--- a/alchemiscale/storage/statestore.py
+++ b/alchemiscale/storage/statestore.py
@@ -27,6 +27,7 @@
     ProtocolDAGResultRef,
     Task,
     TaskHub,
+    TaskRestartPattern,
     TaskStatusEnum,
 )
 from ..strategies import Strategy
@@ -2706,7 +2707,7 @@ def err_msg(t, status):
     ## task restart policy
 
     # TODO: fill in docstring
-    def add_task_restart_policy_patterns(
+    def add_task_restart_patterns(
         self, taskhub: ScopedKey, patterns: List[str], number_of_retries: int
     ):
         """Add a list of restart policy patterns to a `TaskHub` along with the number of retries allowed.
 
         Parameters
         ----------
 
+        Raises
+        ------
+
         """
-        raise NotImplementedError
+
+        # get taskhub node
+        q = """
+        MATCH (th:TaskHub {`_scoped_key`: $taskhub})
+        RETURN th
+        """
+        results = self.execute_query(q, taskhub=str(taskhub))
+
+        ## raise error if taskhub not found
+        if not results.records:
+            raise KeyError("No such TaskHub in the database")
+
+        record_data = results.records[0]["th"]
+        taskhub_node = record_data_to_node(record_data)
+        scope = taskhub.scope
+
+        subgraph = Subgraph()
+
+        for pattern in patterns:
+            task_restart_pattern = TaskRestartPattern(
+                pattern,
+                max_retries=number_of_retries,
+                taskhub_scoped_key=str(taskhub),
+            )
+
+            _, task_restart_policy_node, scoped_key = self._gufe_to_subgraph(
+                task_restart_pattern.to_shallow_dict(),
+                labels=["GufeTokenizable", task_restart_pattern.__class__.__name__],
+                gufe_key=task_restart_pattern.key,
+                scope=scope,
+            )
+
+            subgraph |= Relationship.type("ENFORCES")(
+                task_restart_policy_node,
+                taskhub_node,
+                _org=scope.org,
+                _campaign=scope.campaign,
+                _project=scope.project,
+            )
+
+        with self.transaction() as tx:
+            merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key")
+
+    # TODO: fill in docstring
+    def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]):
+        q = """
+        UNWIND $patterns AS pattern
+
+        MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key})
+
+        DETACH DELETE trp
+        """
+
+        self.execute_query(q, patterns=patterns, taskhub_scoped_key=str(taskhub))
+
+    # TODO: fill in docstring
+    def get_task_restart_patterns(self, taskhubs: List[ScopedKey]):
+
+        q = """
+        UNWIND $taskhub_scoped_keys as taskhub_scoped_key
+        MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_scoped_key})
+        RETURN th, trp
+        """
+
+        records = self.execute_query(
+            q, taskhub_scoped_keys=list(map(str, taskhubs))
+        ).records
+
+        data = {taskhub: set() for taskhub in taskhubs}
+
+        for record in records:
+            pattern = record["trp"]["pattern"]
+            max_retries = record["trp"]["max_retries"]
+            taskhub_sk = ScopedKey.from_str(record["th"]["_scoped_key"])
+            data[taskhub_sk].add((pattern, max_retries))
+
+        return dict(data)
 
     ## authentication
 
     def create_credentialed_entity(self, entity: CredentialedEntity):
diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py
index 2632524b..1b96dfda 100644
--- a/alchemiscale/tests/integration/storage/test_statestore.py
+++ b/alchemiscale/tests/integration/storage/test_statestore.py
@@ -3,6 +3,7 @@
 from typing import List, Dict
 from pathlib import Path
 from itertools import chain
+from collections import defaultdict
 
 import pytest
 from gufe import AlchemicalNetwork
@@ -1851,6 +1852,163 @@ def test_get_task_failures(
         assert pdr_ref_sk in failure_pdr_ref_sks
         assert pdr_ref2_sk in failure_pdr_ref_sks
 
+    ### task restart policies
+
+    def test_add_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+        # create three new alchemical networks (and taskhubs)
+        taskhub_sks = []
+        for network_index in range(3):
+            an = network_tyk2.copy_with_replacements(
+                name=network_tyk2.name
+                + f"_test_add_task_restart_patterns_{network_index}"
+            )
+            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+            taskhub_sks.append(taskhub_scoped_key)
+
+        # test a shared pattern with and without shared number of restarts
+        # this will create 6 unique patterns
+        for network_index in range(3):
+            taskhub_scoped_key = taskhub_sks[network_index]
+            n4js.add_task_restart_patterns(
+                taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+            )
+            n4js.add_task_restart_patterns(
+                taskhub_scoped_key,
+                ["shared_pattern_and_different_restarts.+"],
+                network_index + 1,
+            )
+
+        q = """UNWIND $taskhub_sks AS taskhub_sk
+        MATCH (trp: TaskRestartPattern)-[ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_sk}) RETURN trp, th
+        """
+
+        taskhub_sks = list(map(str, taskhub_sks))
+        records = n4js.execute_query(q, taskhub_sks=taskhub_sks).records
+
+        assert len(records) == 6
+
+        taskhub_scoped_key_set = set()
+        taskrestartpattern_scoped_key_set = set()
+
+        for record in records:
+            taskhub_scoped_key = ScopedKey.from_str(record["th"]["_scoped_key"])
+            taskrestartpattern_scoped_key = ScopedKey.from_str(
+                record["trp"]["_scoped_key"]
+            )
+
+            taskhub_scoped_key_set.add(taskhub_scoped_key)
+            taskrestartpattern_scoped_key_set.add(taskrestartpattern_scoped_key)
+
+        assert len(taskhub_scoped_key_set) == 3
+        assert len(taskrestartpattern_scoped_key_set) == 6
+
+    def test_remove_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+
+        # collect what we expect `get_task_restart_patterns` to return
+        expected_results = defaultdict(set)
+
+        # create three new alchemical networks (and taskhubs)
+        taskhub_sks = []
+        for network_index in range(3):
+            an = network_tyk2.copy_with_replacements(
+                name=network_tyk2.name
+                + f"_test_remove_task_restart_patterns_{network_index}"
+            )
+            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+            taskhub_sks.append(taskhub_scoped_key)
+
+        # test a shared pattern with and without shared number of restarts
+        # this will create 6 unique patterns
+        for network_index in range(3):
+            taskhub_scoped_key = taskhub_sks[network_index]
+            n4js.add_task_restart_patterns(
+                taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+            )
+            expected_results[taskhub_scoped_key].add(
+                ("shared_pattern_and_restarts.+", 5)
+            )
+
+            n4js.add_task_restart_patterns(
+                taskhub_scoped_key,
+                ["shared_pattern_and_different_restarts.+"],
+                network_index + 1,
+            )
+            expected_results[taskhub_scoped_key].add(
+                ("shared_pattern_and_different_restarts.+", network_index + 1)
+            )
+
+        # remove both patterns enforcing the first taskhub at the same time, two patterns
+        target_taskhub = taskhub_sks[0]
+        target_patterns = []
+
+        for pattern, _ in expected_results[target_taskhub]:
+            target_patterns.append(pattern)
+
+        expected_results[target_taskhub].clear()
+
+        n4js.remove_task_restart_patterns(target_taskhub, target_patterns)
+        assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+        # remove both patterns enforcing the second taskhub one at a time, two patterns
+        target_taskhub = taskhub_sks[1]
+        # pointer to underlying set, pops will update comparison data structure
+        target_patterns = expected_results[target_taskhub]
+
+        pattern, _ = target_patterns.pop()
+        n4js.remove_task_restart_patterns(target_taskhub, [pattern])
+        assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+        pattern, _ = target_patterns.pop()
+        n4js.remove_task_restart_patterns(target_taskhub, [pattern])
+        assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+    def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+        # create three new alchemical networks (and taskhubs)
+        taskhub_sks = []
+        for network_index in range(3):
+            an = network_tyk2.copy_with_replacements(
+                name=network_tyk2.name
+                + f"_test_get_task_restart_patterns_{network_index}"
+            )
+            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+            taskhub_sks.append(taskhub_scoped_key)
+
+        expected_results = defaultdict(set)
+        # test a shared pattern with and without shared number of restarts
+        # this will create 6 unique patterns
+        for network_index in range(3):
+            taskhub_scoped_key = taskhub_sks[network_index]
+            n4js.add_task_restart_patterns(
+                taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+            )
+            expected_results[taskhub_scoped_key].add(
+                ("shared_pattern_and_restarts.+", 5)
+            )
+            n4js.add_task_restart_patterns(
+                taskhub_scoped_key,
+                ["shared_pattern_and_different_restarts.+"],
+                network_index + 1,
+            )
+            expected_results[taskhub_scoped_key].add(
+                ("shared_pattern_and_different_restarts.+", network_index + 1)
+            )
+
+        taskhub_grouped_patterns = n4js.get_task_restart_patterns(taskhub_sks)
+
+        assert taskhub_grouped_patterns == expected_results
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_task_actioning_applies_relationship(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_add_restart_applies_relationship(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_task_deaction_applies_relationship(self):
+        raise NotImplementedError
+
     ### authentication
 
     @pytest.mark.parametrize(
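Note: a hypothetical round trip through the three new statestore methods, assuming a connected Neo4jStore `n4js` and a TaskHub ScopedKey `taskhub_sk`; the pattern string is made up:

    n4js.add_task_restart_patterns(taskhub_sk, ["RuntimeError: NaN.*"], 3)

    patterns = n4js.get_task_restart_patterns([taskhub_sk])
    # -> {taskhub_sk: {("RuntimeError: NaN.*", 3)}}

    n4js.remove_task_restart_patterns(taskhub_sk, ["RuntimeError: NaN.*"])
    # -> {taskhub_sk: set()} on the next get

Returning patterns grouped per TaskHub as a dict of sets keeps equality checks in the tests order-independent.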
From 39f986888909d19c9dd0c999a72650ae6a8b01fd Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Wed, 24 Jul 2024 17:06:42 -0700
Subject: [PATCH 08/12] Added APPLIES relationship when adding pattern

The `add_task_restart_patterns` method now establishes the APPLIES
relationship between each new pattern and all Tasks ACTIONED on the
corresponding TaskHub.

Added testing for creation of the APPLIES relationship, asserting the
number of created connections over multiple TaskHubs and Tasks. Further
subdivided the test classes.

Additionally added a `set_task_restart_patterns_max_retries` method for
updating the max_retries of a TaskRestartPattern.
---
 alchemiscale/storage/statestore.py                 |  95 ++++--
 .../integration/storage/test_statestore.py         | 320 +++++++++++-------
 2 files changed, 272 insertions(+), 143 deletions(-)

diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py
index 74390bdf..598bf2e6 100644
--- a/alchemiscale/storage/statestore.py
+++ b/alchemiscale/storage/statestore.py
@@ -9,7 +9,7 @@
 from contextlib import contextmanager
 import json
 from functools import lru_cache
-from typing import Dict, List, Optional, Union, Tuple
+from typing import Dict, List, Optional, Union, Tuple, Set
 import weakref
 
 import numpy as np
@@ -2735,31 +2735,56 @@ def add_task_restart_patterns(
         taskhub_node = record_data_to_node(record_data)
         scope = taskhub.scope
 
-        subgraph = Subgraph()
+        with self.transaction() as tx:
+            actioned_tasks_query = """
+            MATCH (taskhub: TaskHub {`_scoped_key`: $taskhub_scoped_key})-[:ACTIONS]->(task: Task)
+            RETURN task
+            """
 
-        for pattern in patterns:
-            task_restart_pattern = TaskRestartPattern(
-                pattern,
-                max_retries=number_of_retries,
-                taskhub_scoped_key=str(taskhub),
-            )
+            subgraph = Subgraph()
 
-            _, task_restart_policy_node, scoped_key = self._gufe_to_subgraph(
-                task_restart_pattern.to_shallow_dict(),
-                labels=["GufeTokenizable", task_restart_pattern.__class__.__name__],
-                gufe_key=task_restart_pattern.key,
-                scope=scope,
-            )
+            actioned_task_nodes = []
 
-            subgraph |= Relationship.type("ENFORCES")(
-                task_restart_policy_node,
-                taskhub_node,
-                _org=scope.org,
-                _campaign=scope.campaign,
-                _project=scope.project,
-            )
+            for actioned_tasks_record in (
+                tx.run(actioned_tasks_query, taskhub_scoped_key=str(taskhub))
+                .to_eager_result()
+                .records
+            ):
+                actioned_task_nodes.append(
+                    record_data_to_node(actioned_tasks_record["task"])
+                )
 
-        with self.transaction() as tx:
+            for pattern in patterns:
+                task_restart_pattern = TaskRestartPattern(
+                    pattern,
+                    max_retries=number_of_retries,
+                    taskhub_scoped_key=str(taskhub),
+                )
+
+                _, task_restart_pattern_node, scoped_key = self._gufe_to_subgraph(
+                    task_restart_pattern.to_shallow_dict(),
+                    labels=["GufeTokenizable", task_restart_pattern.__class__.__name__],
+                    gufe_key=task_restart_pattern.key,
+                    scope=scope,
+                )
+
+                subgraph |= Relationship.type("ENFORCES")(
+                    task_restart_pattern_node,
+                    taskhub_node,
+                    _org=scope.org,
+                    _campaign=scope.campaign,
+                    _project=scope.project,
+                )
+
+                for actioned_task_node in actioned_task_nodes:
+                    subgraph |= Relationship.type("APPLIES")(
+                        task_restart_pattern_node,
+                        actioned_task_node,
+                        num_retries=0,
+                        _org=scope.org,
+                        _campaign=scope.campaign,
+                        _project=scope.project,
+                    )
             merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key")
 
     # TODO: fill in docstring
@@ -2775,7 +2800,29 @@ def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]):
         self.execute_query(q, patterns=patterns, taskhub_scoped_key=str(taskhub))
 
     # TODO: fill in docstring
-    def get_task_restart_patterns(self, taskhubs: List[ScopedKey]):
+    def set_task_restart_patterns_max_retries(
+        self,
+        taskhub_scoped_key: Union[ScopedKey, str],
+        patterns: List[str],
+        max_retries: int,
+    ):
+        query = """
+        UNWIND $patterns AS pattern
+        MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key})
+        SET trp.max_retries = $max_retries
+        """
+
+        self.execute_query(
+            query,
+            patterns=patterns,
+            taskhub_scoped_key=str(taskhub_scoped_key),
+            max_retries=max_retries,
+        )
+
+    # TODO: fill in docstring
+    def get_task_restart_patterns(
+        self, taskhubs: List[ScopedKey]
+    ) -> Dict[ScopedKey, Set[Tuple[str, int]]]:
 
         q = """
         UNWIND $taskhub_scoped_keys as taskhub_scoped_key
         MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_scoped_key})
         RETURN th, trp
         """
@@ -2795,7 +2842,7 @@ def get_task_restart_patterns(self, taskhubs: List[ScopedKey]):
             taskhub_sk = ScopedKey.from_str(record["th"]["_scoped_key"])
             data[taskhub_sk].add((pattern, max_retries))
 
-        return dict(data)
+        return data
 
     ## authentication
 
diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py
index 1b96dfda..fa8f7e0d 100644
--- a/alchemiscale/tests/integration/storage/test_statestore.py
+++ b/alchemiscale/tests/integration/storage/test_statestore.py
@@ -1854,160 +1854,242 @@ def test_get_task_failures(
 
     ### task restart policies
 
-    def test_add_task_restart_patterns(self, n4js, network_tyk2, scope_test):
-        # create three new alchemical networks (and taskhubs)
-        taskhub_sks = []
-        for network_index in range(3):
-            an = network_tyk2.copy_with_replacements(
-                name=network_tyk2.name
-                + f"_test_add_task_restart_patterns_{network_index}"
-            )
-            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
-            taskhub_sks.append(taskhub_scoped_key)
-
-        # test a shared pattern with and without shared number of restarts
-        # this will create 6 unique patterns
-        for network_index in range(3):
-            taskhub_scoped_key = taskhub_sks[network_index]
-            n4js.add_task_restart_patterns(
-                taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
-            )
-            n4js.add_task_restart_patterns(
-                taskhub_scoped_key,
-                ["shared_pattern_and_different_restarts.+"],
-                network_index + 1,
-            )
-
-        q = """UNWIND $taskhub_sks AS taskhub_sk
-        MATCH (trp: TaskRestartPattern)-[ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_sk}) RETURN trp, th
-        """
-
-        taskhub_sks = list(map(str, taskhub_sks))
-        records = n4js.execute_query(q, taskhub_sks=taskhub_sks).records
-
-        assert len(records) == 6
-
-        taskhub_scoped_key_set = set()
-        taskrestartpattern_scoped_key_set = set()
-
-        for record in records:
-            taskhub_scoped_key = ScopedKey.from_str(record["th"]["_scoped_key"])
-            taskrestartpattern_scoped_key = ScopedKey.from_str(
-                record["trp"]["_scoped_key"]
-            )
-
-            taskhub_scoped_key_set.add(taskhub_scoped_key)
-            taskrestartpattern_scoped_key_set.add(taskrestartpattern_scoped_key)
-
-        assert len(taskhub_scoped_key_set) == 3
-        assert len(taskrestartpattern_scoped_key_set) == 6
+    class TestTaskRestartPolicy:
+
+        def test_add_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+            # create three new alchemical networks (and taskhubs)
+            taskhub_sks = []
+            for network_index in range(3):
+                an = network_tyk2.copy_with_replacements(
+                    name=network_tyk2.name
+                    + f"_test_add_task_restart_patterns_{network_index}"
+                )
+                _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+
+                # don't action tasks on every network, take every other
+                if network_index % 2 == 0:
+                    transformation = list(an.edges)[0]
+                    transformation_sk = n4js.get_scoped_key(transformation, scope_test)
+                    task_sks = n4js.create_tasks([transformation_sk] * 3)
+                    n4js.action_tasks(task_sks, taskhub_scoped_key)
+
+                taskhub_sks.append(taskhub_scoped_key)
+
+            # test a shared pattern with and without shared number of restarts
+            # this will create 6 unique patterns
+            for network_index in range(3):
+                taskhub_scoped_key = taskhub_sks[network_index]
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+                )
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key,
+                    ["shared_pattern_and_different_restarts.+"],
+                    network_index + 1,
+                )
+
+            q = """UNWIND $taskhub_sks AS taskhub_sk
+            MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_sk}) RETURN trp, th
+            """
+
+            taskhub_sks = list(map(str, taskhub_sks))
+            records = n4js.execute_query(q, taskhub_sks=taskhub_sks).records
+
+            assert len(records) == 6
+
+            taskhub_scoped_key_set = set()
+            taskrestartpattern_scoped_key_set = set()
+
+            for record in records:
+                taskhub_scoped_key = ScopedKey.from_str(record["th"]["_scoped_key"])
+                taskrestartpattern_scoped_key = ScopedKey.from_str(
+                    record["trp"]["_scoped_key"]
+                )
+
+                taskhub_scoped_key_set.add(taskhub_scoped_key)
+                taskrestartpattern_scoped_key_set.add(taskrestartpattern_scoped_key)
+
+            assert len(taskhub_scoped_key_set) == 3
+            assert len(taskrestartpattern_scoped_key_set) == 6
+
+            # check that the applies relationships were correctly added
+
+            ## first check that the number of applies relationships is correct and
+            ## that the number of retries is zero
+            applies_query = """
+            MATCH (trp: TaskRestartPattern)-[app:APPLIES {num_retries: 0}]->(task: Task)<-[:ACTIONS]-(th: TaskHub)
+            RETURN th, count(app) AS num_applied
+            """
+
+            records = n4js.execute_query(applies_query).records
+
+            ### one record per taskhub, each with six num_applied
+            assert len(records) == 2
+            assert records[0]["num_applied"] == records[1]["num_applied"] == 6
+
+            applies_nonzero_retries = """
+            MATCH (trp: TaskRestartPattern)-[app:APPLIES]->(task: Task)<-[:ACTIONS]-(th: TaskHub)
+            WHERE app.num_retries <> 0
+            RETURN th, count(app) AS num_applied
+            """
+            assert len(n4js.execute_query(applies_nonzero_retries).records) == 0
 
-    def test_remove_task_restart_patterns(self, n4js, network_tyk2, scope_test):
-
-        # collect what we expect `get_task_restart_patterns` to return
-        expected_results = defaultdict(set)
-
-        # create three new alchemical networks (and taskhubs)
-        taskhub_sks = []
-        for network_index in range(3):
-            an = network_tyk2.copy_with_replacements(
-                name=network_tyk2.name
-                + f"_test_remove_task_restart_patterns_{network_index}"
-            )
-            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
-            taskhub_sks.append(taskhub_scoped_key)
-
-        # test a shared pattern with and without shared number of restarts
-        # this will create 6 unique patterns
-        for network_index in range(3):
-            taskhub_scoped_key = taskhub_sks[network_index]
-            n4js.add_task_restart_patterns(
-                taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
-            )
-            expected_results[taskhub_scoped_key].add(
-                ("shared_pattern_and_restarts.+", 5)
-            )
-
-            n4js.add_task_restart_patterns(
-                taskhub_scoped_key,
-                ["shared_pattern_and_different_restarts.+"],
-                network_index + 1,
-            )
-            expected_results[taskhub_scoped_key].add(
-                ("shared_pattern_and_different_restarts.+", network_index + 1)
-            )
-
-        # remove both patterns enforcing the first taskhub at the same time, two patterns
-        target_taskhub = taskhub_sks[0]
-        target_patterns = []
-
-        for pattern, _ in expected_results[target_taskhub]:
-            target_patterns.append(pattern)
-
-        expected_results[target_taskhub].clear()
-
-        n4js.remove_task_restart_patterns(target_taskhub, target_patterns)
-        assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
-
-        # remove both patterns enforcing the second taskhub one at a time, two patterns
-        target_taskhub = taskhub_sks[1]
-        # pointer to underlying set, pops will update comparison data structure
-        target_patterns = expected_results[target_taskhub]
-
-        pattern, _ = target_patterns.pop()
-        n4js.remove_task_restart_patterns(target_taskhub, [pattern])
-        assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
-
-        pattern, _ = target_patterns.pop()
-        n4js.remove_task_restart_patterns(target_taskhub, [pattern])
-        assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+        def test_remove_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+
+            # collect what we expect `get_task_restart_patterns` to return
+            expected_results = defaultdict(set)
+
+            # create three new alchemical networks (and taskhubs)
+            taskhub_sks = []
+            for network_index in range(3):
+                an = network_tyk2.copy_with_replacements(
+                    name=network_tyk2.name
+                    + f"_test_remove_task_restart_patterns_{network_index}"
+                )
+                _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+                taskhub_sks.append(taskhub_scoped_key)
+
+            # test a shared pattern with and without shared number of restarts
+            # this will create 6 unique patterns
+            for network_index in range(3):
+                taskhub_scoped_key = taskhub_sks[network_index]
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+                )
+                expected_results[taskhub_scoped_key].add(
+                    ("shared_pattern_and_restarts.+", 5)
+                )
+
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key,
+                    ["shared_pattern_and_different_restarts.+"],
+                    network_index + 1,
+                )
+                expected_results[taskhub_scoped_key].add(
+                    ("shared_pattern_and_different_restarts.+", network_index + 1)
+                )
+
+            # remove both patterns enforcing the first taskhub at the same time, two patterns
+            target_taskhub = taskhub_sks[0]
+            target_patterns = []
+
+            for pattern, _ in expected_results[target_taskhub]:
+                target_patterns.append(pattern)
+
+            expected_results[target_taskhub].clear()
+
+            n4js.remove_task_restart_patterns(target_taskhub, target_patterns)
+            assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+            # remove both patterns enforcing the second taskhub one at a time, two patterns
+            target_taskhub = taskhub_sks[1]
+            # pointer to underlying set, pops will update comparison data structure
+            target_patterns = expected_results[target_taskhub]
+
+            pattern, _ = target_patterns.pop()
+            n4js.remove_task_restart_patterns(target_taskhub, [pattern])
+            assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+            pattern, _ = target_patterns.pop()
+            n4js.remove_task_restart_patterns(target_taskhub, [pattern])
+            assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+        def test_set_task_restart_patterns_max_retries(
+            self, n4js, network_tyk2, scope_test
+        ):
+            network_name = (
+                network_tyk2.name + "_test_set_task_restart_patterns_max_retries"
+            )
+            an = network_tyk2.copy_with_replacements(name=network_name)
+            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+
+            pattern_data = [("pattern_1", 5), ("pattern_2", 5), ("pattern_3", 5)]
+            n4js.add_task_restart_patterns(
+                taskhub_scoped_key,
+                patterns=[data[0] for data in pattern_data],
+                number_of_retries=5,
+            )
+
+            expected_results = {taskhub_scoped_key: set(pattern_data)}
+
+            assert expected_results == n4js.get_task_restart_patterns(
+                [taskhub_scoped_key]
+            )
+
+            # reflect changing just one max_retry
+            new_pattern_1_tuple = ("pattern_1", 1)
+
+            expected_results[taskhub_scoped_key].remove(pattern_data[0])
+            expected_results[taskhub_scoped_key].add(new_pattern_1_tuple)
+
+            n4js.set_task_restart_patterns_max_retries(
+                taskhub_scoped_key, [new_pattern_1_tuple[0]], new_pattern_1_tuple[1]
+            )
+
+            assert expected_results == n4js.get_task_restart_patterns(
+                [taskhub_scoped_key]
+            )
+
+            # reflect changing more than one at a time
+            new_pattern_2_tuple = ("pattern_2", 2)
+            new_pattern_3_tuple = ("pattern_3", 2)
+
+            expected_results[taskhub_scoped_key].remove(pattern_data[1])
+            expected_results[taskhub_scoped_key].add(new_pattern_2_tuple)
+
+            expected_results[taskhub_scoped_key].remove(pattern_data[2])
+            expected_results[taskhub_scoped_key].add(new_pattern_3_tuple)
+
+            n4js.set_task_restart_patterns_max_retries(
+                taskhub_scoped_key, [new_pattern_2_tuple[0], new_pattern_3_tuple[0]], 2
+            )
+
+            assert expected_results == n4js.get_task_restart_patterns(
+                [taskhub_scoped_key]
+            )
 
-    def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test):
-        # create three new alchemical networks (and taskhubs)
-        taskhub_sks = []
-        for network_index in range(3):
-            an = network_tyk2.copy_with_replacements(
-                name=network_tyk2.name
-                + f"_test_get_task_restart_patterns_{network_index}"
-            )
-            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
-            taskhub_sks.append(taskhub_scoped_key)
-
-        expected_results = defaultdict(set)
-        # test a shared pattern with and without shared number of restarts
-        # this will create 6 unique patterns
-        for network_index in range(3):
-            taskhub_scoped_key = taskhub_sks[network_index]
-            n4js.add_task_restart_patterns(
-                taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
-            )
-            expected_results[taskhub_scoped_key].add(
-                ("shared_pattern_and_restarts.+", 5)
-            )
-            n4js.add_task_restart_patterns(
-                taskhub_scoped_key,
-                ["shared_pattern_and_different_restarts.+"],
-                network_index + 1,
-            )
-            expected_results[taskhub_scoped_key].add(
-                ("shared_pattern_and_different_restarts.+", network_index + 1)
-            )
-
-        taskhub_grouped_patterns = n4js.get_task_restart_patterns(taskhub_sks)
-
-        assert taskhub_grouped_patterns == expected_results
+        def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+            # create three new alchemical networks (and taskhubs)
+            taskhub_sks = []
+            for network_index in range(3):
+                an = network_tyk2.copy_with_replacements(
+                    name=network_tyk2.name
+                    + f"_test_get_task_restart_patterns_{network_index}"
+                )
+                _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+                taskhub_sks.append(taskhub_scoped_key)
+
+            expected_results = defaultdict(set)
+            # test a shared pattern with and without shared number of restarts
+            # this will create 6 unique patterns
+            for network_index in range(3):
+                taskhub_scoped_key = taskhub_sks[network_index]
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+                )
+                expected_results[taskhub_scoped_key].add(
+                    ("shared_pattern_and_restarts.+", 5)
+                )
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key,
+                    ["shared_pattern_and_different_restarts.+"],
+                    network_index + 1,
+                )
+                expected_results[taskhub_scoped_key].add(
+                    ("shared_pattern_and_different_restarts.+", network_index + 1)
+                )
+
+            taskhub_grouped_patterns = n4js.get_task_restart_patterns(taskhub_sks)
+
+            assert taskhub_grouped_patterns == expected_results
 
-    @pytest.mark.xfail(raises=NotImplementedError)
-    def test_task_actioning_applies_relationship(self):
-        raise NotImplementedError
+        @pytest.mark.xfail(raises=NotImplementedError)
+        def test_task_actioning_applies_relationship(self):
+            raise NotImplementedError
 
-    @pytest.mark.xfail(raises=NotImplementedError)
-    def test_add_restart_applies_relationship(self):
-        raise NotImplementedError
-
-    @pytest.mark.xfail(raises=NotImplementedError)
-    def test_task_deaction_applies_relationship(self):
-        raise NotImplementedError
+        @pytest.mark.xfail(raises=NotImplementedError)
+        def test_task_deaction_applies_relationship(self):
+            raise NotImplementedError
 
     ### authentication
 
     @pytest.mark.parametrize(
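Note: a hypothetical usage sketch of the new max-retries setter, assuming a connected Neo4jStore `n4js` and a TaskHub ScopedKey `taskhub_sk`; pattern names are illustrative:

    n4js.add_task_restart_patterns(taskhub_sk, ["Found NaN.*"], number_of_retries=5)

    # lower the retry budget for an existing pattern in place
    n4js.set_task_restart_patterns_max_retries(taskhub_sk, ["Found NaN.*"], 2)

    assert n4js.get_task_restart_patterns([taskhub_sk])[taskhub_sk] == {("Found NaN.*", 2)}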
From 988155f36c227912a615ee9b2241ee172319ef35 Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Fri, 26 Jul 2024 12:19:16 -0700
Subject: [PATCH 09/12] Establish APPLIES when actioning a Task

"Actioning" a Task on a TaskHub with preexisting TaskRestartPatterns now
creates the APPLIES relationship between them with a num_retries value
of 0. This behavior is tested in the test_action_task function in the
statestore integration tests.
---
 alchemiscale/storage/statestore.py                 | 39 +++++++++---
 .../integration/storage/test_statestore.py         | 59 +++++++++++++++++++
 2 files changed, 91 insertions(+), 7 deletions(-)

diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py
index 598bf2e6..3689b8fd 100644
--- a/alchemiscale/storage/statestore.py
+++ b/alchemiscale/storage/statestore.py
@@ -1406,30 +1406,54 @@ def action_tasks(
         # so we can properly return `None` if needed
         task_map = {str(task): None for task in tasks}
 
-        q = f"""
+        query_safe_task_list = [str(task) for task in tasks if task]
+
+        q = """
         // get our TaskHub
-        UNWIND {cypher_list_from_scoped_keys(tasks)} AS task_sk
-        MATCH (th:TaskHub {{_scoped_key: "{taskhub}"}})-[:PERFORMS]->(an:AlchemicalNetwork)
+        UNWIND $query_safe_task_list AS task_sk
+        MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[:PERFORMS]->(an:AlchemicalNetwork)
 
         // get the task we want to add to the hub; check that it connects to same network
-        MATCH (task:Task {{_scoped_key: task_sk}})-[:PERFORMS]->(tf:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an)
+        MATCH (task:Task {_scoped_key: task_sk})-[:PERFORMS]->(:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an)
 
         // only proceed for cases where task is not already actioned on hub
         // and where the task is either in 'waiting', 'running', or 'error' status
         WITH th, an, task
         WHERE NOT (th)-[:ACTIONS]->(task)
-        AND task.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}']
+        AND task.status IN [$waiting, $running, $error]
 
         // create the connection
-        CREATE (th)-[ar:ACTIONS {{weight: 0.5}}]->(task)
+        CREATE (th)-[ar:ACTIONS {weight: 0.5}]->(task)
 
         // set the task property to the scoped key of the Task
         // this is a convenience for when we have to loop over relationships in Python
         SET ar.task = task._scoped_key
 
+        // we want to preserve the list of tasks for the return, so we need to make a subquery
+        // since the subsequent WHERE clause could reduce the records in task
+        WITH task, th
+        CALL {
+            WITH task, th
+            MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th)
+            WHERE NOT (trp)-[:APPLIES]->(task)
+
+            CREATE (trp)-[:APPLIES {num_retries: 0, `_campaign`: $campaign, `_org`: $org, `_project`: $project}]->(task)
+        }
+
         RETURN task
         """
-        results = self.execute_query(q)
+
+        results = self.execute_query(
+            q,
+            query_safe_task_list=query_safe_task_list,
+            waiting=TaskStatusEnum.waiting.value,
+            running=TaskStatusEnum.running.value,
+            error=TaskStatusEnum.error.value,
+            taskhub_scoped_key=str(taskhub),
+            campaign=taskhub.campaign,
+            org=taskhub.org,
+            project=taskhub.project,
+        )
 
         # update our map with the results, leaving None for tasks that aren't found
         for task_record in results.records:
@@ -1587,6 +1611,7 @@ def cancel_tasks(
                 // get our task hub, as well as the task :ACTIONS relationship we want to remove
                 MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}})
                 DELETE ar
+
                 RETURN task
                 """
                 _task = tx.run(q).to_eager_result()
diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py
index fa8f7e0d..f79a1c4a 100644
--- a/alchemiscale/tests/integration/storage/test_statestore.py
+++ b/alchemiscale/tests/integration/storage/test_statestore.py
@@ -1096,6 +1096,65 @@ def test_action_task(self, n4js: Neo4jStore, network_tyk2, scope_test):
         task_sks_fail = n4js.action_tasks(task_sks, taskhub_sk2)
         assert all([i is None for i in task_sks_fail])
 
+        # test for APPLIES relationship between an ACTIONED task and a TaskRestartPattern
+
+        ## create a restart pattern, should already create APPLIES relationships with those
+        ## already actioned
+        n4js.add_task_restart_patterns(taskhub_sk, ["test_pattern"], 5)
+
+        query = """
+        MATCH (:TaskRestartPattern)-[applies:APPLIES]->(:Task)<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key})
+        // change this so that later tests can show the value was not overwritten
+        SET applies.num_retries = 1
+        RETURN count(applies) AS applies_count
+        """
+
+        ## sanity check that this number makes sense
+        applies_count = n4js.execute_query(
+            query, taskhub_scoped_key=str(taskhub_sk)
+        ).records[0]["applies_count"]
+
+        assert applies_count == 10
+
+        # create 10 more tasks and action them
+        task_sks = n4js.create_tasks([transformation_sk] * 10)
+        n4js.action_tasks(task_sks, taskhub_sk)
+
+        assert len(n4js.get_taskhub_actioned_tasks([taskhub_sk])[0]) == 20
+
+        # same as above query without the set num_retries = 1
+        query = """
+        MATCH (:TaskRestartPattern)-[applies:APPLIES]->(:Task)<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key})
+        RETURN count(applies) AS applies_count
+        """
+
+        applies_count = n4js.execute_query(
+            query, taskhub_scoped_key=str(taskhub_sk)
+        ).records[0]["applies_count"]
+
+        query = """
+        MATCH (:TaskRestartPattern)-[applies:APPLIES]->(:Task)
+        RETURN applies
+        """
+
+        results = n4js.execute_query(query)
+
+        count_0, count_1 = 0, 0
+        for count in map(
+            lambda record: record["applies"]["num_retries"], results.records
+        ):
+            match count:
+                case 0:
+                    count_0 += 1
+                case 1:
+                    count_1 += 1
+                case _:
+                    raise AssertionError(
+                        "Unexpected count value found in num_retries field"
+                    )
+
+        assert count_0 == count_1 == 10
+
     def test_action_task_other_statuses(
         self, n4js: Neo4jStore, network_tyk2, scope_test
     ):
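Note: with this patch, APPLIES is created from both directions (pattern-first via `add_task_restart_patterns`, task-first via `action_tasks`). A hypothetical check against a live store, reusing names from the tests above:

    task_sks = n4js.create_tasks([transformation_sk] * 2)
    n4js.action_tasks(task_sks, taskhub_sk)

    # every pattern ENFORCING the hub now APPLIES to both tasks, at num_retries = 0
    records = n4js.execute_query(
        """
        MATCH (:TaskRestartPattern)-[app:APPLIES {num_retries: 0}]->(t:Task)
        WHERE t._scoped_key IN $tasks
        RETURN count(app) AS n
        """,
        tasks=[str(t) for t in task_sks],
    ).records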
From d3f25f885cc1e13c5ff0f16328ba5cdb7128450e Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Fri, 26 Jul 2024 13:52:21 -0700
Subject: [PATCH 10/12] Canceling a Task removes the APPLIES relationship

When an actioned Task is canceled and also has an APPLIES relationship
with a TaskRestartPattern, APPLIES is removed between the two nodes.

Removed org, project, and campaign fields since they are not necessary
for the APPLIES relationship.
---
 alchemiscale/storage/statestore.py                 | 25 ++++++------
 .../integration/storage/test_statestore.py         | 33 +++++++++++++++
 2 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py
index 3689b8fd..3f9a11d0 100644
--- a/alchemiscale/storage/statestore.py
+++ b/alchemiscale/storage/statestore.py
@@ -1437,7 +1437,7 @@ def action_tasks(
             MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th)
             WHERE NOT (trp)-[:APPLIES]->(task)
 
-            CREATE (trp)-[:APPLIES {num_retries: 0, `_campaign`: $campaign, `_org`: $org, `_project`: $project}]->(task)
+            CREATE (trp)-[:APPLIES {num_retries: 0}]->(task)
         }
 
         RETURN task
@@ -1450,9 +1450,6 @@ def action_tasks(
             running=TaskStatusEnum.running.value,
             error=TaskStatusEnum.error.value,
             taskhub_scoped_key=str(taskhub),
-            campaign=taskhub.campaign,
-            org=taskhub.org,
-            project=taskhub.project,
         )
 
         # update our map with the results, leaving None for tasks that aren't found
@@ -1603,15 +1600,24 @@ def cancel_tasks(
         """
         canceled_sks = []
         with self.transaction() as tx:
-            for t in tasks:
-                q = f"""
+            for task in tasks:
+                query = """
                 // get our task hub, as well as the task :ACTIONS relationship we want to remove
-                MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}})
+                MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: $task_scoped_key})
                 DELETE ar
 
+                WITH task
+                CALL {
+                    WITH task
+                    MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern)
+                    DELETE applies
+                }
+
                 RETURN task
                 """
-                _task = tx.run(q).to_eager_result()
+                _task = tx.run(
+                    query, taskhub_scoped_key=str(taskhub), task_scoped_key=str(task)
+                ).to_eager_result()
 
                 if _task.records:
                     sk = _task.records[0].data()["task"]["_scoped_key"]
@@ -2812,10 +2818,7 @@ def add_task_restart_patterns(
                     subgraph |= Relationship.type("APPLIES")(
                         task_restart_pattern_node,
                         actioned_task_node,
                         num_retries=0,
-                        _org=scope.org,
-                        _campaign=scope.campaign,
-                        _project=scope.project,
                     )
             merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key")
 
diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py
index f79a1c4a..2f5acf03 100644
--- a/alchemiscale/tests/integration/storage/test_statestore.py
+++ b/alchemiscale/tests/integration/storage/test_statestore.py
@@ -1270,6 +1270,14 @@ def test_cancel_task(self, n4js, network_tyk2, scope_test):
         # cancel the second and third task we created
         canceled = n4js.cancel_tasks(task_sks[1:3], taskhub_sk)
 
+        # cancel a fake task
+        fake_canceled = n4js.cancel_tasks(
+            [ScopedKey.from_str("Task-FAKE-test_org-test_campaign-test_project")],
+            taskhub_sk,
+        )
+
+        assert fake_canceled[0] is None
+
         # check that the hub has the contents we expect
         q = f"""MATCH (tq:TaskHub {{_scoped_key: '{taskhub_sk}'}})-[:ACTIONS]->(task:Task)
                 return task
@@ -1283,6 +1291,31 @@ def test_cancel_task(self, n4js, network_tyk2, scope_test):
             actioned
         ) - set(canceled)
 
+        # create a TaskRestartPattern
+        n4js.add_task_restart_patterns(taskhub_sk, ["Test pattern"], 1)
+
+        query = """
+        MATCH (:TaskHub {`_scoped_key`: $taskhub_scoped_key})<-[:ENFORCES]-(:TaskRestartPattern)-[applies:APPLIES]->(:Task)
+        RETURN count(applies) AS applies_count
+        """
+
+        assert (
+            n4js.execute_query(query, taskhub_scoped_key=str(taskhub_sk)).records[0][
+                "applies_count"
+            ]
+            == 8
+        )
+
+        # cancel the fourth and fifth task we created
+        canceled = n4js.cancel_tasks(task_sks[3:5], taskhub_sk)
+
+        assert (
+            n4js.execute_query(query, taskhub_scoped_key=str(taskhub_sk)).records[0][
+                "applies_count"
+            ]
+            == 6
+        )
+
     def test_get_taskhub_tasks(self, n4js, network_tyk2, scope_test):
         an = network_tyk2
         network_sk, taskhub_sk, _ = n4js.assemble_network(an, scope_test)
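Note: a short sketch of the resulting cancel semantics, under the same hypothetical `n4js`/`task_sks`/`taskhub_sk` setup as above:

    n4js.cancel_tasks(task_sks[:1], taskhub_sk)
    # the canceled Task keeps its node and status, but its ACTIONS and APPLIES
    # edges are gone, so the hub's restart patterns no longer track it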
"applies_count" + ] + == 6 + ) + def test_get_taskhub_tasks(self, n4js, network_tyk2, scope_test): an = network_tyk2 network_sk, taskhub_sk, _ = n4js.assemble_network(an, scope_test) From 510ae664d243c0e736412a557222e089e7e231ae Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 1 Aug 2024 09:13:29 -0700 Subject: [PATCH 11/12] Task status changes affect APPLIES relationship Setting an actioned Task status to the following statuses now removes the APPLIES relationship from attached TaskRestartPatterns: * complete * invalid * deleted NOTE: tests have not been added for this yet --- alchemiscale/storage/statestore.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py index 3f9a11d0..07d05d02 100644 --- a/alchemiscale/storage/statestore.py +++ b/alchemiscale/storage/statestore.py @@ -2584,7 +2584,9 @@ def set_task_complete( // if we changed the status to complete, // drop all ACTIONS relationships OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub) + OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern) DELETE ar + DELETE applies WITH scoped_key, t, t_ @@ -2667,9 +2669,11 @@ def set_task_invalid( OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub) OPTIONAL MATCH (extends_task)<-[are:ACTIONS]-(th:TaskHub) + OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern) DELETE ar DELETE are + DELETE applies WITH scoped_key, t, t_ @@ -2717,9 +2721,11 @@ def set_task_deleted( OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub) OPTIONAL MATCH (extends_task)<-[are:ACTIONS]-(th:TaskHub) + OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern) DELETE ar DELETE are + DELETE applies WITH scoped_key, t, t_ From 2310fd575ce560d34408affecdfff3afd3519322 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Sun, 4 Aug 2024 14:06:43 -0700 Subject: [PATCH 12/12] Tests for Task status change on APPLIES Confirming that changing the status of an actioned Task to any of the following removes the APPLIES relationship: * complete * invalid * deleted --- .../integration/storage/test_statestore.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py index 2f5acf03..c7901840 100644 --- a/alchemiscale/tests/integration/storage/test_statestore.py +++ b/alchemiscale/tests/integration/storage/test_statestore.py @@ -1948,6 +1948,56 @@ def test_get_task_failures( class TestTaskRestartPolicy: + @pytest.mark.parametrize("status", ("complete", "invalid", "deleted")) + def test_task_status_change(self, n4js, network_tyk2, scope_test, status): + an = network_tyk2.copy_with_replacements( + name=network_tyk2.name + f"_test_task_status_change" + ) + _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test) + transformation = list(an.edges)[0] + transformation_scoped_key = n4js.get_scoped_key(transformation, scope_test) + task_scoped_keys = n4js.create_tasks([transformation_scoped_key]) + n4js.action_tasks(task_scoped_keys, taskhub_scoped_key) + + n4js.add_task_restart_patterns(taskhub_scoped_key, ["Test pattern"], 10) + + query = """ + MATCH (:TaskRestartPattern)-[:APPLIES]->(task:Task {`_scoped_key`: $task_scoped_key})<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key}) + RETURN task + """ + + results = n4js.execute_query( + query, + task_scoped_key=str(task_scoped_keys[0]), + taskhub_scoped_key=str(taskhub_scoped_key), + ) + + assert len(results.records) == 1 + + target_method = { + "complete": 
From 2310fd575ce560d34408affecdfff3afd3519322 Mon Sep 17 00:00:00 2001
From: Ian Kenney
Date: Sun, 4 Aug 2024 14:06:43 -0700
Subject: [PATCH 12/12] Tests for Task status change on APPLIES

Confirming that changing the status of an actioned Task to any of the
following removes the APPLIES relationship:

* complete
* invalid
* deleted
---
 .../integration/storage/test_statestore.py | 50 +++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py
index 2f5acf03..c7901840 100644
--- a/alchemiscale/tests/integration/storage/test_statestore.py
+++ b/alchemiscale/tests/integration/storage/test_statestore.py
@@ -1948,6 +1948,56 @@ def test_get_task_failures(
 
     class TestTaskRestartPolicy:
 
+        @pytest.mark.parametrize("status", ("complete", "invalid", "deleted"))
+        def test_task_status_change(self, n4js, network_tyk2, scope_test, status):
+            an = network_tyk2.copy_with_replacements(
+                name=network_tyk2.name + "_test_task_status_change"
+            )
+            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+            transformation = list(an.edges)[0]
+            transformation_scoped_key = n4js.get_scoped_key(transformation, scope_test)
+            task_scoped_keys = n4js.create_tasks([transformation_scoped_key])
+            n4js.action_tasks(task_scoped_keys, taskhub_scoped_key)
+
+            n4js.add_task_restart_patterns(taskhub_scoped_key, ["Test pattern"], 10)
+
+            query = """
+            MATCH (:TaskRestartPattern)-[:APPLIES]->(task:Task {`_scoped_key`: $task_scoped_key})<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key})
+            RETURN task
+            """
+
+            results = n4js.execute_query(
+                query,
+                task_scoped_key=str(task_scoped_keys[0]),
+                taskhub_scoped_key=str(taskhub_scoped_key),
+            )
+
+            assert len(results.records) == 1
+
+            target_method = {
+                "complete": n4js.set_task_complete,
+                "invalid": n4js.set_task_invalid,
+                "deleted": n4js.set_task_deleted,
+            }
+
+            if status == "complete":
+                n4js.set_task_running(task_scoped_keys)
+
+            assert target_method[status](task_scoped_keys)[0] is not None
+
+            query = """
+            MATCH (:TaskRestartPattern)-[:APPLIES]->(task:Task)
+            RETURN task
+            """
+
+            results = n4js.execute_query(
+                query,
+                task_scoped_key=str(task_scoped_keys[0]),
+                taskhub_scoped_key=str(taskhub_scoped_key),
+            )
+
+            assert len(results.records) == 0
+
         def test_add_task_restart_patterns(self, n4js, network_tyk2, scope_test):
             # create three new alchemical networks (and taskhubs)
             taskhub_sks = []
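Note: pulling the series together, a hypothetical end-to-end flow once these patches land; `n4js`, `network`, `scope`, and `transformation_sk` are assumed fixtures, and the final resolution step is not yet implemented at this point in the series:

    # assemble a network, action tasks, then attach a restart policy to the hub
    _, taskhub_sk, _ = n4js.assemble_network(network, scope)
    task_sks = n4js.create_tasks([transformation_sk] * 3)
    n4js.action_tasks(task_sks, taskhub_sk)
    n4js.add_task_restart_patterns(taskhub_sk, ["RuntimeError: NaN.*"], 3)

    # errored Tasks whose returned tracebacks match an enforcing pattern would
    # then be eligible for automatic restart, up to max_retries per pattern;
    # the resolver itself (see test_task_resolve_restarts in PATCH 01) is
    # still a placeholder at the end of this series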