From a0550b1d79bbc57f543c2ee6d82d71f410b9dcc5 Mon Sep 17 00:00:00 2001 From: eternal-flame-AD Date: Sun, 4 Aug 2024 16:41:02 -0500 Subject: [PATCH] feat: Automatically split STDIN on null characters on push Signed-off-by: eternal-flame-AD --- command/initialize.go | 6 +-- command/push.go | 86 ++++++++++++++++++++---------------------- command/read.go | 61 ++++++++++++++++++++++++++++++ command/read_test.go | 55 +++++++++++++++++++++++++++ command/watch.go | 8 ++-- utils/readfromstdin.go | 25 ++++++------ 6 files changed, 176 insertions(+), 65 deletions(-) create mode 100644 command/read.go create mode 100644 command/read_test.go diff --git a/command/initialize.go b/command/initialize.go index eca19f8..0c3bff2 100644 --- a/command/initialize.go +++ b/command/initialize.go @@ -71,7 +71,7 @@ func inputConfigLocation() string { for { fmt.Println("Where to put the config file?") for i, location := range locations { - fmt.Println(fmt.Sprintf("%d. %s", i+1, location)) + fmt.Printf("%d. %s\n", i+1, location) } value := inputString("Enter a number: ") hr() @@ -215,9 +215,9 @@ func inputDefaultPriority() int { erred("Priority needs to be a number between 0 and 10.") continue } else { + hr() return defaultPriority } - hr() } } @@ -251,7 +251,7 @@ func inputServerURL() *url.URL { }) if err == nil { info := version.(models.VersionInfo) - fmt.Println(fmt.Sprintf("Gotify v%s@%s", info.Version, info.BuildDate)) + fmt.Printf("Gotify v%s@%s\n", info.Version, info.BuildDate) return parsedURL } hr() diff --git a/command/push.go b/command/push.go index 22a9296..ddcdad4 100644 --- a/command/push.go +++ b/command/push.go @@ -4,7 +4,6 @@ import ( "fmt" "net/url" "os" - "strings" "github.com/gotify/cli/v2/config" "github.com/gotify/cli/v2/utils" @@ -31,6 +30,7 @@ func Push() cli.Command { cli.StringFlag{Name: "contentType", Usage: "The content type of the message. See https://gotify.net/docs/msgextras#client-display"}, cli.StringFlag{Name: "clickUrl", Usage: "An URL to open upon clicking the notification. See https://gotify.net/docs/msgextras#client-notification"}, cli.BoolFlag{Name: "disable-unescape-backslash", Usage: "Disable evaluating \\n and \\t (if set, \\n and \\t will be seen as a string)"}, + cli.BoolFlag{Name: "no-split", Usage: "Do not split the message on null character when reading from stdin"}, }, Action: doPush, } @@ -39,9 +39,12 @@ func Push() cli.Command { func doPush(ctx *cli.Context) { conf, confErr := config.ReadConfig(config.GetLocations()) - msgText := readMessage(ctx) - if !ctx.Bool("disable-unescape-backslash") { - msgText = utils.Evaluate(msgText) + msgText := make(chan string) + null := '\x00' + if ctx.Bool("no-split") { + go readMessage(ctx.Args(), os.Stdin, msgText, nil) + } else { + go readMessage(ctx.Args(), os.Stdin, msgText, &null) } priority := ctx.Int("priority") @@ -72,36 +75,47 @@ func doPush(ctx *cli.Context) { priority = conf.DefaultPriority } - msg := models.MessageExternal{ - Message: msgText, - Title: title, - Priority: priority, + parsedURL, err := url.Parse(stringURL) + if err != nil { + utils.Exit1With("invalid url", stringURL) + return } - msg.Extras = map[string]interface{}{ - } + var sent bool + for msgText := range msgText { + if !ctx.Bool("disable-unescape-backslash") { + msgText = utils.Evaluate(msgText) + } - if contentType != "" { - msg.Extras["client::display"] = map[string]interface{}{ - "contentType": contentType, + msg := models.MessageExternal{ + Message: msgText, + Title: title, + Priority: priority, } - } - if clickUrl != "" { - msg.Extras["client::notification"] = map[string]interface{}{ - "click": map[string]string{ - "url": clickUrl, - }, + msg.Extras = map[string]interface{}{} + + if contentType != "" { + msg.Extras["client::display"] = map[string]interface{}{ + "contentType": contentType, + } } - } - parsedURL, err := url.Parse(stringURL) - if err != nil { - utils.Exit1With("invalid url", stringURL) - return - } + if clickUrl != "" { + msg.Extras["client::notification"] = map[string]interface{}{ + "click": map[string]string{ + "url": clickUrl, + }, + } + } + + pushMessage(parsedURL, token, msg, quiet) - pushMessage(parsedURL, token, msg, quiet) + sent = true + } + if !sent { + utils.Exit1With("no message sent! a message must be set, either as argument or via stdin") + } } func pushMessage(parsedURL *url.URL, token string, msg models.MessageExternal, quiet bool) { @@ -119,23 +133,3 @@ func pushMessage(parsedURL *url.URL, token string, msg models.MessageExternal, q utils.Exit1With(err) } } - -func readMessage(ctx *cli.Context) string { - msgArgs := strings.Join(ctx.Args(), " ") - - msgStdin := utils.ReadFrom(os.Stdin) - - if msgArgs == "" && msgStdin == "" { - utils.Exit1With("a message must be set, either as argument or via stdin") - } - - if msgArgs != "" && msgStdin != "" { - utils.Exit1With("a message is set via stdin and arguments, use only one of them") - } - - if msgArgs == "" { - return msgStdin - } else { - return msgArgs - } -} diff --git a/command/read.go b/command/read.go new file mode 100644 index 0000000..c5b8ca4 --- /dev/null +++ b/command/read.go @@ -0,0 +1,61 @@ +package command + +import ( + "io" + "strings" + + "github.com/gotify/cli/v2/utils" +) + +func readMessage(args []string, r io.Reader, output chan<- string, split *rune) { + msgArgs := strings.Join(args, " ") + + if msgArgs != "" { + if utils.ProbeStdin(r) { + utils.Exit1With("message is set via arguments and stdin, use only one of them") + } + + output <- msgArgs + close(output) + return + } + + var buf strings.Builder + for { + var tmp [256]byte + n, err := r.Read(tmp[:]) + if err != nil { + if err.Error() == "EOF" { + break + } + utils.Exit1With(err) + } + tmpStr := string(tmp[:n]) + if split != nil { + // split the message on the null character + parts := strings.Split(tmpStr, string(*split)) + if len(parts) == 1 { + buf.WriteString(parts[0]) + continue + } + + previous := buf.String() + // fuse previous with parts[0], send parts[1] .. parts[n-2] and set parts[n-1] as new previous + firstMsg := previous + parts[0] + output <- firstMsg + for _, part := range parts[1 : len(parts)-1] { + output <- part + } + buf.Reset() + buf.WriteString(parts[len(parts)-1]) + } else { + buf.WriteString(tmpStr) + } + } + + if buf.Len() > 0 { + output <- buf.String() + } + + close(output) +} diff --git a/command/read_test.go b/command/read_test.go new file mode 100644 index 0000000..7023488 --- /dev/null +++ b/command/read_test.go @@ -0,0 +1,55 @@ +package command + +import ( + "strings" + "testing" +) + +// Polyfill for slices.Equal for Go 1.20 +func slicesEqual[T comparable](a, b []T) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func readChanAll[T any](c chan T) []T { + var res []T + for s := range c { + res = append(res, s) + } + return res +} + +func TestReadMessage(t *testing.T) { + var split rune = '\x00' + + // Test case 1: message set via arguments + output := make(chan string) + go readMessage([]string{"Hello", "World"}, nil, output, nil) + + if res := readChanAll(output); !(slicesEqual(res, []string{"Hello World"})) { + t.Errorf("Expected %v, but got %v", []string{"Hello World"}, res) + } + + // Test case 2: message set via arguments should not split on 'split' character + output = make(chan string) + go readMessage([]string{"Hello\x00World"}, nil, output, &split) + + if res := readChanAll(output); !(slicesEqual(res, []string{"Hello\x00World"})) { + t.Errorf("Expected %v, but got %v", []string{"Hello\x00World"}, res) + } + + // Test case 3: message set via stdin + output = make(chan string) + go readMessage([]string{}, strings.NewReader("Hello\x00World"), output, &split) + + if res := readChanAll(output); !(slicesEqual(res, []string{"Hello", "World"})) { + t.Errorf("Expected %v, but got %v", []string{"Hello", "World"}, res) + } +} diff --git a/command/watch.go b/command/watch.go index f0ee519..e098233 100644 --- a/command/watch.go +++ b/command/watch.go @@ -120,18 +120,18 @@ func doWatch(ctx *cli.Context) { case "long": fmt.Fprintf(msgData, "command output for \"%s\" changed:\n\n", cmdStringNotation) fmt.Fprintln(msgData, "== BEGIN OLD OUTPUT ==") - fmt.Fprint(msgData, lastOutput) + fmt.Fprintln(msgData, lastOutput) fmt.Fprintln(msgData, "== END OLD OUTPUT ==") fmt.Fprintln(msgData, "== BEGIN NEW OUTPUT ==") - fmt.Fprint(msgData, output) + fmt.Fprintln(msgData, output) fmt.Fprintln(msgData, "== END NEW OUTPUT ==") case "default": fmt.Fprintf(msgData, "command output for \"%s\" changed:\n\n", cmdStringNotation) fmt.Fprintln(msgData, "== BEGIN NEW OUTPUT ==") - fmt.Fprint(msgData, output) + fmt.Fprintln(msgData, output) fmt.Fprintln(msgData, "== END NEW OUTPUT ==") case "short": - fmt.Fprintf(msgData, output) + fmt.Fprintln(msgData, output) } msgString := msgData.String() diff --git a/utils/readfromstdin.go b/utils/readfromstdin.go index 53c9b9a..e27008e 100644 --- a/utils/readfromstdin.go +++ b/utils/readfromstdin.go @@ -1,22 +1,23 @@ package utils import ( + "io" "os" - "io/ioutil" ) -func ReadFrom(file *os.File) string { - fi, err := os.Stdin.Stat() - if err != nil { - return "" +func ProbeStdin(file io.Reader) bool { + if file == nil { + return false } - if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() { - return "" + if file, ok := file.(*os.File); ok { + fi, err := file.Stat() + if err != nil { + return false + } + if fi.Mode()&os.ModeNamedPipe == 0 && !fi.Mode().IsRegular() { + return false + } } - bytes, err := ioutil.ReadAll(file) - if err != nil { - return "" - } - return string(bytes) + return true }