diff --git a/alchemiscale/storage/models.py b/alchemiscale/storage/models.py
index c9b000b8..3dc69e0d 100644
--- a/alchemiscale/storage/models.py
+++ b/alchemiscale/storage/models.py
@@ -8,12 +8,12 @@
 from copy import copy
 from datetime import datetime
 from enum import Enum
-from typing import Union, Dict, Optional
+from typing import Union, Optional, List
 from uuid import uuid4
 import hashlib
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
 
 from gufe.tokenization import GufeTokenizable, GufeKey
 
 from ..models import ScopedKey, Scope
@@ -143,6 +143,96 @@ def _defaults(cls):
         return super()._defaults()
 
 
+class TaskRestartPattern(GufeTokenizable):
+    """A pattern to compare returned Task tracebacks to.
+
+    Attributes
+    ----------
+    pattern: str
+        A regular expression pattern that can match to returned tracebacks of errored Tasks.
+    max_retries: int
+        The number of times the pattern can trigger a restart for a Task.
+    taskhub_scoped_key: str
+        The TaskHub the pattern is bound to. This is needed to properly set a unique Gufe key.
+    """
+
+    pattern: str
+    max_retries: int
+    taskhub_scoped_key: str
+
+    def __init__(
+        self, pattern: str, max_retries: int, taskhub_scoped_key: Union[str, ScopedKey]
+    ):
+
+        if not isinstance(pattern, str) or pattern == "":
+            raise ValueError("`pattern` must be a non-empty string")
+
+        self.pattern = pattern
+
+        if not isinstance(max_retries, int) or max_retries <= 0:
+            raise ValueError("`max_retries` must have a positive integer value.")
+        self.max_retries = max_retries
+
+        self.taskhub_scoped_key = str(taskhub_scoped_key)
+
+    def _gufe_tokenize(self):
+        key_string = self.pattern + self.taskhub_scoped_key
+        return hashlib.md5(key_string.encode()).hexdigest()
+
+    @classmethod
+    def _defaults(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def _from_dict(cls, dct):
+        return cls(**dct)
+
+    def _to_dict(self):
+        return {
+            "pattern": self.pattern,
+            "max_retries": self.max_retries,
+            "taskhub_scoped_key": self.taskhub_scoped_key,
+        }
+
+    # TODO: should this also compare taskhub scoped keys?
+    def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return False
+        return self.pattern == other.pattern
+
+
+class Traceback(GufeTokenizable):
+    """The tracebacks returned by an errored Task, for comparison to TaskRestartPatterns."""
+
+    def __init__(self, tracebacks: List[str]):
+        value_error = ValueError(
+            "`tracebacks` must be a non-empty list of string values"
+        )
+        if not isinstance(tracebacks, list) or tracebacks == []:
+            raise value_error
+        else:
+            # we have a list; additionally require every element be a non-empty string
+            all_string_values = all([isinstance(value, str) for value in tracebacks])
+            if not all_string_values or "" in tracebacks:
+                raise value_error
+
+        self.tracebacks = tracebacks
+
+    def _gufe_tokenize(self):
+        return hashlib.md5(str(self.tracebacks).encode()).hexdigest()
+
+    @classmethod
+    def _defaults(cls):
+        raise NotImplementedError
+
+    @classmethod
+    def _from_dict(cls, dct):
+        return cls(**dct)
+
+    def _to_dict(self):
+        return {"tracebacks": self.tracebacks}
+
+
 class TaskHub(GufeTokenizable):
     """
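
A quick illustration of how the two models above behave; this is a hedged sketch, not part of the diff, and the scoped-key string is a made-up placeholder in the same form the unit tests use:

    # illustrative sketch only; assumes the models.py changes above
    from alchemiscale.storage.models import TaskRestartPattern, Traceback

    trp = TaskRestartPattern(
        pattern=r"RuntimeError: simulation blew up.*",  # hypothetical pattern
        max_retries=3,
        taskhub_scoped_key="TaskHub-abc123-some_org-some_campaign-some_project",
    )

    # equality is by pattern (see __eq__ above); to_dict/from_dict round-trips
    assert TaskRestartPattern.from_dict(trp.to_dict()) == trp

    tb = Traceback(["Traceback (most recent call last): ..."])
    assert Traceback.from_dict(tb.to_dict()).tracebacks == tb.tracebacks
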
diff --git a/alchemiscale/storage/statestore.py b/alchemiscale/storage/statestore.py
index 1ffd4f4a..07d05d02 100644
--- a/alchemiscale/storage/statestore.py
+++ b/alchemiscale/storage/statestore.py
@@ -9,7 +9,7 @@
 from contextlib import contextmanager
 import json
 from functools import lru_cache
-from typing import Dict, List, Optional, Union, Tuple
+from typing import Dict, List, Optional, Union, Tuple, Set
 import weakref
 
 import numpy as np
@@ -24,10 +24,11 @@
     ComputeServiceRegistration,
     NetworkMark,
     NetworkStateEnum,
+    ProtocolDAGResultRef,
     Task,
     TaskHub,
+    TaskRestartPattern,
     TaskStatusEnum,
-    ProtocolDAGResultRef,
 )
 from ..strategies import Strategy
 from ..models import Scope, ScopedKey
@@ -1405,30 +1406,51 @@ def action_tasks(
         # so we can properly return `None` if needed
         task_map = {str(task): None for task in tasks}
 
-        q = f"""
+        query_safe_task_list = [str(task) for task in tasks if task]
+
+        q = """
         // get our TaskHub
-        UNWIND {cypher_list_from_scoped_keys(tasks)} AS task_sk
-        MATCH (th:TaskHub {{_scoped_key: "{taskhub}"}})-[:PERFORMS]->(an:AlchemicalNetwork)
+        UNWIND $query_safe_task_list AS task_sk
+        MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[:PERFORMS]->(an:AlchemicalNetwork)
 
         // get the task we want to add to the hub; check that it connects to same network
-        MATCH (task:Task {{_scoped_key: task_sk}})-[:PERFORMS]->(tf:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an)
+        MATCH (task:Task {_scoped_key: task_sk})-[:PERFORMS]->(:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an)
 
         // only proceed for cases where task is not already actioned on hub
         // and where the task is either in 'waiting', 'running', or 'error' status
         WITH th, an, task
         WHERE NOT (th)-[:ACTIONS]->(task)
-        AND task.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}']
+        AND task.status IN [$waiting, $running, $error]
 
         // create the connection
-        CREATE (th)-[ar:ACTIONS {{weight: 0.5}}]->(task)
+        CREATE (th)-[ar:ACTIONS {weight: 0.5}]->(task)
 
         // set the task property to the scoped key of the Task
        // this is a convenience for when we have to loop over relationships in Python
         SET ar.task = task._scoped_key
 
+        // we want to preserve the list of tasks for the return, so we need to make a subquery
+        // since the subsequent WHERE clause could reduce the records in task
+        WITH task, th
+        CALL {
+            WITH task, th
+            MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th)
+            WHERE NOT (trp)-[:APPLIES]->(task)
+
+            CREATE (trp)-[:APPLIES {num_retries: 0}]->(task)
+        }
+
         RETURN task
         """
-        results = self.execute_query(q)
+
+        results = self.execute_query(
+            q,
+            query_safe_task_list=query_safe_task_list,
+            waiting=TaskStatusEnum.waiting.value,
+            running=TaskStatusEnum.running.value,
+            error=TaskStatusEnum.error.value,
+            taskhub_scoped_key=str(taskhub),
+        )
 
         # update our map with the results, leaving None for tasks that aren't found
         for task_record in results.records:
@@ -1581,14 +1603,24 @@ def cancel_tasks(
         """
         canceled_sks = []
         with self.transaction() as tx:
-            for t in tasks:
-                q = f"""
+            for task in tasks:
+                query = """
                 // get our task hub, as well as the task :ACTIONS relationship we want to remove
-                MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}})
+                MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: $task_scoped_key})
                 DELETE ar
+
+                WITH task
+                CALL {
+                    WITH task
+                    MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern)
+                    DELETE applies
+                }
+
                 RETURN task
                 """
-                _task = tx.run(q).to_eager_result()
+                _task = tx.run(
+                    query, taskhub_scoped_key=str(taskhub), task_scoped_key=str(task)
+                ).to_eager_result()
 
                 if _task.records:
                     sk = _task.records[0].data()["task"]["_scoped_key"]
@@ -2552,7 +2584,9 @@ def set_task_complete(
         // if we changed the status to complete,
         // drop all ACTIONS relationships
         OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub)
+        OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern)
         DELETE ar
+        DELETE applies
 
         WITH scoped_key, t, t_
 
@@ -2635,9 +2669,11 @@ def set_task_invalid(
 
         OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub)
         OPTIONAL MATCH (extends_task)<-[are:ACTIONS]-(th:TaskHub)
+        OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern)
 
         DELETE ar
         DELETE are
+        DELETE applies
 
         WITH scoped_key, t, t_
 
@@ -2685,9 +2721,11 @@ def set_task_deleted(
 
         OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub)
         OPTIONAL MATCH (extends_task)<-[are:ACTIONS]-(th:TaskHub)
+        OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern)
 
         DELETE ar
         DELETE are
+        DELETE applies
 
         WITH scoped_key, t, t_
 
@@ -2703,6 +2741,149 @@ def err_msg(t, status):
 
         return self._set_task_status(tasks, q, err_msg, raise_error=raise_error)
 
+    ## task restart policy
+
+    def add_task_restart_patterns(
+        self, taskhub: ScopedKey, patterns: List[str], number_of_retries: int
+    ):
+        """Add a list of restart policy patterns to a `TaskHub` along with the number of retries allowed.
+
+        Parameters
+        ----------
+        taskhub : ScopedKey
+            The ScopedKey of the TaskHub the patterns will enforce.
+        patterns : List[str]
+            The regular expression patterns to add.
+        number_of_retries : int
+            The number of times each pattern can trigger a restart for a Task.
+
+        Raises
+        ------
+        KeyError
+            If no TaskHub exists with the given ScopedKey.
+        """
+
+        # get taskhub node
+        q = """
+        MATCH (th:TaskHub {`_scoped_key`: $taskhub})
+        RETURN th
+        """
+        results = self.execute_query(q, taskhub=str(taskhub))
+
+        ## raise error if taskhub not found
+        if not results.records:
+            raise KeyError("No such TaskHub in the database")
+
+        record_data = results.records[0]["th"]
+        taskhub_node = record_data_to_node(record_data)
+        scope = taskhub.scope
+
+        with self.transaction() as tx:
+            actioned_tasks_query = """
+            MATCH (taskhub: TaskHub {`_scoped_key`: $taskhub_scoped_key})-[:ACTIONS]->(task: Task)
+            RETURN task
+            """
+
+            subgraph = Subgraph()
+
+            actioned_task_nodes = []
+
+            for actioned_tasks_record in (
+                tx.run(actioned_tasks_query, taskhub_scoped_key=str(taskhub))
+                .to_eager_result()
+                .records
+            ):
+                actioned_task_nodes.append(
+                    record_data_to_node(actioned_tasks_record["task"])
+                )
+
+            for pattern in patterns:
+                task_restart_pattern = TaskRestartPattern(
+                    pattern,
+                    max_retries=number_of_retries,
+                    taskhub_scoped_key=str(taskhub),
+                )
+
+                _, task_restart_pattern_node, scoped_key = self._gufe_to_subgraph(
+                    task_restart_pattern.to_shallow_dict(),
+                    labels=["GufeTokenizable", task_restart_pattern.__class__.__name__],
+                    gufe_key=task_restart_pattern.key,
+                    scope=scope,
+                )
+
+                subgraph |= Relationship.type("ENFORCES")(
+                    task_restart_pattern_node,
+                    taskhub_node,
+                    _org=scope.org,
+                    _campaign=scope.campaign,
+                    _project=scope.project,
+                )
+
+                for actioned_task_node in actioned_task_nodes:
+                    subgraph |= Relationship.type("APPLIES")(
+                        task_restart_pattern_node,
+                        actioned_task_node,
+                        num_retries=0,
+                    )
+            merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key")
+
+    def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]):
+        """Remove task restart patterns enforcing a given `TaskHub`."""
+        q = """
+        UNWIND $patterns AS pattern
+
+        MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key})
+
+        DETACH DELETE trp
+        """
+
+        self.execute_query(q, patterns=patterns, taskhub_scoped_key=str(taskhub))
+
+    def set_task_restart_patterns_max_retries(
+        self,
+        taskhub_scoped_key: Union[ScopedKey, str],
+        patterns: List[str],
+        max_retries: int,
+    ):
+        """Set the maximum number of retries of the given patterns enforcing a `TaskHub`."""
+        query = """
+        UNWIND $patterns AS pattern
+        MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key})
+        SET trp.max_retries = $max_retries
+        """
+
+        self.execute_query(
+            query,
+            patterns=patterns,
+            taskhub_scoped_key=str(taskhub_scoped_key),
+            max_retries=max_retries,
+        )
+
+    def get_task_restart_patterns(
+        self, taskhubs: List[ScopedKey]
+    ) -> Dict[ScopedKey, Set[Tuple[str, int]]]:
+        """Get the (pattern, max_retries) pairs enforcing each of the given `TaskHub`s."""
+
+        q = """
+        UNWIND $taskhub_scoped_keys AS taskhub_scoped_key
+        MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_scoped_key})
+        RETURN th, trp
+        """
+
+        records = self.execute_query(
+            q, taskhub_scoped_keys=list(map(str, taskhubs))
+        ).records
+
+        data = {taskhub: set() for taskhub in taskhubs}
+
+        for record in records:
+            pattern = record["trp"]["pattern"]
+            max_retries = record["trp"]["max_retries"]
+            taskhub_sk = ScopedKey.from_str(record["th"]["_scoped_key"])
+            data[taskhub_sk].add((pattern, max_retries))
+
+        return data
+
     ## authentication
 
     def create_credentialed_entity(self, entity: CredentialedEntity):
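
Before the tests, a hedged sketch of how the new Neo4jStore methods compose; `n4js` (a connected Neo4jStore) and `taskhub_sk` (the ScopedKey of an existing TaskHub) are assumptions borrowed from the integration-test fixtures below:

    # illustrative sketch only; assumes the statestore.py changes above
    n4js.add_task_restart_patterns(taskhub_sk, [r".+RuntimeError"], number_of_retries=5)

    # patterns come back grouped by TaskHub as (pattern, max_retries) tuples
    assert n4js.get_task_restart_patterns([taskhub_sk])[taskhub_sk] == {(r".+RuntimeError", 5)}

    # max_retries can be updated in bulk for existing patterns...
    n4js.set_task_restart_patterns_max_retries(taskhub_sk, [r".+RuntimeError"], 2)

    # ...and removing a pattern DETACH DELETEs it, dropping its
    # ENFORCES and APPLIES relationships along with it
    n4js.remove_task_restart_patterns(taskhub_sk, [r".+RuntimeError"])
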
diff --git a/alchemiscale/tests/integration/interface/client/test_client.py b/alchemiscale/tests/integration/interface/client/test_client.py
index ceb968f4..c39ce4f8 100644
--- a/alchemiscale/tests/integration/interface/client/test_client.py
+++ b/alchemiscale/tests/integration/interface/client/test_client.py
@@ -2123,3 +2123,31 @@ def test_get_task_failures(
 
         # TODO: can we mix in a success in here somewhere?
         # not possible with current BrokenProtocol, unfortunately
+
+    # TaskRestartPolicy client methods
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_add_task_restart_policy_patterns(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_get_task_restart_policy_patterns(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_remove_task_restart_policy_patterns(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_clear_task_restart_policy_patterns(self):
+        raise NotImplementedError
+
+    @pytest.mark.xfail(raises=NotImplementedError)
+    def test_task_resolve_restarts(
+        self,
+        scope_test,
+        n4js_preloaded,
+        user_client: client.AlchemiscaleClient,
+        network_tyk2_failure,
+    ):
+        raise NotImplementedError
diff --git a/alchemiscale/tests/integration/storage/test_statestore.py b/alchemiscale/tests/integration/storage/test_statestore.py
index 2632524b..c7901840 100644
--- a/alchemiscale/tests/integration/storage/test_statestore.py
+++ b/alchemiscale/tests/integration/storage/test_statestore.py
@@ -3,6 +3,7 @@
 from typing import List, Dict
 from pathlib import Path
 from itertools import chain
+from collections import defaultdict
 
 import pytest
 from gufe import AlchemicalNetwork
@@ -1095,6 +1096,67 @@ def test_action_task(self, n4js: Neo4jStore, network_tyk2, scope_test):
         task_sks_fail = n4js.action_tasks(task_sks, taskhub_sk2)
         assert all([i is None for i in task_sks_fail])
 
+        # test for APPLIES relationship between an ACTIONED task and a TaskRestartPattern
+
+        ## create a restart pattern, should already create APPLIES relationships with those
+        ## already actioned
+        n4js.add_task_restart_patterns(taskhub_sk, ["test_pattern"], 5)
+
+        query = """
+        MATCH (:TaskRestartPattern)-[applies:APPLIES]->(:Task)<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key})
+        // set num_retries to 1 so later checks can show the value was not overwritten
+        SET applies.num_retries = 1
+        RETURN count(applies) AS applies_count
+        """
+
+        ## sanity check that this number makes sense
+        applies_count = n4js.execute_query(
+            query, taskhub_scoped_key=str(taskhub_sk)
+        ).records[0]["applies_count"]
+
+        assert applies_count == 10
+
+        # create 10 more tasks and action them
+        task_sks = n4js.create_tasks([transformation_sk] * 10)
+        n4js.action_tasks(task_sks, taskhub_sk)
+
+        assert len(n4js.get_taskhub_actioned_tasks([taskhub_sk])[0]) == 20
+
+        # same as the query above, but without setting num_retries = 1
+        query = """
+        MATCH (:TaskRestartPattern)-[applies:APPLIES]->(:Task)<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key})
+        RETURN count(applies) AS applies_count
+        """
+
+        applies_count = n4js.execute_query(
+            query, taskhub_scoped_key=str(taskhub_sk)
+        ).records[0]["applies_count"]
+
+        assert applies_count == 20
+
+        query = """
+        MATCH (:TaskRestartPattern)-[applies:APPLIES]->(:Task)
+        RETURN applies
+        """
+
+        results = n4js.execute_query(query)
+
+        count_0, count_1 = 0, 0
+        for count in map(
+            lambda record: record["applies"]["num_retries"], results.records
+        ):
+            match count:
+                case 0:
+                    count_0 += 1
+                case 1:
+                    count_1 += 1
+                case _:
+                    raise AssertionError(
+                        "Unexpected count value found in num_retries field"
+                    )
+
+        assert count_0 == count_1 == 10
+
     def test_action_task_other_statuses(
         self, n4js: Neo4jStore, network_tyk2, scope_test
     ):
@@ -1210,6 +1270,14 @@ def test_cancel_task(self, n4js, network_tyk2, scope_test):
 
         # cancel the second and third task we created
         canceled = n4js.cancel_tasks(task_sks[1:3], taskhub_sk)
 
+        # cancel a fake task
+        fake_canceled = n4js.cancel_tasks(
+            [ScopedKey.from_str("Task-FAKE-test_org-test_campaign-test_project")],
+            taskhub_sk,
+        )
+
+        assert fake_canceled[0] is None
+
         # check that the hub has the contents we expect
         q = f"""MATCH (tq:TaskHub {{_scoped_key: '{taskhub_sk}'}})-[:ACTIONS]->(task:Task)
                 return task
@@ -1223,6 +1291,31 @@ def test_cancel_task(self, n4js, network_tyk2, scope_test):
             actioned
         ) - set(canceled)
 
+        # create a TaskRestartPattern
+        n4js.add_task_restart_patterns(taskhub_sk, ["Test pattern"], 1)
+
+        query = """
+        MATCH (:TaskHub {`_scoped_key`: $taskhub_scoped_key})<-[:ENFORCES]-(:TaskRestartPattern)-[applies:APPLIES]->(:Task)
+        RETURN count(applies) AS applies_count
+        """
+
+        assert (
+            n4js.execute_query(query, taskhub_scoped_key=str(taskhub_sk)).records[0][
+                "applies_count"
+            ]
+            == 8
+        )
+
+        # cancel the fourth and fifth task we created
+        canceled = n4js.cancel_tasks(task_sks[3:5], taskhub_sk)
+
+        assert (
+            n4js.execute_query(query, taskhub_scoped_key=str(taskhub_sk)).records[0][
+                "applies_count"
+            ]
+            == 6
+        )
+
     def test_get_taskhub_tasks(self, n4js, network_tyk2, scope_test):
         an = network_tyk2
         network_sk, taskhub_sk, _ = n4js.assemble_network(an, scope_test)
@@ -1851,6 +1944,291 @@ def test_get_task_failures(
         assert pdr_ref_sk in failure_pdr_ref_sks
         assert pdr_ref2_sk in failure_pdr_ref_sks
 
+    ### task restart policies
+
+    class TestTaskRestartPolicy:
+
+        @pytest.mark.parametrize("status", ("complete", "invalid", "deleted"))
+        def test_task_status_change(self, n4js, network_tyk2, scope_test, status):
+            an = network_tyk2.copy_with_replacements(
+                name=network_tyk2.name + "_test_task_status_change"
+            )
+            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+            transformation = list(an.edges)[0]
+            transformation_scoped_key = n4js.get_scoped_key(transformation, scope_test)
+            task_scoped_keys = n4js.create_tasks([transformation_scoped_key])
+            n4js.action_tasks(task_scoped_keys, taskhub_scoped_key)
+
+            n4js.add_task_restart_patterns(taskhub_scoped_key, ["Test pattern"], 10)
+
+            query = """
+            MATCH (:TaskRestartPattern)-[:APPLIES]->(task:Task {`_scoped_key`: $task_scoped_key})<-[:ACTIONS]-(:TaskHub {`_scoped_key`: $taskhub_scoped_key})
+            RETURN task
+            """
+
+            results = n4js.execute_query(
+                query,
+                task_scoped_key=str(task_scoped_keys[0]),
+                taskhub_scoped_key=str(taskhub_scoped_key),
+            )
+
+            assert len(results.records) == 1
+
+            target_method = {
+                "complete": n4js.set_task_complete,
+                "invalid": n4js.set_task_invalid,
+                "deleted": n4js.set_task_deleted,
+            }
+
+            if status == "complete":
+                n4js.set_task_running(task_scoped_keys)
+
+            assert target_method[status](task_scoped_keys)[0] is not None
+
+            query = """
+            MATCH (:TaskRestartPattern)-[:APPLIES]->(task:Task)
+            RETURN task
+            """
+
+            results = n4js.execute_query(query)
+
+            assert len(results.records) == 0
+
+        def test_add_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+            # create three new alchemical networks (and taskhubs)
+            taskhub_sks = []
+            for network_index in range(3):
+                an = network_tyk2.copy_with_replacements(
+                    name=network_tyk2.name
+                    + f"_test_add_task_restart_patterns_{network_index}"
+                )
+                _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+
+                # don't action tasks on every network, take every other
+                if network_index % 2 == 0:
+                    transformation = list(an.edges)[0]
+                    transformation_sk = n4js.get_scoped_key(transformation, scope_test)
+                    task_sks = n4js.create_tasks([transformation_sk] * 3)
+                    n4js.action_tasks(task_sks, taskhub_scoped_key)
+
+                taskhub_sks.append(taskhub_scoped_key)
+            # test a shared pattern with and without a shared number of restarts
+            # this will create 6 unique patterns
+            for network_index in range(3):
+                taskhub_scoped_key = taskhub_sks[network_index]
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+                )
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key,
+                    ["shared_pattern_and_different_restarts.+"],
+                    network_index + 1,
+                )
+
+            q = """UNWIND $taskhub_sks AS taskhub_sk
+            MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_sk}) RETURN trp, th
+            """
+
+            taskhub_sks = list(map(str, taskhub_sks))
+            records = n4js.execute_query(q, taskhub_sks=taskhub_sks).records
+
+            assert len(records) == 6
+
+            taskhub_scoped_key_set = set()
+            taskrestartpattern_scoped_key_set = set()
+
+            for record in records:
+                taskhub_scoped_key = ScopedKey.from_str(record["th"]["_scoped_key"])
+                taskrestartpattern_scoped_key = ScopedKey.from_str(
+                    record["trp"]["_scoped_key"]
+                )
+
+                taskhub_scoped_key_set.add(taskhub_scoped_key)
+                taskrestartpattern_scoped_key_set.add(taskrestartpattern_scoped_key)
+
+            assert len(taskhub_scoped_key_set) == 3
+            assert len(taskrestartpattern_scoped_key_set) == 6
+
+            # check that the APPLIES relationships were correctly added
+
+            ## first check that the number of APPLIES relationships is correct and
+            ## that the number of retries is zero
+            applies_query = """
+            MATCH (trp: TaskRestartPattern)-[app:APPLIES {num_retries: 0}]->(task: Task)<-[:ACTIONS]-(th: TaskHub)
+            RETURN th, count(app) AS num_applied
+            """
+
+            records = n4js.execute_query(applies_query).records
+
+            ### one record per TaskHub with actioned Tasks (2 of the 3), each with six APPLIES
+            assert len(records) == 2
+            assert records[0]["num_applied"] == records[1]["num_applied"] == 6
+
+            applies_nonzero_retries = """
+            MATCH (trp: TaskRestartPattern)-[app:APPLIES]->(task: Task)<-[:ACTIONS]-(th: TaskHub)
+            WHERE app.num_retries <> 0
+            RETURN th, count(app) AS num_applied
+            """
+            assert len(n4js.execute_query(applies_nonzero_retries).records) == 0
+
+        def test_remove_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+
+            # collect what we expect `get_task_restart_patterns` to return
+            expected_results = defaultdict(set)
+
+            # create three new alchemical networks (and taskhubs)
+            taskhub_sks = []
+            for network_index in range(3):
+                an = network_tyk2.copy_with_replacements(
+                    name=network_tyk2.name
+                    + f"_test_remove_task_restart_patterns_{network_index}"
+                )
+                _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+                taskhub_sks.append(taskhub_scoped_key)
+
+            # test a shared pattern with and without a shared number of restarts
+            # this will create 6 unique patterns
+            for network_index in range(3):
+                taskhub_scoped_key = taskhub_sks[network_index]
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+                )
+                expected_results[taskhub_scoped_key].add(
+                    ("shared_pattern_and_restarts.+", 5)
+                )
+
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key,
+                    ["shared_pattern_and_different_restarts.+"],
+                    network_index + 1,
+                )
+                expected_results[taskhub_scoped_key].add(
+                    ("shared_pattern_and_different_restarts.+", network_index + 1)
+                )
+
+            # remove both patterns enforcing the first taskhub at the same time
+            target_taskhub = taskhub_sks[0]
+            target_patterns = []
+
+            for pattern, _ in expected_results[target_taskhub]:
+                target_patterns.append(pattern)
+
+            expected_results[target_taskhub].clear()
+
+            n4js.remove_task_restart_patterns(target_taskhub, target_patterns)
+            assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+            # remove both patterns enforcing the second taskhub one at a time
+            target_taskhub = taskhub_sks[1]
+            # pointer to the underlying set; pops will update the comparison data structure
+            target_patterns = expected_results[target_taskhub]
+
+            pattern, _ = target_patterns.pop()
+            n4js.remove_task_restart_patterns(target_taskhub, [pattern])
+            assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+            pattern, _ = target_patterns.pop()
+            n4js.remove_task_restart_patterns(target_taskhub, [pattern])
+            assert expected_results == n4js.get_task_restart_patterns(taskhub_sks)
+
+        def test_set_task_restart_patterns_max_retries(
+            self, n4js, network_tyk2, scope_test
+        ):
+            network_name = (
+                network_tyk2.name + "_test_set_task_restart_patterns_max_retries"
+            )
+            an = network_tyk2.copy_with_replacements(name=network_name)
+            _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+
+            pattern_data = [("pattern_1", 5), ("pattern_2", 5), ("pattern_3", 5)]
+
+            n4js.add_task_restart_patterns(
+                taskhub_scoped_key,
+                patterns=[data[0] for data in pattern_data],
+                number_of_retries=5,
+            )
+
+            expected_results = {taskhub_scoped_key: set(pattern_data)}
+
+            assert expected_results == n4js.get_task_restart_patterns(
+                [taskhub_scoped_key]
+            )
+
+            # change just one pattern's max_retries
+            new_pattern_1_tuple = ("pattern_1", 1)
+
+            expected_results[taskhub_scoped_key].remove(pattern_data[0])
+            expected_results[taskhub_scoped_key].add(new_pattern_1_tuple)
+
+            n4js.set_task_restart_patterns_max_retries(
+                taskhub_scoped_key, [new_pattern_1_tuple[0]], new_pattern_1_tuple[1]
+            )
+
+            assert expected_results == n4js.get_task_restart_patterns(
+                [taskhub_scoped_key]
+            )
+
+            # change more than one at a time
+            new_pattern_2_tuple = ("pattern_2", 2)
+            new_pattern_3_tuple = ("pattern_3", 2)
+
+            expected_results[taskhub_scoped_key].remove(pattern_data[1])
+            expected_results[taskhub_scoped_key].add(new_pattern_2_tuple)
+
+            expected_results[taskhub_scoped_key].remove(pattern_data[2])
+            expected_results[taskhub_scoped_key].add(new_pattern_3_tuple)
+
+            n4js.set_task_restart_patterns_max_retries(
+                taskhub_scoped_key, [new_pattern_2_tuple[0], new_pattern_3_tuple[0]], 2
+            )
+
+            assert expected_results == n4js.get_task_restart_patterns(
+                [taskhub_scoped_key]
+            )
+
+        def test_get_task_restart_patterns(self, n4js, network_tyk2, scope_test):
+            # create three new alchemical networks (and taskhubs)
+            taskhub_sks = []
+            for network_index in range(3):
+                an = network_tyk2.copy_with_replacements(
+                    name=network_tyk2.name
+                    + f"_test_get_task_restart_patterns_{network_index}"
+                )
+                _, taskhub_scoped_key, _ = n4js.assemble_network(an, scope_test)
+                taskhub_sks.append(taskhub_scoped_key)
+
+            expected_results = defaultdict(set)
+            # test a shared pattern with and without a shared number of restarts
+            # this will create 6 unique patterns
+            for network_index in range(3):
+                taskhub_scoped_key = taskhub_sks[network_index]
+                n4js.add_task_restart_patterns(
+                    taskhub_scoped_key, ["shared_pattern_and_restarts.+"], 5
+                )
+                expected_results[taskhub_scoped_key].add(
("shared_pattern_and_restarts.+", 5) + ) + n4js.add_task_restart_patterns( + taskhub_scoped_key, + ["shared_pattern_and_different_restarts.+"], + network_index + 1, + ) + expected_results[taskhub_scoped_key].add( + ("shared_pattern_and_different_restarts.+", network_index + 1) + ) + + taskhub_grouped_patterns = n4js.get_task_restart_patterns(taskhub_sks) + + assert taskhub_grouped_patterns == expected_results + + @pytest.mark.xfail(raises=NotImplementedError) + def test_task_actioning_applies_relationship(self): + raise NotImplementedError + + @pytest.mark.xfail(raises=NotImplementedError) + def test_task_deaction_applies_relationship(self): + raise NotImplementedError + ### authentication @pytest.mark.parametrize( diff --git a/alchemiscale/tests/unit/test_storage_models.py b/alchemiscale/tests/unit/test_storage_models.py index 36678b9a..55dc872f 100644 --- a/alchemiscale/tests/unit/test_storage_models.py +++ b/alchemiscale/tests/unit/test_storage_models.py @@ -1,6 +1,11 @@ import pytest -from alchemiscale.storage.models import NetworkStateEnum, NetworkMark +from alchemiscale.storage.models import ( + NetworkStateEnum, + NetworkMark, + TaskRestartPattern, + Traceback, +) from alchemiscale import ScopedKey @@ -38,3 +43,149 @@ def test_suggested_states_message(self): assert len(suggested_states) == len(NetworkStateEnum) for state in suggested_states: NetworkStateEnum(state) + + +class TestTaskRestartPattern(object): + + pattern_value_error = "`pattern` must be a non-empty string" + max_retries_value_error = "`max_retries` must have a positive integer value." + + def test_empty_pattern(self): + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern( + "", 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + def test_non_string_pattern(self): + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern( + None, 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + with pytest.raises(ValueError, match=self.pattern_value_error): + _ = TaskRestartPattern( + [], 3, "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + def test_non_positive_max_retries(self): + + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern( + "Example pattern", + 0, + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project", + ) + + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern( + "Example pattern", + -1, + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project", + ) + + def test_non_int_max_retries(self): + with pytest.raises(ValueError, match=self.max_retries_value_error): + TaskRestartPattern( + "Example pattern", + 4.0, + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project", + ) + + def test_to_dict(self): + trp = TaskRestartPattern( + "Example pattern", + 3, + "FakeScopedKey-1234-fake_org-fake_campaign-fake_project", + ) + dict_trp = trp.to_dict() + + assert len(dict_trp.keys()) == 6 + + assert dict_trp.pop("__qualname__") == "TaskRestartPattern" + assert dict_trp.pop("__module__") == "alchemiscale.storage.models" + assert ( + dict_trp.pop("taskhub_scoped_key") + == "FakeScopedKey-1234-fake_org-fake_campaign-fake_project" + ) + + # light test of the version key + try: + dict_trp.pop(":version:") + except KeyError: + raise AssertionError("expected to find :version:") + + expected = {"pattern": "Example pattern", "max_retries": 3} + + assert expected == dict_trp + + def test_from_dict(self): + + original_pattern = "Example pattern" + 
+        original_max_retries = 3
+        original_taskhub_scoped_key = (
+            "FakeScopedKey-1234-fake_org-fake_campaign-fake_project"
+        )
+
+        trp_orig = TaskRestartPattern(
+            original_pattern, original_max_retries, original_taskhub_scoped_key
+        )
+        trp_dict = trp_orig.to_dict()
+        trp_reconstructed: TaskRestartPattern = TaskRestartPattern.from_dict(trp_dict)
+
+        assert trp_reconstructed.pattern == original_pattern
+        assert trp_reconstructed.max_retries == original_max_retries
+        assert trp_reconstructed.taskhub_scoped_key == original_taskhub_scoped_key
+
+
+class TestTraceback(object):
+
+    valid_entry = ["traceback1", "traceback2", "traceback3"]
+    tracebacks_value_error = "`tracebacks` must be a non-empty list of string values"
+
+    def test_empty_string_element(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback(self.valid_entry + [""])
+
+    def test_non_list_parameter(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback(None)
+
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback(100)
+
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback("not a list, but still an iterable that yields strings")
+
+    def test_list_non_string_elements(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback(self.valid_entry + [None])
+
+    def test_empty_list(self):
+        with pytest.raises(ValueError, match=self.tracebacks_value_error):
+            Traceback([])
+
+    def test_to_dict(self):
+        tb = Traceback(self.valid_entry)
+        tb_dict = tb.to_dict()
+
+        assert len(tb_dict) == 4
+
+        assert tb_dict.pop("__qualname__") == "Traceback"
+        assert tb_dict.pop("__module__") == "alchemiscale.storage.models"
+
+        # light test of the version key
+        try:
+            tb_dict.pop(":version:")
+        except KeyError:
+            raise AssertionError("expected to find :version:")
+
+        expected = {"tracebacks": self.valid_entry}
+
+        assert expected == tb_dict
+
+    def test_from_dict(self):
+        tb_orig = Traceback(self.valid_entry)
+        tb_dict = tb_orig.to_dict()
+        tb_reconstructed: Traceback = Traceback.from_dict(tb_dict)
+
+        assert tb_reconstructed.tracebacks == self.valid_entry
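
The restart resolution itself is still stubbed out (the xfail tests above); for intuition, a minimal sketch of the matching step these models are built for, assuming `re.search` semantics, since the resolving code is not part of this diff:

    import re

    from alchemiscale.storage.models import TaskRestartPattern, Traceback

    def traceback_matches(trp: TaskRestartPattern, tb: Traceback) -> bool:
        # hypothetical helper: True if any traceback returned by the errored
        # Task matches this restart pattern
        return any(re.search(trp.pattern, entry) for entry in tb.tracebacks)
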