Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

check pid exists when waiting for child #388

Merged
merged 3 commits into from
Jan 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/viztracer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import threading
import time
import types
import re
from typing import Any, Dict, List, Optional, Tuple

from viztracer.vcompressor import VCompressor
Expand Down Expand Up @@ -590,7 +591,18 @@ def wait_children_finish(self) -> None:
try:
if any((f.endswith(".viztmp") for f in os.listdir(self.multiprocess_output_dir))):
same_line_print("Wait for child processes to finish, Ctrl+C to skip")
while any((f.endswith(".viztmp") for f in os.listdir(self.multiprocess_output_dir))):
while True:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Logic here can be simplified. You don't need to check if remain_viztmp is empty if you need to iterate through it. Do the for loop directly and your remaining logic will take care of it.

int(f[7:-12]) is a horrible magic number. It would be better to use a regular expression here because it's self-explanatory. Also, if there happens to be a file that does not match the pattern, this implementation will raise an index error or an integer conversion error, which would be hugely confusing to users. With a regular expression, if it does not match, you can issue a warning (it might be something we changed in other places but forgot to update here).

You probably do not need a remain_children variable, you can iterate through remain_viztmp directly.

remain_viztmp = [f for f in os.listdir(self.multiprocess_output_dir) if f.endswith(".viztmp")]
for viztmp_file in remain_viztmp:
match = re.search(r'result_(\d+).json.viztmp', viztmp_file)
if match:
pid = int(match.group(1))
if pid_exists(pid):
break
else: # pragma: no cover
color_print("WARNING", f"Unknown viztmp file {viztmp_file}")
else:
break
time.sleep(0.5)
except KeyboardInterrupt:
pass
Expand Down
74 changes: 59 additions & 15 deletions src/viztracer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import os
import re
import sys
import ctypes
from typing import Union

# Windows macros
STILL_ACTIVE = 0x103
ERROR_ACCESS_DENIED = 0x5
PROCESS_QUERY_LIMITED_INFORMATION = 0x1000


def size_fmt(num: Union[int, float], suffix: str = 'B') -> str:
for unit in ['', 'Ki', 'Mi', 'Gi']:
Expand Down Expand Up @@ -101,7 +107,6 @@ def time_str_to_us(t_s: str) -> float:
# https://github.com/giampaolo/psutil
def pid_exists(pid):
"""Check whether pid exists in the current process table.
UNIX only.
"""
if pid < 0:
return False
Expand All @@ -110,19 +115,58 @@ def pid_exists(pid):
# in the process group of the calling process.
# On certain systems 0 is a valid PID but we have no way
# to know that in a portable fashion.
# On Windows, 0 is an idle process but we don't need to
# check it here
raise ValueError('invalid PID 0')
try:
os.kill(pid, 0)
except OSError as err:
if err.errno == errno.ESRCH:
# ESRCH == No such process
return False
elif err.errno == errno.EPERM:
# EPERM clearly means there's a process to deny access to
return True
else: # pragma: no cover
# According to "man 2 kill" possible error values are
# (EINVAL, EPERM, ESRCH)
raise
if sys.platform == "win32":
# Windows
kernel32 = ctypes.windll.kernel32

process = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid)
if not process:
if kernel32.GetLastError() == ERROR_ACCESS_DENIED:
# Access is denied, which means there's a process.
# Usually it's impossible to run here in viztracer.
return True # pragma: no cover
else:
return False

exit_code = ctypes.c_ulong()
out = kernel32.GetExitCodeProcess(process, ctypes.byref(exit_code))
kernel32.CloseHandle(process)
# nonzero return value means the function succeeds
if out:
if exit_code.value == STILL_ACTIVE:
# According to documents of GetExitCodeProcess.
# If a thread returns STILL_ACTIVE (259) as an error code,
# then applications that test for that value could interpret
# it to mean that the thread is still running, and continue
# to test for the completion of the thread after the thread
# has terminated, which could put the application into an
# infinite loop.
return True
else:
return False
else: # pragma: no cover
if kernel32.GetLastError() == ERROR_ACCESS_DENIED:
# Access is denied, which means there's a process.
# Usually it's impossible to run here in viztracer.
return True
return False # pragma: no cover
else:
return True
# UNIX
try:
os.kill(pid, 0)
except OSError as err:
if err.errno == errno.ESRCH:
# ESRCH == No such process
return False
elif err.errno == errno.EPERM:
# EPERM clearly means there's a process to deny access to
return True
else: # pragma: no cover
# According to "man 2 kill" possible error values are
# (EINVAL, EPERM, ESRCH)
raise
else:
return True
50 changes: 50 additions & 0 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,53 @@ def test_escape_string(self):
expected_output_file="result.json",
script=issue285_code,
expected_stdout=".*Total Entries:.*")


wait_for_child = """
import time
import multiprocessing

def target():
time.sleep(3)

if __name__ == '__main__':
p = multiprocessing.Process(target=target)
p.start()
# The main process will join the child in multiprocessing.process._children.
# This is a hack to make sure the main process won't join the child process,
# so we can test the VizUI.wait_children_finish function
multiprocessing.process._children = set()
gaogaotiantian marked this conversation as resolved.
Show resolved Hide resolved
time.sleep(1)
"""

# Script source used by test_child_process_exits_abnormally: the child process
# sleeps and then sends SIGTERM to itself, while the parent clears
# multiprocessing.process._children (see the comment inside the script) before
# sleeping briefly and exiting.
wait_for_terminated_child = """
import time
import os
import signal
import multiprocessing

def target():
time.sleep(3)
os.kill(os.getpid(), signal.SIGTERM)

if __name__ == '__main__':
p = multiprocessing.Process(target=target)
p.start()
# The main process will join the child in multiprocessing.process._children.
# This is a hack to make sure the main process won't join the child process,
# so we can test the VizUI.wait_children_finish function
multiprocessing.process._children = set()
time.sleep(1)
"""


class TestWaitForChild(CmdlineTmpl):
    """Command-line regression tests for scripts whose child process is not joined.

    Both tests run viztracer on a script that starts a child process and then
    clears ``multiprocessing.process._children``; they pass ``result.json`` as
    the expected output file and ``Wait`` as the expected stdout pattern.
    """

    def _run_wait_case(self, script_source):
        # Shared invocation: only the script under test differs between cases.
        command = ["viztracer", "-o", "result.json", "cmdline_test.py"]
        self.template(
            command,
            expected_output_file="result.json",
            expected_stdout=r"Wait",
            script=script_source,
        )

    def test_child_process_exits_normally(self):
        # The child just sleeps and returns.
        self._run_wait_case(wait_for_child)

    def test_child_process_exits_abnormally(self):
        # The child sends SIGTERM to itself after sleeping.
        self._run_wait_case(wait_for_terminated_child)
23 changes: 20 additions & 3 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@

import os
import sys
import unittest
import time
from multiprocessing import Process

import viztracer.util

from .base_tmpl import BaseTmpl


def target_child():
    """Sleep for one second; used as the ``Process`` target in ``test_pid_exists``."""
    time.sleep(1)


class TestUtil(BaseTmpl):
def test_size_fmt(self):
size_fmt = viztracer.util.size_fmt
Expand All @@ -34,11 +39,23 @@ def test_time_str_to_us(self):
self.assertRaises(ValueError, time_str_to_us, "0.0.0")
self.assertRaises(ValueError, time_str_to_us, "invalid")

@unittest.skipIf(sys.platform == "win32", "pid_exists only works on Unix")
def test_pid_exists(self):
pid_exists = viztracer.util.pid_exists
self.assertFalse(pid_exists(-1))
self.assertTrue(pid_exists(1))
if sys.platform != "win32":
self.assertTrue(pid_exists(1))
self.assertTrue(pid_exists(os.getpid()))
with self.assertRaises(ValueError):
pid_exists(0)

# test child
p = Process(target=target_child)
p.start()
self.assertTrue(pid_exists(p.pid))
p.join()
self.assertFalse(pid_exists(p.pid))

# test a process that doesn't exist
# Windows pid starts from 4
if sys.platform == "win32":
self.assertFalse(pid_exists(2))
Loading