Skip to content

Commit

Permalink
Merge pull request #240 from tonkeeper/websocket-unsubscribe
Browse files Browse the repository at this point in the history
Implement unsubscribe methods for websocket
  • Loading branch information
mr-tron authored Nov 1, 2023
2 parents 4cccfed + 418f35f commit 0d44100
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 22 deletions.
149 changes: 148 additions & 1 deletion pkg/pusher/websocket/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ var _ sources.TransactionSource = &mockTxSource{}
var _ sources.MemPoolSource = &mockMemPool{}
var _ sources.TraceSource = &mockTraceSource{}

func TestHandler(t *testing.T) {
func TestHandler_UnsubscribeWhenConnectionIsClosed(t *testing.T) {
var txSubscribed atomic.Bool // to make "go test -race" happy
var txUnsubscribed atomic.Bool // to make "go test -race" happy
source := &mockTxSource{
Expand Down Expand Up @@ -145,3 +145,150 @@ func TestHandler(t *testing.T) {
require.True(t, memPoolUnsubscribed.Load())
require.True(t, traceUnsubscribed.Load())
}

func TestHandler_UnsubscribeMethods(t *testing.T) {
var txSubscribed atomic.Bool // to make "go test -race" happy
var txUnsubscribed atomic.Bool // to make "go test -race" happy
source := &mockTxSource{
OnSubscribeToTransactions: func(ctx context.Context, deliveryFn sources.DeliveryFn, opts sources.SubscribeToTransactionsOptions) sources.CancelFn {
txSubscribed.Store(true)
return func() {
txUnsubscribed.Store(true)
}
},
}
var memPoolSubscribed atomic.Bool // to make "go test -race" happy
var memPoolUnsubscribed atomic.Bool // to make "go test -race" happy
mempool := &mockMemPool{
OnSubscribeToMessages: func(ctx context.Context, deliveryFn sources.DeliveryFn) (sources.CancelFn, error) {
memPoolSubscribed.Store(true)
return func() {
memPoolUnsubscribed.Store(true)
}, nil
},
}
var traceSubscribed atomic.Bool // to make "go test -race" happy
var traceUnsubscribed atomic.Bool // to make "go test -race" happy
traceSource := &mockTraceSource{
OnSubscribeToTraces: func(ctx context.Context, deliveryFn sources.DeliveryFn, opts sources.SubscribeToTraceOptions) sources.CancelFn {
traceSubscribed.Store(true)
return func() {
traceUnsubscribed.Store(true)
}
},
}
logger, _ := zap.NewDevelopment()
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
handler := Handler(logger, source, traceSource, mempool)
err := handler(writer, request, 0, false)
require.Nil(t, err)
}))
defer server.Close()

url := strings.Replace(server.URL, "http", "ws", -1)
conn, _, err := websocket.DefaultDialer.Dial(url, nil)
require.Nil(t, err)

requests := []JsonRPCRequest{
{
ID: 1,
JSONRPC: "2.0",
Method: "subscribe_account",
Params: []string{
"-1:5555555555555555555555555555555555555555555555555555555555555555",
"0:5555555555555555555555555555555555555555555555555555555555555555",
},
},
{
ID: 2,
JSONRPC: "2.0",
Method: "subscribe_mempool",
},
{
ID: 3,
JSONRPC: "2.0",
Method: "subscribe_trace",
Params: []string{
"0:5555555555555555555555555555555555555555555555555555555555555555",
},
},
}
expectedResponses := [][]byte{
[]byte(`{"id":1,"jsonrpc":"2.0","method":"subscribe_account","result":"success! 2 new subscriptions created"}` + "\n"),
[]byte(`{"id":2,"jsonrpc":"2.0","method":"subscribe_mempool","result":"success! you have subscribed to mempool"}` + "\n"),
[]byte(`{"id":3,"jsonrpc":"2.0","method":"subscribe_trace","result":"success! 1 new subscriptions created"}` + "\n"),
}

for i, request := range requests {
expectedResponse := expectedResponses[i]
err = conn.WriteJSON(request)
require.Nil(t, err)

time.Sleep(1 * time.Second)

msgType, msg, err := conn.ReadMessage()
require.Nil(t, err)
require.Equal(t, websocket.TextMessage, msgType)

require.Equal(t, expectedResponse, msg)
}
require.True(t, txSubscribed.Load())
require.False(t, txUnsubscribed.Load())

require.True(t, memPoolSubscribed.Load())
require.False(t, memPoolUnsubscribed.Load())

require.True(t, traceSubscribed.Load())
require.False(t, traceUnsubscribed.Load())

time.Sleep(1 * time.Second)

requests = []JsonRPCRequest{
{
ID: 4,
JSONRPC: "2.0",
Method: "unsubscribe_account",
Params: []string{
"-1:5555555555555555555555555555555555555555555555555555555555555555",
"-1:3333333333333333333333333333333333333333333333333333333333333333",
},
},
{
ID: 5,
JSONRPC: "2.0",
Method: "unsubscribe_mempool",
},
{
ID: 6,
JSONRPC: "2.0",
Method: "unsubscribe_trace",
Params: []string{
"0:5555555555555555555555555555555555555555555555555555555555555555",
"0:3333333333333333333333333333333333333333333333333333333333333333",
},
},
}
expectedResponses = [][]byte{
[]byte(`{"id":4,"jsonrpc":"2.0","method":"unsubscribe_account","result":"success! 1 subscription(s) removed"}` + "\n"),
[]byte(`{"id":5,"jsonrpc":"2.0","method":"unsubscribe_mempool","result":"success! you have unsubscribed from mempool"}` + "\n"),
[]byte(`{"id":6,"jsonrpc":"2.0","method":"unsubscribe_trace","result":"success! 1 subscription(s) removed"}` + "\n"),
}

for i, request := range requests {
expectedResponse := expectedResponses[i]
err = conn.WriteJSON(request)
require.Nil(t, err)

time.Sleep(1 * time.Second)

msgType, msg, err := conn.ReadMessage()
require.Nil(t, err)
require.Equal(t, websocket.TextMessage, msgType)

require.Equal(t, expectedResponse, msg)
}

require.True(t, txUnsubscribed.Load())
require.True(t, memPoolUnsubscribed.Load())
require.True(t, traceUnsubscribed.Load())
}
54 changes: 50 additions & 4 deletions pkg/pusher/websocket/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,23 @@ func (s *session) Run(ctx context.Context) chan JsonRPCRequest {
case request := <-requestCh:
var response string
switch request.Method {
// handle transaction subscriptions
case "subscribe_account":
response = s.subscribeToTransactions(ctx, request.Params)
case "unsubscribe_account":
response = s.unsubscribeFromTransactions(request.Params)

// handle mempool subscriptions
case "subscribe_mempool":
response = s.subscribeToMempool(ctx)
case "unsubscribe_mempool":
response = s.unsubscribeFromMempool()

// handle trace subscriptions
case "subscribe_trace":
response = s.subscribeToTraces(ctx, request.Params)
case "unsubscribe_account":
response = s.unsubscribe(request.Params)
case "unsubscribe_trace":
response = s.unsubscribeFromTraces(request.Params)
}
err = s.writeResponse(response, request)
case <-time.After(s.pingInterval):
Expand Down Expand Up @@ -153,8 +162,20 @@ func (s *session) subscribeToTransactions(ctx context.Context, params []string)
return fmt.Sprintf("success! %v new subscriptions created", counter)
}

func (s *session) unsubscribe(params []string) string {
return "not supported yet"
func (s *session) unsubscribeFromTransactions(params []string) string {
var counter int
for _, a := range params {
account, err := tongo.ParseAddress(a)
if err != nil {
return fmt.Sprintf("failed to process '%v' account: %v", a, err)
}
if cancelFn, ok := s.txSubscriptions[account.ID]; ok {
cancelFn()
delete(s.txSubscriptions, account.ID)
counter += 1
}
}
return fmt.Sprintf("success! %v subscription(s) removed", counter)
}

func (s *session) subscribeToTraces(ctx context.Context, params []string) string {
Expand Down Expand Up @@ -190,6 +211,22 @@ func (s *session) subscribeToTraces(ctx context.Context, params []string) string
return fmt.Sprintf("success! %v new subscriptions created", counter)
}

func (s *session) unsubscribeFromTraces(params []string) string {
var counter int
for _, a := range params {
account, err := tongo.ParseAddress(a)
if err != nil {
return fmt.Sprintf("failed to process '%v' account: %v", a, err)
}
if cancelFn, ok := s.traceSubscriptions[account.ID]; ok {
cancelFn()
delete(s.traceSubscriptions, account.ID)
counter += 1
}
}
return fmt.Sprintf("success! %v subscription(s) removed", counter)
}

func (s *session) subscribeToMempool(ctx context.Context) string {
if s.mempoolSubscription != nil {
return fmt.Sprintf("you are already subscribed to mempool")
Expand All @@ -204,6 +241,15 @@ func (s *session) subscribeToMempool(ctx context.Context) string {
return fmt.Sprintf("success! you have subscribed to mempool")
}

func (s *session) unsubscribeFromMempool() string {
if s.mempoolSubscription == nil {
return fmt.Sprintf("you are not subscribed to mempool")
}
s.mempoolSubscription()
s.mempoolSubscription = nil
return fmt.Sprintf("success! you have unsubscribed from mempool")
}

func jsonRPCResponseMessage(message string, id uint64, jsonrpc, method string) (JsonRPCResponse, error) {
mes, err := json.Marshal(message)
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions tonapi/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,26 @@ func NewStreamingAPI(opts ...StreamingOption) *StreamingAPI {
// Websocket contains methods to configure a websocket connection to receive particular events from tonapi.io
// happening in the TON blockchain.
type Websocket interface {
// SubscribeToTransactions subscribes to notifications about new transactions for the specified accounts.
SubscribeToTransactions(accounts []string) error
// UnsubscribeFromTransactions unsubscribes from notifications about new transactions for the specified accounts.
UnsubscribeFromTransactions(accounts []string) error

// SubscribeToTraces subscribes to notifications about new traces for the specified accounts.
SubscribeToTraces(accounts []string) error
// UnsubscribeFromTraces unsubscribes from notifications about new traces for the specified accounts.
UnsubscribeFromTraces(accounts []string) error
// SubscribeToMempool subscribes to notifications about new messages in the TON network.

SubscribeToMempool() error
// UnsubscribeFromMempool unsubscribes from notifications about new messages in the TON network.
UnsubscribeFromMempool() error

// SetMempoolHandler defines a callback that will be called when a new mempool event is received.
SetMempoolHandler(handler MempoolHandler)
// SetTransactionHandler defines a callback that will be called when a new transaction event is received.
SetTransactionHandler(handler TransactionHandler)
// SetTraceHandler defines a callback that will be called when a new trace event is received.
SetTraceHandler(handler TraceHandler)
}

Expand Down
41 changes: 24 additions & 17 deletions tonapi/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,35 +39,42 @@ type websocketConnection struct {
}

func (w *websocketConnection) SubscribeToTransactions(accounts []string) error {
request := JsonRPCRequest{
ID: 1,
JSONRPC: "2.0",
Method: "subscribe_account",
Params: accounts,
}
request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "subscribe_account", Params: accounts}
w.mu.Lock()
defer w.mu.Unlock()
return w.conn.WriteJSON(request)
}

func (w *websocketConnection) UnsubscribeFromTransactions(accounts []string) error {
request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "unsubscribe_account", Params: accounts}
w.mu.Lock()
defer w.mu.Unlock()
return w.conn.WriteJSON(request)
}

func (w *websocketConnection) SubscribeToTraces(accounts []string) error {
request := JsonRPCRequest{
ID: 1,
JSONRPC: "2.0",
Method: "subscribe_trace",
Params: accounts,
}
request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "subscribe_trace", Params: accounts}
w.mu.Lock()
defer w.mu.Unlock()
return w.conn.WriteJSON(request)
}

func (w *websocketConnection) UnsubscribeFromTraces(accounts []string) error {
request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "unsubscribe_trace", Params: accounts}
w.mu.Lock()
defer w.mu.Unlock()
return w.conn.WriteJSON(request)
}

func (w *websocketConnection) SubscribeToMempool() error {
request := JsonRPCRequest{
ID: 1,
JSONRPC: "2.0",
Method: "subscribe_mempool",
}
request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "subscribe_mempool"}
w.mu.Lock()
defer w.mu.Unlock()
return w.conn.WriteJSON(request)
}

func (w *websocketConnection) UnsubscribeFromMempool() error {
request := JsonRPCRequest{ID: 1, JSONRPC: "2.0", Method: "unsubscribe_mempool"}
w.mu.Lock()
defer w.mu.Unlock()
return w.conn.WriteJSON(request)
Expand Down

0 comments on commit 0d44100

Please sign in to comment.