From 6b425f7325e088a1cb00f688c422fb808f08e5a2 Mon Sep 17 00:00:00 2001 From: Santiago De la Cruz <51337247+xhit@users.noreply.github.com> Date: Thu, 6 Jul 2023 12:03:19 -0400 Subject: [PATCH] fix data race sending mail (#82) * fix data race sending mail when timeout exceed * change localhost to 127.0.0.1 in test * minimal fixes in test * reduce loop logic --- email.go | 37 +++++++++++------ email_test.go | 112 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 136 insertions(+), 13 deletions(-) create mode 100644 email_test.go diff --git a/email.go b/email.go index 0054819..e38dbfe 100644 --- a/email.go +++ b/email.go @@ -10,6 +10,7 @@ import ( "net/textproto" "strconv" "strings" + "sync" "time" "github.com/toorop/go-dkim" @@ -55,6 +56,7 @@ type SMTPServer struct { // SMTPClient represents a SMTP Client for send email type SMTPClient struct { + mu sync.Mutex Client *smtpClient KeepAlive bool SendTimeout time.Duration @@ -865,21 +867,29 @@ func (server *SMTPServer) Connect() (*SMTPClient, error) { // Reset send RSET command to smtp client func (smtpClient *SMTPClient) Reset() error { + smtpClient.mu.Lock() + defer smtpClient.mu.Unlock() return smtpClient.Client.reset() } // Noop send NOOP command to smtp client func (smtpClient *SMTPClient) Noop() error { + smtpClient.mu.Lock() + defer smtpClient.mu.Unlock() return smtpClient.Client.noop() } // Quit send QUIT command to smtp client func (smtpClient *SMTPClient) Quit() error { + smtpClient.mu.Lock() + defer smtpClient.mu.Unlock() return smtpClient.Client.quit() } // Close closes the connection func (smtpClient *SMTPClient) Close() error { + smtpClient.mu.Lock() + defer smtpClient.mu.Unlock() return smtpClient.Client.close() } @@ -909,14 +919,14 @@ func send(from string, to []string, msg string, client *SMTPClient) error { if client.SendTimeout != 0 { smtpSendChannel = make(chan error, 1) - go func(from string, to []string, msg string, c *smtpClient) { - smtpSendChannel <- sendMailProcess(from, to, msg, c) - }(from, to, msg, client.Client) + go func(from string, to []string, msg string, client *SMTPClient) { + smtpSendChannel <- sendMailProcess(from, to, msg, client) + }(from, to, msg, client) } if client.SendTimeout == 0 { // no SendTimeout, just fire the sendMailProcess - return sendMailProcess(from, to, msg, client.Client) + return sendMailProcess(from, to, msg, client) } // get the send result or timeout result, which ever happens first @@ -928,35 +938,36 @@ func send(from string, to []string, msg string, client *SMTPClient) error { checkKeepAlive(client) return errors.New("Mail Error: SMTP Send timed out") } - } } return errors.New("Mail Error: No SMTP Client Provided") } -func sendMailProcess(from string, to []string, msg string, c *smtpClient) error { +func sendMailProcess(from string, to []string, msg string, c *SMTPClient) error { + c.mu.Lock() + defer c.mu.Unlock() cmdArgs := make(map[string]string) - if _, ok := c.ext["SIZE"]; ok { + if _, ok := c.Client.ext["SIZE"]; ok { cmdArgs["SIZE"] = strconv.Itoa(len(msg)) } // Set the sender - if err := c.mail(from, cmdArgs); err != nil { + if err := c.Client.mail(from, cmdArgs); err != nil { return err } // Set the recipients for _, address := range to { - if err := c.rcpt(address); err != nil { + if err := c.Client.rcpt(address); err != nil { return err } } // Send the data command - w, err := c.data() + w, err := c.Client.data() if err != nil { return err } @@ -978,9 +989,9 @@ func sendMailProcess(from string, to []string, msg string, c *smtpClient) error // check if keepAlive for close or reset func checkKeepAlive(client *SMTPClient) { if client.KeepAlive { - client.Client.reset() + client.Reset() } else { - client.Client.quit() - client.Client.close() + client.Quit() + client.Close() } } diff --git a/email_test.go b/email_test.go new file mode 100644 index 0000000..a23b9bc --- /dev/null +++ b/email_test.go @@ -0,0 +1,112 @@ +package mail + +import ( + "fmt" + "log" + "net" + "testing" + "time" +) + +func TestSendRace(t *testing.T) { + port := 56666 + port2 := 56667 + timeout := 1 * time.Second + + responses := []string{ + `220 test connected`, + `250 after helo`, + `250 after mail from`, + `250 after rcpt to`, + `354 after data`, + } + + startService(port, responses, 5*time.Second) + startService(port2, responses, 0) + + server := NewSMTPClient() + server.ConnectTimeout = timeout + server.SendTimeout = timeout + server.KeepAlive = false + server.Host = `127.0.0.1` + server.Port = port + + smtpClient, err := server.Connect() + if err != nil { + log.Fatalf("couldn't connect: %s", err.Error()) + } + defer smtpClient.Close() + + // create another server in other port to test timeouts + server.Port = port2 + smtpClient2, err := server.Connect() + if err != nil { + log.Fatalf("couldn't connect: %s", err.Error()) + } + defer smtpClient2.Close() + + msg := NewMSG(). + SetFrom(`foo@bar`). + AddTo(`rcpt@bar`). + SetSubject("subject"). + SetBody(TextPlain, "body") + + // the smtpClient2 has not timeout + err = msg.Send(smtpClient2) + if err != nil { + log.Fatalf("couldn't send: %s", err.Error()) + } + + // the smtpClient send to listener with the last response is after SendTimeout, so when this error is returned the test succeed. + err = msg.Send(smtpClient) + if err != nil && err.Error() != "Mail Error: SMTP Send timed out" { + log.Fatalf("couldn't send: %s", err.Error()) + } +} + +func startService(port int, responses []string, timeout time.Duration) { + log.Printf("starting service at %d...\n", port) + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + if err != nil { + log.Fatalf("couldn't listen to port %d: %s", port, err) + } + + go func() { + for { + conn, err := listener.Accept() + if err != nil { + log.Fatalf("couldn't listen accept the request in port %d", port) + } + go respond(conn, responses, timeout) + } + }() +} + +func respond(conn net.Conn, responses []string, timeout time.Duration) { + buf := make([]byte, 1024) + for _, resp := range responses { + write(conn, resp) + n, err := conn.Read(buf) + if err != nil { + log.Println("couldn't read data") + return + } + readStr := string(buf[:n]) + log.Printf("READ:%s", string(readStr)) + } + + // if timeout, sleep for that time, otherwise sent a 250 OK + if timeout > 0 { + time.Sleep(timeout) + } else { + write(conn, "250 OK") + } + + conn.Close() + fmt.Print("\n\n") +} + +func write(conn net.Conn, command string) { + log.Printf("WRITE:%s", command) + conn.Write([]byte(command + "\n")) +}