Skip to content

Commit

Permalink
Merge pull request #866 from dickens7/feat-bidirectional
Browse files Browse the repository at this point in the history
feat: add nil call server message handler
  • Loading branch information
smallnest authored Jul 25, 2024
2 parents ff384af + 3c32092 commit 4310c44
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
5 changes: 5 additions & 0 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ type Option struct {

// alaways use the selected server until it is bad
Sticky bool

// not call server message handler
NilCallServerMessageHandler func(msg *protocol.Message)
}

// Call represents an active RPC.
Expand Down Expand Up @@ -663,6 +666,8 @@ func (client *Client) input() {
if isServerMessage {
if client.ServerMessageChan != nil {
client.handleServerRequest(res)
} else if client.option.NilCallServerMessageHandler != nil {
client.option.NilCallServerMessageHandler(res)
}
continue
}
Expand Down
60 changes: 60 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package client

import (
"context"
"fmt"
"math/rand"
"net"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -40,6 +42,17 @@ func (t *Arith) ThriftMul(ctx context.Context, args *testutils.ThriftArgs_, repl
return nil
}

type Bidirectional struct {
*server.Server
}

func (t *Bidirectional) Mul(ctx context.Context, args *Args, reply *Reply) error {
conn := ctx.Value(server.RemoteConnContextKey).(net.Conn)
reply.C = args.A * args.B
t.SendMessage(conn, "test_service_path", "test_service_method", nil, []byte("abcde"))
return nil
}

func TestClient_IT(t *testing.T) {
s := server.NewServer()
_ = s.RegisterName("Arith", new(Arith), "")
Expand Down Expand Up @@ -186,3 +199,50 @@ func TestClient_Res_Reset(t *testing.T) {
t.Fatalf("data has been set to empty after response has been reset: %v", data)
}
}

func TestClient_Bidirectional(t *testing.T) {
s := server.NewServer()
_ = s.RegisterName("Bidirectional", &Bidirectional{Server: s}, "")
go func() {
_ = s.Serve("tcp", "127.0.0.1:0")
}()
defer s.Close()
time.Sleep(500 * time.Millisecond)

addr := s.Address().String()

opt := DefaultOption

var receive string

opt.NilCallServerMessageHandler = func(msg *protocol.Message) {
fmt.Printf("receive msg from server: %s\n", msg.Payload)
receive = string(msg.Payload)
}
client := &Client{
option: opt,
}

err := client.Connect("tcp", addr)
if err != nil {
t.Fatalf("failed to connect: %v", err)
}
defer client.Close()

args := &Args{
A: 10,
B: 20,
}
reply := &Reply{}
err = client.Call(context.Background(), "Bidirectional", "Mul", args, reply)
if err != nil {
t.Fatalf("failed to call: %v", err)
}
if receive != "abcde" {
t.Fatalf("expect abcde but got %s", receive)
}
if reply.C != 200 {
t.Fatalf("expect 200 but got %d", reply.C)
}

}

0 comments on commit 4310c44

Please sign in to comment.