From bcf70787968c60ef4302c82678d6dc50a61eb9f1 Mon Sep 17 00:00:00 2001 From: Lihu Ben-Ezri-Ravin Date: Fri, 23 Jul 2021 17:20:28 -0400 Subject: [PATCH] feat: Add support for OpenSSH forwarding Unix sockets (#550) OpenSSH clients can forward Unix sockets instead of TCP ports with the same `-L` flag using a slightly modified argument. This patch should allow users to take advantage of the full syntax. From the man pages: ``` -L [bind_address:]port:host:hostport -L [bind_address:]port:remote_socket -L local_socket:host:hostport -L local_socket:remote_socket ``` This patch also adds a protection for the feature. --- CHANGELOG.rst | 6 ++++ plumbum/machines/ssh_machine.py | 20 ++++++++---- tests/test_remote.py | 54 ++++++++++++++++++++++----------- 3 files changed, 56 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6452d4c7b..f48a30250 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,3 +1,9 @@ +1.8.0 +----- + +* SSHMachine: support forwarding Unix sockets in ``.tunnel()`` (`#550 `_) + + 1.7.0 ----- diff --git a/plumbum/machines/ssh_machine.py b/plumbum/machines/ssh_machine.py index 8250f9a9e..dc99b3d62 100644 --- a/plumbum/machines/ssh_machine.py +++ b/plumbum/machines/ssh_machine.py @@ -225,9 +225,12 @@ def tunnel( r"""Creates an SSH tunnel from the TCP port (``lport``) of the local machine (``lhost``, defaults to ``"localhost"``, but it can be any IP you can ``bind()``) to the remote TCP port (``dport``) of the destination machine (``dhost``, defaults - to ``"localhost"``, which means *this remote machine*). The returned - :class:`SshTunnel ` object can be used as a - *context-manager*. + to ``"localhost"``, which means *this remote machine*). This function also + supports Unix sockets, in which case the local socket should be passed in as + ``lport`` and the local bind address should be ``None``. The same can be done + for a remote socket, by following the same pattern with ``dport`` and ``dhost``. + The returned :class:`SshTunnel ` object can + be used as a *context-manager*. The more conventional use case is the following:: @@ -263,12 +266,17 @@ def tunnel( rem = SshMachine("megazord") - with rem.tunnel(1234, 5678): + with rem.tunnel(1234, "/var/lib/mysql/mysql.sock", dhost=None): sock = socket.socket() sock.connect(("localhost", 1234)) - # sock is now tunneled to megazord:5678 + # sock is now tunneled to the MySQL socket on megazord """ - ssh_opts = ["-L", "[{}]:{}:[{}]:{}".format(lhost, lport, dhost, dport)] + formatted_lhost = "" if lhost is None else "[{}]:".format(lhost) + formatted_dhost = "" if dhost is None else "[{}]:".format(dhost) + ssh_opts = [ + "-L", + "{}{}:{}{}".format(formatted_lhost, lport, formatted_dhost, dport), + ] proc = self.popen((), ssh_opts=ssh_opts, new_session=True) return SshTunnel( ShellSession( diff --git a/tests/test_remote.py b/tests/test_remote.py index c2a1c2075..ec16c5e03 100644 --- a/tests/test_remote.py +++ b/tests/test_remote.py @@ -267,11 +267,26 @@ def test_copy(self): class BaseRemoteMachineTest(object): - TUNNEL_PROG = r"""import sys, socket + TUNNEL_PROG_AF_INET = r"""import sys, socket s = socket.socket() s.bind(("", 0)) s.listen(1) -sys.stdout.write("{0}\n".format( s.getsockname()[1])) +sys.stdout.write("{0}\n".format(s.getsockname()[1])) +sys.stdout.flush() +s2, _ = s.accept() +data = s2.recv(100) +s2.send(b"hello " + data) +s2.close() +s.close() +""" + + TUNNEL_PROG_AF_UNIX = r"""import sys, socket, tempfile +s = socket.socket(family=socket.AF_UNIX) +socket_location = tempfile.NamedTemporaryFile() +socket_location.close() +s.bind(socket_location.name) +s.listen(1) +sys.stdout.write("{0}\n".format(s.getsockname())) sys.stdout.flush() s2, _ = s.accept() data = s2.recv(100) @@ -438,23 +453,26 @@ def _connect(self): return SshMachine(TEST_HOST) def test_tunnel(self): - with self._connect() as rem: - p = (rem.python["-u"] << self.TUNNEL_PROG).popen() - try: - port = int(p.stdout.readline().decode("ascii").strip()) - except ValueError: - print(p.communicate()) - raise - with rem.tunnel(12222, port) as tun: - s = socket.socket() - s.connect(("localhost", 12222)) - s.send(six.b("world")) - data = s.recv(100) - s.close() + for tunnel_prog in (self.TUNNEL_PROG_AF_INET, self.TUNNEL_PROG_AF_UNIX): + with self._connect() as rem: + p = (rem.python["-u"] << tunnel_prog).popen() + port_or_socket = p.stdout.readline().decode("ascii").strip() + try: + port_or_socket = int(port_or_socket) + dhost = "localhost" + except ValueError: + dhost = None + + with rem.tunnel(12222, port_or_socket, dhost=dhost) as tun: + s = socket.socket() + s.connect(("localhost", 12222)) + s.send(six.b("world")) + data = s.recv(100) + s.close() - print(p.communicate()) - assert data == b"hello world" + print(p.communicate()) + assert data == b"hello world" def test_get(self): with self._connect() as rem: @@ -525,7 +543,7 @@ def _connect(self): def test_tunnel(self): with self._connect() as rem: - p = rem.python["-c", self.TUNNEL_PROG].popen() + p = rem.python["-c", self.TUNNEL_PROG_AF_INET].popen() try: port = int(p.stdout.readline().strip()) except ValueError: