Skip to content

Commit

Permalink
Add __mp_main__ as a duplicate for __main__ for pickle to work (#423)
Browse files Browse the repository at this point in the history
  • Loading branch information
gaogaotiantian authored Apr 6, 2024
1 parent 3f6322e commit 9f17472
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/viztracer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ def run_string(self) -> VizProcedureResult:
setattr(main_mod, "__file__", "<string>")
setattr(main_mod, "__builtins__", globals()["__builtins__"])

sys.modules["__main__"] = main_mod
# __mp_main__ should be a duplicate of __main__ for pickle
sys.modules["__main__"] = sys.modules["__mp_main__"] = main_mod
code = compile(cmd_string, "<string>", "exec")
sys.argv = ["-c"] + self.command[:]
return self.run_code(code, main_mod.__dict__)
Expand Down Expand Up @@ -440,7 +441,8 @@ def run_command(self) -> VizProcedureResult:
setattr(main_mod, "__file__", os.path.abspath(file_name))
setattr(main_mod, "__builtins__", globals()["__builtins__"])

sys.modules["__main__"] = main_mod
# __mp_main__ should be a duplicate of __main__ for pickle
sys.modules["__main__"] = sys.modules["__mp_main__"] = main_mod
code = compile(code_string, os.path.abspath(file_name), "exec")
sys.path.insert(0, os.path.dirname(file_name))
sys.argv = command[:]
Expand Down
28 changes: 28 additions & 0 deletions tests/test_multiprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,20 @@ def f(x):
gc.enable()
"""

file_pool_with_pickle = """
from multiprocessing import get_context
class Bar:
pass
def foo(args):
return Bar()
if __name__ == '__main__':
with get_context('spawn').Pool(1) as pool:
_ = list(pool.imap_unordered(foo, [1]))
"""

file_loky = """
from loky import get_reusable_executor
import time
Expand Down Expand Up @@ -428,6 +442,20 @@ def check_func(data):
if not os.getenv("COVERAGE_RUN"):
raise e

@unittest.skipIf("win32" in sys.platform, "Does not support Windows")
def test_multiprocessing_pool_with_pickle(self):
def check_func(data):
pids = set()
for entry in data["traceEvents"]:
pids.add(entry["pid"])
self.assertGreater(len(pids), 1)

self.template(["viztracer", "-o", "result.json", "cmdline_test.py"],
expected_output_file="result.json",
script=file_pool_with_pickle,
check_func=check_func,
concurrency="multiprocessing")

def test_multiprosessing_stack_depth(self):
def check_func(data):
for entry in data["traceEvents"]:
Expand Down

0 comments on commit 9f17472

Please sign in to comment.