Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug in handling postgres COPY command and a few others #610

Merged
merged 4 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions network/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
ctx = context.Background()
}

var received int
total := 0
buffer := bytes.NewBuffer(nil)
// Read the data in chunks.
for ctx.Err() == nil {
Expand All @@ -240,19 +240,19 @@ func (c *Client) Receive() (int, []byte, *gerr.GatewayDError) {
if err != nil {
c.logger.Error().Err(err).Msg("Couldn't receive data from the server")
span.RecordError(err)
return received, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
return total, buffer.Bytes(), gerr.ErrClientReceiveFailed.Wrap(err)
}
received += read
total += read
buffer.Write(chunk[:read])

if read == 0 || read < c.ReceiveChunkSize {
if read < c.ReceiveChunkSize {
break
}
}

span.AddEvent("Received data from server")

return received, buffer.Bytes(), nil
return total, buffer.Bytes(), nil
}

// Reconnect reconnects to the server.
Expand Down
28 changes: 18 additions & 10 deletions network/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -494,8 +494,17 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate
received, response, err := pr.receiveTrafficFromServer(client)
span.AddEvent("Received traffic from server")

// If the response is empty, don't send anything, instead just close the ingress connection.
if received == 0 || err != nil {
// If there is no data to send to the client,
// we don't need to run the hooks and
// we obviously have no data to send to the client.
if received == 0 {
span.AddEvent("No data to send to client")
stack.PopLastRequest()
return nil
}

// If there is an error, close the ingress connection.
if err != nil {
fields := map[string]interface{}{"function": "proxy.passthrough"}
if client.LocalAddr() != "" {
fields["localAddr"] = client.LocalAddr()
Expand All @@ -517,7 +526,7 @@ func (pr *Proxy) PassThroughToClient(conn *ConnWrapper, stack *Stack) *gerr.Gate

// Get the last request from the stack.
lastRequest := stack.PopLastRequest()
request := make([]byte, 0)
request := []byte{}
if lastRequest != nil {
request = lastRequest.Data
}
Expand Down Expand Up @@ -698,7 +707,7 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD
defer span.End()

// request contains the data from the client.
received := 0
total := 0
buffer := bytes.NewBuffer(nil)
for {
chunk := make([]byte, pr.ClientConfig.ReceiveChunkSize)
Expand All @@ -713,10 +722,10 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD
return chunk[:read], gerr.ErrReadFailed.Wrap(err)
}

received += read
total += read
buffer.Write(chunk[:read])

if received == 0 || received < pr.ClientConfig.ReceiveChunkSize {
if read < pr.ClientConfig.ReceiveChunkSize {
break
}

Expand All @@ -725,19 +734,18 @@ func (pr *Proxy) receiveTrafficFromClient(conn net.Conn) ([]byte, *gerr.GatewayD
}
}

length := len(buffer.Bytes())
pr.Logger.Debug().Fields(
map[string]interface{}{
"length": length,
"length": total,
"local": LocalAddr(conn),
"remote": RemoteAddr(conn),
},
).Msg("Received data from client")

span.AddEvent("Received data from client")

metrics.BytesReceivedFromClient.Observe(float64(length))
metrics.TotalTrafficBytes.Observe(float64(length))
metrics.BytesReceivedFromClient.Observe(float64(total))
metrics.TotalTrafficBytes.Observe(float64(total))

return buffer.Bytes(), nil
}
Expand Down
20 changes: 14 additions & 6 deletions network/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ type Server struct {
LoadbalancerStrategyName string
LoadbalancerRules []config.LoadBalancingRule
LoadbalancerConsistentHash *config.ConsistentHash
connectionToProxyMap map[*ConnWrapper]IProxy
connectionToProxyMap *sync.Map
}

var _ IServer = (*Server)(nil)
Expand Down Expand Up @@ -181,7 +181,7 @@ func (s *Server) OnOpen(conn *ConnWrapper) ([]byte, Action) {
}

// Assign connection to proxy
s.connectionToProxyMap[conn] = proxy
s.connectionToProxyMap.Store(conn, proxy)

// Run the OnOpened hooks.
pluginTimeoutCtx, cancel = context.WithTimeout(context.Background(), s.PluginTimeout)
Expand Down Expand Up @@ -696,7 +696,7 @@ func NewServer(
connections: 0,
running: &atomic.Bool{},
stopServer: make(chan struct{}),
connectionToProxyMap: make(map[*ConnWrapper]IProxy),
connectionToProxyMap: &sync.Map{},
LoadbalancerStrategyName: srv.LoadbalancerStrategyName,
LoadbalancerRules: srv.LoadbalancerRules,
LoadbalancerConsistentHash: srv.LoadbalancerConsistentHash,
Expand Down Expand Up @@ -737,11 +737,19 @@ func (s *Server) CountConnections() int {

// GetProxyForConnection returns the proxy associated with the given connection.
func (s *Server) GetProxyForConnection(conn *ConnWrapper) (IProxy, bool) {
proxy, exists := s.connectionToProxyMap[conn]
return proxy, exists
proxy, exists := s.connectionToProxyMap.Load(conn)
if !exists {
return nil, false
}

if proxy, ok := proxy.(IProxy); ok {
return proxy, true
}

return nil, false
}

// RemoveConnectionFromMap removes the given connection from the connection-to-proxy map.
func (s *Server) RemoveConnectionFromMap(conn *ConnWrapper) {
delete(s.connectionToProxyMap, conn)
s.connectionToProxyMap.Delete(conn)
}