Skip to content

Commit

Permalink
fix(pro): wait for outer tunnels to exit and clean up properly
Browse files Browse the repository at this point in the history
  • Loading branch information
pascalbreuninger committed Nov 19, 2024
1 parent 57ae622 commit 4856dbd
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 16 deletions.
3 changes: 1 addition & 2 deletions cmd/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ type SSHCmd struct {
SetEnvVars []string

Stdio bool
JumpContainer bool
AgentForwarding bool
GPGAgentForwarding bool
GitSSHSignatureForwarding bool
Expand Down Expand Up @@ -436,7 +435,7 @@ func (cmd *SSHCmd) startTunnel(ctx context.Context, devPodConfig *config.Config,
if cmd.Proxy {
go func() {
if err := cmd.startRunnerServices(ctx, devPodConfig, containerClient, log); err != nil {
log.Error(err)
log.Debug(err)
}
}()
}
Expand Down
24 changes: 16 additions & 8 deletions pkg/inject/inject.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
"io"
"os"
"strings"
"sync"
"time"

"github.com/loft-sh/devpod/pkg/command"
"github.com/loft-sh/devpod/pkg/util"
"github.com/loft-sh/log"
perrors "github.com/pkg/errors"
)
Expand Down Expand Up @@ -114,7 +116,8 @@ func InjectAndExecute(
case err = <-execErrChan:
result = <-injectChan
case result = <-injectChan:
// we don't wait for the command termination here and will just retry on error
// give exec some time to properly terminate and clean up
util.WaitForChan(execErrChan, 2*time.Second)
}

// prefer result error
Expand All @@ -126,7 +129,7 @@ func InjectAndExecute(
return result.wasExecuted, nil
}

log.Debugf("Rerun command as binary was injected")
log.Debug("Rerun command as binary was injected")
delayedStderr.Start()
return true, exec(ctx, scriptParams.Command, stdin, stdout, delayedStderr)
}
Expand Down Expand Up @@ -281,14 +284,19 @@ func readLine(reader io.Reader) (string, error) {
}

func pipe(toStdin io.Writer, fromStdin io.Reader, toStdout io.Writer, fromStdout io.Reader) error {
errChan := make(chan error, 2)
var err error
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
_, err := io.Copy(toStdout, fromStdout)
errChan <- err
defer wg.Done()
_, err = io.Copy(toStdout, fromStdout)
}()
wg.Add(1)
go func() {
_, err := io.Copy(toStdin, fromStdin)
errChan <- err
defer wg.Done()
_, err = io.Copy(toStdin, fromStdin)
}()
return <-errChan

wg.Wait()
return err
}
7 changes: 4 additions & 3 deletions pkg/ssh/server/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (s *Server) handler(sess ssh.Session) {
var err error
if isPty {
s.log.Debugf("Execute SSH server PTY command: %s", strings.Join(cmd.Args, " "))
err = HandlePTY(sess, ptyReq, winCh, cmd, nil)
err = s.HandlePTY(sess, ptyReq, winCh, cmd, nil)
} else {
s.log.Debugf("Execute SSH server command: %s", strings.Join(cmd.Args, " "))
err = s.HandleNonPTY(sess, cmd)
Expand Down Expand Up @@ -201,16 +201,17 @@ func (s *Server) HandleNonPTY(sess ssh.Session, cmd *exec.Cmd) (err error) {
}
}()

waitGroup.Wait()
// order is important here!
err = cmd.Wait()
waitGroup.Wait()
if err != nil {
return err
}

return nil
}

func HandlePTY(
func (s *Server) HandlePTY(
sess ssh.Session,
ptyReq ssh.Pty,
winCh <-chan ssh.Window,
Expand Down
6 changes: 3 additions & 3 deletions pkg/tunnel/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/loft-sh/devpod/pkg/config"
"github.com/loft-sh/devpod/pkg/provider"
devssh "github.com/loft-sh/devpod/pkg/ssh"
"github.com/loft-sh/devpod/pkg/util"
"github.com/loft-sh/log"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -128,8 +129,10 @@ func (c *ContainerHandler) Run(ctx context.Context, handler Handler, cfg *config
// wait for result
select {
case err := <-containerChan:
util.WaitForChan(tunnelChan, 2*time.Second)
return errors.Wrap(err, "tunnel to container")
case err := <-tunnelChan:
util.WaitForChan(containerChan, 2*time.Second)
return errors.Wrap(err, "connect to server")
}
}
Expand Down Expand Up @@ -204,9 +207,6 @@ func (c *ContainerHandler) runRunInContainer(ctx context.Context, sshClient *ssh
defer stdoutWriter.Close()
defer cancel()

c.log.Debugf("Run container tunnel")
defer c.log.Debugf("Container tunnel exited")

command := fmt.Sprintf("'%s' agent container-tunnel --workspace-info '%s'", c.client.AgentPath(), workspaceInfo)
if c.log.GetLevel() == logrus.DebugLevel {
command += " --debug"
Expand Down
4 changes: 4 additions & 0 deletions pkg/tunnel/direct.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"context"
"io"
"os"
"time"

devssh "github.com/loft-sh/devpod/pkg/ssh"
"github.com/loft-sh/devpod/pkg/util"
"github.com/pkg/errors"
)

Expand Down Expand Up @@ -52,8 +54,10 @@ func NewTunnel(ctx context.Context, tunnel Tunnel, handler Handler) error {
// wait for result
select {
case err := <-innerTunnelChan:
util.WaitForChan(outerTunnelChan, 2*time.Second)
return errors.Wrap(err, "inner tunnel")
case err := <-outerTunnelChan:
util.WaitForChan(innerTunnelChan, 2*time.Second)
return errors.Wrap(err, "outer tunnel")
}
}
13 changes: 13 additions & 0 deletions pkg/util/channel.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package util

import "time"

// WaitForChan races the channel closing against a timeout
func WaitForChan(channel <-chan error, timeout time.Duration) {
select {
case <-time.After(timeout):
return
case <-channel:
return
}
}

0 comments on commit 4856dbd

Please sign in to comment.