Skip to content

Commit 9ec6b54

Browse files
authored
Fix bug in handling postgres COPY command and a few others (#610)
* Fix bug in COPY command by differentiating between total reads vs. chunk reads * Fix bug in concurrent reads and writes * Remove explicit condition for checking total being zero * Return early if there is no data from the server (to send to the client)
1 parent 0ad84b0 commit 9ec6b54

File tree

3 files changed

+37
-21
lines changed

3 files changed

+37
-21
lines changed

network/client.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
231231
ctx = context.Background()
232232
}
233233

234-
var received int
234+
total := 0
235235
buffer := bytes.NewBuffer(nil)
236236
// Read the data in chunks.
237237
for ctx.Err() == nil {
@@ -240,19 +240,19 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
240240
if err != nil {
241241
c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
242242
span.RecordError(err)
243-
return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
243+
return total, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
244244
}
245-
received += read
245+
total += read
246246
buffer.Write(chunk[:read])
247247

248-
if read == 0 || read < c.ReceiveChunkSize {
248+
if read < c.ReceiveChunkSize {
249249
break
250250
}
251251
}
252252

253253
span.AddEvent("Received data from server")
254254

255-
return received, buffer.Bytes(), nil
255+
return total, buffer.Bytes(), nil
256256
}
257257

258258
// Reconnect reconnects to the server.

network/proxy.go

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -494,8 +494,17 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
494494
received, response, err := pr.receiveTrafficFromServer(client)
495495
span.AddEvent("Received traffic from server")
496496

497-
// If the response is empty, don't send anything, instead just close the ingress connection.
498-
if received == 0 || err != nil {
497+
// If there is no data to send to the client,
498+
// we don't need to run the hooks and
499+
// we obviously have no data to send to the client.
500+
if received == 0 {
501+
span.AddEvent("No data to send to client")
502+
stack.PopLastRequest()
503+
return nil
504+
}
505+
506+
// If there is an error, close the ingress connection.
507+
if err != nil {
499508
fields := map[string]interface{}{"function": "proxy.passthrough"}
500509
if client.LocalAddr() != "" {
501510
fields["localAddr"] = client.LocalAddr()
@@ -517,7 +526,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
517526

518527
// Get the last request from the stack.
519528
lastRequest := stack.PopLastRequest()
520-
request := make([]byte, 0)
529+
request := []byte{}
521530
if lastRequest != nil {
522531
request = lastRequest.Data
523532
}
@@ -698,7 +707,7 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD
698707
defer span.End()
699708

700709
// request contains the data from the client.
701-
received := 0
710+
total := 0
702711
buffer := bytes.NewBuffer(nil)
703712
for {
704713
chunk := make([]byte, pr.ClientConfig.ReceiveChunkSize)
@@ -713,10 +722,10 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD
713722
return chunk[:read], gerr.ErrReadFailed.Wrap(err)
714723
}
715724

716-
received += read
725+
total += read
717726
buffer.Write(chunk[:read])
718727

719-
if received == 0 || received < pr.ClientConfig.ReceiveChunkSize {
728+
if read < pr.ClientConfig.ReceiveChunkSize {
720729
break
721730
}
722731

@@ -725,19 +734,18 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD
725734
}
726735
}
727736

728-
length := len(buffer.Bytes())
729737
pr.Logger.Debug().Fields(
730738
map[string]interface{}{
731-
"length": length,
739+
"length": total,
732740
"local": LocalAddr(conn),
733741
"remote": RemoteAddr(conn),
734742
},
735743
).Msg("Received data from client")
736744

737745
span.AddEvent("Received data from client")
738746

739-
metrics.BytesReceivedFromClient.Observe(float64(length))
740-
metrics.TotalTrafficBytes.Observe(float64(length))
747+
metrics.BytesReceivedFromClient.Observe(float64(total))
748+
metrics.TotalTrafficBytes.Observe(float64(total))
741749

742750
return buffer.Bytes(), nil
743751
}

network/server.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ type Server struct {
7979
LoadbalancerStrategyName string
8080
LoadbalancerRules []config.LoadBalancingRule
8181
LoadbalancerConsistentHash *config.ConsistentHash
82-
connectionToProxyMap map[*ConnWrapper]IProxy
82+
connectionToProxyMap *sync.Map
8383
}
8484

8585
var _ IServer = (*Server)(nil)
@@ -181,7 +181,7 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
181181
}
182182

183183
// Assign connection to proxy
184-
s.connectionToProxyMap[conn] = proxy
184+
s.connectionToProxyMap.Store(conn, proxy)
185185

186186
// Run the OnOpened hooks.
187187
pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout)
@@ -696,7 +696,7 @@ func NewServer(
696696
connections: 0,
697697
running: &atomic.Bool{},
698698
stopServer: make(chan struct{}),
699-
connectionToProxyMap: make(map[*ConnWrapper]IProxy),
699+
connectionToProxyMap: &sync.Map{},
700700
LoadbalancerStrategyName: srv.LoadbalancerStrategyName,
701701
LoadbalancerRules: srv.LoadbalancerRules,
702702
LoadbalancerConsistentHash: srv.LoadbalancerConsistentHash,
@@ -737,11 +737,19 @@ func (s *Server) CountConnections() int {
737737

738738
// GetProxyForConnection returns the proxy associated with the given connection.
739739
func (s *Server) GetProxyForConnection(conn *ConnWrapper) (IProxy, bool) {
740-
proxy, exists := s.connectionToProxyMap[conn]
741-
return proxy, exists
740+
proxy, exists := s.connectionToProxyMap.Load(conn)
741+
if !exists {
742+
return nil, false
743+
}
744+
745+
if proxy, ok := proxy.(IProxy); ok {
746+
return proxy, true
747+
}
748+
749+
return nil, false
742750
}
743751

744752
// RemoveConnectionFromMap removes the given connection from the connection-to-proxy map.
745753
func (s *Server) RemoveConnectionFromMap(conn *ConnWrapper) {
746-
delete(s.connectionToProxyMap, conn)
754+
s.connectionToProxyMap.Delete(conn)
747755
}

0 commit comments

Comments
 (0)