Skip to content

Commit

Permalink
natlab: Add generic ssh connection class
Browse files Browse the repository at this point in the history
  • Loading branch information
lcruz99 committed Mar 7, 2025
1 parent 04c7046 commit f3403e6
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 235 deletions.
40 changes: 16 additions & 24 deletions nat-lab/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from utils.process import ProcessExecError
from utils.router import IPStack
from utils.tcpdump import make_tcpdump, make_local_tcpdump
from utils.vm import mac_vm_util, windows_vm_util

DERP_SERVER_1_ADDR = "http://10.0.10.1:8765"
DERP_SERVER_2_ADDR = "http://10.0.10.2:8765"
Expand Down Expand Up @@ -266,27 +265,16 @@ async def perform_pretest_cleanups():


async def _copy_vm_binaries(tag: ConnectionTag):
if tag in [ConnectionTag.WINDOWS_VM_1, ConnectionTag.WINDOWS_VM_2]:
try:
print(f"copying for {tag}")
async with windows_vm_util.new_connection(
LAN_ADDR_MAP[tag], copy_binaries=True, reenable_nat=True
):
pass
except OSError as e:
if os.environ.get("GITLAB_CI"):
raise e
print(e)
elif tag is ConnectionTag.MAC_VM:
try:
async with mac_vm_util.new_connection(
copy_binaries=True, reenable_nat=True
):
pass
except OSError as e:
if os.environ.get("GITLAB_CI"):
raise e
print(e)
try:
print(f"copying for {tag}")
async with SshConnection.new_connection(
LAN_ADDR_MAP[tag], tag, copy_binaries=True, reenable_nat=True
):
pass
except OSError as e:
if os.environ.get("GITLAB_CI"):
raise e
print(e)


async def _copy_vm_binaries_if_needed(items):
Expand Down Expand Up @@ -358,7 +346,9 @@ async def collect_kernel_logs(items, suffix):
for item in items:
if any(mark.name == "mac" for mark in item.own_markers):
try:
async with mac_vm_util.new_connection() as conn:
async with SshConnection.new_connection(
LAN_ADDR_MAP[ConnectionTag.MAC_VM], ConnectionTag.MAC_VM
) as conn:
await _save_macos_logs(conn, suffix)
except OSError as e:
if os.environ.get("GITLAB_CI"):
Expand Down Expand Up @@ -465,7 +455,9 @@ async def collect_mac_diagnostic_reports():
return
print("Collect mac diagnostic reports")
try:
async with mac_vm_util.new_connection() as connection:
async with SshConnection.new_connection(
LAN_ADDR_MAP[ConnectionTag.MAC_VM], ConnectionTag.MAC_VM
) as connection:
await connection.download(
"/Library/Logs/DiagnosticReports", "logs/system_diagnostic_reports"
)
Expand Down
23 changes: 8 additions & 15 deletions nat-lab/tests/telio.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ async def on_stderr(stderr: str) -> None:
f"Test cleanup: Exception while flushing logs: {e}",
)

await self.get_proxy().shutdown(self._connection.target_name())
await self.get_proxy().shutdown(self._connection.tag.name)
else:
print(
datetime.now(),
Expand Down Expand Up @@ -1083,14 +1083,7 @@ async def _save_logs(self) -> None:

system_log_content = await self.get_system_log()

if self._connection.target_os == TargetOS.Linux:
process = self._connection.create_process(["cat", "/etc/hostname"])
await process.execute()
container_id = process.get_stdout().strip()
else:
container_id = str(self._connection.target_os.name)

filename = container_id + ".log"
filename = self._connection.tag.name.lower() + ".log"
if len(filename.encode("utf-8")) > 256:
filename = f"{filename[:251]}.log"

Expand All @@ -1117,7 +1110,9 @@ async def _save_logs(self) -> None:
file_name = os.path.basename(trace_path)
os.rename(
os.path.join(log_dir, file_name),
os.path.join(log_dir, f"{container_id}-{file_name}"),
os.path.join(
log_dir, f"{self._connection.tag.name.lower()}-{file_name}"
),
)

async def save_moose_db(self) -> None:
Expand Down Expand Up @@ -1150,9 +1145,7 @@ async def save_mac_network_info(self) -> None:

network_info_info = await self.get_network_info()

container_id = str(self._connection.target_os.name)

filename = container_id + "_network_info.log"
filename = self._connection.tag.name.lower() + "_network_info.log"
if len(filename.encode("utf-8")) > 256:
filename = f"{filename[:251]}.log"

Expand Down Expand Up @@ -1220,7 +1213,7 @@ async def collect_core_dumps(self):

# if we collected some core dumps, copy them
if isinstance(self._connection, DockerConnection) and should_copy_coredumps:
container_name = self._connection.container_name()
container_name = container_id(self._connection.tag)
test_name = get_current_test_case_and_parameters()[0] or ""
for i, file_path in enumerate(dump_files):
file_name = file_path.rsplit("/", 1)[-1]
Expand Down Expand Up @@ -1250,7 +1243,7 @@ async def find_files(connection, where, name_pattern):
def copy_file(from_connection, from_path, destination_path):
"""Copy a file from within the docker container connection to the destination path"""
if isinstance(from_connection, DockerConnection):
container_name = from_connection.container_name()
container_name = container_id(from_connection.tag)

file_name = os.path.basename(from_path)
core_dump_destination = os.path.join(destination_path, file_name)
Expand Down
2 changes: 1 addition & 1 deletion nat-lab/tests/test_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ async def test_batching(
gateway_container_names = [container_id(conn_tag) for conn_tag in gateways]
conns = [client.get_connection() for client in clients]
node_container_names = [
conn.container_name()
container_id(conn.tag)
for conn in conns
if isinstance(conn, DockerConnection)
]
Expand Down
4 changes: 2 additions & 2 deletions nat-lab/tests/test_cleanup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def test_get_network_interface_tunnel_keys(adapter_type, name) -> None:
connection = await exit_stack.enter_async_context(
new_connection_raw(ConnectionTag.WINDOWS_VM_1)
)
assert [] == await _get_network_interface_tunnel_keys(connection)
assert [] == await get_network_interface_tunnel_keys(connection)
_env = await exit_stack.enter_async_context(
setup_environment(
exit_stack,
Expand All @@ -40,7 +40,7 @@ async def test_get_network_interface_tunnel_keys(adapter_type, name) -> None:

# This function is used during test startup to remove interfaces
# that might have managed to survive the end of the previous test.
keys = await _get_network_interface_tunnel_keys(connection)
keys = await get_network_interface_tunnel_keys(connection)
assert [
"HKEY_LOCAL_MACHINE\\SYSTEM\\CurrentControlSet\\Control\\Class\\{4d36e972-e325-11ce-bfc1-08002be10318}\\0006"
] == keys
Expand Down
4 changes: 2 additions & 2 deletions nat-lab/tests/test_lana.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,12 @@ async def wait_for_event_dump(
events = fetch_moose_events(events_path)
if len(events) == nr_events:
print(
f"Found db from {connection.target_name()} with the expected {nr_events} events."
f"Found db from {connection.tag.name} with the expected {nr_events} events."
)
return events
await asyncio.sleep(DEFAULT_CHECK_INTERVAL)
print(
f"Failed looking db from {connection.target_name()}, expected {nr_events} but"
f"Failed looking db from {connection.tag.name}, expected {nr_events} but"
f" {len(events)} were found."
)
return None
Expand Down
15 changes: 11 additions & 4 deletions nat-lab/tests/utils/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ def local():

class Connection(ABC):
_target_os: Optional[TargetOS]
_tag: Optional[ConnectionTag]

def __init__(self, target_os: TargetOS) -> None:
def __init__(self, target_os: TargetOS, tag: ConnectionTag) -> None:
self._target_os = target_os
self._tag = tag

@abstractmethod
def create_process(
Expand All @@ -90,9 +92,14 @@ def target_os(self) -> TargetOS:
def target_os(self, target_os: TargetOS) -> None:
self._target_os = target_os

@abstractmethod
def target_name(self) -> str:
pass
@property
def tag(self) -> ConnectionTag:
assert self._tag
return self._tag

@tag.setter
def tag(self, tag: ConnectionTag) -> None:
self._tag = tag

@abstractmethod
async def download(self, remote_path: str, local_path: str) -> None:
Expand Down
54 changes: 33 additions & 21 deletions nat-lab/tests/utils/connection/docker_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from aiodocker.containers import DockerContainer
from asyncio import to_thread
from config import LINUX_INTERFACE_NAME
from contextlib import asynccontextmanager
from datetime import datetime
from subprocess import run
from typing import List, Type
from typing import List, Type, Dict, AsyncIterator
from typing_extensions import Self
from utils.process import Process, DockerProcess

Expand Down Expand Up @@ -82,47 +83,52 @@

class DockerConnection(Connection):
_container: DockerContainer
_name: str

def __init__(self, container: DockerContainer, container_name: str):
super().__init__(TargetOS.Linux)
self._name = container_name
def __init__(self, container: DockerContainer, tag: ConnectionTag):
super().__init__(TargetOS.Linux, tag)
self._container = container

@classmethod
async def new(cls: Type[Self], docker: Docker, container_name: str) -> Self:
new_docker_conn = cls(
await docker.containers.get(container_name), container_name
)
await new_docker_conn.restore_ip_tables()
await new_docker_conn.clean_interface()
async def __aenter__(self):
await self.restore_ip_tables()
await self.clean_interface()
await setup_ephemeral_ports(self)
return self

return new_docker_conn
async def __aexit__(self, *exc_details):
await self.restore_ip_tables()
await self.clean_interface()
return self

def container_name(self) -> str:
return self._name

def target_name(self) -> str:
return self.container_name()
@classmethod
@asynccontextmanager
async def new_connection(
cls: Type[Self], docker: Docker, tag: ConnectionTag
) -> AsyncIterator["DockerConnection"]:
async with cls(
await docker.containers.get(container_id(tag)), tag
) as connection:
yield connection

async def download(self, remote_path: str, local_path: str) -> None:
def aux():
run(["docker", "cp", self._name + ":" + remote_path, local_path])
run(
["docker", "cp", container_id(self.tag) + ":" + remote_path, local_path]
)

await to_thread(aux)

def create_process(
self, command: List[str], kill_id=None, term_type=None
) -> "Process":
process = DockerProcess(
self._container, self.container_name(), command, kill_id
self._container, container_id(self.tag), command, kill_id
)
print(
datetime.now(),
"Executing",
command,
"on",
self._name,
self.tag.name,
"with Kill ID:",
process.get_kill_id(),
)
Expand Down Expand Up @@ -161,3 +167,9 @@ async def clean_interface(self) -> None:
).execute()
except:
pass # Most of the time there will be no interface to be deleted


def container_id(tag: ConnectionTag) -> str:
if tag in DOCKER_SERVICE_IDS:
return f"nat-lab-{DOCKER_SERVICE_IDS[tag]}-1"
assert False, f"tag {tag} not a docker container"
Loading

0 comments on commit f3403e6

Please sign in to comment.