From 38a108b8bc4d0468430c0f1349195f8f55b84c86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?An=C5=BEe=20Stari=C4=8D?= Date: Wed, 1 Nov 2023 18:24:30 +0100 Subject: [PATCH] Fix stalling in Pipeline command (#632) * Fix stalling in Pipeline command When a command (other than the first) in a pipeline wrote more than 64k to the stderr, and the output was consumed with the iter_lines function, the whole pipeline stalled. Fixed by reading the output of all commands where either stdout or stderr was set to PIPE. * disable new tests on windows, increase timeout * black & flake8 * trying to figure out what the extra items are * test_pipelines: ignore empty lines somehow stderr includes an empty line when run on pypy * skip new tests on windows * Apply suggestions from code review --------- Co-authored-by: Henry Schreiner --- plumbum/commands/base.py | 2 - plumbum/commands/processes.py | 41 ++++++++++++++--- tests/test_pipelines.py | 87 +++++++++++++++++++++++++++++++++++ 3 files changed, 122 insertions(+), 8 deletions(-) create mode 100644 tests/test_pipelines.py diff --git a/plumbum/commands/base.py b/plumbum/commands/base.py index 52c0f2626..3172c526e 100644 --- a/plumbum/commands/base.py +++ b/plumbum/commands/base.py @@ -398,8 +398,6 @@ def popen(self, args=(), **kwargs): dstproc = self.dstcmd.popen(**kwargs) # allow p1 to receive a SIGPIPE if p2 exits srcproc.stdout.close() - if srcproc.stderr is not None: - dstproc.stderr = srcproc.stderr if srcproc.stdin and src_kwargs.get("stdin") != PIPE: srcproc.stdin.close() dstproc.srcproc = srcproc diff --git a/plumbum/commands/processes.py b/plumbum/commands/processes.py index 802ede4d2..89c276342 100644 --- a/plumbum/commands/processes.py +++ b/plumbum/commands/processes.py @@ -18,14 +18,44 @@ def _check_process(proc, retcode, timeout, stdout, stderr): return proc.returncode, stdout, stderr +def _get_piped_streams(proc): + """Get a list of all valid standard streams for proc that were opened with PIPE option. 
+ + If proc was started from a Pipeline command, this function assumes it will have a + "srcproc" member pointing to the previous command in the pipeline. That link will + be used to traverse all processes started from the pipeline; the list will + include stdout/stderr streams opened as PIPE for all commands in the pipeline. + If that were not the case, some processes could write to pipes no one reads from, + which would result in the process stalling after the pipe's buffer is filled. + + Streams that were closed (because they were redirected to the input of a subsequent command) + are not included in the result. + """ + streams = [] + + def add_stream(type_, stream): + if stream is None or stream.closed: + return + streams.append((type_, stream)) + + while proc: + add_stream(1, proc.stderr) + add_stream(0, proc.stdout) + proc = getattr(proc, "srcproc", None) + + return streams + + def _iter_lines_posix(proc, decode, linesize, line_timeout=None): from selectors import EVENT_READ, DefaultSelector + streams = _get_piped_streams(proc) + # Python 3.4+ implementation def selector(): sel = DefaultSelector() - sel.register(proc.stdout, EVENT_READ, 0) - sel.register(proc.stderr, EVENT_READ, 1) + for stream_type, stream in streams: + sel.register(stream, EVENT_READ, stream_type) while True: ready = sel.select(line_timeout) if not ready and line_timeout: @@ -41,10 +71,9 @@ def selector(): yield ret if proc.poll() is not None: break - for line in proc.stdout: - yield 0, decode(line) - for line in proc.stderr: - yield 1, decode(line) + for stream_type, stream in streams: + for line in stream: + yield stream_type, decode(line) def _iter_lines_win32(proc, decode, linesize, line_timeout=None): diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py new file mode 100644 index 000000000..8ea47e2d4 --- /dev/null +++ b/tests/test_pipelines.py @@ -0,0 +1,87 @@ +from typing import List, Tuple + +import pytest + +import plumbum +from plumbum._testtools import skip_on_windows
+from plumbum.commands import BaseCommand + + +@skip_on_windows +@pytest.mark.timeout(3) +def test_draining_stderr(generate_cmd, process_cmd): + stdout, stderr = get_output_with_iter_lines( + generate_cmd | process_cmd | process_cmd + ) + expected_output = {f"generated {i}" for i in range(5000)} + expected_output.update(f"consumed {i}" for i in range(5000)) + assert set(stderr) - expected_output == set() + assert len(stderr) == 15000 + assert len(stdout) == 5000 + + +@skip_on_windows +@pytest.mark.timeout(3) +def test_draining_stderr_with_stderr_redirect(tmp_path, generate_cmd, process_cmd): + stdout, stderr = get_output_with_iter_lines( + generate_cmd | (process_cmd >= str(tmp_path / "output.txt")) | process_cmd + ) + expected_output = {f"generated {i}" for i in range(5000)} + expected_output.update(f"consumed {i}" for i in range(5000)) + assert set(stderr) - expected_output == set() + assert len(stderr) == 10000 + assert len(stdout) == 5000 + + +@skip_on_windows +@pytest.mark.timeout(3) +def test_draining_stderr_with_stdout_redirect(tmp_path, generate_cmd, process_cmd): + stdout, stderr = get_output_with_iter_lines( + generate_cmd | process_cmd | process_cmd > str(tmp_path / "output.txt") + ) + expected_output = {f"generated {i}" for i in range(5000)} + expected_output.update(f"consumed {i}" for i in range(5000)) + assert set(stderr) - expected_output == set() + assert len(stderr) == 15000 + assert len(stdout) == 0 + + +@pytest.fixture() +def generate_cmd(tmp_path): + generate = tmp_path / "generate.py" + generate.write_text( + """\ +import sys +for i in range(5000): + print("generated", i, file=sys.stderr) + print(i) +""" + ) + return plumbum.local["python"][generate] + + +@pytest.fixture() +def process_cmd(tmp_path): + process = tmp_path / "process.py" + process.write_text( + """\ +import sys +for line in sys.stdin: + i = line.strip() + print("consumed", i, file=sys.stderr) + print(i) +""" + ) + return plumbum.local["python"][process] + + +def 
get_output_with_iter_lines(cmd: BaseCommand) -> Tuple[List[str], List[str]]: + stderr, stdout = [], [] + proc = cmd.popen() + for stdout_line, stderr_line in proc.iter_lines(retcode=[0, None]): + if stderr_line: + stderr.append(stderr_line) + if stdout_line: + stdout.append(stdout_line) + proc.wait() + return stdout, stderr