Skip to content

Commit

Permalink
add option to change terminal type for remote process
Browse files Browse the repository at this point in the history
  • Loading branch information
gytsto committed Dec 6, 2024
1 parent b75b367 commit e3fd5c5
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 26 deletions.
4 changes: 3 additions & 1 deletion nat-lab/tests/utils/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ def __init__(self, target_os: TargetOS) -> None:
self._target_os = target_os

@abstractmethod
def create_process(self, command: List[str], kill_id=None) -> "Process":
def create_process(
self, command: List[str], kill_id=None, term_type=None
) -> "Process":
pass

@property
Expand Down
4 changes: 3 additions & 1 deletion nat-lab/tests/utils/connection/docker_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ def aux():

await to_thread(aux)

def create_process(self, command: List[str], kill_id=None) -> "Process":
def create_process(
self, command: List[str], kill_id=None, term_type=None
) -> "Process":
process = DockerProcess(
self._container, self.container_name(), command, kill_id
)
Expand Down
8 changes: 6 additions & 2 deletions nat-lab/tests/utils/connection/ssh_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def __init__(
self._connection = connection
self._target_os = target_os

def create_process(self, command: List[str], kill_id=None) -> "Process":
def create_process(
self, command: List[str], kill_id=None, term_type=None
) -> "Process":
print(datetime.now(), "Executing", command, "on", self.target_os)
if self._target_os == TargetOS.Windows:
escape_argument = cmd_exe_escape.escape_argument
Expand All @@ -32,7 +34,9 @@ def create_process(self, command: List[str], kill_id=None) -> "Process":
else:
assert False, f"not supported target_os '{self._target_os}'"

return SshProcess(self._connection, self._vm_name, command, escape_argument)
return SshProcess(
self._connection, self._vm_name, command, escape_argument, term_type
)

async def get_ip_address(self) -> tuple[str, str]:
ip = self._connection._host # pylint: disable=protected-access
Expand Down
4 changes: 2 additions & 2 deletions nat-lab/tests/utils/connection_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ class ConnectionTag(Enum):
ConnectionTag.DOCKER_DERP_1: "derp-01",
ConnectionTag.DOCKER_DERP_2: "derp-02",
ConnectionTag.DOCKER_DERP_3: "derp-03",
ConnectionTag.DOCKER_DNS_SERVER_1: "dns-server-01",
ConnectionTag.DOCKER_DNS_SERVER_2: "dns-server-02",
ConnectionTag.DOCKER_DNS_SERVER_1: "dns-server-1",
ConnectionTag.DOCKER_DNS_SERVER_2: "dns-server-2",
}


Expand Down
7 changes: 6 additions & 1 deletion nat-lab/tests/utils/process/ssh_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ class SshProcess(Process):
_stdin: Optional[asyncssh.SSHWriter]
_process: Optional[asyncssh.SSHClientProcess]
_running: bool
_term_type: Optional[str]

def __init__(
self,
ssh_connection: asyncssh.SSHClientConnection,
vm_name: str,
command: List[str],
escape_argument: Callable[[str], str],
term_type: Optional[str] = None,
) -> None:
self._ssh_connection = ssh_connection
self._vm_name = vm_name
Expand All @@ -36,6 +38,7 @@ def __init__(
self._escape_argument = escape_argument
self._process = None
self._running = False
self._term_type = term_type

async def execute(
self,
Expand All @@ -48,7 +51,9 @@ async def execute(
escaped = [self._escape_argument(arg) for arg in self._command]
command_str = " ".join(escaped)

self._process = await self._ssh_connection.create_process(command_str)
self._process = await self._ssh_connection.create_process(
command_str, term_type=self._term_type
)
self._running = True
self._stdin = self._process.stdin
self._stdin_ready.set()
Expand Down
54 changes: 35 additions & 19 deletions nat-lab/tests/utils/tcpdump.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
import os
from asyncio import Event
from asyncio import Event, wait_for
from config import WINDUMP_BINARY_WINDOWS
from contextlib import asynccontextmanager, AsyncExitStack
from typing import AsyncIterator, Optional, List
from typing import AsyncIterator, Optional
from utils.connection import TargetOS, Connection
from utils.output_notifier import OutputNotifier
from utils.process import Process
from utils.testing import get_current_test_log_path

PCAP_FILE_PATH = {
TargetOS.Linux: "/dump.pcap",
TargetOS.Mac: "/tmp/dump.pcap",
TargetOS.Mac: "/var/root/dump.pcap",
TargetOS.Windows: "C:\\workspace\\dump.pcap",
}


class TcpDump:
interfaces: Optional[List[str]]
interfaces: Optional[list[str]]
connection: Connection
process: Process
stdout: str
Expand All @@ -28,8 +28,9 @@ class TcpDump:
def __init__(
self,
connection: Connection,
filters: List[str],
interfaces: Optional[List[str]] = None,
flags: Optional[list[str]] = None,
expressions: Optional[list[str]] = None,
interfaces: Optional[list[str]] = None,
output_file: Optional[str] = None,
count: Optional[int] = None,
) -> None:
Expand All @@ -44,34 +45,40 @@ def __init__(

self.output_notifier.notify_output("listening on", self.start_event)

command = [
self.get_tcpdump_binary(connection.target_os),
"-l",
]
command = [self.get_tcpdump_binary(connection.target_os), "-n"]

if self.output_file:
command += ["-w", self.output_file]
else:
command += ["-w", PCAP_FILE_PATH[self.connection.target_os]]

if self.interfaces:
for interface in self.interfaces:
command += ["-i", interface]
command += ["-i", ",".join(self.interfaces)]
else:
if self.connection.target_os != TargetOS.Windows:
command += ["-i", "any"]
else:
command += ["-i", "1", "-i", "2"]

if self.count:
command += ["-c", str(self.count)]

if flags:
command += flags

if self.connection.target_os != TargetOS.Windows:
command += ["--immediate-mode"]
command += ["port not 22"]
else:
command += ["not port 22"]

if self.count:
command += ["-c", self.count]

command += filters
if expressions:
command += expressions

self.process = self.connection.create_process(command)
self.process = self.connection.create_process(
command,
term_type="xterm" if self.connection.target_os == TargetOS.Mac else None,
)

@staticmethod
def get_tcpdump_binary(target_os: TargetOS) -> str:
Expand All @@ -90,10 +97,12 @@ def get_stderr(self) -> str:
return self.stderr

async def on_stdout(self, output: str) -> None:
print(f"tcpdump: {output}")
self.stdout += output
await self.output_notifier.handle_output(output)

async def on_stderr(self, output: str) -> None:
print(f"tcpdump err: {output}")
self.stderr += output
await self.output_notifier.handle_output(output)

Expand All @@ -107,7 +116,7 @@ async def execute(self) -> None:
@asynccontextmanager
async def run(self) -> AsyncIterator["TcpDump"]:
async with self.process.run(self.on_stdout, self.on_stderr, True):
await self.start_event.wait()
await wait_for(self.start_event.wait(), 10)
yield self


Expand All @@ -132,7 +141,7 @@ async def make_tcpdump(
try:
async with AsyncExitStack() as exit_stack:
for conn in connection_list:
await exit_stack.enter_async_context(TcpDump(conn, ["-U"]).run())
await exit_stack.enter_async_context(TcpDump(conn).run())
yield
finally:
if download:
Expand All @@ -143,3 +152,10 @@ async def make_tcpdump(
store_in if store_in else log_dir, conn.target_name()
)
await conn.download(PCAP_FILE_PATH[conn.target_os], path)

if conn.target_os != TargetOS.Windows:
await conn.create_process(
["rm", "-f", PCAP_FILE_PATH[conn.target_os]]
).execute()
else:
await conn.create_process(["del", PCAP_FILE_PATH[conn.target_os]]).execute()

0 comments on commit e3fd5c5

Please sign in to comment.