diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py index c31e36ac8..6f3b6f6ac 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/client.py @@ -1,5 +1,7 @@ +import os import shlex import subprocess +import tempfile from dataclasses import dataclass from urllib.parse import urlparse @@ -45,6 +47,7 @@ def run(self, direct, args): # Get SSH command and default username from driver ssh_command = self.call("get_ssh_command") default_username = self.call("get_default_username") + ssh_identity = self.call("get_ssh_identity") if direct: # Use direct TCP address @@ -56,7 +59,7 @@ def run(self, direct, args): if not host or not port: raise ValueError(f"Invalid address format: {address}") self.logger.debug(f"Using direct TCP connection for SSH - host: {host}, port: {port}") - return self._run_ssh_local(host, port, ssh_command, default_username, args) + return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args) except (DriverMethodNotImplemented, ValueError) as e: self.logger.error(f"Direct address connection failed ({e}), falling back to SSH port forwarding") return self.run(False, args) @@ -69,27 +72,61 @@ def run(self, direct, args): host = addr[0] port = addr[1] self.logger.debug(f"SSH port forward established - host: {host}, port: {port}") - return self._run_ssh_local(host, port, ssh_command, default_username, args) + return self._run_ssh_local(host, port, ssh_command, default_username, ssh_identity, args) - def _run_ssh_local(self, host, port, ssh_command, default_username, args): + def _run_ssh_local(self, host, port, ssh_command, default_username, ssh_identity, args): """Run SSH command with the given host, port, and arguments""" - # Build SSH command arguments - ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, args) - - # Separate SSH options from command arguments - ssh_options, command_args = self._separate_ssh_options_and_command_args(args) - - # Build final SSH command - ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args) - - # Execute the command - return self._execute_ssh_command(ssh_args) + # Create temporary identity file if needed + identity_file = None + temp_file = None + if ssh_identity: + try: + temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='_ssh_key') + temp_file.write(ssh_identity) + temp_file.close() + # Set proper permissions (600) for SSH key + os.chmod(temp_file.name, 0o600) + identity_file = temp_file.name + self.logger.debug(f"Created temporary identity file: {identity_file}") + except Exception as e: + self.logger.error(f"Failed to create temporary identity file: {e}") + if temp_file: + try: + os.unlink(temp_file.name) + except Exception: + pass + raise - def _build_ssh_command_args(self, ssh_command, port, default_username, args): + try: + # Build SSH command arguments + ssh_args = self._build_ssh_command_args(ssh_command, port, default_username, identity_file, args) + + # Separate SSH options from command arguments + ssh_options, command_args = self._separate_ssh_options_and_command_args(args) + + # Build final SSH command + ssh_args = self._build_final_ssh_command(ssh_args, ssh_options, host, command_args) + + # Execute the command + return self._execute_ssh_command(ssh_args) + finally: + # Clean up temporary identity file + if identity_file: + try: + os.unlink(identity_file) + self.logger.debug(f"Cleaned up temporary identity file: {identity_file}") + except Exception as e: + self.logger.warning(f"Failed to clean up temporary identity file {identity_file}: {e}") + + def _build_ssh_command_args(self, ssh_command, port, default_username, identity_file, args): """Build initial SSH command arguments""" # Split the SSH command into individual arguments ssh_args = shlex.split(ssh_command) + # Add identity file if provided + if identity_file: + ssh_args.extend(["-i", identity_file]) + # Add port if specified if port and port != 22: ssh_args.extend(["-p", str(port)]) diff --git a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py index 657e4e8f8..ec5597ca9 100644 --- a/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py +++ b/packages/jumpstarter-driver-ssh/jumpstarter_driver_ssh/driver.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from pathlib import Path from jumpstarter.common.exceptions import ConfigurationError from jumpstarter.driver import Driver, export @@ -10,6 +11,8 @@ class SSHWrapper(Driver): default_username: str = "" ssh_command: str = "ssh -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o LogLevel=ERROR" + ssh_identity: str | None = None + ssh_identity_file: str | None = None def __post_init__(self): if hasattr(super(), "__post_init__"): @@ -18,6 +21,16 @@ def __post_init__(self): if "tcp" not in self.children: raise ConfigurationError("'tcp' child is required via ref, or directly as a TcpNetwork driver instance") + if self.ssh_identity and self.ssh_identity_file: + raise ConfigurationError("Cannot specify both ssh_identity and ssh_identity_file") + + # If ssh_identity_file is provided, read it into ssh_identity + if self.ssh_identity_file: + try: + self.ssh_identity = Path(self.ssh_identity_file).read_text() + except Exception as e: + raise ConfigurationError(f"Failed to read ssh_identity_file '{self.ssh_identity_file}': {e}") from None + @classmethod def client(cls) -> str: return "jumpstarter_driver_ssh.client.SSHWrapperClient" @@ -31,3 +44,8 @@ def get_default_username(self): def get_ssh_command(self): """Get the SSH command to use""" return self.ssh_command + + @export + def get_ssh_identity(self): + """Get the SSH identity key content""" + return self.ssh_identity