Skip to content

Commit

Permalink
Merge pull request #32 from tcriess/change-websocket-improvements
Browse files Browse the repository at this point in the history
Fix potential write to closed channel
  • Loading branch information
GRVYDEV authored Feb 21, 2021
2 parents 2d62791 + 9a7a5ce commit ec2e136
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 31 deletions.
14 changes: 10 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,11 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) {
Event: ws.MessageTypeCandidate,
Data: string(candidateString),
}); err == nil {
c.Send <- msg
hub.RLock()
if _, ok := hub.Clients[c]; ok {
c.Send <- msg
}
hub.RUnlock()
} else {
log.Println(err)
}
Expand Down Expand Up @@ -262,12 +266,14 @@ func websocketHandler(w http.ResponseWriter, r *http.Request) {
Event: ws.MessageTypeOffer,
Data: string(offerString),
}); err == nil {
c.Send <- msg
hub.RLock()
if _, ok := hub.Clients[c]; ok {
c.Send <- msg
}
hub.RUnlock()
} else {
log.Printf("could not marshal ws message: %s", err)
}

go hub.SendInfo(hub.GetInfo()) // non-blocking broadcast, required as the read loop is not started yet.

c.ReadLoop()
}
17 changes: 5 additions & 12 deletions ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ func (c *Client) ReadLoop() {
return
}
}

// we do not send anything to the other clients!
//message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1))
//c.hub.Broadcast <- message
}
}

Expand Down Expand Up @@ -120,14 +116,11 @@ func (c *Client) WriteLoop() {
if err != nil {
return
}
_, _ = w.Write(message)

// Add queued messages to the current websocket message.
n := len(c.Send)
for i := 0; i < n; i++ {
_, _ = w.Write([]byte{'\n'})
message = <-c.Send
_, _ = w.Write(message)
_, err = w.Write(message)
if err != nil {
log.Printf("could not send message: %s",err)
w.Close()
return
}

if err := w.Close(); err != nil {
Expand Down
48 changes: 33 additions & 15 deletions ws/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ws
import (
"encoding/json"
"log"
"sync"
"time"
)

Expand All @@ -18,49 +19,66 @@ type Info struct {
}

type Hub struct {
// Registered clients.
clients map[*Client]struct{}
// Registered Clients.
Clients map[*Client]struct{}

// Broadcast messages to all clients.
// Broadcast messages to all Clients.
Broadcast chan []byte

// Register a new client to the hub.
Register chan *Client

// Unregister a client from the hub.
Unregister chan *Client

// lock to prevent write to closed channel
sync.RWMutex
}

func NewHub() *Hub {
return &Hub{
clients: make(map[*Client]struct{}),
Clients: make(map[*Client]struct{}),
Broadcast: make(chan []byte),
Register: make(chan *Client),
Unregister: make(chan *Client),
Register: make(chan *Client, 1),
Unregister: make(chan *Client, 1),
}
}

// NoClients returns the number of clients registered
// NoClients returns the number of Clients registered
func (h *Hub) NoClients() int {
return len(h.clients)
h.RLock()
defer h.RUnlock()
return len(h.Clients)
}

// Run is the main hub event loop handling register, unregister and broadcast events.
func (h *Hub) Run() {
for {
select {
case client := <-h.Register:
h.clients[client] = struct{}{}
h.Lock()
h.Clients[client] = struct{}{}
h.Unlock()
go h.SendInfo(h.GetInfo())
case client := <-h.Unregister:
if _, ok := h.clients[client]; ok {
delete(h.clients, client)
h.RLock()
if _, ok := h.Clients[client]; ok {
h.RUnlock()
h.Lock()
delete(h.Clients, client)
h.Unlock()
client.conn.Close()
close(client.Send)
go h.SendInfo(h.GetInfo()) // this way the number of clients does not change between calling the goroutine and executing it
go h.SendInfo(h.GetInfo()) // this way the number of Clients does not change between calling the goroutine and executing it
} else {
h.RUnlock()
}
case message := <-h.Broadcast:
for client := range h.clients {
h.RLock()
for client := range h.Clients {
client.Send <- message
}
h.RUnlock()
}
}
}
Expand All @@ -71,7 +89,7 @@ func (h *Hub) GetInfo() Info {
}
}

// SendInfo broadcasts hub statistics to all clients.
// SendInfo broadcasts hub statistics to all Clients.
func (h *Hub) SendInfo(info Info) {
i, err := json.Marshal(info)
if err != nil {
Expand All @@ -85,4 +103,4 @@ func (h *Hub) SendInfo(info Info) {
} else {
log.Printf("could not marshal ws message: %s", err)
}
}
}

0 comments on commit ec2e136

Please sign in to comment.