From 418f35fbe2d9d921cdafc7f5f0f84025731bf5c9 Mon Sep 17 00:00:00 2001 From: "aleksej.paschenko" Date: Wed, 1 Nov 2023 18:27:55 +0300 Subject: [PATCH] Implement unsubscribe methods for websocket --- pkg/pusher/websocket/handler_test.go | 149 ++++++++++++++++++++++++++- pkg/pusher/websocket/session.go | 54 +++++++++- tonapi/streaming.go | 15 +++ tonapi/websocket.go | 41 +++++--- 4 files changed, 237 insertions(+), 22 deletions(-) diff --git a/pkg/pusher/websocket/handler_test.go b/pkg/pusher/websocket/handler_test.go index ec84557f..8f8240ac 100644 --- a/pkg/pusher/websocket/handler_test.go +++ b/pkg/pusher/websocket/handler_test.go @@ -44,7 +44,7 @@ var _ sources.TransactionSource = &mockTxSource{} var _ sources.MemPoolSource = &mockMemPool{} var _ sources.TraceSource = &mockTraceSource{} -func TestHandler(t *testing.T) { +func TestHandler_UnsubscribeWhenConnectionIsClosed(t *testing.T) { var txSubscribed atomic.Bool // to make "go test -race" happy var txUnsubscribed atomic.Bool // to make "go test -race" happy source := &mockTxSource{ @@ -145,3 +145,150 @@ func TestHandler(t *testing.T) { require.True(t, memPoolUnsubscribed.Load()) require.True(t, traceUnsubscribed.Load()) } + +func TestHandler_UnsubscribeMethods(t *testing.T) { + var txSubscribed atomic.Bool // to make "go test -race" happy + var txUnsubscribed atomic.Bool // to make "go test -race" happy + source := &mockTxSource{ + OnSubscribeToTransactions: func(ctx context.Context, deliveryFn sources.DeliveryFn, opts sources.SubscribeToTransactionsOptions) sources.CancelFn { + txSubscribed.Store(true) + return func() { + txUnsubscribed.Store(true) + } + }, + } + var memPoolSubscribed atomic.Bool // to make "go test -race" happy + var memPoolUnsubscribed atomic.Bool // to make "go test -race" happy + mempool := &mockMemPool{ + OnSubscribeToMessages: func(ctx context.Context, deliveryFn sources.DeliveryFn) (sources.CancelFn, error) { + memPoolSubscribed.Store(true) + return func() { + memPoolUnsubscribed.Store(true) + }, nil + }, + } + var traceSubscribed atomic.Bool // to make "go test -race" happy + var traceUnsubscribed atomic.Bool // to make "go test -race" happy + traceSource := &mockTraceSource{ + OnSubscribeToTraces: func(ctx context.Context, deliveryFn sources.DeliveryFn, opts sources.SubscribeToTraceOptions) sources.CancelFn { + traceSubscribed.Store(true) + return func() { + traceUnsubscribed.Store(true) + } + }, + } + logger, _ := zap.NewDevelopment() + server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + handler := Handler(logger, source, traceSource, mempool) + err := handler(writer, request, 0, false) + require.Nil(t, err) + })) + defer server.Close() + + url := strings.Replace(server.URL, "http", "ws", -1) + conn, _, err := websocket.DefaultDialer.Dial(url, nil) + require.Nil(t, err) + + requests := []JsonRPCRequest{ + { + ID: 1, + JSONRPC: "2.0", + Method: "subscribe_account", + Params: []string{ + "-1:5555555555555555555555555555555555555555555555555555555555555555", + "0:5555555555555555555555555555555555555555555555555555555555555555", + }, + }, + { + ID: 2, + JSONRPC: "2.0", + Method: "subscribe_mempool", + }, + { + ID: 3, + JSONRPC: "2.0", + Method: "subscribe_trace", + Params: []string{ + "0:5555555555555555555555555555555555555555555555555555555555555555", + }, + }, + } + expectedResponses := [][]byte{ + []byte(`{"id":1,"jsonrpc":"2.0","method":"subscribe_account","result":"success! 2 new subscriptions created"}` + "\n"), + []byte(`{"id":2,"jsonrpc":"2.0","method":"subscribe_mempool","result":"success! you have subscribed to mempool"}` + "\n"), + []byte(`{"id":3,"jsonrpc":"2.0","method":"subscribe_trace","result":"success! 1 new subscriptions created"}` + "\n"), + } + + for i, request := range requests { + expectedResponse := expectedResponses[i] + err = conn.WriteJSON(request) + require.Nil(t, err) + + time.Sleep(1 * time.Second) + + msgType, msg, err := conn.ReadMessage() + require.Nil(t, err) + require.Equal(t, websocket.TextMessage, msgType) + + require.Equal(t, expectedResponse, msg) + } + require.True(t, txSubscribed.Load()) + require.False(t, txUnsubscribed.Load()) + + require.True(t, memPoolSubscribed.Load()) + require.False(t, memPoolUnsubscribed.Load()) + + require.True(t, traceSubscribed.Load()) + require.False(t, traceUnsubscribed.Load()) + + time.Sleep(1 * time.Second) + + requests = []JsonRPCRequest{ + { + ID: 4, + JSONRPC: "2.0", + Method: "unsubscribe_account", + Params: []string{ + "-1:5555555555555555555555555555555555555555555555555555555555555555", + "-1:3333333333333333333333333333333333333333333333333333333333333333", + }, + }, + { + ID: 5, + JSONRPC: "2.0", + Method: "unsubscribe_mempool", + }, + { + ID: 6, + JSONRPC: "2.0", + Method: "unsubscribe_trace", + Params: []string{ + "0:5555555555555555555555555555555555555555555555555555555555555555", + "0:3333333333333333333333333333333333333333333333333333333333333333", + }, + }, + } + expectedResponses = [][]byte{ + []byte(`{"id":4,"jsonrpc":"2.0","method":"unsubscribe_account","result":"success! 1 subscription(s) removed"}` + "\n"), + []byte(`{"id":5,"jsonrpc":"2.0","method":"unsubscribe_mempool","result":"success! you have unsubscribed from mempool"}` + "\n"), + []byte(`{"id":6,"jsonrpc":"2.0","method":"unsubscribe_trace","result":"success! 1 subscription(s) removed"}` + "\n"), + } + + for i, request := range requests { + expectedResponse := expectedResponses[i] + err = conn.WriteJSON(request) + require.Nil(t, err) + + time.Sleep(1 * time.Second) + + msgType, msg, err := conn.ReadMessage() + require.Nil(t, err) + require.Equal(t, websocket.TextMessage, msgType) + + require.Equal(t, expectedResponse, msg) + } + + require.True(t, txUnsubscribed.Load()) + require.True(t, memPoolUnsubscribed.Load()) + require.True(t, traceUnsubscribed.Load()) +} diff --git a/pkg/pusher/websocket/session.go b/pkg/pusher/websocket/session.go index b25a2c43..9c607ad3 100644 --- a/pkg/pusher/websocket/session.go +++ b/pkg/pusher/websocket/session.go @@ -87,14 +87,23 @@ func (s *session) Run(ctx context.Context) chan JsonRPCRequest { case request := <-requestCh: var response string switch request.Method { + // handle transaction subscriptions case "subscribe_account": response = s.subscribeToTransactions(ctx, request.Params) + case "unsubscribe_account": + response = s.unsubscribeFromTransactions(request.Params) + + // handle mempool subscriptions case "subscribe_mempool": response = s.subscribeToMempool(ctx) + case "unsubscribe_mempool": + response = s.unsubscribeFromMempool() + + // handle trace subscriptions case "subscribe_trace": response = s.subscribeToTraces(ctx, request.Params) - case "unsubscribe_account": - response = s.unsubscribe(request.Params) + case "unsubscribe_trace": + response = s.unsubscribeFromTraces(request.Params) } err = s.writeResponse(response, request) case <-time.After(s.pingInterval): @@ -153,8 +162,20 @@ func (s *session) subscribeToTransactions(ctx context.Context, params []string) return fmt.Sprintf("success! %v new subscriptions created", counter) } -func (s *session) unsubscribe(params []string) string { - return "not supported yet" +func (s *session) unsubscribeFromTransactions(params []string) string { + var counter int + for _, a := range params { + account, err := tongo.ParseAddress(a) + if err != nil { + return fmt.Sprintf("failed to process '%v' account: %v", a, err) + } + if cancelFn, ok := s.txSubscriptions[account.ID]; ok { + cancelFn() + delete(s.txSubscriptions, account.ID) + counter += 1 + } + } + return fmt.Sprintf("success! %v subscription(s) removed", counter) } func (s *session) subscribeToTraces(ctx context.Context, params []string) string { @@ -190,6 +211,22 @@ func (s *session) subscribeToTraces(ctx context.Context, params []string) string return fmt.Sprintf("success! %v new subscriptions created", counter) } +func (s *session) unsubscribeFromTraces(params []string) string { + var counter int + for _, a := range params { + account, err := tongo.ParseAddress(a) + if err != nil { + return fmt.Sprintf("failed to process '%v' account: %v", a, err) + } + if cancelFn, ok := s.traceSubscriptions[account.ID]; ok { + cancelFn() + delete(s.traceSubscriptions, account.ID) + counter += 1 + } + } + return fmt.Sprintf("success! %v subscription(s) removed", counter) +} + func (s *session) subscribeToMempool(ctx context.Context) string { if s.mempoolSubscription != nil { return fmt.Sprintf("you are already subscribed to mempool") @@ -204,6 +241,15 @@ func (s *session) subscribeToMempool(ctx context.Context) string { return fmt.Sprintf("success! you have subscribed to mempool") } +func (s *session) unsubscribeFromMempool() string { + if s.mempoolSubscription == nil { + return fmt.Sprintf("you are not subscribed to mempool") + } + s.mempoolSubscription() + s.mempoolSubscription = nil + return fmt.Sprintf("success! you have unsubscribed from mempool") +} + func jsonRPCResponseMessage(message string, id uint64, jsonrpc, method string) (JsonRPCResponse, error) { mes, err := json.Marshal(message) if err != nil { diff --git a/tonapi/streaming.go b/tonapi/streaming.go index 5bffa276..7078014b 100644 --- a/tonapi/streaming.go +++ b/tonapi/streaming.go @@ -100,11 +100,26 @@ func NewStreamingAPI(opts ...StreamingOption) *StreamingAPI { // Websocket contains methods to configure a websocket connection to receive particular events from tonapi.io // happening in the TON blockchain. type Websocket interface { + // SubscribeToTransactions subscribes to notifications about new transactions for the specified accounts. SubscribeToTransactions(accounts []string) error + // UnsubscribeFromTransactions unsubscribes from notifications about new transactions for the specified accounts. + UnsubscribeFromTransactions(accounts []string) error + + // SubscribeToTraces subscribes to notifications about new traces for the specified accounts. SubscribeToTraces(accounts []string) error + // UnsubscribeFromTraces unsubscribes from notifications about new traces for the specified accounts. + UnsubscribeFromTraces(accounts []string) error + // SubscribeToMempool subscribes to notifications about new messages in the TON network. + SubscribeToMempool() error + // UnsubscribeFromMempool unsubscribes from notifications about new messages in the TON network. + UnsubscribeFromMempool() error + + // SetMempoolHandler defines a callback that will be called when a new mempool event is received. SetMempoolHandler(handler MempoolHandler) + // SetTransactionHandler defines a callback that will be called when a new transaction event is received. SetTransactionHandler(handler TransactionHandler) + // SetTraceHandler defines a callback that will be called when a new trace event is received. SetTraceHandler(handler TraceHandler) } diff --git a/tonapi/websocket.go b/tonapi/websocket.go index a15c904c..c53e2fe5 100644 --- a/tonapi/websocket.go +++ b/tonapi/websocket.go @@ -39,35 +39,42 @@ type websocketConnection struct { } func (w *websocketConnection) SubscribeToTransactions(accounts []string) error { - request := JsonRPCRequest{ - ID: 1, - JSONRPC: "2.0", - Method: "subscribe_account", - Params: accounts, - } + request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "subscribe_account", Params: accounts} + w.mu.Lock() + defer w.mu.Unlock() + return w.conn.WriteJSON(request) +} + +func (w *websocketConnection) UnsubscribeFromTransactions(accounts []string) error { + request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "unsubscribe_account", Params: accounts} w.mu.Lock() defer w.mu.Unlock() return w.conn.WriteJSON(request) } func (w *websocketConnection) SubscribeToTraces(accounts []string) error { - request := JsonRPCRequest{ - ID: 1, - JSONRPC: "2.0", - Method: "subscribe_trace", - Params: accounts, - } + request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "subscribe_trace", Params: accounts} + w.mu.Lock() + defer w.mu.Unlock() + return w.conn.WriteJSON(request) +} + +func (w *websocketConnection) UnsubscribeFromTraces(accounts []string) error { + request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "unsubscribe_trace", Params: accounts} w.mu.Lock() defer w.mu.Unlock() return w.conn.WriteJSON(request) } func (w *websocketConnection) SubscribeToMempool() error { - request := JsonRPCRequest{ - ID: 1, - JSONRPC: "2.0", - Method: "subscribe_mempool", - } + request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "subscribe_mempool"} + w.mu.Lock() + defer w.mu.Unlock() + return w.conn.WriteJSON(request) +} + +func (w *websocketConnection) UnsubscribeFromMempool() error { + request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "unsubscribe_mempool"} w.mu.Lock() defer w.mu.Unlock() return w.conn.WriteJSON(request)