Skip to content

Commit

Permalink
feat: Add support for OpenSSH forwarding Unix sockets (#550)
Browse files Browse the repository at this point in the history
OpenSSH clients can forward Unix sockets instead of TCP ports with the
same `-L` flag using a slightly modified argument. This patch should
allow users to take advantage of the full syntax. From the man pages:

```
     -L [bind_address:]port:host:hostport
     -L [bind_address:]port:remote_socket
     -L local_socket:host:hostport
     -L local_socket:remote_socket
```

This patch also adds a protection for the feature.
  • Loading branch information
Lihu Ben-Ezri-Ravin authored Jul 23, 2021
1 parent c95b8ee commit bcf7078
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 24 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
1.8.0
-----

* SSHMachine: support forwarding Unix sockets in ``.tunnel()`` (`#550 <https://github.com/tomerfiliba/plumbum/pull/550>`_)


1.7.0
-----

Expand Down
20 changes: 14 additions & 6 deletions plumbum/machines/ssh_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,12 @@ def tunnel(
r"""Creates an SSH tunnel from the TCP port (``lport``) of the local machine
(``lhost``, defaults to ``"localhost"``, but it can be any IP you can ``bind()``)
to the remote TCP port (``dport``) of the destination machine (``dhost``, defaults
to ``"localhost"``, which means *this remote machine*). The returned
:class:`SshTunnel <plumbum.machines.remote.SshTunnel>` object can be used as a
*context-manager*.
to ``"localhost"``, which means *this remote machine*). This function also
supports Unix sockets, in which case the local socket should be passed in as
``lport`` and the local bind address should be ``None``. The same can be done
for a remote socket, by following the same pattern with ``dport`` and ``dhost``.
The returned :class:`SshTunnel <plumbum.machines.remote.SshTunnel>` object can
be used as a *context-manager*.
The more conventional use case is the following::
Expand Down Expand Up @@ -263,12 +266,17 @@ def tunnel(
rem = SshMachine("megazord")
with rem.tunnel(1234, 5678):
with rem.tunnel(1234, "/var/lib/mysql/mysql.sock", dhost=None):
sock = socket.socket()
sock.connect(("localhost", 1234))
# sock is now tunneled to megazord:5678
# sock is now tunneled to the MySQL socket on megazord
"""
ssh_opts = ["-L", "[{}]:{}:[{}]:{}".format(lhost, lport, dhost, dport)]
formatted_lhost = "" if lhost is None else "[{}]:".format(lhost)
formatted_dhost = "" if dhost is None else "[{}]:".format(dhost)
ssh_opts = [
"-L",
"{}{}:{}{}".format(formatted_lhost, lport, formatted_dhost, dport),
]
proc = self.popen((), ssh_opts=ssh_opts, new_session=True)
return SshTunnel(
ShellSession(
Expand Down
54 changes: 36 additions & 18 deletions tests/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,11 +267,26 @@ def test_copy(self):


class BaseRemoteMachineTest(object):
TUNNEL_PROG = r"""import sys, socket
TUNNEL_PROG_AF_INET = r"""import sys, socket
s = socket.socket()
s.bind(("", 0))
s.listen(1)
sys.stdout.write("{0}\n".format( s.getsockname()[1]))
sys.stdout.write("{0}\n".format(s.getsockname()[1]))
sys.stdout.flush()
s2, _ = s.accept()
data = s2.recv(100)
s2.send(b"hello " + data)
s2.close()
s.close()
"""

TUNNEL_PROG_AF_UNIX = r"""import sys, socket, tempfile
s = socket.socket(family=socket.AF_UNIX)
socket_location = tempfile.NamedTemporaryFile()
socket_location.close()
s.bind(socket_location.name)
s.listen(1)
sys.stdout.write("{0}\n".format(s.getsockname()))
sys.stdout.flush()
s2, _ = s.accept()
data = s2.recv(100)
Expand Down Expand Up @@ -438,23 +453,26 @@ def _connect(self):
return SshMachine(TEST_HOST)

def test_tunnel(self):
with self._connect() as rem:
p = (rem.python["-u"] << self.TUNNEL_PROG).popen()
try:
port = int(p.stdout.readline().decode("ascii").strip())
except ValueError:
print(p.communicate())
raise

with rem.tunnel(12222, port) as tun:
s = socket.socket()
s.connect(("localhost", 12222))
s.send(six.b("world"))
data = s.recv(100)
s.close()
for tunnel_prog in (self.TUNNEL_PROG_AF_INET, self.TUNNEL_PROG_AF_UNIX):
with self._connect() as rem:
p = (rem.python["-u"] << tunnel_prog).popen()
port_or_socket = p.stdout.readline().decode("ascii").strip()
try:
port_or_socket = int(port_or_socket)
dhost = "localhost"
except ValueError:
dhost = None

with rem.tunnel(12222, port_or_socket, dhost=dhost) as tun:
s = socket.socket()
s.connect(("localhost", 12222))
s.send(six.b("world"))
data = s.recv(100)
s.close()

print(p.communicate())
assert data == b"hello world"
print(p.communicate())
assert data == b"hello world"

def test_get(self):
with self._connect() as rem:
Expand Down Expand Up @@ -525,7 +543,7 @@ def _connect(self):

def test_tunnel(self):
with self._connect() as rem:
p = rem.python["-c", self.TUNNEL_PROG].popen()
p = rem.python["-c", self.TUNNEL_PROG_AF_INET].popen()
try:
port = int(p.stdout.readline().strip())
except ValueError:
Expand Down

0 comments on commit bcf7078

Please sign in to comment.