diff --git a/src/viztracer/main.py b/src/viztracer/main.py index b7691108..f2dcbe90 100644 --- a/src/viztracer/main.py +++ b/src/viztracer/main.py @@ -16,6 +16,7 @@ import threading import time import types +import re from typing import Any, Dict, List, Optional, Tuple from viztracer.vcompressor import VCompressor @@ -590,7 +591,18 @@ def wait_children_finish(self) -> None: try: if any((f.endswith(".viztmp") for f in os.listdir(self.multiprocess_output_dir))): same_line_print("Wait for child processes to finish, Ctrl+C to skip") - while any((f.endswith(".viztmp") for f in os.listdir(self.multiprocess_output_dir))): + while True: + remain_viztmp = [f for f in os.listdir(self.multiprocess_output_dir) if f.endswith(".viztmp")] + for viztmp_file in remain_viztmp: + match = re.search(r'result_(\d+).json.viztmp', viztmp_file) + if match: + pid = int(match.group(1)) + if pid_exists(pid): + break + else: # pragma: no cover + color_print("WARNING", f"Unknown viztmp file {viztmp_file}") + else: + break time.sleep(0.5) except KeyboardInterrupt: pass diff --git a/src/viztracer/util.py b/src/viztracer/util.py index 286214fe..875d65c7 100644 --- a/src/viztracer/util.py +++ b/src/viztracer/util.py @@ -6,8 +6,14 @@ import os import re import sys +import ctypes from typing import Union +# Windows macros +STILL_ACTIVE = 0x103 +ERROR_ACCESS_DENIED = 0x5 +PROCESS_QUERY_LIMITED_INFORMATION = 0x1000 + def size_fmt(num: Union[int, float], suffix: str = 'B') -> str: for unit in ['', 'Ki', 'Mi', 'Gi']: @@ -101,7 +107,6 @@ def time_str_to_us(t_s: str) -> float: # https://github.com/giampaolo/psutil def pid_exists(pid): """Check whether pid exists in the current process table. - UNIX only. """ if pid < 0: return False @@ -110,19 +115,58 @@ def pid_exists(pid): # in the process group of the calling process. # On certain systems 0 is a valid PID but we have no way # to know that in a portable fashion. 
+ # On Windows, 0 is an idle process but we don't need to + # check it here raise ValueError('invalid PID 0') - try: - os.kill(pid, 0) - except OSError as err: - if err.errno == errno.ESRCH: - # ESRCH == No such process - return False - elif err.errno == errno.EPERM: - # EPERM clearly means there's a process to deny access to - return True - else: # pragma: no cover - # According to "man 2 kill" possible error values are - # (EINVAL, EPERM, ESRCH) - raise + if sys.platform == "win32": + # Windows + kernel32 = ctypes.windll.kernel32 + + process = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid) + if not process: + if kernel32.GetLastError() == ERROR_ACCESS_DENIED: + # Access is denied, which means there's a process. + # Usually it's impossible to run here in viztracer. + return True # pragma: no cover + else: + return False + + exit_code = ctypes.c_ulong() + out = kernel32.GetExitCodeProcess(process, ctypes.byref(exit_code)) + kernel32.CloseHandle(process) + # nonzero return value means the function succeeds + if out: + if exit_code.value == STILL_ACTIVE: + # According to documents of GetExitCodeProcess. + # If a thread returns STILL_ACTIVE (259) as an error code, + # then applications that test for that value could interpret + # it to mean that the thread is still running, and continue + # to test for the completion of the thread after the thread + # has terminated, which could put the application into an + # infinite loop. + return True + else: + return False + else: # pragma: no cover + if kernel32.GetLastError() == ERROR_ACCESS_DENIED: + # Access is denied, which means there's a process. + # Usually it's impossible to run here in viztracer. 
+ return True + return False # pragma: no cover else: - return True + # UNIX + try: + os.kill(pid, 0) + except OSError as err: + if err.errno == errno.ESRCH: + # ESRCH == No such process + return False + elif err.errno == errno.EPERM: + # EPERM clearly means there's a process to deny access to + return True + else: # pragma: no cover + # According to "man 2 kill" possible error values are + # (EINVAL, EPERM, ESRCH) + raise + else: + return True diff --git a/tests/test_regression.py b/tests/test_regression.py index 84cf27c6..82481fe8 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -384,3 +384,53 @@ def test_escape_string(self): expected_output_file="result.json", script=issue285_code, expected_stdout=".*Total Entries:.*") + + +wait_for_child = """ +import time +import multiprocessing + +def target(): + time.sleep(3) + +if __name__ == '__main__': + p = multiprocessing.Process(target=target) + p.start() + # The main process will join the child in multiprocessing.process._children. + # This is a hack to make sure the main process won't join the child process, + # so we can test the VizUI.wait_children_finish function + multiprocessing.process._children = set() + time.sleep(1) +""" + +wait_for_terminated_child = """ +import time +import os +import signal +import multiprocessing + +def target(): + time.sleep(3) + os.kill(os.getpid(), signal.SIGTERM) + +if __name__ == '__main__': + p = multiprocessing.Process(target=target) + p.start() + # The main process will join the child in multiprocessing.process._children. 
+ # This is a hack to make sure the main process won't join the child process, + # so we can test the VizUI.wait_children_finish function + multiprocessing.process._children = set() + time.sleep(1) +""" + + +class TestWaitForChild(CmdlineTmpl): + def test_child_process_exits_normally(self): + self.template(["viztracer", "-o", "result.json", "cmdline_test.py"], + expected_output_file="result.json", expected_stdout=r"Wait", + script=wait_for_child) + + def test_child_process_exits_abnormally(self): + self.template(["viztracer", "-o", "result.json", "cmdline_test.py"], + expected_output_file="result.json", expected_stdout=r"Wait", + script=wait_for_terminated_child) diff --git a/tests/test_util.py b/tests/test_util.py index ba462506..51a5861d 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -3,13 +3,18 @@ import os import sys -import unittest +import time +from multiprocessing import Process import viztracer.util from .base_tmpl import BaseTmpl +def target_child(): + time.sleep(1) + + class TestUtil(BaseTmpl): def test_size_fmt(self): size_fmt = viztracer.util.size_fmt @@ -34,11 +39,23 @@ def test_time_str_to_us(self): self.assertRaises(ValueError, time_str_to_us, "0.0.0") self.assertRaises(ValueError, time_str_to_us, "invalid") - @unittest.skipIf(sys.platform == "win32", "pid_exists only works on Unix") def test_pid_exists(self): pid_exists = viztracer.util.pid_exists self.assertFalse(pid_exists(-1)) - self.assertTrue(pid_exists(1)) + if sys.platform != "win32": + self.assertTrue(pid_exists(1)) self.assertTrue(pid_exists(os.getpid())) with self.assertRaises(ValueError): pid_exists(0) + + # test child + p = Process(target=target_child) + p.start() + self.assertTrue(pid_exists(p.pid)) + p.join() + self.assertFalse(pid_exists(p.pid)) + + # test a process that doesn't exist + # Windows pid starts from 4 + if sys.platform == "win32": + self.assertFalse(pid_exists(2))