5
5
import functools
6
6
from multiprocessing import Process
7
7
import os
8
+ import re
8
9
import sys
9
10
import textwrap
10
11
from typing import Any , Callable , Dict , List , Sequence , Union
@@ -53,7 +54,7 @@ def subprocess_init(self, args: Union[str, Sequence], **kwargs) -> None:
53
54
setattr (subprocess .Popen , "__init__" , subprocess_init )
54
55
55
56
56
- def patch_multiprocessing (tracer : VizTracer ) -> None :
57
+ def patch_multiprocessing (tracer : VizTracer , args : List [ str ] ) -> None :
57
58
58
59
# For fork process
59
60
def func_after_fork (tracer : VizTracer ):
@@ -83,7 +84,7 @@ def get_command_line(**kwds) -> List[str]:
83
84
prog = textwrap .dedent (f"""
84
85
from multiprocessing.spawn import spawn_main;
85
86
from viztracer.patch import patch_spawned_process;
86
- patch_spawned_process({ tracer .init_kwargs } );
87
+ patch_spawned_process({ tracer .init_kwargs } , { args } );
87
88
spawn_main(%s)
88
89
""" )
89
90
prog %= ', ' .join ('%s=%r' % item for item in kwds .items ())
@@ -100,24 +101,27 @@ def __init__(
100
101
run : Callable ,
101
102
target : Callable ,
102
103
args : List [Any ],
103
- kwargs : Dict [str , Any ]):
104
+ kwargs : Dict [str , Any ],
105
+ cmdline_args : List [str ]):
104
106
self ._viztracer_kwargs = viztracer_kwargs
105
107
self ._run = run
106
108
self ._target = target
107
109
self ._args = args
108
110
self ._kwargs = kwargs
111
+ self ._cmdline_args = cmdline_args
109
112
self ._exiting = False
110
113
111
114
def run (self ) -> None :
112
115
import viztracer
113
116
114
117
tracer = viztracer .VizTracer (** self ._viztracer_kwargs )
118
+ install_all_hooks (tracer , self ._cmdline_args )
115
119
tracer .register_exit ()
116
120
tracer .start ()
117
121
self ._run ()
118
122
119
123
120
- def patch_spawned_process (viztracer_kwargs : Dict [str , Any ]):
124
+ def patch_spawned_process (viztracer_kwargs : Dict [str , Any ], cmdline_args : List [ str ] ):
121
125
from multiprocessing import reduction , process # type: ignore
122
126
from multiprocessing .spawn import prepare
123
127
import multiprocessing .spawn
@@ -130,7 +134,7 @@ def _main_3839(fd, parent_sentinel):
130
134
preparation_data = reduction .pickle .load (from_parent )
131
135
prepare (preparation_data )
132
136
self : Process = reduction .pickle .load (from_parent )
133
- sp = SpawnProcess (viztracer_kwargs , self .run , self ._target , self ._args , self ._kwargs )
137
+ sp = SpawnProcess (viztracer_kwargs , self .run , self ._target , self ._args , self ._kwargs , cmdline_args )
134
138
self .run = sp .run
135
139
finally :
136
140
del process .current_process ()._inheriting
@@ -144,7 +148,7 @@ def _main_3637(fd):
144
148
preparation_data = reduction .pickle .load (from_parent )
145
149
prepare (preparation_data )
146
150
self : Process = reduction .pickle .load (from_parent )
147
- sp = SpawnProcess (viztracer_kwargs , self .run , self ._target , self ._args , self ._kwargs )
151
+ sp = SpawnProcess (viztracer_kwargs , self .run , self ._target , self ._args , self ._kwargs , cmdline_args )
148
152
self .run = sp .run
149
153
finally :
150
154
del process .current_process ()._inheriting
@@ -154,3 +158,33 @@ def _main_3637(fd):
154
158
multiprocessing .spawn ._main = _main_3839 # type: ignore
155
159
else :
156
160
multiprocessing .spawn ._main = _main_3637 # type: ignore
161
+
162
+
163
+ def install_all_hooks (
164
+ tracer : VizTracer ,
165
+ args : List [str ],
166
+ patch_multiprocess : bool = True ) -> None :
167
+
168
+ # multiprocess hook
169
+ if patch_multiprocess :
170
+ patch_multiprocessing (tracer , args )
171
+ patch_subprocess (args + ["--subprocess_child" , "--dump_raw" , "-o" , tracer .output_file ])
172
+
173
+ # If we want to hook fork correctly with file waiter, we need to
174
+ # use os.register_at_fork to write the file, and make sure
175
+ # os.exec won't clear viztracer so that the file lives forever.
176
+ # This is basically equivalent to py3.8 + Linux
177
+ if hasattr (sys , "addaudithook" ):
178
+ if hasattr (os , "register_at_fork" ) and patch_multiprocess :
179
+ def audit_hook (event , _ ): # pragma: no cover
180
+ if event == "os.exec" :
181
+ tracer .exit_routine ()
182
+ sys .addaudithook (audit_hook ) # type: ignore
183
+ os .register_at_fork (after_in_child = lambda : tracer .label_file_to_write ()) # type: ignore
184
+ if tracer .log_audit is not None :
185
+ audit_regex_list = [re .compile (regex ) for regex in tracer .log_audit ]
186
+
187
+ def audit_hook (event , _ ): # pragma: no cover
188
+ if len (audit_regex_list ) == 0 or any ((regex .fullmatch (event ) for regex in audit_regex_list )):
189
+ tracer .log_instant (event , args = {"args" : [str (arg ) for arg in args ]})
190
+ sys .addaudithook (audit_hook ) # type: ignore
0 commit comments