Skip to content

Commit

Permalink
check pid exists when waiting for child
Browse files Browse the repository at this point in the history
  • Loading branch information
TTianshun committed Dec 25, 2023
1 parent 68b7784 commit f8898c1
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 18 deletions.
11 changes: 10 additions & 1 deletion src/viztracer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,16 @@ def wait_children_finish(self) -> None:
try:
if any((f.endswith(".viztmp") for f in os.listdir(self.multiprocess_output_dir))):
same_line_print("Wait for child processes to finish, Ctrl+C to skip")
while any((f.endswith(".viztmp") for f in os.listdir(self.multiprocess_output_dir))):
while True:
remain_viztmp = [f for f in os.listdir(self.multiprocess_output_dir) if f.endswith(".viztmp")]
if not remain_viztmp:
break
remain_children = list(int(f[7:-12]) for f in remain_viztmp)
for pid in remain_children:
if pid_exists(pid):
break
else:
break
time.sleep(0.5)
except KeyboardInterrupt:
pass
Expand Down
72 changes: 58 additions & 14 deletions src/viztracer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import os
import re
import sys
import ctypes
from typing import Union

# Windows macros
STILL_ACTIVE = 0x103
ERROR_ACCESS_DENIED = 0x5
PROCESS_QUERY_LIMITED_INFORMATION = 0x1000


def size_fmt(num: Union[int, float], suffix: str = 'B') -> str:
for unit in ['', 'Ki', 'Mi', 'Gi']:
Expand Down Expand Up @@ -101,7 +107,6 @@ def time_str_to_us(t_s: str) -> float:
# https://github.com/giampaolo/psutil
def pid_exists(pid):
"""Check whether pid exists in the current process table.
UNIX only.
"""
if pid < 0:
return False
Expand All @@ -110,19 +115,58 @@ def pid_exists(pid):
# in the process group of the calling process.
# On certain systems 0 is a valid PID but we have no way
# to know that in a portable fashion.
# On Windows, 0 is an idle process buw we don't need to
# check it here
raise ValueError('invalid PID 0')
try:
os.kill(pid, 0)
except OSError as err:
if err.errno == errno.ESRCH:
# ESRCH == No such process
return False
elif err.errno == errno.EPERM:
# EPERM clearly means there's a process to deny access to
# UNIX
if sys.platform != "win32":
try:
os.kill(pid, 0)
except OSError as err:
if err.errno == errno.ESRCH:
# ESRCH == No such process
return False
elif err.errno == errno.EPERM:
# EPERM clearly means there's a process to deny access to
return True
else: # pragma: no cover
# According to "man 2 kill" possible error values are
# (EINVAL, EPERM, ESRCH)
raise
else:
return True
else: # pragma: no cover
# According to "man 2 kill" possible error values are
# (EINVAL, EPERM, ESRCH)
raise
# Windows
else:
return True
kernel32 = ctypes.windll.kernel32

process = kernel32.OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid)
if not process:
if kernel32.GetLastError() == ERROR_ACCESS_DENIED:
# Access is denied, which means there's a process.
# Usually it's impossible to run here in viztracer.
return True # pragma: no cover
else:
return False

exit_code = ctypes.c_ulong()
out = kernel32.GetExitCodeProcess(process, ctypes.byref(exit_code))
kernel32.CloseHandle(process)
# nonzero return value means the funtion succeeds
if out:
if exit_code.value == STILL_ACTIVE:
# According to documents of GetExitCodeProcess.
# If a thread returns STILL_ACTIVE (259) as an error code,
# then applications that test for that value could interpret
# it to mean that the thread is still running, and continue
# to test for the completion of the thread after the thread
# has terminated, which could put the application into an
# infinite loop.
return True
else:
return False
else: # pragma: no cover
if kernel32.GetLastError() == ERROR_ACCESS_DENIED:
# Access is denied, which means there's a process.
# Usually it's impossible to run here in viztracer.
return True
return False # pragma: no cover
44 changes: 44 additions & 0 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,47 @@ def test_escape_string(self):
expected_output_file="result.json",
script=issue285_code,
expected_stdout=".*Total Entries:.*")


wait_for_child = """
import time
import multiprocessing
def target():
time.sleep(3)
if __name__ == '__main__':
p = multiprocessing.Process(target=target)
p.start()
multiprocessing.process._children = set()
time.sleep(1)
"""

wait_for_terminated_child = """
import time
import os
import signal
import multiprocessing
def target():
time.sleep(3)
os.kill(os.getpid(), signal.SIGTERM)
if __name__ == '__main__':
p = multiprocessing.Process(target=target)
p.start()
multiprocessing.process._children = set()
time.sleep(1)
"""


class TestWaitForChild(CmdlineTmpl):
def test_child_process_exits_normally(self):
self.template(["viztracer", "-o", "result.json", "--dump_raw", "cmdline_test.py"],
expected_output_file="result.json", expected_stdout=r"Wait",
script=wait_for_child)

def test_child_process_exits_abnormally(self):
self.template(["viztracer", "-o", "result.json", "--dump_raw", "cmdline_test.py"],
expected_output_file="result.json", expected_stdout=r"Wait",
script=wait_for_terminated_child)
23 changes: 20 additions & 3 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@

import os
import sys
import unittest
import time
from multiprocessing import Process

import viztracer.util

from .base_tmpl import BaseTmpl


def target_child():
time.sleep(1)


class TestUtil(BaseTmpl):
def test_size_fmt(self):
size_fmt = viztracer.util.size_fmt
Expand All @@ -34,11 +39,23 @@ def test_time_str_to_us(self):
self.assertRaises(ValueError, time_str_to_us, "0.0.0")
self.assertRaises(ValueError, time_str_to_us, "invalid")

@unittest.skipIf(sys.platform == "win32", "pid_exists only works on Unix")
def test_pid_exists(self):
pid_exists = viztracer.util.pid_exists
self.assertFalse(pid_exists(-1))
self.assertTrue(pid_exists(1))
if sys.platform != "win32":
self.assertTrue(pid_exists(1))
self.assertTrue(pid_exists(os.getpid()))
with self.assertRaises(ValueError):
pid_exists(0)

# test child
p = Process(target=target_child)
p.start()
self.assertTrue(pid_exists(p.pid))
p.join()
self.assertFalse(pid_exists(p.pid))

# test a process that doesn't exist
# Windows pid starts from 4
if sys.platform == "win32":
self.assertFalse(pid_exists(2))

0 comments on commit f8898c1

Please sign in to comment.