Skip to content

Commit

Permalink
Updates SSHDialer to be more robust to network fluctuations (#3005)
Browse files Browse the repository at this point in the history
  • Loading branch information
nickzelei authored Dec 4, 2024
1 parent ccf9bf2 commit 302dd92
Show file tree
Hide file tree
Showing 12 changed files with 435 additions and 75 deletions.
8 changes: 5 additions & 3 deletions backend/pkg/sqlconnect/sql-connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"errors"
"fmt"
"log/slog"
"net"
"strconv"
"sync"
"time"

Expand Down Expand Up @@ -175,7 +177,7 @@ func getTunnelConnectorFn(
return nil, nil, fmt.Errorf("unable to construct ssh tunnel config: %w", err)
}
logger.Debug("constructed tunnel config")
dialer := tun.NewLazySSHDialer(cfg.Addr, cfg.ClientConfig)
dialer := tun.NewLazySSHDialer(cfg.Addr, cfg.ClientConfig, tun.DefaultSSHDialerConfig(), logger)
conn, cleanup, err := getConnector(dialer)
if err != nil {
return nil, nil, fmt.Errorf("unable to build db connector: %w", err)
Expand Down Expand Up @@ -275,7 +277,7 @@ func getTunnelConfig(tunnel *mgmtv1alpha1.SSHTunnel) (*tunnelConfig, error) {
User: tunnel.GetUser(),
Auth: authmethods,
HostKeyCallback: hostcallback,
Timeout: 10 * time.Second, // todo: make configurable
Timeout: 15 * time.Second, // todo: make configurable
},
}, nil
}
Expand All @@ -284,7 +286,7 @@ func getSshAddr(tunnel *mgmtv1alpha1.SSHTunnel) string {
host := tunnel.GetHost()
port := tunnel.GetPort()
if port > 0 {
return fmt.Sprintf("%s:%d", host, port)
return net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
}
return host
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ func New(
return dialer.DialContext(ctx, network, addr)
}

// RegisterConnConfig returns unique connection strings, so even if the dsn is used for multiple calls to New()
// The unregister will not interfere with any other instances of Connector that are using the same input dsn
connStr := stdlib.RegisterConnConfig(cfg)
cleanup := func() {
stdlib.UnregisterConnConfig(connStr)
Expand Down
166 changes: 143 additions & 23 deletions internal/sshtunnel/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package sshtunnel
import (
"context"
"fmt"
"log/slog"
"net"
"sync"
"time"
Expand All @@ -17,30 +18,58 @@ type Dialer interface {

var _ Dialer = (*SSHDialer)(nil)

type SSHDialerConfig struct {
MaxRetries int
InitialBackoff time.Duration
MaxBackoff time.Duration // Max allowed backoff time
BackoffFactor float64 // backoff multiplier

KeepAliveInterval time.Duration
KeepAliveTimeout time.Duration
}

type SSHDialer struct {
addr string
cfg *ssh.ClientConfig
ccfg *ssh.ClientConfig

dialCfg *SSHDialerConfig

client *ssh.Client
clientmu *sync.RWMutex
clientmu *sync.Mutex
logger *slog.Logger
}

func DefaultSSHDialerConfig() *SSHDialerConfig {
return &SSHDialerConfig{
MaxRetries: 3,
InitialBackoff: 100 * time.Millisecond,
MaxBackoff: 1 * time.Second,
BackoffFactor: 2,

KeepAliveInterval: 30 * time.Second,
KeepAliveTimeout: 15 * time.Second,
}
}

func NewLazySSHDialer(addr string, cfg *ssh.ClientConfig) *SSHDialer {
return &SSHDialer{addr: addr, cfg: cfg, clientmu: &sync.RWMutex{}}
func NewLazySSHDialer(addr string, ccfg *ssh.ClientConfig, dialCfg *SSHDialerConfig, logger *slog.Logger) *SSHDialer {
if dialCfg == nil {
dialCfg = DefaultSSHDialerConfig()
}
return &SSHDialer{addr: addr, ccfg: ccfg, clientmu: &sync.Mutex{}, dialCfg: dialCfg, logger: logger}
}

func NewSSHDialer(client *ssh.Client) *SSHDialer {
return &SSHDialer{client: client, clientmu: &sync.RWMutex{}}
func NewSSHDialer(client *ssh.Client, logger *slog.Logger) *SSHDialer {
return &SSHDialer{client: client, clientmu: &sync.Mutex{}, logger: logger}
}

func (s *SSHDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
client, err := s.getClient()
client, err := s.getClient(ctx)
if err != nil {
return nil, err
return nil, fmt.Errorf("unable to get or create ssh client during DialContext: %w", err)
}
conn, err := client.DialContext(ctx, network, addr)
if err != nil {
return nil, err
return nil, fmt.Errorf("unable to dial address: %w", err)
}
return &wrappedSshConn{Conn: conn}, nil
}
Expand All @@ -52,35 +81,113 @@ func (s *SSHDialer) Dial(network, addr string) (net.Conn, error) {
func (s *SSHDialer) Close() error {
s.clientmu.Lock()
defer s.clientmu.Unlock()
if s.client == nil {
return nil
}
client := s.client
s.client = nil
return client.Close()
}

func (s *SSHDialer) getClient() (*ssh.Client, error) {
s.clientmu.RLock()
client := s.client
s.clientmu.RUnlock()
if client != nil {
return client, nil
return client.Close()
}
return nil
}

const (
keepaliveName = "[email protected]"
)

func (s *SSHDialer) getClient(ctx context.Context) (*ssh.Client, error) {
s.clientmu.Lock()
defer s.clientmu.Unlock()

if s.client != nil {
return s.client, nil
wantReply := true
_, _, err := s.client.SendRequest(keepaliveName, wantReply, nil)
if err == nil {
return s.client, nil
}
s.logger.Info(fmt.Sprintf("SSH client was dead, closing and attempting to re-created: %s", err.Error()))
s.client.Close()
s.client = nil
}

var client *ssh.Client
var err error
backoff := s.dialCfg.InitialBackoff

for i := 0; i < s.dialCfg.MaxRetries; i++ {
client, err = ssh.Dial("tcp", s.addr, s.ccfg)
if err == nil {
s.startKeepAlive(client)
break
}
s.logger.Error(fmt.Sprintf("failed to dial SSH Server on attempt %d/%d: %s", i, s.dialCfg.MaxRetries, err.Error()))
if i < s.dialCfg.MaxRetries-1 {
s.logger.Debug(fmt.Sprintf("waiting %.1f seconds until attempting to re-connect to SSH Server", backoff.Seconds()))
err = sleepContext(ctx, backoff)
if err != nil {
break
}
nextBackoff := time.Duration(float64(backoff) * s.dialCfg.BackoffFactor)
if nextBackoff > s.dialCfg.MaxBackoff {
nextBackoff = s.dialCfg.MaxBackoff
}
backoff = nextBackoff
}
}
// todo: implement retries
client, err := ssh.Dial("tcp", s.addr, s.cfg)

if err != nil {
return nil, fmt.Errorf("unable to dial ssh server: %w", err)
return nil, fmt.Errorf("unable to dial ssh server after %d attempts: %w", s.dialCfg.MaxRetries, err)
}
s.client = client
return client, nil
}

func (s *SSHDialer) startKeepAlive(client *ssh.Client) {
go func() {
s.logger.Info("keepalive started for ssh client")
t := time.NewTicker(s.dialCfg.KeepAliveInterval)
defer t.Stop()

for range t.C {
s.clientmu.Lock()
if s.client != client {
s.clientmu.Unlock()
return
}

// Create a timeout context for the keepalive request
ctx, cancel := context.WithTimeout(context.Background(), s.dialCfg.KeepAliveTimeout)
done := make(chan error, 1)

go func() {
wantReply := true
_, _, err := client.SendRequest(keepaliveName, wantReply, nil)
done <- err
}()

// Wait for either timeout or response
select {
case err := <-done:
if err != nil {
s.logger.Error("keepalive failed", "error", err)
s.client = nil
client.Close()
}
case <-ctx.Done():
s.logger.Error("keepalive timed out")
s.client = nil
client.Close()
}

cancel()
s.clientmu.Unlock()

if s.client == nil {
return
}
}
}()
}

type wrappedSshConn struct {
net.Conn
}
Expand All @@ -101,3 +208,16 @@ func (w *wrappedSshConn) SetReadDeadline(deadline time.Time) error {
func (w *wrappedSshConn) SetWriteDeadline(deadline time.Time) error {
return nil
}

func sleepContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}

select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(d):
return nil
}
}
Loading

0 comments on commit 302dd92

Please sign in to comment.