diff --git a/internal/smtpconn/pool/pool.go b/internal/smtpconn/pool/pool.go index 35ab27ba..4b700ee4 100644 --- a/internal/smtpconn/pool/pool.go +++ b/internal/smtpconn/pool/pool.go @@ -26,6 +26,7 @@ import ( type Conn interface { Usable() bool + LastUseAt() time.Time Close() error } @@ -95,33 +96,38 @@ func (p *P) CleanUp(ctx context.Context) { close(v.c) for conn := range v.c { - conn.Close() + go conn.Close() } delete(p.keys, k) } } func (p *P) Get(ctx context.Context, key string) (Conn, error) { - // TODO: See if it is possible to get rid of this lock. p.keysLock.Lock() - defer p.keysLock.Unlock() bucket, ok := p.keys[key] if !ok { + p.keysLock.Unlock() return p.cfg.New(ctx, key) } if time.Now().Unix()-bucket.lastUse > p.cfg.MaxConnLifetimeSec { // Drop bucket. + delete(p.keys, key) close(bucket.c) + + // Close might take some time, unlock early. + p.keysLock.Unlock() + for conn := range bucket.c { conn.Close() } - delete(p.keys, key) return p.cfg.New(ctx, key) } + p.keysLock.Unlock() + for { var conn Conn select { @@ -134,7 +140,12 @@ func (p *P) Get(ctx context.Context, key string) (Conn, error) { } if !conn.Usable() { - conn.Close() + // Close might take some time, run in parallel. + go conn.Close() + continue + } + if conn.LastUseAt().Add(time.Duration(p.cfg.MaxConnLifetimeSec) * time.Second).Before(time.Now()) { + go conn.Close() continue } @@ -158,12 +169,12 @@ func (p *P) Return(key string, c Conn) { if v.lastUse+p.cfg.StaleKeyLifetimeSec > time.Now().Unix() { continue } - + delete(p.keys, k) close(v.c) + for conn := range v.c { conn.Close() } - delete(p.keys, k) } } @@ -179,7 +190,7 @@ func (p *P) Return(key string, c Conn) { bucket.lastUse = time.Now().Unix() default: // Let it go, let it go... - c.Close() + go c.Close() } } diff --git a/internal/smtpconn/smtpconn.go b/internal/smtpconn/smtpconn.go index 451a8442..4dca9575 100644 --- a/internal/smtpconn/smtpconn.go +++ b/internal/smtpconn/smtpconn.go @@ -79,6 +79,7 @@ type C struct { // "ADDRESS said: ..." AddrInSMTPMsg bool + conn net.Conn serverName string cl *smtp.Client rcpts []string @@ -163,26 +164,28 @@ func (c *C) wrapClientErr(err error, serverName string) error { // Connect actually estabilishes the network connection with the remote host, // executes HELO/EHLO and optionally STARTTLS command. func (c *C) Connect(ctx context.Context, endp config.Endpoint, starttls bool, tlsConfig *tls.Config) (didTLS bool, err error) { - didTLS, cl, err := c.attemptConnect(ctx, false, endp, starttls, tlsConfig) + didTLS, cl, conn, err := c.attemptConnect(ctx, false, endp, starttls, tlsConfig) if err != nil { return false, c.wrapClientErr(err, endp.Host) } c.serverName = endp.Host c.cl = cl + c.conn = conn return didTLS, nil } // ConnectLMTP estabilishes the network connection with the remote host and // sends LHLO command, negotiating LMTP use. func (c *C) ConnectLMTP(ctx context.Context, endp config.Endpoint, starttls bool, tlsConfig *tls.Config) (didTLS bool, err error) { - didTLS, cl, err := c.attemptConnect(ctx, true, endp, starttls, tlsConfig) + didTLS, cl, conn, err := c.attemptConnect(ctx, true, endp, starttls, tlsConfig) if err != nil { return false, c.wrapClientErr(err, endp.Host) } c.serverName = endp.Host c.cl = cl + c.conn = conn return didTLS, nil } @@ -203,14 +206,27 @@ func (err TLSError) Unwrap() error { return err.Err } -func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, starttls bool, tlsConfig *tls.Config) (didTLS bool, cl *smtp.Client, err error) { - var conn net.Conn +func (c *C) LocalAddr() net.Addr { + if c.conn == nil { + return nil + } + return c.conn.LocalAddr() +} + +func (c *C) RemoteAddr() net.Addr { + if c.conn == nil { + return nil + } + return c.conn.RemoteAddr() +} + +func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, starttls bool, tlsConfig *tls.Config) (didTLS bool, cl *smtp.Client, conn net.Conn, err error) { dialCtx, cancel := context.WithTimeout(ctx, c.ConnectTimeout) conn, err = c.Dialer(dialCtx, endp.Network(), endp.Address()) cancel() if err != nil { - return false, nil, err + return false, nil, nil, err } if endp.IsTLS() { @@ -233,15 +249,15 @@ func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, // i18n: hostname is already expected to be in A-labels form. if err := cl.Hello(c.Hostname); err != nil { cl.Close() - return false, nil, err + return false, nil, nil, err } if endp.IsTLS() || !starttls { - return endp.IsTLS(), cl, nil + return endp.IsTLS(), cl, nil, nil } if ok, _ := cl.Extension("STARTTLS"); !ok { - return false, cl, nil + return false, cl, nil, nil } cfg := tlsConfig.Clone() @@ -255,10 +271,10 @@ func (c *C) attemptConnect(ctx context.Context, lmtp bool, endp config.Endpoint, cl.Close() } - return false, nil, TLSError{err} + return false, nil, nil, TLSError{err} } - return true, cl, nil + return true, cl, conn, nil } // Mail sends the MAIL FROM command to the remote server. @@ -307,7 +323,8 @@ func (c *C) Mail(ctx context.Context, from string, opts smtp.MailOptions) error return c.wrapClientErr(err, c.serverName) } - c.Log.DebugMsg("connected", "remote_server", c.serverName) + c.Log.DebugMsg("connected", "remote_server", c.serverName, + "local_addr", c.LocalAddr(), "remote_addr", c.RemoteAddr()) return nil } @@ -482,11 +499,27 @@ func (c *C) Noop() error { return c.cl.Noop() } -// Close sends the QUIT command, if it fail - it directly closes the +// Close sends the QUIT command, if it fails - it directly closes the // connection. func (c *C) Close() error { + c.cl.CommandTimeout = 5 * time.Second + if err := c.cl.Quit(); err != nil { - c.Log.Error("QUIT error", c.wrapClientErr(err, c.serverName)) + var smtpErr *smtp.SMTPError + var netErr *net.OpError + if errors.As(err, &smtpErr) && smtpErr.Code == 421 { + // 421 "Service not available" is typically sent + // when idle timeout happens. + c.Log.DebugMsg("QUIT error", "reason", c.wrapClientErr(err, c.serverName)) + } else if errors.As(err, &netErr) && + (netErr.Timeout() || netErr.Err.Error() == "write: broken pipe" || netErr.Err.Error() == "read: connection reset") { + + // The case for silently closed connections. + c.Log.DebugMsg("QUIT error", "reason", c.wrapClientErr(err, c.serverName)) + } else { + c.Log.Error("QUIT error", c.wrapClientErr(err, c.serverName)) + } + return c.cl.Close() } diff --git a/internal/target/remote/connect.go b/internal/target/remote/connect.go index 0824d4aa..2290d137 100644 --- a/internal/target/remote/connect.go +++ b/internal/target/remote/connect.go @@ -26,6 +26,7 @@ import ( "net" "runtime/trace" "sort" + "time" "github.com/foxcpp/maddy/framework/config" "github.com/foxcpp/maddy/framework/dns" @@ -48,6 +49,7 @@ type mxConn struct { // Amount of times connection was used for an SMTP transaction. transactions int + lastUseAt time.Time // MX/TLS security level established for this connection. mxLevel module.MXLevel @@ -55,12 +57,16 @@ type mxConn struct { } func (c *mxConn) Usable() bool { - if c.C == nil || c.transactions > c.reuseLimit || c.C.Client() == nil { + if c.C == nil || c.transactions > c.reuseLimit || c.C.Client() == nil || c.errored { return false } return c.C.Client().Reset() == nil } +func (c *mxConn) LastUseAt() time.Time { + return c.lastUseAt +} + func (c *mxConn) Close() error { return c.C.Close() } @@ -196,9 +202,9 @@ func (rd *remoteDelivery) attemptMX(ctx context.Context, conn *mxConn, record *n return nil } -func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string) (*smtpconn.C, error) { +func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string) (*mxConn, error) { if c, ok := rd.connections[domain]; ok { - return c.C, nil + return c, nil } pooledConn, err := rd.rt.pool.Get(ctx, domain) @@ -212,7 +218,8 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string // connection with weaker security. if pooledConn != nil && !rd.msgMeta.SMTPOpts.RequireTLS { conn = pooledConn.(*mxConn) - rd.Log.Msg("reusing cached connection", "domain", domain, "transactions_counter", conn.transactions) + rd.Log.Msg("reusing cached connection", "domain", domain, "transactions_counter", conn.transactions, + "local_addr", conn.LocalAddr(), "remote_addr", conn.RemoteAddr()) } else { rd.Log.DebugMsg("opening new connection", "domain", domain, "cache_ignored", pooledConn != nil) conn, err = rd.newConn(ctx, domain) @@ -249,6 +256,7 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string region := trace.StartRegion(ctx, "remote/limits.TakeDest") if err := rd.rt.limits.TakeDest(ctx, domain); err != nil { region.End() + conn.Close() return nil, err } region.End() @@ -269,9 +277,10 @@ func (rd *remoteDelivery) connectionForDomain(ctx context.Context, domain string conn.Close() return nil, err } + conn.lastUseAt = time.Now() rd.connections[domain] = conn - return conn.C, nil + return conn, nil } func (rd *remoteDelivery) newConn(ctx context.Context, domain string) (*mxConn, error) { @@ -279,6 +288,7 @@ func (rd *remoteDelivery) newConn(ctx context.Context, domain string) (*mxConn, reuseLimit: rd.rt.connReuseLimit, C: smtpconn.New(), domain: domain, + lastUseAt: time.Now(), } conn.Dialer = rd.rt.dialer @@ -329,7 +339,7 @@ func (rd *remoteDelivery) newConn(ctx context.Context, domain string) (*mxConn, } region.End() - // Stil not connected? Bail out. + // Still not connected? Bail out. if conn.Client() == nil { return nil, &exterrors.SMTPError{ Code: exterrors.SMTPCode(lastErr, 451, 550), diff --git a/internal/target/remote/remote.go b/internal/target/remote/remote.go index 6fded8e8..3e661c4f 100644 --- a/internal/target/remote/remote.go +++ b/internal/target/remote/remote.go @@ -148,7 +148,7 @@ func (rt *Target) Init(cfg *config.Map) error { MaxConnLifetimeSec: 150, // 2.5 mins, half of recommended idle time from RFC 5321 StaleKeyLifetimeSec: 60 * 5, // should be bigger than MaxConnLifetimeSec } - cfg.Int("conn_max_idle_count", false, false, 10, &poolCfg.MaxConnsPerKey) + cfg.Int("conn_max_idle_count", false, false, 5, &poolCfg.MaxConnsPerKey) cfg.Int64("conn_max_idle_time", false, false, 150, &poolCfg.MaxConnLifetimeSec) if _, err := cfg.Process(); err != nil { @@ -315,6 +315,7 @@ func (rd *remoteDelivery) AddRcpt(ctx context.Context, to string, opts smtp.Rcpt if err := conn.Rcpt(ctx, to, opts); err != nil { return moduleError(err) } + conn.lastUseAt = time.Now() rd.recipients = append(rd.recipients, to) return nil @@ -425,6 +426,7 @@ func (rd *remoteDelivery) BodyNonAtomic(ctx context.Context, c module.StatusColl c.SetStatus(rcpt, err) } rd.connections[i].errored = err != nil + conn.lastUseAt = time.Now() }() } @@ -446,12 +448,12 @@ func (rd *remoteDelivery) Close() error { rd.rt.limits.ReleaseDest(conn.domain) conn.transactions++ - if conn.C == nil || conn.transactions > rd.rt.connReuseLimit || conn.C.Client() == nil || conn.errored { - rd.Log.Debugf("disconnected from %s (errored=%v,transactions=%v,disconnected before=%v)", - conn.ServerName(), conn.errored, conn.transactions, conn.C.Client() == nil) + if !conn.Usable() { + rd.Log.Debugf("disconnected %v from %s (errored=%v,transactions=%v,disconnected before=%v)", + conn.LocalAddr(), conn.ServerName(), conn.errored, conn.transactions, conn.C.Client() == nil) conn.Close() } else { - rd.Log.Debugf("returning connection for %s to pool", conn.ServerName()) + rd.Log.Debugf("returning connection %v for %s to pool", conn.LocalAddr(), conn.ServerName()) rd.rt.pool.Return(conn.domain, conn) } }