Skip to content

Commit

Permalink
Updated docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ianmkenney committed Oct 7, 2024
1 parent fdc25a7 commit 51194ff
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 20 deletions.
4 changes: 0 additions & 4 deletions alchemiscale/interface/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,6 @@ def get_task_status(
return status[0].value


# TODO docstring
@router.post("/networks/{network_scoped_key}/restartpolicy/add")
def add_task_restart_patterns(
network_scoped_key: str,
Expand All @@ -961,7 +960,6 @@ def add_task_restart_patterns(
n4js.add_task_restart_patterns(taskhub_scoped_key, patterns, number_of_retries)


# TODO docstring
@router.post("/networks/{network_scoped_key}/restartpolicy/remove")
def remove_task_restart_patterns(
network_scoped_key: str,
Expand All @@ -974,7 +972,6 @@ def remove_task_restart_patterns(
n4js.remove_task_restart_patterns(taskhub_scoped_key, patterns)


# TODO: docstring
@router.get("/networks/{network_scoped_key}/restartpolicy/clear")
def clear_task_restart_patterns(
network_scoped_key: str,
Expand All @@ -987,7 +984,6 @@ def clear_task_restart_patterns(
return [network_scoped_key]


# TODO docstring
@router.post("/bulk/networks/restartpolicy/get")
def get_task_restart_patterns(
*,
Expand Down
12 changes: 10 additions & 2 deletions alchemiscale/storage/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ def _defaults(cls):
return super()._defaults()


# TODO: fill in docstrings
class TaskRestartPattern(GufeTokenizable):
"""A pattern to compare returned Task tracebacks to.
Expand Down Expand Up @@ -202,8 +201,17 @@ def __eq__(self, other):
return self.pattern == other.pattern


# TODO: docstrings
class Tracebacks(GufeTokenizable):
"""
Attributes
----------
tracebacks: list[str]
The tracebacks returned with the ProtocolUnitFailures.
source_keys:list[ScopedKey]
The ScopedKeys of the Protocols that failed.
failure_keys: list[ScopedKey]
The ScopedKeys of the ProtocolUnitFailures.
"""

def __init__(
self, tracebacks: List[str], source_keys: List[str], failure_keys: List[str]
Expand Down
80 changes: 66 additions & 14 deletions alchemiscale/storage/statestore.py
Original file line number Diff line number Diff line change
Expand Up @@ -2846,18 +2846,25 @@ def err_msg(t, status):

## task restart policy

# TODO: fill in docstring
def add_task_restart_patterns(
self, taskhub: ScopedKey, patterns: List[str], number_of_retries: int
self, taskhub: ScopedKey, patterns: list[str], number_of_retries: int
):
"""Add a list of restart policy patterns to a `TaskHub` along with the number of retries allowed.
Parameters
----------
taskhub : ScopedKey
TaskHub for the restart patterns to enforce.
patterns: list[str]
Regular expression patterns that will be compared to tracebacks returned by ProtocolUnitFailures.
number_of_retries: int
The number of times the given patterns will apply to a single Task, attempts to restart beyond
this value will result in a canceled Task with an error status.
Raises
------
KeyError
Raised when the provided TaskHub ScopedKey cannot be associated with a TaskHub in the database.
"""

# get taskhub node
Expand All @@ -2866,8 +2873,8 @@ def add_task_restart_patterns(
RETURN th
"""
results = self.execute_query(q, taskhub=str(taskhub))
## raise error if taskhub not found


# raise error if taskhub not found
if not results.records:
raise KeyError("No such TaskHub in the database")

Expand Down Expand Up @@ -2935,8 +2942,16 @@ def add_task_restart_patterns(

self.resolve_task_restarts(actioned_task_scoped_keys, tx=tx)

# TODO: fill in docstring
def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]):
def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: list[str]):
"""Remove a list of restart patterns enforcing a TaskHub from the database.
Parameters
----------
taskhub: ScopedKey
The ScopedKey of the TaskHub that the patterns enforce.
patterns: list[str]
The patterns to remove. Patterns not enforcing the TaskHub are ignored.
"""
q = """
UNWIND $patterns AS pattern
Expand All @@ -2948,19 +2963,36 @@ def remove_task_restart_patterns(self, taskhub: ScopedKey, patterns: List[str]):
self.execute_query(q, patterns=patterns, taskhub_scoped_key=str(taskhub))

def clear_task_restart_patterns(self, taskhub: ScopedKey):
"""Clear all restart patterns from a TaskHub.
Parameters
----------
taskhub: ScopedKey
The ScopedKey of the TaskHub to clear of restart patterns.
"""
q = """
MATCH (trp: TaskRestartPattern {taskhub_scoped_key: $taskhub_scoped_key})
DETACH DELETE trp
"""
self.execute_query(q, taskhub_scoped_key=str(taskhub))

# TODO: fill in docstring
def set_task_restart_patterns_max_retries(
self,
taskhub_scoped_key: Union[ScopedKey, str],
patterns: List[str],
taskhub_scoped_key: ScopedKey,
patterns: list[str],
max_retries: int,
):
"""Set the maximum number of retries of a pattern enforcing a TaskHub.
Parameters
----------
taskhub_scoped_key: ScopedKey
The ScopedKey of the TaskHub that the patterns enforce.
patterns: list[str]
The patterns to change the maximum retries value for.
max_retries: int
The new maximum retries value.
"""
query = """
UNWIND $patterns AS pattern
MATCH (trp: TaskRestartPattern {pattern: pattern, taskhub_scoped_key: $taskhub_scoped_key})
Expand All @@ -2974,11 +3006,24 @@ def set_task_restart_patterns_max_retries(
max_retries=max_retries,
)

# TODO: fill in docstring
# TODO: validation of taskhubs variable, will fail in weird ways if not enforced
def get_task_restart_patterns(
self, taskhubs: List[ScopedKey]
) -> Dict[ScopedKey, Set[Tuple[str, int]]]:
self, taskhubs: list[ScopedKey]
) -> dict[ScopedKey, set[tuple[str, int]]]:
"""For a list of TaskHub ScopedKeys, get the associated restart patterns along with the maximum number of retries for each pattern.
Parameters
----------
taskhubs: list[ScopedKey]
The ScopedKeys of the TaskHubs to get the restart patterns of.
Returns
-------
dict[ScopedKey, set[tuple[str, int]]]
A dictionary containing whose keys are the ScopedKeys of the TaskHubs provided and whose
values are a set of tuples containing the patterns enforcing each TaskHub along with their
associated maximum number of retries.
"""

q = """
UNWIND $taskhub_scoped_keys as taskhub_scoped_key
Expand All @@ -3002,9 +3047,16 @@ def get_task_restart_patterns(

return data

# TODO: docstrings
@chainable
def resolve_task_restarts(self, task_scoped_keys: Iterable[ScopedKey], *, tx=None):
"""Determine whether or not Tasks need to be restarted or canceled and perform that action.
Parameters
----------
task_scoped_keys: Iterable[ScopedKey]
An iterable of Task ScopedKeys that need to be resolved. Tasks without the error status
are filtered out and ignored.
"""

# Given the scoped keys of a list of Tasks, find all tasks that have an
# error status and have a TaskRestartPattern applied. A subquery is executed
Expand Down

0 comments on commit 51194ff

Please sign in to comment.