Skip to content

Commit 621a97b

Browse files
Make nested multiprocess work (#264)
* Make nested multiprocess work
1 parent 19889f3 commit 621a97b

File tree

5 files changed

+130
-57
lines changed

5 files changed

+130
-57
lines changed

src/viztracer/main.py

Lines changed: 15 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,6 @@
1010
import multiprocessing.util # type: ignore
1111
import os
1212
import platform
13-
import re
1413
import shutil
1514
import signal
1615
import sys
@@ -23,7 +22,7 @@
2322
from . import __version__
2423
from .attach_process.add_code_to_python_process import run_python_code # type: ignore
2524
from .code_monkey import CodeMonkey
26-
from .patch import patch_multiprocessing, patch_subprocess
25+
from .patch import install_all_hooks
2726
from .report_builder import ReportBuilder
2827
from .util import time_str_to_us, color_print, same_line_print, pid_exists
2928
from .viztracer import VizTracer
@@ -153,13 +152,6 @@ def create_parser(self) -> argparse.ArgumentParser:
153152
help="time you want to trace the process")
154153
return parser
155154

156-
@property
157-
def is_main_process(self) -> bool:
158-
options = self.options
159-
if options.subprocess_child:
160-
return False # pragma: no cover
161-
return self.parent_pid == os.getpid()
162-
163155
def load_config_file(self, filename: str = ".viztracerrc") -> argparse.Namespace:
164156
ret = argparse.Namespace()
165157
if os.path.exists(filename):
@@ -262,6 +254,7 @@ def parse(self, argv: List[str]) -> VizProcedureResult:
262254
"log_gc": options.log_gc,
263255
"log_sparse": options.log_sparse,
264256
"log_async": options.log_async,
257+
"log_audit": options.log_audit,
265258
"vdb": options.vdb,
266259
"pid_suffix": True,
267260
"file_info": False,
@@ -296,45 +289,6 @@ def search_file(self, file_name: str) -> Optional[str]:
296289

297290
return None
298291

299-
def install_all_hooks(self, tracer: VizTracer) -> None:
300-
options = self.options
301-
302-
# multiprocess hook
303-
if not options.ignore_multiprocess:
304-
patch_multiprocessing(tracer)
305-
if not options.subprocess_child:
306-
patch_subprocess(self.args + ["--subprocess_child", "--dump_raw", "-o", tracer.output_file])
307-
308-
# If we want to hook fork correctly with file waiter, we need to
309-
# use os.register_at_fork to write the file, and make sure
310-
# os.exec won't clear viztracer so that the file lives forever.
311-
# This is basically equivalent to py3.8 + Linux
312-
if hasattr(sys, "addaudithook"):
313-
if hasattr(os, "register_at_fork"):
314-
def audit_hook(event, args): # pragma: no cover
315-
if event == "os.exec":
316-
tracer.exit_routine()
317-
sys.addaudithook(audit_hook) # type: ignore
318-
os.register_at_fork(after_in_child=lambda: tracer.label_file_to_write()) # type: ignore
319-
if options.log_audit is not None:
320-
audit_regex_list = [re.compile(regex) for regex in options.log_audit]
321-
322-
def audit_hook(event, args): # pragma: no cover
323-
if len(audit_regex_list) == 0 or any((regex.fullmatch(event) for regex in audit_regex_list)):
324-
tracer.log_instant(event, args={"args": [str(arg) for arg in args]})
325-
sys.addaudithook(audit_hook) # type: ignore
326-
327-
# SIGTERM hook
328-
def term_handler(signalnum, frame):
329-
sys.exit(0)
330-
331-
if self.is_main_process:
332-
signal.signal(signal.SIGTERM, term_handler)
333-
multiprocessing.util.Finalize(self, self.exit_routine, exitpriority=-1)
334-
else:
335-
signal.signal(signal.SIGTERM, term_handler)
336-
multiprocessing.util.Finalize(tracer, tracer.exit_routine, exitpriority=-1)
337-
338292
def run(self) -> VizProcedureResult:
339293
if self.options.version:
340294
return self.show_version()
@@ -365,7 +319,19 @@ def run_code(self, code: Any, global_dict: Dict[str, Any]) -> VizProcedureResult
365319
tracer = VizTracer(**self.init_kwargs)
366320
self.tracer = tracer
367321

368-
self.install_all_hooks(tracer)
322+
install_all_hooks(tracer,
323+
self.args,
324+
patch_multiprocess=not options.ignore_multiprocess)
325+
326+
def term_handler(signalnum, frame):
327+
sys.exit(0)
328+
329+
signal.signal(signal.SIGTERM, term_handler)
330+
331+
if options.subprocess_child:
332+
multiprocessing.util.Finalize(tracer, tracer.exit_routine, exitpriority=-1)
333+
else:
334+
multiprocessing.util.Finalize(self, self.exit_routine, exitpriority=-1)
369335

370336
if not options.log_sparse:
371337
tracer.start()

src/viztracer/patch.py

Lines changed: 40 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
import functools
66
from multiprocessing import Process
77
import os
8+
import re
89
import sys
910
import textwrap
1011
from typing import Any, Callable, Dict, List, Sequence, Union
@@ -53,7 +54,7 @@ def subprocess_init(self, args: Union[str, Sequence], **kwargs) -> None:
5354
setattr(subprocess.Popen, "__init__", subprocess_init)
5455

5556

56-
def patch_multiprocessing(tracer: VizTracer) -> None:
57+
def patch_multiprocessing(tracer: VizTracer, args: List[str]) -> None:
5758

5859
# For fork process
5960
def func_after_fork(tracer: VizTracer):
@@ -83,7 +84,7 @@ def get_command_line(**kwds) -> List[str]:
8384
prog = textwrap.dedent(f"""
8485
from multiprocessing.spawn import spawn_main;
8586
from viztracer.patch import patch_spawned_process;
86-
patch_spawned_process({tracer.init_kwargs});
87+
patch_spawned_process({tracer.init_kwargs}, {args});
8788
spawn_main(%s)
8889
""")
8990
prog %= ', '.join('%s=%r' % item for item in kwds.items())
@@ -100,24 +101,27 @@ def __init__(
100101
run: Callable,
101102
target: Callable,
102103
args: List[Any],
103-
kwargs: Dict[str, Any]):
104+
kwargs: Dict[str, Any],
105+
cmdline_args: List[str]):
104106
self._viztracer_kwargs = viztracer_kwargs
105107
self._run = run
106108
self._target = target
107109
self._args = args
108110
self._kwargs = kwargs
111+
self._cmdline_args = cmdline_args
109112
self._exiting = False
110113

111114
def run(self) -> None:
112115
import viztracer
113116

114117
tracer = viztracer.VizTracer(**self._viztracer_kwargs)
118+
install_all_hooks(tracer, self._cmdline_args)
115119
tracer.register_exit()
116120
tracer.start()
117121
self._run()
118122

119123

120-
def patch_spawned_process(viztracer_kwargs: Dict[str, Any]):
124+
def patch_spawned_process(viztracer_kwargs: Dict[str, Any], cmdline_args: List[str]):
121125
from multiprocessing import reduction, process # type: ignore
122126
from multiprocessing.spawn import prepare
123127
import multiprocessing.spawn
@@ -130,7 +134,7 @@ def _main_3839(fd, parent_sentinel):
130134
preparation_data = reduction.pickle.load(from_parent)
131135
prepare(preparation_data)
132136
self: Process = reduction.pickle.load(from_parent)
133-
sp = SpawnProcess(viztracer_kwargs, self.run, self._target, self._args, self._kwargs)
137+
sp = SpawnProcess(viztracer_kwargs, self.run, self._target, self._args, self._kwargs, cmdline_args)
134138
self.run = sp.run
135139
finally:
136140
del process.current_process()._inheriting
@@ -144,7 +148,7 @@ def _main_3637(fd):
144148
preparation_data = reduction.pickle.load(from_parent)
145149
prepare(preparation_data)
146150
self: Process = reduction.pickle.load(from_parent)
147-
sp = SpawnProcess(viztracer_kwargs, self.run, self._target, self._args, self._kwargs)
151+
sp = SpawnProcess(viztracer_kwargs, self.run, self._target, self._args, self._kwargs, cmdline_args)
148152
self.run = sp.run
149153
finally:
150154
del process.current_process()._inheriting
@@ -154,3 +158,33 @@ def _main_3637(fd):
154158
multiprocessing.spawn._main = _main_3839 # type: ignore
155159
else:
156160
multiprocessing.spawn._main = _main_3637 # type: ignore
161+
162+
163+
def install_all_hooks(
        tracer: "VizTracer",
        args: List[str],
        patch_multiprocess: bool = True) -> None:
    """Install the process-level hooks that let ``tracer`` follow child processes.

    Args:
        tracer: The tracer whose ``output_file`` and ``log_audit`` settings
            drive the hooks, and which receives the logged audit events.
        args: The viztracer command-line arguments. They are forwarded to
            ``patch_multiprocessing`` and, with child-mode flags appended,
            to ``patch_subprocess``.
        patch_multiprocess: When False, neither multiprocessing/subprocess
            patching nor the fork/exec hooks are installed; audit-event
            logging is still installed if ``tracer.log_audit`` is set.
    """
    # multiprocess hook
    if patch_multiprocess:
        patch_multiprocessing(tracer, args)
        patch_subprocess(args + ["--subprocess_child", "--dump_raw", "-o", tracer.output_file])

    # If we want to hook fork correctly with file waiter, we need to
    # use os.register_at_fork to write the file, and make sure
    # os.exec won't clear viztracer so that the file lives forever.
    # This is basically equivalent to py3.8 + Linux
    if hasattr(sys, "addaudithook"):
        if hasattr(os, "register_at_fork") and patch_multiprocess:
            def exec_audit_hook(event, _):  # pragma: no cover
                # Run the tracer's exit routine before the process image is replaced.
                if event == "os.exec":
                    tracer.exit_routine()
            sys.addaudithook(exec_audit_hook)  # type: ignore
            os.register_at_fork(after_in_child=lambda: tracer.label_file_to_write())  # type: ignore
        if tracer.log_audit is not None:
            audit_regex_list = [re.compile(regex) for regex in tracer.log_audit]

            # The hook's second parameter must have its own name: reading
            # ``args`` here would pick up the enclosing function's
            # command-line ``args`` and log those for every audit event.
            def log_audit_hook(event, event_args):  # pragma: no cover
                # An empty pattern list means "log every audit event".
                if not audit_regex_list or any(regex.fullmatch(event) for regex in audit_regex_list):
                    tracer.log_instant(event, args={"args": [str(arg) for arg in event_args]})
            sys.addaudithook(log_audit_hook)  # type: ignore

src/viztracer/viztracer.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,7 @@ def __init__(self,
3232
log_gc: bool = False,
3333
log_sparse: bool = False,
3434
log_async: bool = False,
35+
log_audit: Optional[Sequence[str]] = None,
3536
vdb: bool = False,
3637
pid_suffix: bool = False,
3738
file_info: bool = True,
@@ -66,6 +67,7 @@ def __init__(self,
6667
self.output_file = output_file
6768
self.system_print = None
6869
self.log_sparse = log_sparse
70+
self.log_audit = log_audit
6971
self.dump_raw = dump_raw
7072
self.sanitize_function_name = sanitize_function_name
7173
self.minimize_memory = minimize_memory
@@ -112,6 +114,7 @@ def init_kwargs(self) -> Dict:
112114
"log_gc": self.log_gc,
113115
"log_sparse": self.log_sparse,
114116
"log_async": self.log_async,
117+
"log_audit": self.log_audit,
115118
"vdb": self.vdb,
116119
"pid_suffix": self.pid_suffix,
117120
"min_duration": self.min_duration,

tests/test_multiprocess.py

Lines changed: 71 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,12 @@
1111
from .cmdline_tmpl import CmdlineTmpl
1212

1313

14+
file_grandparent = """
15+
import subprocess
16+
subprocess.run(["python", "parent.py"])
17+
"""
18+
19+
1420
file_parent = """
1521
import subprocess
1622
subprocess.run(["python", "child.py"])
@@ -82,6 +88,33 @@ def f():
8288
time.sleep(0.1)
8389
"""
8490

91+
file_nested_multiprocessing = """
92+
import multiprocessing
93+
from multiprocessing import Process
94+
import time
95+
96+
97+
def fib(n):
98+
if n < 2:
99+
return 1
100+
return fib(n-1) + fib(n-2)
101+
102+
def f():
103+
fib(5)
104+
105+
def spawn():
106+
p = Process(target=f)
107+
p.start()
108+
p.join()
109+
110+
if __name__ == "__main__":
111+
fib(2)
112+
p = Process(target=spawn)
113+
p.start()
114+
p.join()
115+
time.sleep(0.1)
116+
"""
117+
85118
file_multiprocessing_overload_run = """
86119
import multiprocessing
87120
from multiprocessing import Process
@@ -196,6 +229,30 @@ def test_child_process(self):
196229
expected_output_file=None)
197230
self.assertEqual(len(os.listdir(tmpdir)), 1)
198231

232+
def test_nested(self):
    """Trace a script whose subprocess launches a further subprocess.

    ``file_grandparent`` runs ``python parent.py`` and ``file_parent`` runs
    ``python child.py``; the combined report is expected to contain trace
    events from exactly five distinct pids.
    """
    def check_func(data):
        pids = {entry["pid"] for entry in data["traceEvents"]}
        self.assertEqual(len(pids), 5)

    with open("parent.py", "w") as f:
        f.write(file_parent)
    try:
        self.template(["viztracer", "-o", "result.json", "cmdline_test.py"],
                      expected_output_file="result.json", script=file_grandparent, check_func=check_func)
    finally:
        # Clean up even if the template call raises, so a failure here
        # does not leave parent.py in the working directory.
        os.remove("parent.py")
243+
244+
def test_nested_multiproessing(self):
    """Trace a script whose subprocess runs the multiprocessing script.

    ``file_grandparent`` runs ``python parent.py``, and ``parent.py`` is
    written here from ``file_multiprocessing``; the combined report is
    expected to contain trace events from exactly three distinct pids.
    """
    def check_func(data):
        pids = {entry["pid"] for entry in data["traceEvents"]}
        self.assertEqual(len(pids), 3)

    with open("parent.py", "w") as f:
        f.write(file_multiprocessing)
    try:
        self.template(["viztracer", "-o", "result.json", "cmdline_test.py"],
                      expected_output_file="result.json", script=file_grandparent, check_func=check_func)
    finally:
        # Clean up even if the template call raises, so a failure here
        # does not leave parent.py in the working directory.
        os.remove("parent.py")
255+
199256
@unittest.skipIf(sys.platform == "win32", "Can't get anything on Windows with SIGTERM")
200257
def test_term(self):
201258
with tempfile.TemporaryDirectory() as tmpdir:
@@ -248,6 +305,19 @@ def check_func(data):
248305
check_func=check_func,
249306
concurrency="multiprocessing")
250307

308+
def test_nested_multiprosessing(self):
    """Trace a process that starts a child process which starts its own child.

    ``file_nested_multiprocessing`` starts ``Process(target=spawn)`` from
    its main block, and ``spawn`` starts ``Process(target=f)``; the report
    is expected to contain trace events from exactly three distinct pids.
    """
    def check_func(data):
        distinct_pids = {event["pid"] for event in data["traceEvents"]}
        self.assertEqual(len(distinct_pids), 3)

    command = ["viztracer", "-o", "result.json", "cmdline_test.py"]
    self.template(command,
                  expected_output_file="result.json",
                  script=file_nested_multiprocessing,
                  check_func=check_func,
                  concurrency="multiprocessing")
320+
251321
def test_multiprocessing_entry_limit(self):
252322
result = self.template(["viztracer", "-o", "result.json", "--tracer_entries", "10", "cmdline_test.py"],
253323
expected_output_file="result.json",
@@ -263,7 +333,7 @@ def check_func(data):
263333
pids.add(entry["pid"])
264334
self.assertEqual(len(pids), 1)
265335

266-
self.template(["viztracer", "-o", "result.json", "--ignore_multiproces", "cmdline_test.py"],
336+
self.template(["viztracer", "-o", "result.json", "--ignore_multiprocess", "cmdline_test.py"],
267337
expected_output_file="result.json",
268338
script=file_multiprocessing,
269339
check_func=check_func,

tests/test_patch.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@
3232
set_spawning_popen(None)
3333
child_r, parent_w = os.pipe()
3434
35-
patch_spawned_process({'output_file': "$tmpdir/result.json", 'pid_suffix': True})
35+
patch_spawned_process({'output_file': "$tmpdir/result.json", 'pid_suffix': True}, [])
3636
pid = os.getpid()
3737
3838
assert multiprocessing.spawn._main.__qualname__ == "_main"

0 commit comments

Comments
 (0)