Skip to content

Commit

Permalink
feat: add interactive mode on ask sub-command
Browse files Browse the repository at this point in the history
  • Loading branch information
yiblet committed Feb 14, 2024
1 parent 496f001 commit 7e024f7
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 21 deletions.
109 changes: 91 additions & 18 deletions ask_cmd.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package main

import (
"bufio"
"context"
"fmt"
"io"
"os"
"os/signal"
"strings"

"github.com/PullRequestInc/go-gpt3"
Expand Down Expand Up @@ -33,6 +35,7 @@ type askCmd struct {
Bash bool `arg:"--bash" help:"output only valid bash"`
Model string `arg:"--model,-m" help:"set openai model"`
Attach []string `arg:"--attach,-a,separate" help:"attach additional files at the end of the message. pass '-' to pass in stdin"`
Once bool `arg:"--once,-o" help:"whether to just ask the model once"`
}

func (args *askCmd) buildContent(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -87,6 +90,39 @@ func (args *askCmd) messages(content string) []gpt3.ChatCompletionRequestMessage

}

func (args *askCmd) poll(input *bufio.Reader) (string, bool, error) {
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, os.Interrupt)
lineCh := make(chan string)
defer close(lineCh)
errCh := make(chan error)
defer close(errCh)

go func() {
line, err := input.ReadString('\n')
if err != nil {
errCh <- err
} else {
lineCh <- line
}
}()

select {
case err := <-errCh:
return "", false, err
case <-sigCh:
return "", false, nil
case line := <-lineCh:
line = strings.TrimSpace(line)
if line == "" {
return "", false, nil
}
signal.Stop(sigCh)
close(sigCh)
return line, true, nil
}
}

func (args *askCmd) Execute(ctx context.Context, config *config) error {
model := args.Model
if model == "" {
Expand All @@ -99,25 +135,62 @@ func (args *askCmd) Execute(ctx context.Context, config *config) error {
if err != nil {
return fmt.Errorf("cannot build message: %w", err)
}
err = client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{
Messages: args.messages(content),
MaxTokens: args.MaxTokens,
Temperature: &args.Temperature,
Stream: true,
Model: model,
}, func(cr *gpt3.ChatCompletionStreamResponse) error {
message := cr.Choices[0].Delta.Content
if message != "" {
lastMessage = message

input := bufio.NewReader(os.Stdin)

messages := args.messages(content)
for {
var response strings.Builder
out := io.MultiWriter(os.Stdout, &response)

err = client.ChatCompletionStream(ctx, gpt3.ChatCompletionRequest{
Messages: messages,
MaxTokens: args.MaxTokens,
Temperature: &args.Temperature,
Stream: true,
Model: model,
}, func(cr *gpt3.ChatCompletionStreamResponse) error {
message := cr.Choices[0].Delta.Content
if message != "" {
lastMessage = message
}
_, err := fmt.Fprintf(out, "%s", cr.Choices[0].Delta.Content)
return err
})
if err != nil {
return err
}
fmt.Printf("%s", cr.Choices[0].Delta.Content)
return nil
})
if err != nil {
return err
}
if len(lastMessage) == 0 || lastMessage[len(lastMessage)-1] != '\n' {
fmt.Printf("\n")
if len(lastMessage) == 0 || lastMessage[len(lastMessage)-1] != '\n' {
_, err := fmt.Fprintf(out, "\n")
if err != nil {
return err
}
}

if args.Once {
break
}

_, err := fmt.Printf("%shlp>%s ", colorGreen, colorReset)
if err != nil {
return err
}

line, cont, err := args.poll(input)
if err != nil {
return err
}

if !cont {
return nil
}

messages = append(
messages,
gpt3.ChatCompletionRequestMessage{Role: "assistant", Content: response.String()},
gpt3.ChatCompletionRequestMessage{Role: "user", Content: line},
)

}
return nil
}
12 changes: 12 additions & 0 deletions color.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package main

const (
colorReset = "\033[0m"
colorRed = "\033[31m"
colorGreen = "\033[32m"
colorYellow = "\033[33m"
colorBlue = "\033[34m"
colorPurple = "\033[35m"
colorCyan = "\033[36m"
colorWhite = "\033[37m"
)
14 changes: 11 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,22 @@ func (args *mainCmd) Execute(ctx context.Context) error {
return err
}

func main() {
func run() error {
var args mainCmd
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Minute*5)
defer cancel()
arg.MustParse(&args)

if err := args.Execute(ctx); err != nil {
log.Panic(err)
return err
}
return nil
}

func main() {
if err := run(); err != nil {
log.Printf("error: %v", err)
os.Exit(1)
}
os.Exit(0)
}

0 comments on commit 7e024f7

Please sign in to comment.