Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement task restart policies #280

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
94 changes: 92 additions & 2 deletions alchemiscale/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from copy import copy
from datetime import datetime
from enum import Enum
from typing import Union, Dict, Optional
from typing import Union, Optional, List
from uuid import uuid4
import hashlib


from pydantic import BaseModel, Field
from pydantic import BaseModel
from gufe.tokenization import GufeTokenizable, GufeKey

from ..models import ScopedKey, Scope
Expand Down Expand Up @@ -143,6 +143,96 @@ def _defaults(cls):
return super()._defaults()


# TODO: fill in docstrings
class TaskRestartPattern(GufeTokenizable):
    """A pattern to compare returned Task tracebacks to.

    Attributes
    ----------
    pattern: str
        A regular expression pattern that can match to returned tracebacks of errored Tasks.
    max_retries: int
        The number of times the pattern can trigger a restart for a Task.
    taskhub_scoped_key: str
        The TaskHub the pattern is bound to. This is needed to properly set a unique Gufe key.
    """

    pattern: str
    max_retries: int
    taskhub_scoped_key: str

    def __init__(
        self, pattern: str, max_retries: int, taskhub_scoped_key: Union[str, ScopedKey]
    ):
        """Create a restart pattern bound to a single TaskHub.

        Raises
        ------
        ValueError
            If ``pattern`` is not a non-empty string, or ``max_retries``
            is not a positive integer.
        """
        if not isinstance(pattern, str) or pattern == "":
            raise ValueError("`pattern` must be a non-empty string")

        self.pattern = pattern

        if not isinstance(max_retries, int) or max_retries <= 0:
            raise ValueError("`max_retries` must have a positive integer value.")
        self.max_retries = max_retries

        # normalize to str so tokenization is stable whether a str or
        # ScopedKey was provided
        self.taskhub_scoped_key = str(taskhub_scoped_key)

    def _gufe_tokenize(self):
        # key is determined by the pattern and the TaskHub it is bound to;
        # max_retries is mutable state and deliberately excluded
        key_string = self.pattern + self.taskhub_scoped_key
        return hashlib.md5(key_string.encode()).hexdigest()

    @classmethod
    def _defaults(cls):
        raise NotImplementedError

    @classmethod
    def _from_dict(cls, dct):
        return cls(**dct)

    def _to_dict(self):
        return {
            "pattern": self.pattern,
            "max_retries": self.max_retries,
            "taskhub_scoped_key": self.taskhub_scoped_key,
        }

    # TODO: should this also compare taskhub scoped keys?
    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.pattern == other.pattern

    def __hash__(self):
        # defining __eq__ without __hash__ would set __hash__ = None and
        # make instances unhashable; hash only on `pattern` so that equal
        # objects (per __eq__ above) always hash equal
        return hash(self.pattern)


class Traceback(GufeTokenizable):
    """A collection of traceback strings returned by errored Tasks.

    Parameters
    ----------
    tracebacks: List[str]
        Traceback strings; must be a non-empty list containing only
        non-empty strings.

    Raises
    ------
    ValueError
        If ``tracebacks`` is not a non-empty list of non-empty strings.
    """

    def __init__(self, tracebacks: List[str]):
        value_error = ValueError(
            "`tracebacks` must be a non-empty list of string values"
        )
        if not isinstance(tracebacks, list) or tracebacks == []:
            raise value_error
        # reject non-string entries and empty strings
        all_string_values = all(isinstance(value, str) for value in tracebacks)
        if not all_string_values or "" in tracebacks:
            raise value_error

        self.tracebacks = tracebacks

    def _gufe_tokenize(self):
        return hashlib.md5(str(self.tracebacks).encode()).hexdigest()

    @classmethod
    def _defaults(cls):
        raise NotImplementedError

    @classmethod
    def _from_dict(cls, dct):
        # use cls (not the class name) so subclasses deserialize to the
        # subclass; consistent with TaskRestartPattern._from_dict
        return cls(**dct)

    def _to_dict(self):
        return {"tracebacks": self.tracebacks}


class TaskHub(GufeTokenizable):
"""

Expand Down
201 changes: 188 additions & 13 deletions alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from contextlib import contextmanager
import json
from functools import lru_cache
from typing import Dict, List, Optional, Union, Tuple
from typing import Dict, List, Optional, Union, Tuple, Set
import weakref
import numpy as np

Expand All @@ -24,10 +24,11 @@
ComputeServiceRegistration,
NetworkMark,
NetworkStateEnum,
ProtocolDAGResultRef,
Task,
TaskHub,
TaskRestartPattern,
TaskStatusEnum,
ProtocolDAGResultRef,
)
from ..strategies import Strategy
from ..models import Scope, ScopedKey
Expand Down Expand Up @@ -1405,30 +1406,51 @@ def action_tasks(
# so we can properly return `None` if needed
task_map = {str(task): None for task in tasks}

q = f"""
query_safe_task_list = [str(task) for task in tasks if task]

q = """
// get our TaskHub
UNWIND {cypher_list_from_scoped_keys(tasks)} AS task_sk
MATCH (th:TaskHub {{_scoped_key: "{taskhub}"}})-[:PERFORMS]->(an:AlchemicalNetwork)
UNWIND $query_safe_task_list AS task_sk
MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[:PERFORMS]->(an:AlchemicalNetwork)

// get the task we want to add to the hub; check that it connects to same network
MATCH (task:Task {{_scoped_key: task_sk}})-[:PERFORMS]->(tf:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an)
MATCH (task:Task {_scoped_key: task_sk})-[:PERFORMS]->(:Transformation|NonTransformation)<-[:DEPENDS_ON]-(an)

// only proceed for cases where task is not already actioned on hub
// and where the task is either in 'waiting', 'running', or 'error' status
WITH th, an, task
WHERE NOT (th)-[:ACTIONS]->(task)
AND task.status IN ['{TaskStatusEnum.waiting.value}', '{TaskStatusEnum.running.value}', '{TaskStatusEnum.error.value}']
AND task.status IN [$waiting, $running, $error]

// create the connection
CREATE (th)-[ar:ACTIONS {{weight: 0.5}}]->(task)
CREATE (th)-[ar:ACTIONS {weight: 0.5}]->(task)

// set the task property to the scoped key of the Task
// this is a convenience for when we have to loop over relationships in Python
SET ar.task = task._scoped_key

// we want to preserve the list of tasks for the return, so we need to make a subquery
// since the subsequent WHERE clause could reduce the records in task
WITH task, th
CALL {
WITH task, th
MATCH (trp: TaskRestartPattern)-[:ENFORCES]->(th)
WHERE NOT (trp)-[:APPLIES]->(task)

CREATE (trp)-[:APPLIES {num_retries: 0}]->(task)
}

RETURN task
"""
results = self.execute_query(q)

results = self.execute_query(
q,
query_safe_task_list=query_safe_task_list,
waiting=TaskStatusEnum.waiting.value,
running=TaskStatusEnum.running.value,
error=TaskStatusEnum.error.value,
taskhub_scoped_key=str(taskhub),
)

# update our map with the results, leaving None for tasks that aren't found
for task_record in results.records:
Expand Down Expand Up @@ -1581,14 +1603,24 @@ def cancel_tasks(
"""
canceled_sks = []
with self.transaction() as tx:
for t in tasks:
q = f"""
for task in tasks:
query = """
// get our task hub, as well as the task :ACTIONS relationship we want to remove
MATCH (th:TaskHub {{_scoped_key: '{taskhub}'}})-[ar:ACTIONS]->(task:Task {{_scoped_key: '{t}'}})
MATCH (th:TaskHub {_scoped_key: $taskhub_scoped_key})-[ar:ACTIONS]->(task:Task {_scoped_key: $task_scoped_key})
DELETE ar

WITH task
CALL {
WITH task
MATCH (task)<-[applies:APPLIES]-(:TaskRestartPattern)
DELETE applies
}

RETURN task
"""
_task = tx.run(q).to_eager_result()
_task = tx.run(
query, taskhub_scoped_key=str(taskhub), task_scoped_key=str(task)
).to_eager_result()

if _task.records:
sk = _task.records[0].data()["task"]["_scoped_key"]
Expand Down Expand Up @@ -2552,7 +2584,9 @@ def set_task_complete(
// if we changed the status to complete,
// drop all ACTIONS relationships
OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub)
OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern)
DELETE ar
DELETE applies

WITH scoped_key, t, t_

Expand Down Expand Up @@ -2635,9 +2669,11 @@ def set_task_invalid(

OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub)
OPTIONAL MATCH (extends_task)<-[are:ACTIONS]-(th:TaskHub)
OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern)

DELETE ar
DELETE are
DELETE applies

WITH scoped_key, t, t_

Expand Down Expand Up @@ -2685,9 +2721,11 @@ def set_task_deleted(

OPTIONAL MATCH (t_)<-[ar:ACTIONS]-(th:TaskHub)
OPTIONAL MATCH (extends_task)<-[are:ACTIONS]-(th:TaskHub)
OPTIONAL MATCH (t_)<-[applies:APPLIES]-(:TaskRestartPattern)

DELETE ar
DELETE are
DELETE applies

WITH scoped_key, t, t_

Expand All @@ -2703,6 +2741,143 @@ def err_msg(t, status):

return self._set_task_status(tasks, q, err_msg, raise_error=raise_error)

## task restart policy

def add_task_restart_patterns(
    self, taskhub: ScopedKey, patterns: List[str], number_of_retries: int
) -> None:
    """Add a list of restart policy patterns to a `TaskHub` along with the number of retries allowed.

    Patterns are also immediately applied (with a retry count of zero) to
    every Task currently actioned on the TaskHub.

    Parameters
    ----------
    taskhub
        ScopedKey of the TaskHub the patterns will enforce.
    patterns
        Regular expression patterns to compare against tracebacks of
        errored Tasks actioned on this TaskHub.
    number_of_retries
        The number of times each pattern may trigger a restart for a Task.

    Raises
    ------
    KeyError
        If the given TaskHub does not exist in the database.
    """

    # get taskhub node
    q = """
    MATCH (th:TaskHub {`_scoped_key`: $taskhub})
    RETURN th
    """
    results = self.execute_query(q, taskhub=str(taskhub))
    ## raise error if taskhub not found

    if not results.records:
        raise KeyError("No such TaskHub in the database")

    record_data = results.records[0]["th"]
    taskhub_node = record_data_to_node(record_data)
    scope = taskhub.scope

    with self.transaction() as tx:
        actioned_tasks_query = """
        MATCH (taskhub: TaskHub {`_scoped_key`: $taskhub_scoped_key})-[:ACTIONS]->(task: Task)
        RETURN task
        """

        subgraph = Subgraph()

        actioned_task_nodes = []

        # collect Tasks already actioned on the hub; new patterns must
        # APPLY to these as well, not just Tasks actioned later
        for actioned_tasks_record in (
            tx.run(actioned_tasks_query, taskhub_scoped_key=str(taskhub))
            .to_eager_result()
            .records
        ):
            actioned_task_nodes.append(
                record_data_to_node(actioned_tasks_record["task"])
            )

        for pattern in patterns:
            task_restart_pattern = TaskRestartPattern(
                pattern,
                max_retries=number_of_retries,
                taskhub_scoped_key=str(taskhub),
            )

            _, task_restart_pattern_node, scoped_key = self._gufe_to_subgraph(
                task_restart_pattern.to_shallow_dict(),
                labels=["GufeTokenizable", task_restart_pattern.__class__.__name__],
                gufe_key=task_restart_pattern.key,
                scope=scope,
            )

            # bind the pattern to the hub it enforces
            subgraph |= Relationship.type("ENFORCES")(
                task_restart_pattern_node,
                taskhub_node,
                _org=scope.org,
                _campaign=scope.campaign,
                _project=scope.project,
            )

            # apply the pattern to all currently actioned Tasks,
            # starting their retry counters at zero
            for actioned_task_node in actioned_task_nodes:
                subgraph |= Relationship.type("APPLIES")(
                    task_restart_pattern_node,
                    actioned_task_node,
                    num_retries=0,
                )
        merge_subgraph(tx, subgraph, "GufeTokenizable", "_scoped_key")

# TODO: fill in docstring
def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]):
q = """
UNWIND $patterns AS pattern

MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key})

DETACH DELETE trp
"""

self.execute_query(q, patterns=patterns, taskhub_scoped_key=str(taskhub))

# TODO: fill in docstring
def set_task_restart_patterns_max_retries(
self,
taskhub_scoped_key: Union[ScopedKey, str],
patterns: List[str],
max_retries: int,
):
query = """
UNWIND $patterns AS pattern
MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key})
SET trp.max_retries = $max_retries
"""

self.execute_query(
query,
patterns=patterns,
taskhub_scoped_key=str(taskhub_scoped_key),
max_retries=max_retries,
)

# TODO: fill in docstring
def get_task_restart_patterns(
self, taskhubs: List[ScopedKey]
) -> Dict[ScopedKey, Set[Tuple[str, int]]]:

q = """
UNWIND $taskhub_scoped_keys as taskhub_scoped_key
MATCH (trp: TaskRestartPattern)-[ENFORCES]->(th: TaskHub {`_scoped_key`: taskhub_scoped_key})
RETURN th, trp
"""

records = self.execute_query(
q, taskhub_scoped_keys=list(map(str, taskhubs))
).records

data = {taskhub: set() for taskhub in taskhubs}

for record in records:
pattern = record["trp"]["pattern"]
max_retries = record["trp"]["max_retries"]
taskhub_sk = ScopedKey.from_str(record["th"]["_scoped_key"])
data[taskhub_sk].add((pattern, max_retries))

return data

## authentication

def create_credentialed_entity(self, entity: CredentialedEntity):
Expand Down
Loading