From f3403e614e97ba6a5913370110f0e73f701b8eba Mon Sep 17 00:00:00 2001 From: Luis Cruz Date: Fri, 7 Mar 2025 10:55:37 +0000 Subject: [PATCH] natlab: Add generic ssh connection class --- nat-lab/tests/conftest.py | 40 ++++----- nat-lab/tests/telio.py | 23 ++--- nat-lab/tests/test_batching.py | 2 +- nat-lab/tests/test_cleanup.py | 4 +- nat-lab/tests/test_lana.py | 4 +- nat-lab/tests/utils/connection/connection.py | 15 +++- .../utils/connection/docker_connection.py | 54 +++++++----- .../tests/utils/connection/ssh_connection.py | 87 ++++++++++++++++--- nat-lab/tests/utils/connection_util.py | 67 +++++--------- nat-lab/tests/utils/tcpdump.py | 2 +- nat-lab/tests/utils/vm/mac_vm_util.py | 53 +---------- nat-lab/tests/utils/vm/windows_vm_util.py | 59 +------------ 12 files changed, 175 insertions(+), 235 deletions(-) diff --git a/nat-lab/tests/conftest.py b/nat-lab/tests/conftest.py index 0f2d97229..762957fc0 100644 --- a/nat-lab/tests/conftest.py +++ b/nat-lab/tests/conftest.py @@ -18,7 +18,6 @@ from utils.process import ProcessExecError from utils.router import IPStack from utils.tcpdump import make_tcpdump, make_local_tcpdump -from utils.vm import mac_vm_util, windows_vm_util DERP_SERVER_1_ADDR = "http://10.0.10.1:8765" DERP_SERVER_2_ADDR = "http://10.0.10.2:8765" @@ -266,27 +265,16 @@ async def perform_pretest_cleanups(): async def _copy_vm_binaries(tag: ConnectionTag): - if tag in [ConnectionTag.WINDOWS_VM_1, ConnectionTag.WINDOWS_VM_2]: - try: - print(f"copying for {tag}") - async with windows_vm_util.new_connection( - LAN_ADDR_MAP[tag], copy_binaries=True, reenable_nat=True - ): - pass - except OSError as e: - if os.environ.get("GITLAB_CI"): - raise e - print(e) - elif tag is ConnectionTag.MAC_VM: - try: - async with mac_vm_util.new_connection( - copy_binaries=True, reenable_nat=True - ): - pass - except OSError as e: - if os.environ.get("GITLAB_CI"): - raise e - print(e) + try: + print(f"copying for {tag}") + async with SshConnection.new_connection( + 
LAN_ADDR_MAP[tag], tag, copy_binaries=True, reenable_nat=True + ): + pass + except OSError as e: + if os.environ.get("GITLAB_CI"): + raise e + print(e) async def _copy_vm_binaries_if_needed(items): @@ -358,7 +346,9 @@ async def collect_kernel_logs(items, suffix): for item in items: if any(mark.name == "mac" for mark in item.own_markers): try: - async with mac_vm_util.new_connection() as conn: + async with SshConnection.new_connection( + LAN_ADDR_MAP[ConnectionTag.MAC_VM], ConnectionTag.MAC_VM + ) as conn: await _save_macos_logs(conn, suffix) except OSError as e: if os.environ.get("GITLAB_CI"): @@ -465,7 +455,9 @@ async def collect_mac_diagnostic_reports(): return print("Collect mac diagnostic reports") try: - async with mac_vm_util.new_connection() as connection: + async with SshConnection.new_connection( + LAN_ADDR_MAP[ConnectionTag.MAC_VM], ConnectionTag.MAC_VM + ) as connection: await connection.download( "/Library/Logs/DiagnosticReports", "logs/system_diagnostic_reports" ) diff --git a/nat-lab/tests/telio.py b/nat-lab/tests/telio.py index 3879f9af4..2c1f08ba0 100644 --- a/nat-lab/tests/telio.py +++ b/nat-lab/tests/telio.py @@ -508,7 +508,7 @@ async def on_stderr(stderr: str) -> None: f"Test cleanup: Exception while flushing logs: {e}", ) - await self.get_proxy().shutdown(self._connection.target_name()) + await self.get_proxy().shutdown(self._connection.tag.name) else: print( datetime.now(), @@ -1083,14 +1083,7 @@ async def _save_logs(self) -> None: system_log_content = await self.get_system_log() - if self._connection.target_os == TargetOS.Linux: - process = self._connection.create_process(["cat", "/etc/hostname"]) - await process.execute() - container_id = process.get_stdout().strip() - else: - container_id = str(self._connection.target_os.name) - - filename = container_id + ".log" + filename = self._connection.tag.name.lower() + ".log" if len(filename.encode("utf-8")) > 256: filename = f"{filename[:251]}.log" @@ -1117,7 +1110,9 @@ async def _save_logs(self) 
-> None: file_name = os.path.basename(trace_path) os.rename( os.path.join(log_dir, file_name), - os.path.join(log_dir, f"{container_id}-{file_name}"), + os.path.join( + log_dir, f"{self._connection.tag.name.lower()}-{file_name}" + ), ) async def save_moose_db(self) -> None: @@ -1150,9 +1145,7 @@ async def save_mac_network_info(self) -> None: network_info_info = await self.get_network_info() - container_id = str(self._connection.target_os.name) - - filename = container_id + "_network_info.log" + filename = self._connection.tag.name.lower() + "_network_info.log" if len(filename.encode("utf-8")) > 256: filename = f"{filename[:251]}.log" @@ -1220,7 +1213,7 @@ async def collect_core_dumps(self): # if we collected some core dumps, copy them if isinstance(self._connection, DockerConnection) and should_copy_coredumps: - container_name = self._connection.container_name() + container_name = container_id(self._connection.tag) test_name = get_current_test_case_and_parameters()[0] or "" for i, file_path in enumerate(dump_files): file_name = file_path.rsplit("/", 1)[-1] @@ -1250,7 +1243,7 @@ async def find_files(connection, where, name_pattern): def copy_file(from_connection, from_path, destination_path): """Copy a file from within the docker container connection to the destination path""" if isinstance(from_connection, DockerConnection): - container_name = from_connection.container_name() + container_name = container_id(from_connection.tag) file_name = os.path.basename(from_path) core_dump_destination = os.path.join(destination_path, file_name) diff --git a/nat-lab/tests/test_batching.py b/nat-lab/tests/test_batching.py index 23721ddfb..c50a1e30d 100644 --- a/nat-lab/tests/test_batching.py +++ b/nat-lab/tests/test_batching.py @@ -161,7 +161,7 @@ async def test_batching( gateway_container_names = [container_id(conn_tag) for conn_tag in gateways] conns = [client.get_connection() for client in clients] node_container_names = [ - conn.container_name() + container_id(conn.tag) for 
conn in conns if isinstance(conn, DockerConnection) ] diff --git a/nat-lab/tests/test_cleanup.py b/nat-lab/tests/test_cleanup.py index fc0694b23..c20b2c039 100644 --- a/nat-lab/tests/test_cleanup.py +++ b/nat-lab/tests/test_cleanup.py @@ -21,7 +21,7 @@ async def test_get_network_interface_tunnel_keys(adapter_type, name) -> None: connection = await exit_stack.enter_async_context( new_connection_raw(ConnectionTag.WINDOWS_VM_1) ) - assert [] == await _get_network_interface_tunnel_keys(connection) + assert [] == await get_network_interface_tunnel_keys(connection) _env = await exit_stack.enter_async_context( setup_environment( exit_stack, @@ -40,7 +40,7 @@ async def test_get_network_interface_tunnel_keys(adapter_type, name) -> None: # This function is used during test startup to remove interfaces # that might have managed to survive the end of the previous test. - keys = await _get_network_interface_tunnel_keys(connection) + keys = await get_network_interface_tunnel_keys(connection) assert [ "HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Control\\Class\\{4d36e972-e325-11ce-bfc1-08002be10318}\\0006" ] == keys diff --git a/nat-lab/tests/test_lana.py b/nat-lab/tests/test_lana.py index 0b9d9536b..05bdf29d5 100644 --- a/nat-lab/tests/test_lana.py +++ b/nat-lab/tests/test_lana.py @@ -202,12 +202,12 @@ async def wait_for_event_dump( events = fetch_moose_events(events_path) if len(events) == nr_events: print( - f"Found db from {connection.target_name()} with the expected {nr_events} events." + f"Found db from {connection.tag.name} with the expected {nr_events} events." ) return events await asyncio.sleep(DEFAULT_CHECK_INTERVAL) print( - f"Failed looking db from {connection.target_name()}, expected {nr_events} but" + f"Failed looking db from {connection.tag.name}, expected {nr_events} but" f" {len(events)} were found." 
) return None diff --git a/nat-lab/tests/utils/connection/connection.py b/nat-lab/tests/utils/connection/connection.py index 59ba3a7f1..e1106c1d2 100644 --- a/nat-lab/tests/utils/connection/connection.py +++ b/nat-lab/tests/utils/connection/connection.py @@ -71,9 +71,11 @@ def local(): class Connection(ABC): _target_os: Optional[TargetOS] + _tag: Optional[ConnectionTag] - def __init__(self, target_os: TargetOS) -> None: + def __init__(self, target_os: TargetOS, tag: ConnectionTag) -> None: self._target_os = target_os + self._tag = tag @abstractmethod def create_process( @@ -90,9 +92,14 @@ def target_os(self) -> TargetOS: def target_os(self, target_os: TargetOS) -> None: self._target_os = target_os - @abstractmethod - def target_name(self) -> str: - pass + @property + def tag(self) -> ConnectionTag: + assert self._tag + return self._tag + + @tag.setter + def tag(self, tag: ConnectionTag) -> None: + self._tag = tag @abstractmethod async def download(self, remote_path: str, local_path: str) -> None: diff --git a/nat-lab/tests/utils/connection/docker_connection.py b/nat-lab/tests/utils/connection/docker_connection.py index 61322df56..b243b4b5f 100644 --- a/nat-lab/tests/utils/connection/docker_connection.py +++ b/nat-lab/tests/utils/connection/docker_connection.py @@ -3,9 +3,10 @@ from aiodocker.containers import DockerContainer from asyncio import to_thread from config import LINUX_INTERFACE_NAME +from contextlib import asynccontextmanager from datetime import datetime from subprocess import run -from typing import List, Type +from typing import List, Type, Dict, AsyncIterator from typing_extensions import Self from utils.process import Process, DockerProcess @@ -82,32 +83,37 @@ class DockerConnection(Connection): _container: DockerContainer - _name: str - def __init__(self, container: DockerContainer, container_name: str): - super().__init__(TargetOS.Linux) - self._name = container_name + def __init__(self, container: DockerContainer, tag: ConnectionTag): + 
super().__init__(TargetOS.Linux, tag) self._container = container - @classmethod - async def new(cls: Type[Self], docker: Docker, container_name: str) -> Self: - new_docker_conn = cls( - await docker.containers.get(container_name), container_name - ) - await new_docker_conn.restore_ip_tables() - await new_docker_conn.clean_interface() + async def __aenter__(self): + await self.restore_ip_tables() + await self.clean_interface() + await setup_ephemeral_ports(self) + return self - return new_docker_conn + async def __aexit__(self, *exc_details): + await self.restore_ip_tables() + await self.clean_interface() + return None - def container_name(self) -> str: - return self._name - - def target_name(self) -> str: - return self.container_name() + @classmethod + @asynccontextmanager + async def new_connection( + cls: Type[Self], docker: Docker, tag: ConnectionTag + ) -> AsyncIterator["DockerConnection"]: + async with cls( + await docker.containers.get(container_id(tag)), tag + ) as connection: + yield connection async def download(self, remote_path: str, local_path: str) -> None: def aux(): - run(["docker", "cp", self._name + ":" + remote_path, local_path]) + run( + ["docker", "cp", container_id(self.tag) + ":" + remote_path, local_path] + ) await to_thread(aux) @@ -115,14 +121,14 @@ def create_process( self, command: List[str], kill_id=None, term_type=None ) -> "Process": process = DockerProcess( - self._container, self.container_name(), command, kill_id + self._container, container_id(self.tag), command, kill_id ) print( datetime.now(), "Executing", command, "on", - self._name, + self.tag.name, "with Kill ID:", process.get_kill_id(), ) @@ -161,3 +167,9 @@ async def clean_interface(self) -> None: ).execute() except: pass # Most of the time there will be no interface to be deleted + + +def container_id(tag: ConnectionTag) -> str: + if tag in DOCKER_SERVICE_IDS: + return f"nat-lab-{DOCKER_SERVICE_IDS[tag]}-1" + assert False, f"tag {tag} not a docker container" diff --git 
a/nat-lab/tests/utils/connection/ssh_connection.py b/nat-lab/tests/utils/connection/ssh_connection.py index c82ac1b31..5c122ce64 100644 --- a/nat-lab/tests/utils/connection/ssh_connection.py +++ b/nat-lab/tests/utils/connection/ssh_connection.py @@ -1,6 +1,8 @@ import asyncssh import shlex import subprocess +import utils.vm.mac_vm_util as utils_mac +import utils.vm.windows_vm_util as utils_win from .connection import Connection, TargetOS, ConnectionTag, setup_ephemeral_ports from contextlib import asynccontextmanager from datetime import datetime @@ -11,42 +13,93 @@ class SshConnection(Connection): _connection: asyncssh.SSHClientConnection - _vm_name: str - _target_os: TargetOS def __init__( self, connection: asyncssh.SSHClientConnection, - vm_name: str, - target_os: TargetOS, + tag: ConnectionTag, ): - super().__init__(target_os) - self._vm_name = vm_name + if tag in [ConnectionTag.WINDOWS_VM_1, ConnectionTag.WINDOWS_VM_2]: + target_os = TargetOS.Windows + elif tag is ConnectionTag.MAC_VM: + target_os = TargetOS.Mac + else: + assert False, ( + f"Can't create ssh connection for the provided tag: {tag.name}" + ) + + super().__init__(target_os, tag) self._connection = connection - self._target_os = target_os + + async def __aenter__(self): + await setup_ephemeral_ports(self) + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + @classmethod + @asynccontextmanager + async def new_connection( + cls, + ip: str, + tag: ConnectionTag, + copy_binaries: bool = False, + reenable_nat=False, + ) -> AsyncIterator["SshConnection"]: + if reenable_nat: + subprocess.check_call(["sudo", "bash", "vm_nat.sh", "disable"]) + subprocess.check_call(["sudo", "bash", "vm_nat.sh", "enable"]) + + ssh_options = asyncssh.SSHClientConnectionOptions( + encryption_algs=[ + "aes128-gcm@openssh.com", + "aes256-ctr", + "aes192-ctr", + "aes128-ctr", + ], + compression_algs=None, + ) + + async with asyncssh.connect( + ip, + username="root", + password="vagrant", # 
Hardcoded password for transient VM used in tests + known_hosts=None, + options=ssh_options, + ) as ssh_connection: + async with cls(ssh_connection, tag) as connection: + if copy_binaries: + await connection.copy_binaries() + + if connection.target_os is TargetOS.Windows: + keys = await utils_win.get_network_interface_tunnel_keys(connection) + for key in keys: + await connection.create_process( + ["reg", "delete", key, "/F"] + ).execute() + + yield connection def create_process( self, command: List[str], kill_id=None, term_type=None ) -> "Process": print(datetime.now(), "Executing", command, "on", self.target_os) - if self._target_os == TargetOS.Windows: + if self.target_os == TargetOS.Windows: escape_argument = cmd_exe_escape.escape_argument - elif self._target_os in [TargetOS.Linux, TargetOS.Mac]: + elif self.target_os in [TargetOS.Linux, TargetOS.Mac]: escape_argument = shlex.quote else: - assert False, f"not supported target_os '{self._target_os}'" + assert False, f"not supported target_os '{self.target_os}'" return SshProcess( - self._connection, self._vm_name, command, escape_argument, term_type + self._connection, self.tag.name, command, escape_argument, term_type ) async def get_ip_address(self) -> tuple[str, str]: ip = self._connection._host # pylint: disable=protected-access return (ip, ip) - def target_name(self) -> str: - return str(self._target_os) - async def download(self, remote_path: str, local_path: str) -> None: """Copy file from 'remote_path' on the node connected via this connection, to local directory 'local_path'""" try: @@ -57,3 +110,9 @@ async def download(self, remote_path: str, local_path: str) -> None: if "No such file or directory" in e.reason: return raise e + + async def copy_binaries(self) -> None: + if self.target_os is TargetOS.Windows: + await utils_win.copy_binaries(self._connection, self) + elif self.target_os is TargetOS.Mac: + await utils_mac.copy_binaries(self._connection, self) diff --git 
a/nat-lab/tests/utils/connection_util.py b/nat-lab/tests/utils/connection_util.py index f2930a59e..13ffb0b3f 100644 --- a/nat-lab/tests/utils/connection_util.py +++ b/nat-lab/tests/utils/connection_util.py @@ -58,40 +58,24 @@ def get_uniffi_path(connection: Connection) -> str: assert False, f"target_os not supported '{target_os}'" -@asynccontextmanager -async def connection_setup( - connection: Connection, tag: ConnectionTag -) -> AsyncIterator[Connection]: - await setup_ephemeral_ports(connection, tag) - try: - yield connection - finally: - if isinstance(connection, DockerConnection): - await connection.restore_ip_tables() - await connection.clean_interface() - - @asynccontextmanager async def new_connection_raw( tag: ConnectionTag, ) -> AsyncIterator[Connection]: if tag in DOCKER_SERVICE_IDS: async with Docker() as docker: - connection: Connection = await DockerConnection.new( - docker, container_id(tag) - ) - async with connection_setup(connection, tag) as conn: - yield conn - - elif tag in [ConnectionTag.WINDOWS_VM_1, ConnectionTag.WINDOWS_VM_2]: - async with windows_vm_util.new_connection(LAN_ADDR_MAP[tag]) as connection: - async with connection_setup(connection, tag) as conn: - yield conn - - elif tag == ConnectionTag.MAC_VM: - async with mac_vm_util.new_connection() as connection: - async with connection_setup(connection, tag) as conn: - yield conn + async with DockerConnection.new_connection(docker, tag) as connection: + try: + yield connection + finally: + pass + + elif is_tag_valid_for_ssh_connection(tag): + async with SshConnection.new_connection(LAN_ADDR_MAP[tag], tag) as connection: + try: + yield connection + finally: + pass else: assert False, f"tag {tag} not supported" @@ -117,22 +101,23 @@ async def new_connection_manager_by_tag( tag: ConnectionTag, conn_tracker_config: Optional[List[ConnTrackerEventsValidator]] = None, ) -> AsyncIterator[ConnectionManager]: - # pylint: disable-next=contextmanager-generator-missing-cleanup async with 
new_connection_raw(tag) as connection: network_switcher = await create_network_switcher(tag, connection) async with network_switcher.switch_to_primary_network(): if tag in DOCKER_GW_MAP: - # pylint: disable-next=contextmanager-generator-missing-cleanup async with new_connection_raw(DOCKER_GW_MAP[tag]) as gw_connection: async with ConnectionTracker( gw_connection, conn_tracker_config ).run() as conn_tracker: - yield ConnectionManager( - connection, - gw_connection, - network_switcher, - conn_tracker, - ) + try: + yield ConnectionManager( + connection, + gw_connection, + network_switcher, + conn_tracker, + ) + finally: + pass else: async with ConnectionTracker( connection, conn_tracker_config @@ -147,14 +132,12 @@ async def new_connection_with_conn_tracker( tag: ConnectionTag, conn_tracker_config: Optional[List[ConnTrackerEventsValidator]], ) -> AsyncIterator[Tuple[Connection, ConnectionTracker]]: - # pylint: disable-next=contextmanager-generator-missing-cleanup async with new_connection_manager_by_tag(tag, conn_tracker_config) as conn_manager: yield (conn_manager.connection, conn_manager.tracker) @asynccontextmanager async def new_connection_by_tag(tag: ConnectionTag) -> AsyncIterator[Connection]: - # pylint: disable-next=contextmanager-generator-missing-cleanup async with new_connection_manager_by_tag(tag, None) as conn_manager: yield conn_manager.connection @@ -165,7 +148,6 @@ async def new_connection_with_node_tracker( conn_tracker_config: Optional[List[ConnTrackerEventsValidator]], ) -> AsyncIterator[Tuple[Connection, ConnectionTracker]]: if tag in DOCKER_SERVICE_IDS: - # pylint: disable-next=contextmanager-generator-missing-cleanup async with new_connection_raw(tag) as connection: network_switcher = await create_network_switcher(tag, connection) async with network_switcher.switch_to_primary_network(): @@ -173,17 +155,10 @@ async def new_connection_with_node_tracker( connection, conn_tracker_config ).run() as conn_tracker: yield (connection, conn_tracker) - 
else: assert False, f"tag {tag} not supported with node tracker" -def container_id(tag: ConnectionTag) -> str: - if tag in DOCKER_SERVICE_IDS: - return f"nat-lab-{DOCKER_SERVICE_IDS[tag]}-1" - assert False, f"tag {tag} not a docker container" - - def convert_port_to_integer(port: Union[str, int, None]) -> int: if isinstance(port, int): return port diff --git a/nat-lab/tests/utils/tcpdump.py b/nat-lab/tests/utils/tcpdump.py index ea5f2772d..d179f3b59 100644 --- a/nat-lab/tests/utils/tcpdump.py +++ b/nat-lab/tests/utils/tcpdump.py @@ -245,7 +245,7 @@ async def make_tcpdump( continue path = find_unique_path_for_tcpdump( - store_in if store_in else log_dir, conn.target_name() + store_in if store_in else log_dir, conn.tag.name ) await conn.download(PCAP_FILE_PATH[conn.target_os], path) diff --git a/nat-lab/tests/utils/vm/mac_vm_util.py b/nat-lab/tests/utils/vm/mac_vm_util.py index b7dec7eb7..d6a3b23db 100644 --- a/nat-lab/tests/utils/vm/mac_vm_util.py +++ b/nat-lab/tests/utils/vm/mac_vm_util.py @@ -1,61 +1,14 @@ import asyncssh import os -import subprocess -from config import ( - get_root_path, - LIBTELIO_BINARY_PATH_MAC_VM, - MAC_VM_IP, - UNIFFI_PATH_MAC_VM, -) -from contextlib import asynccontextmanager -from typing import AsyncIterator -from utils.connection import Connection, SshConnection, TargetOS +from config import get_root_path, LIBTELIO_BINARY_PATH_MAC_VM, UNIFFI_PATH_MAC_VM +from utils.connection import Connection from utils.process import ProcessExecError VM_TCLI_DIR = LIBTELIO_BINARY_PATH_MAC_VM VM_UNIFFI_DIR = UNIFFI_PATH_MAC_VM -@asynccontextmanager -async def new_connection( - ip: str = MAC_VM_IP, - copy_binaries: bool = False, - reenable_nat=False, -) -> AsyncIterator[Connection]: - if reenable_nat: - subprocess.check_call(["sudo", "bash", "vm_nat.sh", "disable"]) - subprocess.check_call(["sudo", "bash", "vm_nat.sh", "enable"]) - - # Speedup large file transfer: https://github.com/ronf/asyncssh/issues/374 - ssh_options = 
asyncssh.SSHClientConnectionOptions( - encryption_algs=[ - "aes128-gcm@openssh.com", - "aes256-ctr", - "aes192-ctr", - "aes128-ctr", - ], - compression_algs=None, - ) - - async with asyncssh.connect( - ip, - username="root", - password="vagrant", # NOTE: this is hardcoded password for transient vm existing only during the tests - known_hosts=None, - options=ssh_options, - ) as ssh_connection: - connection = SshConnection(ssh_connection, "Mac", TargetOS.Mac) - - if copy_binaries: - await _copy_binaries(ssh_connection, connection) - - try: - yield connection - finally: - pass - - -async def _copy_binaries( +async def copy_binaries( ssh_connection: asyncssh.SSHClientConnection, connection: Connection ) -> None: for directory in [VM_TCLI_DIR, VM_UNIFFI_DIR]: diff --git a/nat-lab/tests/utils/vm/windows_vm_util.py b/nat-lab/tests/utils/vm/windows_vm_util.py index a0084acf0..65f5b94b3 100644 --- a/nat-lab/tests/utils/vm/windows_vm_util.py +++ b/nat-lab/tests/utils/vm/windows_vm_util.py @@ -1,70 +1,19 @@ import asyncssh import os -import subprocess from config import ( get_root_path, LIBTELIO_BINARY_PATH_WINDOWS_VM, UNIFFI_PATH_WINDOWS_VM, - WINDOWS_1_VM_IP, - WINDOWS_2_VM_IP, ) -from contextlib import asynccontextmanager from datetime import datetime -from typing import AsyncIterator, List -from utils.connection import Connection, SshConnection, TargetOS +from typing import List +from utils.connection import Connection from utils.process import ProcessExecError VM_TCLI_DIR = LIBTELIO_BINARY_PATH_WINDOWS_VM VM_UNIFFI_DIR = UNIFFI_PATH_WINDOWS_VM VM_SYSTEM32 = "C:\\Windows\\System32" -NAME = { - WINDOWS_1_VM_IP: "Windows-1", - WINDOWS_2_VM_IP: "Windows-2", -} - - -@asynccontextmanager -async def new_connection( - ip: str = WINDOWS_1_VM_IP, - copy_binaries: bool = False, - reenable_nat=False, -) -> AsyncIterator[Connection]: - if reenable_nat: - subprocess.check_call(["sudo", "bash", "vm_nat.sh", "disable"]) - subprocess.check_call(["sudo", "bash", "vm_nat.sh", "enable"]) - 
- # Speedup large file transfer: https://github.com/ronf/asyncssh/issues/374 - ssh_options = asyncssh.SSHClientConnectionOptions( - encryption_algs=[ - "aes128-gcm@openssh.com", - "aes256-ctr", - "aes192-ctr", - "aes128-ctr", - ], - compression_algs=None, - ) - - async with asyncssh.connect( - ip, - username="vagrant", - password="vagrant", # NOTE: this is hardcoded password for transient vm existing only during the tests - known_hosts=None, - options=ssh_options, - ) as ssh_connection: - connection = SshConnection(ssh_connection, NAME[ip], TargetOS.Windows) - - keys = await _get_network_interface_tunnel_keys(connection) - for key in keys: - await connection.create_process(["reg", "delete", key, "/F"]).execute() - - if copy_binaries: - await _copy_binaries(ssh_connection, connection) - try: - yield connection - finally: - pass - def _file_copy_progress_handler( srcpath, dstpath, bytes_copied, total, file_copy_progress_buffer @@ -112,7 +61,7 @@ async def _copy_file_with_progress_handler( raise e -async def _copy_binaries( +async def copy_binaries( ssh_connection: asyncssh.SSHClientConnection, connection: Connection ) -> None: for directory in [VM_TCLI_DIR, VM_UNIFFI_DIR]: @@ -156,7 +105,7 @@ async def _copy_binaries( await _copy_file_with_progress_handler(ssh_connection, src, dst, allow_missing) -async def _get_network_interface_tunnel_keys(connection): +async def get_network_interface_tunnel_keys(connection): result = await connection.create_process([ "reg", "query",