diff --git a/client/client.go b/client/client.go index 3357c25e..34a26bb5 100644 --- a/client/client.go +++ b/client/client.go @@ -197,6 +197,9 @@ type Option struct { // alaways use the selected server until it is bad Sticky bool + + // not call server message handler + NilCallServerMessageHandler func(msg *protocol.Message) } // Call represents an active RPC. @@ -663,6 +666,8 @@ func (client *Client) input() { if isServerMessage { if client.ServerMessageChan != nil { client.handleServerRequest(res) + } else if client.option.NilCallServerMessageHandler != nil { + client.option.NilCallServerMessageHandler(res) } continue } diff --git a/client/client_test.go b/client/client_test.go index 25aaca2b..3bd02321 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -2,7 +2,9 @@ package client import ( "context" + "fmt" "math/rand" + "net" "sync" "testing" "time" @@ -40,6 +42,17 @@ func (t *Arith) ThriftMul(ctx context.Context, args *testutils.ThriftArgs_, repl return nil } +type Bidirectional struct { + *server.Server +} + +func (t *Bidirectional) Mul(ctx context.Context, args *Args, reply *Reply) error { + conn := ctx.Value(server.RemoteConnContextKey).(net.Conn) + reply.C = args.A * args.B + t.SendMessage(conn, "test_service_path", "test_service_method", nil, []byte("abcde")) + return nil +} + func TestClient_IT(t *testing.T) { s := server.NewServer() _ = s.RegisterName("Arith", new(Arith), "") @@ -186,3 +199,50 @@ func TestClient_Res_Reset(t *testing.T) { t.Fatalf("data has been set to empty after response has been reset: %v", data) } } + +func TestClient_Bidirectional(t *testing.T) { + s := server.NewServer() + _ = s.RegisterName("Bidirectional", &Bidirectional{Server: s}, "") + go func() { + _ = s.Serve("tcp", "127.0.0.1:0") + }() + defer s.Close() + time.Sleep(500 * time.Millisecond) + + addr := s.Address().String() + + opt := DefaultOption + + var receive string + + opt.NilCallServerMessageHandler = func(msg *protocol.Message) { + fmt.Printf("receive msg from server: %s\n", msg.Payload) + receive = string(msg.Payload) + } + client := &Client{ + option: opt, + } + + err := client.Connect("tcp", addr) + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer client.Close() + + args := &Args{ + A: 10, + B: 20, + } + reply := &Reply{} + err = client.Call(context.Background(), "Bidirectional", "Mul", args, reply) + if err != nil { + t.Fatalf("failed to call: %v", err) + } + if receive != "abcde" { + t.Fatalf("expect abcde but got %s", receive) + } + if reply.C != 200 { + t.Fatalf("expect 200 but got %d", reply.C) + } + +}