diff --git a/util/websocket_utils.go b/util/websocket_utils.go index 5e87dac..d8d78da 100644 --- a/util/websocket_utils.go +++ b/util/websocket_utils.go @@ -10,9 +10,13 @@ import ( "github.com/gorilla/websocket" ) +const ( + pongWait = 15 * time.Second +) + type WebSocketEventHub struct { Connections map[string]*websocket.Conn - receiveMap map[string]bool + listenTag sync.Map funcMap sync.Map } @@ -30,7 +34,6 @@ type EventReceives struct { func NewWebSocketEvent() *WebSocketEventHub { return &WebSocketEventHub{ Connections: make(map[string]*websocket.Conn), - receiveMap: make(map[string]bool), } } @@ -41,18 +44,28 @@ func NewEventReceives() *EventReceives { } } +func createConnection(uri string, headers http.Header) *websocket.Conn { + conn, _, err := websocket.DefaultDialer.Dial(uri, headers) + if err != nil { + log.Fatal("dial:", err) + return nil + } + // conn.SetReadDeadline(time.Now().Add(pongWait)) + conn.SetPongHandler(func(appData string) error { + conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + return conn +} + func (eventHub *WebSocketEventHub) CreateManagement(eventCode string, websocketHost string, token string) bool { if eventHub.Connections[eventCode] == nil { var socketUri = fmt.Sprintf("%s/events/v1/management/sub?code=%s", websocketHost, eventCode) // fmt.Println(socketUri) - conn, _, err := websocket.DefaultDialer.Dial( - socketUri, - http.Header{ - "Authorization": []string{token}, - }, - ) - if err != nil { - log.Printf("dail: %s", err.Error()) + conn := createConnection(socketUri, http.Header{ + "Authorization": []string{token}, + }) + if conn == nil { return false } eventHub.Connections[eventCode] = conn @@ -64,12 +77,8 @@ func (eventHub *WebSocketEventHub) CreateAuthentication(eventCode string, websoc if eventHub.Connections[eventCode] == nil { var socketUri = fmt.Sprintf("%s/events/v1/authentication/sub?code=%s&token=%s", websocketHost, eventCode, token) // fmt.Println(socketUri) - conn, _, err := websocket.DefaultDialer.Dial( - socketUri, - http.Header{}, - ) - if err != nil { - log.Printf("dail: %s", err.Error()) + conn := createConnection(socketUri, http.Header{}) + if conn == nil { return false } eventHub.Connections[eventCode] = conn @@ -89,27 +98,50 @@ func (eventHub *WebSocketEventHub) AddReceiver(eventCode string, onSuccess func( eventHub.funcMap.Store(eventCode, receivers) } +func connPong(conn *websocket.Conn) { + ticker := time.NewTicker(pongWait) + defer ticker.Stop() + for range ticker.C { + err := conn.WriteMessage(websocket.PongMessage, nil) + if err != nil { + log.Fatal(err) + } + } +} + func (eventHub *WebSocketEventHub) StartReceive(eventCode string) { - started, ok := eventHub.receiveMap[eventCode] - if ok && started { + started, loaded := eventHub.listenTag.LoadOrStore(eventCode, true) + if loaded && started.(bool) { return } - eventHub.receiveMap[eventCode] = true + log.Println("start connection receive") + + conn := eventHub.Connections[eventCode] + defer conn.Close() + go connPong(conn) + ticker := time.NewTicker(pongWait) + defer ticker.Stop() + begin_time := time.Now() + count := 0 for { - _, message, err := eventHub.Connections[eventCode].ReadMessage() - if funcMap, ok := eventHub.funcMap.Load(eventCode); ok { - funcs := funcMap.(*EventReceives) - if err != nil { - for _, onError := range funcs.Errors { - onError(err) - } - } else { - for _, onSuccess := range funcs.Successes { - onSuccess(message) + select { + case <-ticker.C: + log.Printf("received %v messages, has been running for %v seconds", count, time.Since(begin_time).Seconds()) + default: + _, message, err := conn.ReadMessage() + if funcMap, ok := eventHub.funcMap.Load(eventCode); ok { + funcs := funcMap.(*EventReceives) + if err != nil { + for _, onError := range funcs.Errors { + onError(err) + } + } else { + for _, onSuccess := range funcs.Successes { + onSuccess(message) + } } } + count += 1 } - - time.Sleep(time.Microsecond * 500) } }