Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 52 additions & 15 deletions packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
import shlex
import subprocess
import tempfile
from dataclasses import dataclass
from urllib.parse import urlparse

Expand Down Expand Up @@ -45,6 +47,7 @@ def run(self, direct, args):
# Get SSH command and default username from driver
ssh_command = self.call("get_ssh_command")
default_username = self.call("get_default_username")
ssh_identity = self.call("get_ssh_identity")

if direct:
# Use direct TCP address
Expand All @@ -56,7 +59,7 @@ def run(self, direct, args):
if not host or not port:
raise ValueError(f"Invalid address format: {address}")
self.logger.debug(f"Using direct TCP connection for SSH - host: {host}, port: {port}")
return self._run_ssh_local(host, port, ssh_command, default_username, args)
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)
except (DriverMethodNotImplemented, ValueError) as e:
self.logger.error(f"Direct address connection failed ({e}), falling back to SSH port forwarding")
return self.run(False, args)
Expand All @@ -69,27 +72,61 @@ def run(self, direct, args):
host = addr[0]
port = addr[1]
self.logger.debug(f"SSH port forward established - host: {host}, port: {port}")
return self._run_ssh_local(host, port, ssh_command, default_username, args)
return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args)

def _run_ssh_local(self, host, port, ssh_command, default_username, args):
def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity, args):
"""Run SSH command with the given host, port, and arguments"""
# Build SSH command arguments
ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, args)

# Separate SSH options from command arguments
ssh_options, command_args = self._separate_ssh_options_and_command_args(args)

# Build final SSH command
ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)

# Execute the command
return self._execute_ssh_command(ssh_args)
# Create temporary identity file if needed
identity_file = None
temp_file = None
if ssh_identity:
try:
temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_ssh_key')
temp_file.write(ssh_identity)
temp_file.close()
# Set proper permissions (600) for SSH key
os.chmod(temp_file.name, 0o600)
identity_file = temp_file.name
self.logger.debug(f"Created temporary identity file: {identity_file}")
except Exception as e:
self.logger.error(f"Failed to create temporary identity file: {e}")
if temp_file:
try:
os.unlink(temp_file.name)
except Exception:
pass
raise

def _build_ssh_command_args(self, ssh_command, port, default_username, args):
try:
# Build SSH command arguments
ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, identity_file, args)

# Separate SSH options from command arguments
ssh_options, command_args = self._separate_ssh_options_and_command_args(args)

# Build final SSH command
ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args)

# Execute the command
return self._execute_ssh_command(ssh_args)
finally:
# Clean up temporary identity file
if identity_file:
try:
os.unlink(identity_file)
self.logger.debug(f"Cleaned up temporary identity file: {identity_file}")
except Exception as e:
self.logger.warning(f"Failed to clean up temporary identity file {identity_file}: {e}")

def _build_ssh_command_args(self, ssh_command, port, default_username, identity_file, args):
"""Build initial SSH command arguments"""
# Split the SSH command into individual arguments
ssh_args = shlex.split(ssh_command)

# Add identity file if provided
if identity_file:
ssh_args.extend(["-i", identity_file])

# Add port if specified
if port and port != 22:
ssh_args.extend(["-p", str(port)])
Expand Down
18 changes: 18 additions & 0 deletions packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from pathlib import Path

from jumpstarter.common.exceptions import ConfigurationError
from jumpstarter.driver import Driver, export
Expand All @@ -10,6 +11,8 @@ class SSHWrapper(Driver):

default_username: str = ""
ssh_command: str = "ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR"
ssh_identity: str | None = None
ssh_identity_file: str | None = None

def __post_init__(self):
if hasattr(super(), "__post_init__"):
Expand All @@ -18,6 +21,16 @@ def __post_init__(self):
if "tcp" not in self.children:
raise ConfigurationError("'tcp' child is required via ref, or directly as a TcpNetwork driver instance")

if self.ssh_identity and self.ssh_identity_file:
raise ConfigurationError("Cannot specify both ssh_identity and ssh_identity_file")

# If ssh_identity_file is provided, read it into ssh_identity
if self.ssh_identity_file:
try:
self.ssh_identity = Path(self.ssh_identity_file).read_text()
except Exception as e:
raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None

@classmethod
def client(cls) -> str:
return "jumpstarter_driver_ssh.client.SSHWrapperClient"
Expand All @@ -31,3 +44,8 @@ def get_default_username(self):
def get_ssh_command(self):
"""Get the SSH command to use"""
return self.ssh_command

@export
def get_ssh_identity(self):
"""Get the SSH identity key content"""
return self.ssh_identity
Loading