Skip to content

Commit

Permalink
Replace ps | grep with psutil in SSHAttach (#2029)
Browse files Browse the repository at this point in the history
Fixes: #2019
  • Loading branch information
un-def authored Nov 25, 2024
1 parent 3eb746e commit 0ba717a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 45 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def get_long_description():
"alembic-postgresql-enum",
"asyncpg",
"jinja2",
"psutil",
]

AWS_DEPS = [
Expand Down
66 changes: 21 additions & 45 deletions src/dstack/_internal/core/services/ssh/attach.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import atexit
import re
import subprocess
import time
from pathlib import Path
from typing import Optional

from dstack._internal.compat import IS_WINDOWS
import psutil

from dstack._internal.core.errors import SSHError
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.services.configs import ConfigManager
Expand All @@ -19,6 +19,9 @@
update_ssh_config,
)

# ssh -L option format: [bind_address:]port:host:hostport
_SSH_TUNNEL_REGEX = re.compile(r"(?:[\w.-]+:)?(?P<local_port>\d+):localhost:(?P<remote_port>\d+)")


class SSHAttach:
@classmethod
Expand All @@ -27,51 +30,24 @@ def get_control_sock_path(cls, run_name: str) -> Path:

@classmethod
def reuse_ports_lock(cls, run_name: str) -> Optional[PortsLock]:
if not get_ssh_client_info().supports_control_socket:
ssh_client_info = get_ssh_client_info()
if not ssh_client_info.supports_control_socket:
raise SSHError("Unsupported SSH client")
ssh_exe = str(ssh_client_info.path)
control_sock_path = normalize_path(cls.get_control_sock_path(run_name))
filter_prefix: str
output: bytes
if IS_WINDOWS:
filter_prefix = "powershell"
# This script always returns zero exit status.
output = subprocess.check_output(
[
"powershell",
"-c",
f"""Get-CimInstance Win32_Process \
-filter "commandline like '%-S {control_sock_path}%'" \
| select -ExpandProperty CommandLine \
""",
]
)
else:
filter_prefix = "grep"
ps = subprocess.Popen(("ps", "-A", "-o", "command"), stdout=subprocess.PIPE)
cp = subprocess.run(
["grep", "-F", "--", f"-S {control_sock_path}"],
stdin=ps.stdout,
stdout=subprocess.PIPE,
)
ps.wait()
# From grep man page: "the exit status is 0 if a line is selected,
# 1 if no lines were selected, and 2 if an error occurred".
if cp.returncode == 1:
return None
if cp.returncode != 0:
raise SSHError(
f"Unexpected grep exit status {cp.returncode} while searching for ssh process"
)
output = cp.stdout
commands = list(
filter(lambda s: not s.startswith(filter_prefix), output.decode().strip().split("\n"))
)
if commands:
port_pattern = r"-L (?:[\w.-]+:)?(\d+):localhost:(\d+)"
matches = re.findall(port_pattern, commands[0])
return PortsLock(
{int(target_port): int(local_port) for local_port, target_port in matches}
)
for process in psutil.process_iter(["cmdline"]):
cmdline = process.info["cmdline"]
if not (cmdline and cmdline[0] == ssh_exe and control_sock_path in cmdline):
continue
port_mapping: dict[int, int] = {}
cmdline_iter = iter(cmdline)
for arg in cmdline_iter:
if arg != "-L" or not (next_arg := next(cmdline_iter, None)):
continue
if match := _SSH_TUNNEL_REGEX.fullmatch(next_arg):
local_port, remote_port = match.group("local_port", "remote_port")
port_mapping[int(remote_port)] = int(local_port)
return PortsLock(port_mapping)
return None

def __init__(
Expand Down

0 comments on commit 0ba717a

Please sign in to comment.