diff --git a/sshd/enter.py b/sshd/enter.py index d19e3203a..1ce1b3922 100755 --- a/sshd/enter.py +++ b/sshd/enter.py @@ -6,7 +6,8 @@ import shlex import sys import time - +import signal +import threading import docker import redis @@ -36,6 +37,14 @@ def get_docker_client(user_id): docker_client = docker.DockerClient(base_url=docker_host, tls=False) return docker_host, docker_client, is_mac +def kill_exec_on_container_death(container, exec_pid): + container.wait(condition="not-running") + try: + os.kill(exec_pid, signal.SIGTERM) + time.sleep(0.5) + except ProcessLookupError: + pass + def main(): original_command = os.getenv("SSH_ORIGINAL_COMMAND") @@ -86,8 +95,8 @@ def print(*args, **kwargs): attempts = 0 print("\r", " " * 80, "\rConnected!") - - if not os.fork(): + child_pid = os.fork(); + if not child_pid: ssh_entrypoint = "/run/dojo/bin/ssh-entrypoint" if is_mac: cmd = f"/bin/bash -c {shlex.quote(original_command)}" if original_command else "zsh -i" @@ -114,6 +123,13 @@ def print(*args, **kwargs): ) else: + runtime = (container.attrs or {}).get("HostConfig",{}).get("Runtime") + is_kata = runtime == "io.containerd.run.kata.v2" + if is_kata: + monitor_thread = threading.Thread(target=kill_exec_on_container_death, + args=(container,child_pid), + daemon=True) + monitor_thread.start() _, status = os.wait() if simple or status == 0: break