From 5b3b7a35374e483b77cbc085f2724f8a32340f7a Mon Sep 17 00:00:00 2001 From: james-barrow <56474510+james-barrow@users.noreply.github.com> Date: Sun, 14 May 2023 21:13:42 +0100 Subject: [PATCH] bug fixes --- client_all.go | 93 +++++++++++++++++++++++------------------------- connect_other.go | 10 +++--- encryption.go | 4 +-- go.mod | 5 ++- go.sum | 60 +++++++++++++++++++++++++------ handshake.go | 12 +++---- server_all.go | 46 +++++++++++++----------- 7 files changed, 138 insertions(+), 92 deletions(-) diff --git a/client_all.go b/client_all.go index 66313ab..d76c147 100755 --- a/client_all.go +++ b/client_all.go @@ -3,6 +3,7 @@ package ipc import ( "bufio" "errors" + "io" "strings" "time" ) @@ -12,7 +13,6 @@ import ( // ipcName = is the name of the unix socket or named pipe that the client will try and connect to. // timeout = number of seconds before the socket/pipe times out trying to connect/re-cconnect - if -1 or 0 it never times out. // retryTimer = number of seconds before the client tries to connect again. -// func StartClient(ipcName string, config *ClientConfig) (*Client, error) { err := checkIpcName(ipcName) @@ -48,7 +48,7 @@ func StartClient(ipcName string, config *ClientConfig) (*Client, error) { cc.retryTimer = time.Duration(config.RetryTimer) } - if config.Encryption == false { + if !config.Encryption { cc.encryptionReq = false } else { cc.encryptionReq = true // defualt is to always enforce encryption @@ -86,7 +86,7 @@ func (cc *Client) read() { for { res := cc.readData(bLen) - if res == false { + if !res { break } @@ -95,11 +95,11 @@ func (cc *Client) read() { msgRecvd := make([]byte, mLen) res = cc.readData(msgRecvd) - if res == false { + if !res { break } - if cc.encryption == true { + if cc.encryption { msgFinal, err := decrypt(*cc.enc.cipher, msgRecvd) if err != nil { break @@ -124,7 +124,8 @@ func (cc *Client) read() { func (cc *Client) readData(buff []byte) bool { - _, err := cc.conn.Read(buff) + _, err := io.ReadFull(cc.conn, buff) + //_, err := cc.conn.Read(buff) if err != nil { if strings.Contains(err.Error(), "EOF") { // the connection has been closed by the client. cc.conn.Close() @@ -151,90 +152,85 @@ func (cc *Client) readData(buff []byte) bool { } -func (cc *Client) reconnect() { +func (c *Client) reconnect() { - cc.status = ReConnecting - cc.recieved <- &Message{Status: cc.status.String(), MsgType: -1} + c.status = ReConnecting + c.recieved <- &Message{Status: c.status.String(), MsgType: -1} - err := cc.dial() // connect to the pipe + err := c.dial() // connect to the pipe if err != nil { if err.Error() == "Timed out trying to connect" { - cc.status = Timeout - cc.recieved <- &Message{Status: cc.status.String(), MsgType: -1} - cc.recieved <- &Message{err: errors.New("Timed out trying to re-connect"), MsgType: -2} + c.status = Timeout + c.recieved <- &Message{Status: c.status.String(), MsgType: -1} + c.recieved <- &Message{err: errors.New("timed out trying to re-connect"), MsgType: -2} } return } - cc.status = Connected - cc.recieved <- &Message{Status: cc.status.String(), MsgType: -1} - - go cc.read() + c.status = Connected + c.recieved <- &Message{Status: c.status.String(), MsgType: -1} + go c.read() } // Read - blocking function that waits until an non multipart message is recieved // returns the message type, data and any error. -// -func (cc *Client) Read() (*Message, error) { +func (c *Client) Read() (*Message, error) { - m, ok := (<-cc.recieved) - if ok == false { + m, ok := (<-c.recieved) + if !ok { return nil, errors.New("the recieve channel has been closed") } if m.err != nil { - close(cc.recieved) - close(cc.toWrite) + close(c.recieved) + close(c.toWrite) return nil, m.err } return m, nil - } // Write - writes a non multipart message to the ipc connection. // msgType - denotes the type of data being sent. 0 is a reserved type for internal messages and errors. -// -func (cc *Client) Write(msgType int, message []byte) error { +func (c *Client) Write(msgType int, message []byte) error { if msgType == 0 { return errors.New("Message type 0 is reserved") } - if cc.status != Connected { - return errors.New(cc.status.String()) + if c.status != Connected { + return errors.New(c.status.String()) } mlen := len(message) - if mlen > cc.maxMsgSize { + if mlen > c.maxMsgSize { return errors.New("Message exceeds maximum message length") } - cc.toWrite <- &Message{MsgType: msgType, Data: message} + c.toWrite <- &Message{MsgType: msgType, Data: message} return nil - } -func (cc *Client) write() { +func (c *Client) write() { for { - m, ok := <-cc.toWrite + m, ok := <-c.toWrite - if ok == false { + if !ok{ break } toSend := intToBytes(m.MsgType) - writer := bufio.NewWriter(cc.conn) + writer := bufio.NewWriter(c.conn) - if cc.encryption == true { + if c.encryption { toSend = append(toSend, m.Data...) - toSendEnc, err := encrypt(*cc.enc.cipher, toSend) + toSendEnc, err := encrypt(*c.enc.cipher, toSend) if err != nil { //return err } @@ -257,27 +253,28 @@ func (cc *Client) write() { } // getStatus - get the current status of the connection -func (cc *Client) getStatus() Status { - - return cc.status +func (c *Client) getStatus() Status { + return c.status } // StatusCode - returns the current connection status -func (cc *Client) StatusCode() Status { - return cc.status +func (c *Client) StatusCode() Status { + return c.status } // Status - returns the current connection status as a string -func (cc *Client) Status() string { - - return cc.status.String() +func (c *Client) Status() string { + return c.status.String() } // Close - closes the connection -func (cc *Client) Close() { +func (c *Client) Close() { + + c.status = Closing - cc.status = Closing - cc.conn.Close() + if c.conn != nil { + c.conn.Close() + } } diff --git a/connect_other.go b/connect_other.go index 949ecc9..eff12a9 100755 --- a/connect_other.go +++ b/connect_other.go @@ -22,13 +22,13 @@ func (sc *Server) run() error { } var oldUmask int - if sc.unMask == true { + if sc.unMask { oldUmask = syscall.Umask(0) } listen, err := net.Listen("unix", base+sc.name+sock) - if sc.unMask == true { + if sc.unMask { syscall.Umask(oldUmask) } @@ -65,16 +65,16 @@ func (cc *Client) dial() error { if cc.timeout != 0 { if time.Now().Sub(startTime).Seconds() > cc.timeout { cc.status = Closed - return errors.New("Timed out trying to connect") + return errors.New("timed out trying to connect") } } conn, err := net.Dial("unix", base+cc.Name+sock) if err != nil { - if strings.Contains(err.Error(), "connect: no such file or directory") == true { + if strings.Contains(err.Error(), "connect: no such file or directory") { - } else if strings.Contains(err.Error(), "connect: connection refused") == true { + } else if strings.Contains(err.Error(), "connect: connection refused") { } else { cc.recieved <- &Message{err: err, MsgType: -2} diff --git a/encryption.go b/encryption.go index 004a04a..705b50b 100755 --- a/encryption.go +++ b/encryption.go @@ -78,7 +78,7 @@ func generateKeys() (*ecdsa.PrivateKey, *ecdsa.PublicKey, error) { puba := &priva.PublicKey - if priva.IsOnCurve(puba.X, puba.Y) == false { + if !priva.IsOnCurve(puba.X, puba.Y) { return nil, nil, errors.New("keys created arn't on curve") } @@ -115,7 +115,7 @@ func recvPublic(conn net.Conn) (*ecdsa.PublicKey, error) { recvdPub := bytesToPublicKey(buff[:i]) - if recvdPub.IsOnCurve(recvdPub.X, recvdPub.Y) == false { + if !recvdPub.IsOnCurve(recvdPub.X, recvdPub.Y) { return nil, errors.New("didn't recieve valid public key") } diff --git a/go.mod b/go.mod index c9465da..50fda1b 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module github.com/james-barrow/golang-ipc go 1.15 -require github.com/Microsoft/go-winio v0.4.16 +require ( + github.com/Microsoft/go-winio v0.6.1 + golang.org/x/tools v0.9.1 // indirect +) diff --git a/go.sum b/go.sum index 284f8a4..cb8152b 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,52 @@ -github.com/Microsoft/go-winio v0.4.16 h1:FtSW/jqD+l4ba5iPBj9CODVtgfYAD8w2wS923g/cFDk= -github.com/Microsoft/go-winio v0.4.16/go.mod h1:XB6nPKklQyQ7GC9LdcBEcBl8PF76WugXOPRXwdLnMv0= +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3 h1:7TYNF4UdlohbFwpNH04CoPMp1cHUZgO1Ebq5r2hIjfo= -golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= +golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= +golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handshake.go b/handshake.go index 15112d4..f4f9ec6 100755 --- a/handshake.go +++ b/handshake.go @@ -16,7 +16,7 @@ func (sc *Server) handshake() error { return err } - if sc.encryption == true { + if sc.encryption { err = sc.startEncryption() if err != nil { return err @@ -38,7 +38,7 @@ func (sc *Server) one() error { buff[0] = byte(version) - if sc.encryption == true { + if sc.encryption { buff[1] = byte(1) } else { buff[1] = byte(0) @@ -100,7 +100,7 @@ func (sc *Server) msgLength() error { buff := make([]byte, 4) binary.BigEndian.PutUint32(buff, uint32(sc.maxMsgSize)) - if sc.encryption == true { + if sc.encryption { maxMsg, err := encrypt(*sc.enc.cipher, buff) if err != nil { return err @@ -139,7 +139,7 @@ func (cc *Client) handshake() error { return err } - if cc.encryption == true { + if cc.encryption { err := cc.startEncryption() if err != nil { return err @@ -168,7 +168,7 @@ func (cc *Client) one() error { return errors.New("server has sent a different version number") } - if recv[1] != 1 && cc.encryptionReq == true { + if recv[1] != 1 && cc.encryptionReq { cc.handshakeSendReply(2) return errors.New("server tried to connect without encryption") } @@ -225,7 +225,7 @@ func (cc *Client) msgLength() error { return errors.New("failed to recieve max message length 2") } var buff2 []byte - if cc.encryption == true { + if cc.encryption { buff2, err = decrypt(*cc.enc.cipher, buff) if err != nil { return errors.New("failed to recieve max message length 3") diff --git a/server_all.go b/server_all.go index c0b13a2..3690690 100755 --- a/server_all.go +++ b/server_all.go @@ -3,6 +3,7 @@ package ipc import ( "bufio" "errors" + "io" "time" ) @@ -10,7 +11,6 @@ import ( // // ipcName = is the name of the unix socket or named pipe that will be created. // timeout = number of seconds before the socket/pipe times out waiting for a connection/re-cconnection - if -1 or 0 it never times out. -// func StartServer(ipcName string, config *ServerConfig) (*Server, error) { err := checkIpcName(ipcName) @@ -45,13 +45,13 @@ func StartServer(ipcName string, config *ServerConfig) (*Server, error) { sc.maxMsgSize = config.MaxMsgSize } - if config.Encryption == false { + if !config.Encryption { sc.encryption = false } else { sc.encryption = true } - if config.UnmaskPermissions == true { + if config.UnmaskPermissions { sc.unMask = true } else { sc.unMask = false @@ -121,15 +121,18 @@ func (sc *Server) connectionTimer() error { return nil case <-timeout: sc.listen.Close() - return errors.New("Timed out waiting for client to connect") + return errors.New("timed out waiting for client to connect") } } - select { + //select { - case <-sc.connChannel: - return nil - } + //case <-sc.connChannel: + // return nil + //} + + <-sc.connChannel + return nil } @@ -140,7 +143,7 @@ func (sc *Server) read() { for { res := sc.readData(bLen) - if res == false { + if !res { break } @@ -149,11 +152,11 @@ func (sc *Server) read() { msgRecvd := make([]byte, mLen) res = sc.readData(msgRecvd) - if res == false { + if !res { break } - if sc.encryption == true { + if sc.encryption { msgFinal, err := decrypt(*sc.enc.cipher, msgRecvd) if err != nil { sc.recieved <- &Message{err: err, MsgType: -2} @@ -179,7 +182,8 @@ func (sc *Server) read() { func (sc *Server) readData(buff []byte) bool { - _, err := sc.conn.Read(buff) + _, err := io.ReadFull(sc.conn, buff) + //_, err := sc.conn.Read(buff) if err != nil { if sc.status == Closing { @@ -220,7 +224,7 @@ func (sc *Server) reConnect() { func (sc *Server) Read() (*Message, error) { m, ok := (<-sc.recieved) - if ok == false { + if !ok { return nil, errors.New("the recieve channel has been closed") } @@ -236,7 +240,6 @@ func (sc *Server) Read() (*Message, error) { // Write - writes a non multipart message to the ipc connection. // msgType - denotes the type of data being sent. 0 is a reserved type for internal messages and errors. -// func (sc *Server) Write(msgType int, message []byte) error { if msgType == 0 { @@ -267,7 +270,7 @@ func (sc *Server) write() { m, ok := <-sc.toWrite - if ok == false { + if !ok { break } @@ -275,7 +278,7 @@ func (sc *Server) write() { writer := bufio.NewWriter(sc.conn) - if sc.encryption == true { + if sc.encryption { toSend = append(toSend, m.Data...) toSendEnc, err := encrypt(*sc.enc.cipher, toSend) if err != nil { @@ -306,7 +309,6 @@ func (sc *Server) write() { func (sc *Server) getStatus() Status { return sc.status - } // StatusCode - returns the current connection status @@ -318,14 +320,18 @@ func (sc *Server) StatusCode() Status { func (sc *Server) Status() string { return sc.status.String() - } // Close - closes the connection func (sc *Server) Close() { sc.status = Closing - sc.listen.Close() - sc.conn.Close() + if sc.listen != nil { + sc.listen.Close() + } + + if sc.conn != nil { + sc.conn.Close() + } }