Skip to content

Commit

Permalink
Merge pull request #2662 from cta-observatory/fix_tool_exit_order
Browse files Browse the repository at this point in the history
Fix tool exit order
  • Loading branch information
maxnoe authored Dec 4, 2024
2 parents db1722d + b345fc0 commit ee0a356
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 25 deletions.
2 changes: 2 additions & 0 deletions docs/changes/2662.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix the order in which ``Tool`` runs final operations to fix an issue
of provenance not being correctly recorded.
82 changes: 82 additions & 0 deletions src/ctapipe/core/tests/test_tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import json
import logging
import os
import signal
import sys
import tempfile
from multiprocessing import Barrier, Process
from pathlib import Path

import pytest
Expand Down Expand Up @@ -487,3 +490,82 @@ class MyTool(Tool):
assert len(inputs) == 1
assert inputs[0]["role"] == "Tool Configuration"
assert inputs[0]["url"] == str(config_path)


@pytest.mark.parametrize(
    ("exit_code", "expected_status"),
    [
        (0, "completed"),
        (None, "completed"),
        (1, "error"),
        (2, "error"),
    ],
)
def test_exit_status(exit_code, expected_status, tmp_path, provenance):
    """Check that the provenance log records the right status for each exit path.

    The parametrized ``exit_code`` selects how ``MyTool.start`` ends:

    - ``None``: returns normally
    - ``0``: calls ``sys.exit(0)``
    - ``1``: raises a plain ``ValueError``
    - anything else: raises an exception carrying that value in an
      ``exit_code`` attribute

    After the run, the single activity written to the provenance log must
    have ``expected_status`` as its ``"status"``.
    """

    class MyTool(Tool):
        exit_code = Int(allow_none=True, default_value=None).tag(config=True)

        def start(self):
            code = self.exit_code

            # normal termination without any explicit exit
            if code is None:
                return

            # explicit but successful exit
            if code == 0:
                sys.exit(0)

            # generic failure, exception has no ``exit_code`` attribute
            if code == 1:
                raise ValueError("Some error happened")

            # failure where the exception itself carries the exit code
            class CustomError(ValueError):
                exit_code = code

            raise CustomError("Some error with specific code happened")

    provenance_path = tmp_path / "provlog.json"
    argv = [f"--provenance-log={provenance_path}"]
    run_tool(MyTool(exit_code=exit_code), argv, raises=False)

    activities = json.loads(provenance_path.read_text())
    assert len(activities) == 1
    assert activities[0]["status"] == expected_status


class InterruptTestTool(Tool):
    """Tool that signals it has started and then sleeps until a signal arrives.

    Used by the interrupt test: the ``barrier`` lets the parent process know
    that ``start`` has been entered before it sends a signal to this process.
    """

    name = "test-interrupt"

    def __init__(self, barrier):
        """Store the ``barrier`` shared with the process driving the test."""
        super().__init__()
        self.barrier = barrier

    def start(self):
        """Rendezvous with the other party on the barrier, then wait for a signal."""
        # tell the waiting party that the tool is now inside ``start``
        self.barrier.wait()
        # sleep until a signal is received
        signal.pause()


def test_exit_status_interrupted(tmp_path, provenance):
    """Check that a tool stopped by SIGINT is recorded as ``"interrupted"``.

    ``InterruptTestTool`` is run in a child process. Once the child has
    reached ``start`` (synchronized through a barrier), SIGINT is sent to it.
    The provenance log written by the child must then contain a single
    activity with the tool's name and status ``"interrupted"``.

    All blocking calls use a timeout and the child is killed on failure, so a
    misbehaving child fails this test instead of hanging the test session.
    """
    # generous upper bound in seconds, only reached if something went wrong
    timeout = 60

    # to make sure we only kill the process once it is running
    barrier = Barrier(2)
    tool = InterruptTestTool(barrier)

    provenance_path = tmp_path / "provlog.json"
    args = [f"--provenance-log={provenance_path}", "--log-level=INFO"]
    process = Process(target=run_tool, args=(tool, args), kwargs=dict(raises=False))
    process.start()
    try:
        # raises BrokenBarrierError if the child never reaches ``start``
        barrier.wait(timeout=timeout)

        os.kill(process.pid, signal.SIGINT)
        process.join(timeout=timeout)
        assert not process.is_alive(), "tool did not exit after SIGINT"
    finally:
        # never leave a stray child process behind
        if process.is_alive():
            process.kill()
            process.join()

    activities = json.loads(provenance_path.read_text())
    assert len(activities) == 1
    provlog = activities[0]
    assert provlog["activity_name"] == InterruptTestTool.name
    assert provlog["status"] == "interrupted"
38 changes: 15 additions & 23 deletions src/ctapipe/core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def run(self, argv=None, raises=False):
# return codes are taken from:
# https://tldp.org/LDP/abs/html/exitcodes.html

status = "completed"
exit_status = 0
current_exception = None

Expand All @@ -430,51 +431,42 @@ def run(self, argv=None, raises=False):

self.start()
self.finish()
self.log.info("Finished: %s", self.name)
Provenance().finish_activity(activity_name=self.name)
except (ToolConfigurationError, TraitError) as err:
current_exception = err
self.log.error("%s", err)
self.log.error("Use --help for more info")
exit_status = 2 # wrong cmd line parameter
Provenance().finish_activity(
activity_name=self.name, status="error", exit_code=exit_status
)
status = "error"
except KeyboardInterrupt:
self.log.warning("WAS INTERRUPTED BY CTRL-C")
exit_status = 130 # Script terminated by Control-C
Provenance().finish_activity(
activity_name=self.name, status="interrupted", exit_code=exit_status
)
status = "interrupted"
except Exception as err:
current_exception = err
exit_status = getattr(err, "exit_code", 1)
status = "error"
self.log.exception("Caught unexpected exception: %s", err)
Provenance().finish_activity(
activity_name=self.name, status="error", exit_code=exit_status
)
except SystemExit as err:
exit_status = err.code
if exit_status == 0:
# Finish normally
Provenance().finish_activity(activity_name=self.name)
else:
# Finish with error
if exit_status != 0:
status = "error"
current_exception = err
self.log.critical(
"Caught SystemExit with exit code %s", exit_status
)
Provenance().finish_activity(
activity_name=self.name,
status="error",
exit_code=exit_status,
)
finally:
if not {"-h", "--help", "--help-all"}.intersection(self.argv):
self.write_provenance()
if raises and current_exception:
self.write_provenance()
raise current_exception

Provenance().finish_activity(
activity_name=self.name, status=status, exit_code=exit_status
)

if not {"-h", "--help", "--help-all"}.intersection(self.argv):
self.write_provenance()

self.log.info("Finished %s", self.name)
self.exit(exit_status)

def write_provenance(self):
Expand Down
20 changes: 18 additions & 2 deletions src/ctapipe/tools/tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Test ctapipe-process on a few different use cases
"""

import json
from subprocess import CalledProcessError

import astropy.units as u
Expand Down Expand Up @@ -160,17 +161,20 @@ def test_stage1_datalevels(tmp_path):
assert isinstance(tool.event_source, DummyEventSource)


def test_stage_2_from_simtel(tmp_path):
def test_stage_2_from_simtel(tmp_path, provenance):
"""check we can go to DL2 geometry from simtel file"""
config = resource_file("stage2_config.json")
output = tmp_path / "test_stage2_from_simtel.DL2.h5"

provenance_log = tmp_path / "provenance.log"
input_path = get_dataset_path("gamma_prod5.simtel.zst")
run_tool(
ProcessorTool(),
argv=[
f"--config={config}",
"--input=dataset://gamma_prod5.simtel.zst",
f"--input={input_path}",
f"--output={output}",
f"--provenance-log={provenance_log}",
"--overwrite",
],
cwd=tmp_path,
Expand All @@ -190,6 +194,18 @@ def test_stage_2_from_simtel(tmp_path):
assert dl2["HillasReconstructor_telescopes"].dtype == np.bool_
assert dl2["HillasReconstructor_telescopes"].shape[1] == len(subarray)

activities = json.loads(provenance_log.read_text())
assert len(activities) == 1

activity = activities[0]
assert activity["status"] == "completed"
assert len(activity["input"]) == 2
assert activity["input"][0]["url"] == str(config)
assert activity["input"][1]["url"] == str(input_path)

assert len(activity["output"]) == 1
assert activity["output"][0]["url"] == str(output)


def test_stage_2_from_dl1_images(tmp_path, dl1_image_file):
"""check we can go to DL2 geometry from DL1 images"""
Expand Down

0 comments on commit ee0a356

Please sign in to comment.