diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..5c7247b --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,7 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [] +} \ No newline at end of file diff --git a/README.md b/README.md index 93b1e1f..fd93e5f 100755 --- a/README.md +++ b/README.md @@ -13,7 +13,9 @@ As well as using this library just for go processes it was also designed to work #### NodeJs -I currently use this library to comunicate between a ElectronJs GUI and a go program. +I currently use this library to comunicate between a ElectronJS GUI and a go program. + +Below is a link to the nodeJS client library: https://github.com/james-barrow/node-ipc-client @@ -27,59 +29,97 @@ Create a server with the default configuation and start listening for the client ```go - sc, err := ipc.StartServer("", nil) + s, err := ipc.StartServer("", nil) if err != nil { log.Println(err) return } ``` - Create a client and connect to the server: ```go - cc, err := ipc.StartClient("", nil) + c, err := ipc.StartClient("", nil) if err != nil { log.Println(err) return } ``` -Read and write data to the connection: + +### Read messages + +Read each message sent: + +```go + + for { + + // message, err := s.Read() server + message, err := c.Read() // client + + if err == nil { + // handle error + } + + // do something with the received messages + } + +``` + +All received messages are formated into the type Message + +```go + +type Message struct { + Err error // details of any error + MsgType int // 0 = reserved , -1 is an internal message (disconnection or error etc), all messages recieved will be > 0 + Data []byte // message data received + Status string // the status of the connection +} + +``` + +### Write a message + + +```go + + //err := s.Write(1, []byte("", config) - + Encryption: false ``` ### Unix Socket Permissions Under most configurations, a socket created by a user will by default not be writable by another user, making it impossible for the client and server to communicate if being run by separate users. - The permission mask can be dropped during socket creation by passing custom configuration to the server start function. **This will make the socket writable by any user.** + The permission mask can be dropped during socket creation by passing a custom configuration to the server start function. **This will make the socket writable for any user.** ```go - - config := &ipc.ServerConfig{UnmaskPermissions: true} - sc, err := ipc.StartServer("", config) - + UnmaskPermissions: true ``` Note: Tested on Linux, not tested on Mac, not implemented on Windows. - ### Testing + ## Testing The package has been tested on Mac, Windows and Linux and has extensive test coverage. -### Licence +## Licence MIT diff --git a/client_all.go b/client_all.go index 6146f65..0a89de7 100755 --- a/client_all.go +++ b/client_all.go @@ -10,10 +10,7 @@ import ( ) // StartClient - start the ipc client. -// // ipcName = is the name of the unix socket or named pipe that the client will try and connect to. -// timeout = number of seconds before the socket/pipe times out trying to connect/re-cconnect - if -1 or 0 it never times out. -// retryTimer = number of seconds before the client tries to connect again. func StartClient(ipcName string, config *ClientConfig) (*Client, error) { err := checkIpcName(ipcName) @@ -32,7 +29,7 @@ func StartClient(ipcName string, config *ClientConfig) (*Client, error) { if config == nil { cc.timeout = 0 - cc.retryTimer = time.Duration(1) + cc.retryTimer = time.Duration(20) cc.encryptionReq = true } else { @@ -68,7 +65,7 @@ func startClient(c *Client) { err := c.dial() if err != nil { - c.received <- &Message{err: err, MsgType: -2} + c.received <- &Message{Err: err, MsgType: -1} return } @@ -137,7 +134,7 @@ func (c *Client) readData(buff []byte) bool { if c.status == Closing { c.status = Closed c.received <- &Message{Status: c.status.String(), MsgType: -1} - c.received <- &Message{err: errors.New("client has closed the connection"), MsgType: -2} + c.received <- &Message{Err: errors.New("client has closed the connection"), MsgType: -2} return false } @@ -157,10 +154,10 @@ func (c *Client) reconnect() { err := c.dial() // connect to the pipe if err != nil { - if err.Error() == "Timed out trying to connect" { + if err.Error() == "timed out trying to connect" { c.status = Timeout c.received <- &Message{Status: c.status.String(), MsgType: -1} - c.received <- &Message{err: errors.New("timed out trying to re-connect"), MsgType: -2} + c.received <- &Message{Err: errors.New("timed out trying to re-connect"), MsgType: -1} } return @@ -172,8 +169,8 @@ func (c *Client) reconnect() { go c.read() } -// Read - blocking function that waits until an non multipart message is received -// returns the message type, data and any error. +// Read - blocking function that receices messages +// if MsgType is a negative number its an internal message func (c *Client) Read() (*Message, error) { m, ok := (<-c.received) @@ -181,16 +178,16 @@ func (c *Client) Read() (*Message, error) { return nil, errors.New("the received channel has been closed") } - if m.err != nil { + if m.Err != nil { close(c.received) close(c.toWrite) - return nil, m.err + return nil, m.Err } return m, nil } -// Write - writes a non multipart message to the ipc connection. +// Write - writes a message to the ipc connection. // msgType - denotes the type of data being sent. 0 is a reserved type for internal messages and errors. func (c *Client) Write(msgType int, message []byte) error { diff --git a/connect_other.go b/connect_other.go index 3d6a166..cf2473f 100755 --- a/connect_other.go +++ b/connect_other.go @@ -39,16 +39,9 @@ func (s *Server) run() error { s.listen = listen - s.status = Listening - s.received <- &Message{Status: s.status.String(), MsgType: -1} - s.connChannel = make(chan bool) - go s.acceptLoop() - err = s.connectionTimer() - if err != nil { - return err - } + s.status = Listening return nil @@ -63,9 +56,10 @@ func (c *Client) dial() error { startTime := time.Now() for { + if c.timeout != 0 { - if time.Now().Sub(startTime).Seconds() > c.timeout { + if time.Since(startTime).Seconds() > c.timeout { c.status = Closed return errors.New("timed out trying to connect") } @@ -79,7 +73,7 @@ func (c *Client) dial() error { } else if strings.Contains(err.Error(), "connect: connection refused") { } else { - c.received <- &Message{err: err, MsgType: -2} + c.received <- &Message{Err: err, MsgType: -1} } } else { diff --git a/connect_windows.go b/connect_windows.go index a10f371..2db313c 100755 --- a/connect_windows.go +++ b/connect_windows.go @@ -48,7 +48,7 @@ func (c *Client) dial() error { for { if c.timeout != 0 { - if time.Now().Sub(startTime).Seconds() > c.timeout { + if time.Since(startTime).Seconds() > c.timeout { c.status = Closed return errors.New("timed out trying to connect") } diff --git a/encryption.go b/encryption.go index 9ef54c4..73af1b3 100755 --- a/encryption.go +++ b/encryption.go @@ -103,7 +103,7 @@ func sendPublic(conn net.Conn, pub *ecdsa.PublicKey) error { func recvPublic(conn net.Conn) (*ecdsa.PublicKey, error) { - buff := make([]byte, 300) + buff := make([]byte, 98) i, err := conn.Read(buff) if err != nil { return nil, errors.New("didn't received public key") @@ -145,7 +145,6 @@ func bytesToPublicKey(recvdPub []byte) *ecdsa.PublicKey { func createCipher(shared [32]byte) (*cipher.AEAD, error) { b, err := aes.NewCipher(shared[:]) - if err != nil { return nil, err } @@ -162,11 +161,9 @@ func encrypt(g cipher.AEAD, data []byte) ([]byte, error) { nonce := make([]byte, g.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return nil, err - } + _, err := io.ReadFull(rand.Reader, nonce) - return g.Seal(nonce, nonce, data, nil), nil + return g.Seal(nonce, nonce, data, nil), err } diff --git a/example/example.go b/example/example.go old mode 100755 new mode 100644 index 3cf3f06..ec0cea6 --- a/example/example.go +++ b/example/example.go @@ -2,195 +2,85 @@ package main import ( "log" - "time" ipc "github.com/james-barrow/golang-ipc" ) func main() { - log.Println("starting") - go server() - client() - -} - -func server() { - - //&ipc.ServerConfig{Encryption: false} - - sc, err := ipc.StartServer("testtest", nil) + c, err := ipc.StartClient("example1", nil) if err != nil { log.Println(err) return } - go func() { - - for { - m, err := sc.Read() - - if err == nil { - if m.MsgType > 0 { - log.Println("Server received: "+string(m.Data)+" - Message type: ", m.MsgType) - } - - } else { - - log.Println("Server error") - log.Println(err) - break - } - } - }() - - go serverSend(sc) - go serverSend1(sc) - serverSend2(sc) - -} - -func serverSend(sc *ipc.Server) { - - for { - - sc.Write(3, []byte("Hello Client 4")) - sc.Write(23, []byte("Hello Client 5")) - sc.Write(65, []byte("Hello Client 6")) - - time.Sleep(time.Second / 30) - - } -} - -func serverSend1(sc *ipc.Server) { - for { - sc.Write(5, []byte("Hello Client 1")) - sc.Write(7, []byte("Hello Client 2")) - sc.Write(9, []byte("Hello Client 3")) + message, err := c.Read() - time.Sleep(time.Second / 30) + if err == nil { - } + if message.MsgType == -1 { -} + log.Println("client status", c.Status()) -func serverSend2(sc *ipc.Server) { + if message.Status == "Reconnecting" { + c.Close() + return + } - for { + } else { - sc.Write(88, []byte("Hello Client 7")) - sc.Write(99, []byte("Hello Client 8")) - sc.Write(22, []byte("Hello Client 9")) + log.Println("Client received: "+string(message.Data)+" - Message type: ", message.MsgType) + c.Write(5, []byte("Message from client - PONG")) - time.Sleep(time.Second / 30) + } + } else { + log.Println(err) + break + } } -} -func client() { +} - //config := &ipc.ClientConfig{Encryption: false} +func server() { - cc, err := ipc.StartClient("testtest", nil) + s, err := ipc.StartServer("example1", nil) if err != nil { - log.Println(err) + log.Println("server error", err) return } - go func() { - - for { - m, err := cc.Read() - - if err != nil { - // An error is only returned if the received channel has been closed, - //so you know the connection has either been intentionally closed or has timmed out waiting to connect/re-connect. - break - } - - //if m.MsgType == -1 { // message type -1 is status change - //log.Println("Status: " + m.Status) - //} - - if m.MsgType == -2 { // message type -2 is an error, these won't automatically cause the received channel to close. - log.Println("Error: " + err.Error()) - } - - if m.MsgType > 0 { // all message types above 0 have been received over the connection - - log.Println(" Message type: ", m.MsgType) - log.Println("Client received: " + string(m.Data)) - } - //} - } - - }() - - go clientSend(cc) - go clientSend(cc) - clientSend2(cc) - -} - -func clientSend(cc *ipc.Client) { + log.Println("server status", s.Status()) for { - _ = cc.Write(14, []byte("hello server 4")) - _ = cc.Write(44, []byte("hello server 5")) - _ = cc.Write(88, []byte("hello server 6")) + message, err := s.Read() - time.Sleep(time.Second / 20) + if err == nil { - } - -} + if message.MsgType == -1 { -func clientSend2(cc *ipc.Client) { - - for { + if message.Status == "Connected" { - _ = cc.Write(444, []byte("hello server 7")) - _ = cc.Write(234, []byte("hello server 8")) - _ = cc.Write(111, []byte("hello server 9")) + log.Println("server status", s.Status()) + s.Write(1, []byte("server - PING")) - time.Sleep(time.Second / 20) - - } -} + } -/* -func clientRecv(c *ipc.Client) { + } else { - for { - m, err := c.Read() + log.Println("Server received: "+string(message.Data)+" - Message type: ", message.MsgType) + s.Close() + return + } - if err != nil { - // An error is only returned if the received channel has been closed, - //so you know the connection has either been intentionally closed or has timmed out waiting to connect/re-connect. + } else { break } - - //if m.MsgType == -1 { // message type -1 is status change - // //log.Println("Status: " + m.Status) - //} - - if m.MsgType == -2 { // message type -2 is an error, these won't automatically cause the received channel to close. - log.Println("Error: " + err.Error()) - } - - if m.MsgType > 0 { // all message types above 0 have been received over the connection - - log.Println(" Message type: ", m.MsgType) - log.Println("Client received: " + string(m.Data)) - } - //} } } -*/ diff --git a/ipc_test.go b/ipc_test.go index 2f104b2..194e775 100755 --- a/ipc_test.go +++ b/ipc_test.go @@ -1,6 +1,9 @@ package ipc import ( + "fmt" + "net" + "os" "testing" "time" ) @@ -43,7 +46,6 @@ func TestStartUp_Configs(t *testing.T) { t.Error(err) } - scon.Timeout = -1 scon.MaxMsgSize = -1 _, err5 := StartServer("test", scon) @@ -75,7 +77,7 @@ func TestStartUp_Configs(t *testing.T) { t.Run("Unmask Server Socket Permissions", func(t *testing.T) { scon.UnmaskPermissions = true - _, err := StartServer("test_perm", scon) + srv, err := StartServer("test_perm", scon) if err != nil { t.Error(err) } @@ -83,24 +85,24 @@ func TestStartUp_Configs(t *testing.T) { // test would not work in windows // can check test_perm.sock in /tmp after running tests to see perms - /* - time.Sleep(time.Second / 4) + time.Sleep(time.Second / 4) - info, err := os.Stat(srv.listen.Addr().String()) - if err != nil { - t.Error(err) - } - got := fmt.Sprintf("%04o", info.Mode().Perm()) - want := "0777" + info, err := os.Stat(srv.listen.Addr().String()) + if err != nil { + t.Error(err) + } + got := fmt.Sprintf("%04o", info.Mode().Perm()) + want := "0777" + + if got != want { + t.Errorf("Got %q, Wanted %q", got, want) + } - if got != want { - t.Errorf("Got %q, Wanted %q", got, want) - } - */ scon.UnmaskPermissions = false }) - } + +/* func TestStartUp_Timeout(t *testing.T) { scon := &ServerConfig{ @@ -113,7 +115,7 @@ func TestStartUp_Timeout(t *testing.T) { _, err1 := sc.Read() if err1 != nil { - if err1.Error() != "Timed out waiting for client to connect" { + if err1.Error() != "timed out waiting for client to connect" { t.Error("should of got server timeout") } break @@ -131,7 +133,7 @@ func TestStartUp_Timeout(t *testing.T) { for { _, err := cc.Read() if err != nil { - if err.Error() != "Timed out trying to connect" { + if err.Error() != "timed out trying to connect" { t.Error("should of got timeout as client was trying to connect") } @@ -141,6 +143,7 @@ func TestStartUp_Timeout(t *testing.T) { } } +*/ func TestWrite(t *testing.T) { @@ -182,14 +185,14 @@ func TestWrite(t *testing.T) { buf := make([]byte, 1) err3 := sc.Write(0, buf) - if err3.Error() != "Message type 0 is reserved" { + if err3.Error() != "message type 0 is reserved" { t.Error("0 is not allowed as a message type") } buf = make([]byte, sc.maxMsgSize+5) err4 := sc.Write(2, buf) - if err4.Error() != "Message exceeds maximum message length" { + if err4.Error() != "message exceeds maximum message length" { t.Error("There should be an error as the data we're attempting to write is bigger than the maxMsgSize") } @@ -225,7 +228,6 @@ func TestWrite(t *testing.T) { } else { t.Error("we should have an error becuse there is no connection") } - } func TestRead(t *testing.T) { @@ -308,7 +310,6 @@ func TestRead(t *testing.T) { close(cIPC.received) // close received channel <-clientFinished - } func TestStatus(t *testing.T) { @@ -351,8 +352,8 @@ func TestStatus(t *testing.T) { s3 := sc.getStatus() - if s3.String() != "Re-connecting" { - t.Error("status string should have returned Re-connecting") + if s3.String() != "Reconnecting" { + t.Error("status string should have returned Reconnecting") } sc.status = Closed @@ -371,7 +372,26 @@ func TestStatus(t *testing.T) { t.Error("status string should have returned Error") } - sc.Status() + sc.status = Closing + + s6 := sc.getStatus() + + if s6.String() != "Closing" { + t.Error("status string should have returned Error") + } + + if s6.String() != "Closing" { + t.Error("status string should have returned Error") + } + + sc.status = 33 + + s7 := sc.getStatus() + + fmt.Println(s7.String()) + if s7.String() != "Status not found" { + t.Error("status string should have returned 'Status not found'") + } cc := &Client{ status: NotConnected, @@ -386,29 +406,7 @@ func TestStatus(t *testing.T) { cc2.getStatus() cc2.Status() - -} - -/* -func checkStatus(sc *Server, t *testing.T) bool { - - for i := 0; i < 25; i++ { - - if sc.getStatus() == 3 { - return true - - } else if i == 25 { - t.Error("Server failed to connect") - break - } - - time.Sleep(time.Second / 5) - } - - return false - } -*/ func TestGetConnected(t *testing.T) { @@ -432,7 +430,6 @@ func TestGetConnected(t *testing.T) { break } } - } func TestServerWrongMessageType(t *testing.T) { @@ -567,7 +564,6 @@ func TestClientWrongMessageType(t *testing.T) { sc.Write(2, []byte("")) <-complete - } func TestServerCorrectMessageType(t *testing.T) { @@ -635,7 +631,6 @@ func TestServerCorrectMessageType(t *testing.T) { sc.Write(5, []byte("")) <-complete - } func TestClientCorrectMessageType(t *testing.T) { @@ -705,7 +700,6 @@ func TestClientCorrectMessageType(t *testing.T) { cc.Write(5, []byte("")) <-complete - } func TestServerSendMessage(t *testing.T) { @@ -783,7 +777,6 @@ func TestServerSendMessage(t *testing.T) { sc.Write(5, []byte("Here is a test message sent from the server to the client... -/and some more test data to pad it out a bit")) <-complete - } func TestClientSendMessage(t *testing.T) { @@ -861,6 +854,23 @@ func TestClientSendMessage(t *testing.T) { cc.Write(5, []byte("Here is a test message sent from the client to the server... -/and some more test data to pad it out a bit")) <-complete +} + +func TestEncryptionFunctions(t *testing.T) { + + res := publicKeyToBytes(nil) + if len(res) != 0 { + t.Error("should have returned 0 bytes") + + } + + buff := make([]byte, 0) + + if bytesToPublicKey(buff) != nil { + t.Error("should have failed as buff is 0 bytes") + } + + } @@ -931,7 +941,6 @@ func TestNoEncrytion(t *testing.T) { <-complete <-complete2 - } func TestServerWrongEncrytion(t *testing.T) { @@ -974,7 +983,6 @@ func TestServerWrongEncrytion(t *testing.T) { break } } - } func TestClientClose(t *testing.T) { @@ -998,10 +1006,8 @@ func TestClientClose(t *testing.T) { for { m, _ := sc.Read() - //if m.Status == "Connected" { - //} - if m.Status == "Re-connecting" { + if m.Status == "Disconnected" { holdIt <- false break } @@ -1010,8 +1016,6 @@ func TestClientClose(t *testing.T) { }() - ready := false - for { mm, err := cc.Read() @@ -1022,18 +1026,13 @@ func TestClientClose(t *testing.T) { } if mm.Status == "Closed" { - ready = true + break } } - if err != nil && ready == true { - break - } - } <-holdIt - } func TestServerClose(t *testing.T) { @@ -1058,10 +1057,7 @@ func TestServerClose(t *testing.T) { m, _ := cc.Read() - //if m.Status == "Connected" { - //} - - if m.Status == "Re-connecting" { + if m.Status == "Reconnecting" { holdIt <- false break } @@ -1069,8 +1065,6 @@ func TestServerClose(t *testing.T) { }() - ready := true - for { mm, err2 := sc.Read() @@ -1085,15 +1079,11 @@ func TestServerClose(t *testing.T) { } } - if err2 != nil && ready == true { - break - } - } <-holdIt - } + func TestClientReconnect(t *testing.T) { sc, err := StartServer("test127", nil) @@ -1135,7 +1125,7 @@ func TestClientReconnect(t *testing.T) { clientConnected <- true } - if m.Status == "Re-connecting" { + if m.Status == "Reconnecting" { reconnectCheck = 2 } @@ -1164,111 +1154,95 @@ func TestClientReconnect(t *testing.T) { break } } - } -func TestServerReconnect(t *testing.T) { +func TestClientReconnectTimeout(t *testing.T) { - sc, err := StartServer("test337", nil) + server, err := StartServer("test7", nil) if err != nil { t.Error(err) } time.Sleep(time.Second / 4) - cc, err2 := StartClient("test337", nil) + config := &ClientConfig{ + Timeout: 2, + RetryTimer: 1, + } + + cc, err2 := StartClient("test7", config) if err2 != nil { t.Error(err) } - connected := make(chan bool, 1) - serverConfirm := make(chan bool, 1) - serverConnected := make(chan bool, 1) go func() { for { - m, _ := cc.Read() + m, _ := server.Read() if m.Status == "Connected" { - connected <- true + server.Close() break } } }() - go func() { + connect := false + reconnect := false - reconnectCheck := 0 + for { - for { + mm, err5 := cc.Read() - m, _ := sc.Read() - if m.Status == "Connected" { - serverConnected <- true + if err5 == nil { + if mm.Status == "Connected" { + connect = true } - if m.Status == "Re-connecting" { - reconnectCheck = 2 + if mm.Status == "Reconnecting" { + reconnect = true } - if m.Status == "Connected" && reconnectCheck == 2 { - serverConfirm <- true - break + if mm.Status == "Timeout" && reconnect == true && connect == true { + return } } - }() - - <-connected - <-serverConnected - - cc.Close() - - cc2, err2 := StartClient("test337", nil) - if err2 != nil { - t.Error(err) - } - for { + if err5 != nil { + if err5.Error() != "timed out trying to re-connect" { + t.Fatal("should have got the timed out error") + } - m, _ := cc2.Read() - if m.Status == "Connected" { - <-serverConfirm break + } } - } -func TestClientReconnectTimeout(t *testing.T) { +func TestServerReconnect(t *testing.T) { - sc, err := StartServer("test7", nil) + sc, err := StartServer("test127", nil) if err != nil { t.Error(err) } - time.Sleep(time.Second / 4) - - config := &ClientConfig{ - Timeout: 2, - RetryTimer: 1, - } - - cc, err2 := StartClient("test7", config) + cc, err2 := StartClient("test127", nil) if err2 != nil { - t.Error(err) + t.Error(err2) } connected := make(chan bool, 1) - clientTimout := make(chan bool, 1) + clientConfirm := make(chan bool, 1) clientConnected := make(chan bool, 1) - clientError := make(chan bool, 1) go func() { for { - m, _ := sc.Read() + m, _ := cc.Read() if m.Status == "Connected" { + <-clientConnected + connected <- true break } @@ -1278,71 +1252,68 @@ func TestClientReconnectTimeout(t *testing.T) { go func() { - reconnect := false + reconnectCheck := 0 for { - mm, err5 := cc.Read() - - if err5 == nil { - if mm.Status == "Connected" { - - clientConnected <- true - } - - if mm.Status == "Re-connecting" { - reconnect = true - } - - if mm.Status == "Timeout" && reconnect == true { - clientTimout <- true + m, err := sc.Read() + if err != nil { + fmt.Println(err) + return + } - } + if m.Status == "Connected" { + clientConnected <- true } - if err5 != nil { - if err5.Error() != "Timed out trying to re-connect" { - t.Fatal("should have got the timed out error") - } + if m.Status == "Disconnected" { + reconnectCheck = 1 + } - clientError <- true + if m.Status == "Connected" && reconnectCheck == 1 { + clientConfirm <- true break - } } }() <-connected - <-clientConnected - sc.Close() + cc.Close() - <-clientTimout - <-clientError -} + c2, err := StartClient("test127", nil) + if err != nil { + t.Error(err) + } -func TestServerReconnectTimeout(t *testing.T) { + for { - config := &ServerConfig{ - Timeout: 1, - Encryption: true, + m, _ := c2.Read() + if m.Status == "Connected" { + break + } } - sc, err := StartServer("test7", config) + <-clientConfirm +} + +func TestServerReconnect2(t *testing.T) { + + sc, err := StartServer("test337", nil) if err != nil { t.Error(err) } time.Sleep(time.Second / 4) - cc, err2 := StartClient("test7", nil) + cc, err2 := StartClient("test337", nil) if err2 != nil { t.Error(err) } - connected := make(chan bool, 1) - serverTimout := make(chan bool, 1) - serverConnected := make(chan bool, 1) - serverError := make(chan bool, 1) + + hasConnected := make(chan bool) + hasDisconnected := make(chan bool) + hasReconnected := make(chan bool) go func() { @@ -1350,127 +1321,57 @@ func TestServerReconnectTimeout(t *testing.T) { m, _ := cc.Read() if m.Status == "Connected" { - connected <- true - break - } - } - }() - - go func() { + <-hasConnected - reconnect := false + cc.Close() - for { + <-hasDisconnected - m, err := sc.Read() - - if err == nil { - if m.Status == "Connected" { - serverConnected <- true - } - - if m.Status == "Re-connecting" { - reconnect = true + c2, err2 := StartClient("test337", nil) + if err2 != nil { + t.Error(err) } - if m.Status == "Timeout" && reconnect == true { - serverTimout <- true - } - } + for { - if err != nil { - if err.Error() != "Timed out waiting for client to connect" { - t.Fatal("the error should be timed out waiting for client") - } else { - serverError <- true - break + m, _ := c2.Read() + if m.Status == "Connected" { + break + } } - } - - } - }() - - <-connected - <-serverConnected - - cc.Close() - - <-serverTimout - <-serverError -} - -// From here -func TestServerReadClose(t *testing.T) { - - config := &ServerConfig{ - Timeout: 1, - Encryption: true, - } - - sc, err := StartServer("test7Q", config) - if err != nil { - t.Error(err) - } - - time.Sleep(time.Second / 4) - cc, err2 := StartClient("test7Q", nil) - if err2 != nil { - t.Error(err) - } - connected := make(chan bool, 1) - serverTimout := make(chan bool, 1) - serverConnected := make(chan bool, 1) - serverErrorTwo := make(chan bool, 1) + <-hasReconnected - go func() { - - for { - - m, _ := cc.Read() - if m.Status == "Connected" { - connected <- true - break + return } } }() - go func() { + connect := false + disconnect := false - for { - - m, err3 := sc.Read() - - if err3 != nil { - if err3.Error() == "the received channel has been closed" { - serverErrorTwo <- true // after the connection times out the received channel is closed, so we're now testing that the close error is returned. - // This is the only error the received function returns. - break - } - } + for { - if err3 == nil { - if m.Status == "Connected" { - serverConnected <- true - } + m, _ := sc.Read() + if m.Status == "Connected" && connect == false { + hasConnected <- true + connect = true - if m.Status == "Timeout" { - serverTimout <- true // checks the connection times out - } - } } - }() - - <-connected - <-serverConnected + if m.Status == "Disconnected" { + hasDisconnected <- true + disconnect = true - cc.Close() + } - <-serverTimout - <-serverErrorTwo + if m.Status == "Connected" && connect == true && disconnect == true { + hasReconnected <- true + return + } + } } func TestClientReadClose(t *testing.T) { @@ -1530,7 +1431,7 @@ func TestClientReadClose(t *testing.T) { clientConnected <- true } - if m.Status == "Re-connecting" { + if m.Status == "Reconnecting" { reconnect = true } @@ -1551,58 +1452,68 @@ func TestClientReadClose(t *testing.T) { <-clientError } -/* - -func TestServerSendWrongVersionNumber(t *testing.T) { +func TestServerReceiveWrongVersionNumber(t *testing.T) { sc, err := StartServer("test5", nil) if err != nil { t.Error(err) } - time.Sleep(time.Second / 4) + go func() { - cc := &Client{ - Name: "", - status: NotConnected, - received: make(chan *Message), - encryptionReq: false, - } + cc := &Client{ + Name: "", + status: NotConnected, + received: make(chan *Message), + encryptionReq: false, + } + + time.Sleep(3 * time.Second) + + base := "/tmp/" + sock := ".sock" + conn, _ := net.Dial("unix", base+"test5"+sock) + + cc.conn = conn + + recv := make([]byte, 2) + _, err2 := cc.conn.Read(recv) + if err2 != nil { + return + } + + if recv[0] != 4 { + cc.handshakeSendReply(1) + return + } - go func() { - cc.Read() }() - go func() { - m := sc.Read() + for { + + m, err := sc.Read() + if err != nil { + fmt.Println(err) + return + } + if m.Err != nil { if m.Err.Error() != "client has a different version number" { t.Error("should have error because server sent the client the wrong version number 1") } } - }() - - base := "/tmp/" - sock := ".sock" - conn, _ := net.Dial("unix", base+"test5"+sock) - - cc.conn = conn - - recv := make([]byte, 2) - _, err2 := cc.conn.Read(recv) - if err2 != nil { - //return errors.New("failed to received handshake message") - } - - if recv[0] != 4 { - cc.handshakeSendReply(1) - //return errors.New("server has sent a different version number") } +} - time.Sleep(3 * time.Second) +func TestClientReceiveWrongVersionNumber(t *testing.T) { + //conn, err := s.listen.Accept() + //if err != nil { + // break + //} } +/* // This test will not pass on Windows unless the net.Listen part for unix sockets is replaced with winio.ListenPipe func TestServerWrongVersionNumber(t *testing.T) { diff --git a/server_all.go b/server_all.go index ad65dc4..4f7a61b 100755 --- a/server_all.go +++ b/server_all.go @@ -10,8 +10,7 @@ import ( // StartServer - starts the ipc server. // -// ipcName = is the name of the unix socket or named pipe that will be created. -// timeout = number of seconds before the socket/pipe times out waiting for a connection/re-cconnection - if -1 or 0 it never times out. +// ipcName - is the name of the unix socket or named pipe that will be created, the client needs to use the same name func StartServer(ipcName string, config *ServerConfig) (*Server, error) { err := checkIpcName(ipcName) @@ -34,12 +33,6 @@ func StartServer(ipcName string, config *ServerConfig) (*Server, error) { } else { - if config.Timeout < 0 { - s.timeout = 0 - } else { - s.timeout = config.Timeout - } - if config.MaxMsgSize < 1024 { s.maxMsgSize = maxMsgSize } else { @@ -59,44 +52,37 @@ func StartServer(ipcName string, config *ServerConfig) (*Server, error) { } } - go startServer(s) + err = s.run() return s, err } -func startServer(s *Server) { - - err := s.run() - if err != nil { - s.received <- &Message{err: err, MsgType: -2} - } -} - func (s *Server) acceptLoop() { + for { conn, err := s.listen.Accept() if err != nil { break } - if s.status == Listening || s.status == ReConnecting { + if s.status == Listening || s.status == Disconnected { s.conn = conn err2 := s.handshake() if err2 != nil { - s.received <- &Message{err: err2, MsgType: -2} + s.received <- &Message{Err: err2, MsgType: -1} s.status = Error s.listen.Close() s.conn.Close() } else { + go s.read() go s.write() s.status = Connected s.received <- &Message{Status: s.status.String(), MsgType: -1} - s.connChannel <- true } } @@ -105,31 +91,6 @@ func (s *Server) acceptLoop() { } -func (s *Server) connectionTimer() error { - - if s.timeout != 0 { - - timeout := make(chan bool) - - go func() { - time.Sleep(s.timeout * time.Second) - timeout <- true - }() - - select { - - case <-s.connChannel: - return nil - case <-timeout: - s.listen.Close() - return errors.New("timed out waiting for client to connect") - } - } - - <-s.connChannel - return nil -} - func (s *Server) read() { bLen := make([]byte, 4) @@ -138,6 +99,8 @@ func (s *Server) read() { res := s.readData(bLen) if !res { + s.conn.Close() + break } @@ -147,13 +110,15 @@ func (s *Server) read() { res = s.readData(msgRecvd) if !res { + s.conn.Close() + break } if s.encryption { msgFinal, err := decrypt(*s.enc.cipher, msgRecvd) if err != nil { - s.received <- &Message{err: err, MsgType: -2} + s.received <- &Message{Err: err, MsgType: -1} continue } @@ -172,6 +137,7 @@ func (s *Server) read() { } } + } func (s *Server) readData(buff []byte) bool { @@ -183,35 +149,24 @@ func (s *Server) readData(buff []byte) bool { s.status = Closed s.received <- &Message{Status: s.status.String(), MsgType: -1} - s.received <- &Message{err: errors.New("server has closed the connection"), MsgType: -2} + s.received <- &Message{Err: errors.New("server has closed the connection"), MsgType: -1} return false } - go s.reConnect() - return false + if err == io.EOF { + + s.status = Disconnected + s.received <- &Message{Status: s.status.String(), MsgType: -1} + return false + } } return true } -func (s *Server) reConnect() { - - s.status = ReConnecting - s.received <- &Message{Status: s.status.String(), MsgType: -1} - - err := s.connectionTimer() - if err != nil { - s.status = Timeout - s.received <- &Message{Status: s.status.String(), MsgType: -1} - - s.received <- &Message{err: err, MsgType: -2} - - } -} - -// Read - blocking function that waits until an non multipart message is received - +// Read - blocking function, reads each message recieved +// if MsgType is a negative number its an internal message func (s *Server) Read() (*Message, error) { m, ok := (<-s.received) @@ -219,16 +174,16 @@ func (s *Server) Read() (*Message, error) { return nil, errors.New("the received channel has been closed") } - if m.err != nil { - close(s.received) - close(s.toWrite) - return nil, m.err + if m.Err != nil { + //close(s.received) + //close(s.toWrite) + return nil, m.Err } return m, nil } -// Write - writes a non multipart message to the ipc connection. +// Write - writes a message to the ipc connection // msgType - denotes the type of data being sent. 0 is a reserved type for internal messages and errors. func (s *Server) Write(msgType int, message []byte) error { @@ -296,12 +251,14 @@ func (s *Server) write() { } } + // getStatus - get the current status of the connection func (s *Server) getStatus() Status { return s.status } + // StatusCode - returns the current connection status func (s *Server) StatusCode() Status { return s.status diff --git a/shared.go b/shared.go index 245661f..5681102 100755 --- a/shared.go +++ b/shared.go @@ -17,13 +17,15 @@ func (status *Status) String() string { case Closing: return "Closing" case ReConnecting: - return "Re-connecting" + return "Reconnecting" case Timeout: return "Timeout" case Closed: return "Closed" case Error: return "Error" + case Disconnected: + return "Disconnected" default: return "Status not found" } @@ -37,5 +39,4 @@ func checkIpcName(ipcName string) error { } return nil - } diff --git a/types.go b/types.go index 2fafa9f..1788b1b 100755 --- a/types.go +++ b/types.go @@ -8,18 +8,17 @@ import ( // Server - holds the details of the server connection & config. type Server struct { - name string - listen net.Listener - conn net.Conn - status Status - received chan (*Message) - connChannel chan bool - toWrite chan (*Message) - timeout time.Duration - encryption bool - maxMsgSize int - enc *encryption - unMask bool + name string + listen net.Listener + conn net.Conn + status Status + received chan (*Message) + toWrite chan (*Message) + timeout time.Duration + encryption bool + maxMsgSize int + enc *encryption + unMask bool } // Client - holds the details of the client connection and config. @@ -37,12 +36,12 @@ type Client struct { enc *encryption } -// Message - contains the received message +// Message - contains the received message type Message struct { - err error // details of any error - MsgType int // type of message sent - 0 is reserved + Err error // details of any error + MsgType int // 0 = reserved , -1 is an internal message (disconnection or error etc), all messages recieved will be > 0 Data []byte // message data received - Status string + Status string // the status of the connection } // Status - Status of the connection @@ -68,11 +67,12 @@ const ( Error Status = iota // Timeout - 8 Timeout Status = iota + // Disconnected - 9 + Disconnected Status = iota ) // ServerConfig - used to pass configuation overrides to ServerStart() type ServerConfig struct { - Timeout time.Duration MaxMsgSize int Encryption bool UnmaskPermissions bool