From 7bdc981eac3ff1e83585184563949ac1171e643e Mon Sep 17 00:00:00 2001 From: "fox.cpp" Date: Mon, 29 Jan 2024 23:45:25 +0300 Subject: [PATCH] target/remote: Improve handling of stale connections in pool 1. Apply conn_max_idle_time to each connection individually, not to the pool bucket as a whole. 2. Include local_addr in some log messages to help identify individual connections in the pool. 3. Run conn.Close outside of keysLock and asynchronously. This ensures a slow server or dead connection won't cause pool operations to hang. 4. Set a 5-second timeout for the QUIT call in conn.Close so that dead connections are detected faster; there is no reason for any server to take more than 5 seconds to respond to QUIT. See #675. --- internal/smtpconn/pool/pool.go | 27 +++++++++----- internal/smtpconn/smtpconn.go | 59 ++++++++++++++++++++++++------- internal/target/remote/connect.go | 22 ++++++++---- internal/target/remote/remote.go | 12 ++++--- 4 files changed, 88 insertions(+), 32 deletions(-) diff --git a/internal/smtpconn/pool/pool.go b/internal/smtpconn/pool/pool.go index 35ab27ba..4b700ee4 100644 --- a/internal/smtpconn/pool/pool.go +++ b/internal/smtpconn/pool/pool.go @@ -26,6 +26,7 @@ import ( type Conn interface { Usable() bool + LastUseAt() time.Time Close() error } @@ -95,33 +96,38 @@ func (p *P) CleanUp(ctx context.Context) { close(v.c) for conn := range v.c { - conn.Close() + go conn.Close() } delete(p.keys, k) } } func (p *P) Get(ctx context.Context, key string) (Conn, error) { - // TODO: See if it is possible to get rid of this lock. p.keysLock.Lock() - defer p.keysLock.Unlock() bucket, ok := p.keys[key] if !ok { + p.keysLock.Unlock() return p.cfg.New(ctx, key) } if time.Now().Unix()-bucket.lastUse > p.cfg.MaxConnLifetimeSec { // Drop bucket. + delete(p.keys, key) close(bucket.c) + + // Close might take some time, unlock early. 
+ p.keysLock.Unlock() + for conn := range bucket.c { conn.Close() } - delete(p.keys, key) return p.cfg.New(ctx, key) } + p.keysLock.Unlock() + for { var conn Conn select { @@ -134,7 +140,12 @@ func (p *P) Get(ctx context.Context, key string) (Conn, error) { } if !conn.Usable() { - conn.Close() + // Close might take some time, run in parallel. + go conn.Close() + continue + } + if conn.LastUseAt().Add(time.Duration(p.cfg.MaxConnLifetimeSec) * time.Second).Before(time.Now()) { + go conn.Close() continue } @@ -158,12 +169,12 @@ func (p *P) Return(key string, c Conn) { if v.lastUse+p.cfg.StaleKeyLifetimeSec > time.Now().Unix() { continue } - + delete(p.keys, k) close(v.c) + for conn := range v.c { conn.Close() } - delete(p.keys, k) } } @@ -179,7 +190,7 @@ func (p *P) Return(key string, c Conn) { bucket.lastUse = time.Now().Unix() default: // Let it go, let it go... - c.Close() + go c.Close() } } diff --git a/internal/smtpconn/smtpconn.go b/internal/smtpconn/smtpconn.go index 451a8442..4dca9575 100644 --- a/internal/smtpconn/smtpconn.go +++ b/internal/smtpconn/smtpconn.go @@ -79,6 +79,7 @@ type C struct { // "ADDRESS said: ..." AddrInSMTPMsg bool + conn net.Conn serverName string cl *smtp.Client rcpts []string @@ -163,26 +164,28 @@ func (c *C) wrapClientErr(err error, serverName string) error { // Connect actually estabilishes the network connection with the remote host, // executes HELO/EHLO and optionally STARTTLS command. func (c *C) Connect(ctx context.Context, endp config.Endpoint, starttls bool, tlsConfig *tls.Config) (didTLS bool, err error) { - didTLS, cl, err := c.attemptConnect(ctx, false, endp, starttls, tlsConfig) + didTLS, cl, conn, err := c.attemptConnect(ctx, false, endp, starttls, tlsConfig) if err != nil { return false, c.wrapClientErr(err, endp.Host) } c.serverName = endp.Host c.cl = cl + c.conn = conn return didTLS, nil } // ConnectLMTP estabilishes the network connection with the remote host and // sends LHLO command, negotiating LMTP use. 
func (c *C) ConnectLMTP(ctx context.Context, endp config.Endpoint, starttls bool, tlsConfig *tls.Config) (didTLS bool, err error) { - didTLS, cl, err := c.attemptConnect(ctx, true, endp, starttls, tlsConfig) + didTLS, cl, conn, err := c.attemptConnect(ctx, true, endp, starttls, tlsConfig) if err != nil { return false, c.wrapClientErr(err, endp.Host) } c.serverName = endp.Host c.cl = cl + c.conn = conn return didTLS, nil } @@ -203,14 +206,27 @@ func (err TLSError) Unwrap() error { return err.Err } -func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, starttls bool, tlsConfig *tls.Config) (didTLS bool, cl *smtp.Client, err error) { - var conn net.Conn +func (c *C) LocalAddr() net.Addr { + if c.conn == nil { + return nil + } + return c.conn.LocalAddr() +} + +func (c *C) RemoteAddr() net.Addr { + if c.conn == nil { + return nil + } + return c.conn.RemoteAddr() +} + +func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, starttls bool, tlsConfig *tls.Config) (didTLS bool, cl *smtp.Client, conn net.Conn, err error) { dialCtx, cancel := context.WithTimeout(ctx, c.ConnectTimeout) conn, err = c.Dialer(dialCtx, endp.Network(), endp.Address()) cancel() if err != nil { - return false, nil, err + return false, nil, nil, err } if endp.IsTLS() { @@ -233,15 +249,15 @@ func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, // i18n: hostname is already expected to be in A-labels form. 
if err := cl.Hello(c.Hostname); err != nil { cl.Close() - return false, nil, err + return false, nil, nil, err } if endp.IsTLS() || !starttls { - return endp.IsTLS(), cl, nil + return endp.IsTLS(), cl, nil, nil } if ok, _ := cl.Extension("STARTTLS"); !ok { - return false, cl, nil + return false, cl, nil, nil } cfg := tlsConfig.Clone() @@ -255,10 +271,10 @@ func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, cl.Close() } - return false, nil, TLSError{err} + return false, nil, nil, TLSError{err} } - return true, cl, nil + return true, cl, conn, nil } // Mail sends the MAIL FROM command to the remote server. @@ -307,7 +323,8 @@ func (c *C) Mail(ctx context.Context, from string, opts smtp.MailOptions) error return c.wrapClientErr(err, c.serverName) } - c.Log.DebugMsg("connected", "remote_server", c.serverName) + c.Log.DebugMsg("connected", "remote_server", c.serverName, + "local_addr", c.LocalAddr(), "remote_addr", c.RemoteAddr()) return nil } @@ -482,11 +499,27 @@ func (c *C) Noop() error { return c.cl.Noop() } -// Close sends the QUIT command, if it fail - it directly closes the +// Close sends the QUIT command, if it fails - it directly closes the // connection. func (c *C) Close() error { + c.cl.CommandTimeout = 5 * time.Second + if err := c.cl.Quit(); err != nil { - c.Log.Error("QUIT error", c.wrapClientErr(err, c.serverName)) + var smtpErr *smtp.SMTPError + var netErr *net.OpError + if errors.As(err, &smtpErr) && smtpErr.Code == 421 { + // 421 "Service not available" is typically sent + // when idle timeout happens. + c.Log.DebugMsg("QUIT error", "reason", c.wrapClientErr(err, c.serverName)) + } else if errors.As(err, &netErr) && + (netErr.Timeout() || netErr.Err.Error() == "write: broken pipe" || netErr.Err.Error() == "read: connection reset") { + + // The case for silently closed connections. 
+ c.Log.DebugMsg("QUIT error", "reason", c.wrapClientErr(err, c.serverName)) + } else { + c.Log.Error("QUIT error", c.wrapClientErr(err, c.serverName)) + } + return c.cl.Close() } diff --git a/internal/target/remote/connect.go b/internal/target/remote/connect.go index 0824d4aa..2290d137 100644 --- a/internal/target/remote/connect.go +++ b/internal/target/remote/connect.go @@ -26,6 +26,7 @@ import ( "net" "runtime/trace" "sort" + "time" "github.com/foxcpp/maddy/framework/config" "github.com/foxcpp/maddy/framework/dns" @@ -48,6 +49,7 @@ type mxConn struct { // Amount of times connection was used for an SMTP transaction. transactions int + lastUseAt time.Time // MX/TLS security level established for this connection. mxLevel module.MXLevel @@ -55,12 +57,16 @@ type mxConn struct { } func (c *mxConn) Usable() bool { - if c.C == nil || c.transactions > c.reuseLimit || c.C.Client() == nil { + if c.C == nil || c.transactions > c.reuseLimit || c.C.Client() == nil || c.errored { return false } return c.C.Client().Reset() == nil } +func (c *mxConn) LastUseAt() time.Time { + return c.lastUseAt +} + func (c *mxConn) Close() error { return c.C.Close() } @@ -196,9 +202,9 @@ func (rd *remoteDelivery) attemptMX(ctx context.Context, conn *mxConn, record *n return nil } -func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string) (*smtpconn.C, error) { +func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string) (*mxConn, error) { if c, ok := rd.connections[domain]; ok { - return c.C, nil + return c, nil } pooledConn, err := rd.rt.pool.Get(ctx, domain) @@ -212,7 +218,8 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string // connection with weaker security. 
if pooledConn != nil && !rd.msgMeta.SMTPOpts.RequireTLS { conn = pooledConn.(*mxConn) - rd.Log.Msg("reusing cached connection", "domain", domain, "transactions_counter", conn.transactions) + rd.Log.Msg("reusing cached connection", "domain", domain, "transactions_counter", conn.transactions, + "local_addr", conn.LocalAddr(), "remote_addr", conn.RemoteAddr()) } else { rd.Log.DebugMsg("opening new connection", "domain", domain, "cache_ignored", pooledConn != nil) conn, err = rd.newConn(ctx, domain) @@ -249,6 +256,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string region := trace.StartRegion(ctx, "remote/limits.TakeDest") if err := rd.rt.limits.TakeDest(ctx, domain); err != nil { region.End() + conn.Close() return nil, err } region.End() @@ -269,9 +277,10 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string conn.Close() return nil, err } + conn.lastUseAt = time.Now() rd.connections[domain] = conn - return conn.C, nil + return conn, nil } func (rd *remoteDelivery) newConn(ctx context.Context, domain string) (*mxConn, error) { @@ -279,6 +288,7 @@ func (rd *remoteDelivery) newConn(ctx context.Context, domain string) (*mxConn, reuseLimit: rd.rt.connReuseLimit, C: smtpconn.New(), domain: domain, + lastUseAt: time.Now(), } conn.Dialer = rd.rt.dialer @@ -329,7 +339,7 @@ func (rd *remoteDelivery) newConn(ctx context.Context, domain string) (*mxConn, } region.End() - // Stil not connected? Bail out. + // Still not connected? Bail out. 
if conn.Client() == nil { return nil, &exterrors.SMTPError{ Code: exterrors.SMTPCode(lastErr, 451, 550), diff --git a/internal/target/remote/remote.go b/internal/target/remote/remote.go index 6fded8e8..3e661c4f 100644 --- a/internal/target/remote/remote.go +++ b/internal/target/remote/remote.go @@ -148,7 +148,7 @@ func (rt *Target) Init(cfg *config.Map) error { MaxConnLifetimeSec: 150, // 2.5 mins, half of recommended idle time from RFC 5321 StaleKeyLifetimeSec: 60 * 5, // should be bigger than MaxConnLifetimeSec } - cfg.Int("conn_max_idle_count", false, false, 10, &poolCfg.MaxConnsPerKey) + cfg.Int("conn_max_idle_count", false, false, 5, &poolCfg.MaxConnsPerKey) cfg.Int64("conn_max_idle_time", false, false, 150, &poolCfg.MaxConnLifetimeSec) if _, err := cfg.Process(); err != nil { @@ -315,6 +315,7 @@ func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string, opts smtp.Rcpt if err := conn.Rcpt(ctx, to, opts); err != nil { return moduleError(err) } + conn.lastUseAt = time.Now() rd.recipients = append(rd.recipients, to) return nil @@ -425,6 +426,7 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl c.SetStatus(rcpt, err) } rd.connections[i].errored = err != nil + conn.lastUseAt = time.Now() }() } @@ -446,12 +448,12 @@ func (rd *remoteDelivery) Close() error { rd.rt.limits.ReleaseDest(conn.domain) conn.transactions++ - if conn.C == nil || conn.transactions > rd.rt.connReuseLimit || conn.C.Client() == nil || conn.errored { - rd.Log.Debugf("disconnected from %s (errored=%v,transactions=%v,disconnected before=%v)", - conn.ServerName(), conn.errored, conn.transactions, conn.C.Client() == nil) + if !conn.Usable() { + rd.Log.Debugf("disconnected %v from %s (errored=%v,transactions=%v,disconnected before=%v)", + conn.LocalAddr(), conn.ServerName(), conn.errored, conn.transactions, conn.C.Client() == nil) conn.Close() } else { - rd.Log.Debugf("returning connection for %s to pool", conn.ServerName()) + rd.Log.Debugf("returning connection 
%v for %s to pool", conn.LocalAddr(), conn.ServerName()) rd.rt.pool.Return(conn.domain, conn) } }