diff --git a/ask_cmd.go b/ask_cmd.go index 8292fe9..4413ecc 100644 --- a/ask_cmd.go +++ b/ask_cmd.go @@ -1,10 +1,12 @@ package main import ( + "bufio" "context" "fmt" "io" "os" + "os/signal" "strings" "github.com/PullRequestInc/go-gpt3" @@ -33,6 +35,7 @@ type askCmd struct { Bash bool `arg:"--bash" help:"output only valid bash"` Model string `arg:"--model,-m" help:"set openai model"` Attach []string `arg:"--attach,-a,separate" help:"attach additional files at the end of the message. pass '-' to pass in stdin"` + Once bool `arg:"--once,-o" help:"whether to just ask the model once"` } func (args *askCmd) buildContent(ctx context.Context) (string, error) { @@ -87,6 +90,39 @@ func (args *askCmd) messages(content string) []gpt3.ChatCompletionRequestMessage } +func (args *askCmd) poll(input *bufio.Reader) (string, bool, error) { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt) + lineCh := make(chan string) + defer close(lineCh) + errCh := make(chan error) + defer close(errCh) + + go func() { + line, err := input.ReadString('\n') + if err != nil { + errCh <- err + } else { + lineCh <- line + } + }() + + select { + case err := <-errCh: + return "", false, err + case <-sigCh: + return "", false, nil + case line := <-lineCh: + line = strings.TrimSpace(line) + if line == "" { + return "", false, nil + } + signal.Stop(sigCh) + close(sigCh) + return line, true, nil + } +} + func (args *askCmd) Execute(ctx context.Context, config *config) error { model := args.Model if model == "" { @@ -99,25 +135,62 @@ func (args *askCmd) Execute(ctx context.Context, config *config) error { if err != nil { return fmt.Errorf("cannot build message: %w", err) } - err = client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{ - Messages: args.messages(content), - MaxTokens: args.MaxTokens, - Temperature: &args.Temperature, - Stream: true, - Model: model, - }, func(cr *gpt3.ChatCompletionStreamResponse) error { - message := cr.Choices[0].Delta.Content - if message != "" { - lastMessage = message + + input := bufio.NewReader(os.Stdin) + + messages := args.messages(content) + for { + var response strings.Builder + out := io.MultiWriter(os.Stdout, &response) + + err = client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{ + Messages: messages, + MaxTokens: args.MaxTokens, + Temperature: &args.Temperature, + Stream: true, + Model: model, + }, func(cr *gpt3.ChatCompletionStreamResponse) error { + message := cr.Choices[0].Delta.Content + if message != "" { + lastMessage = message + } + _, err := fmt.Fprintf(out, "%s", cr.Choices[0].Delta.Content) + return err + }) + if err != nil { + return err } - fmt.Printf("%s", cr.Choices[0].Delta.Content) - return nil - }) - if err != nil { - return err - } - if len(lastMessage) == 0 || lastMessage[len(lastMessage)-1] != '\n' { - fmt.Printf("\n") + if len(lastMessage) == 0 || lastMessage[len(lastMessage)-1] != '\n' { + _, err := fmt.Fprintf(out, "\n") + if err != nil { + return err + } + } + + if args.Once { + break + } + + _, err := fmt.Printf("%shlp>%s ", colorGreen, colorReset) + if err != nil { + return err + } + + line, cont, err := args.poll(input) + if err != nil { + return err + } + + if !cont { + return nil + } + + messages = append( + messages, + gpt3.ChatCompletionRequestMessage{Role: "assistant", Content: response.String()}, + gpt3.ChatCompletionRequestMessage{Role: "user", Content: line}, + ) + } return nil } diff --git a/color.go b/color.go new file mode 100644 index 0000000..e6ea1b5 --- /dev/null +++ b/color.go @@ -0,0 +1,12 @@ +package main + +const ( + colorReset = "\033[0m" + colorRed = "\033[31m" + colorGreen = "\033[32m" + colorYellow = "\033[33m" + colorBlue = "\033[34m" + colorPurple = "\033[35m" + colorCyan = "\033[36m" + colorWhite = "\033[37m" +) diff --git a/main.go b/main.go index c71a2a4..1dd0e84 100644 --- a/main.go +++ b/main.go @@ -46,14 +46,22 @@ func (args *mainCmd) Execute(ctx context.Context) error { return err } -func main() { +func run() error { var args mainCmd ctx := context.Background() ctx, cancel := context.WithTimeout(ctx, time.Minute*5) defer cancel() arg.MustParse(&args) - if err := args.Execute(ctx); err != nil { - log.Panic(err) + return err + } + return nil +} + +func main() { + if err := run(); err != nil { + log.Printf("error: %v", err) + os.Exit(1) } + os.Exit(0) }