diff --git a/gptscript.go b/gptscript.go index 28a46a2..8ca9747 100644 --- a/gptscript.go +++ b/gptscript.go @@ -170,6 +170,11 @@ func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Ru }).NextChat(ctx, opts.Input) } +func (g *GPTScript) AbortRun(ctx context.Context, run *Run) error { + _, err := g.runBasicCommand(ctx, "abort/"+run.id, (map[string]any)(nil)) + return err +} + type ParseOptions struct { DisableCache bool } diff --git a/gptscript_test.go b/gptscript_test.go index 8f299c8..476492d 100644 --- a/gptscript_test.go +++ b/gptscript_test.go @@ -11,6 +11,7 @@ import ( "strconv" "strings" "testing" + "time" "github.com/getkin/kin-openapi/openapi3" "github.com/stretchr/testify/require" @@ -134,7 +135,7 @@ func TestListModelsWithDefaultProvider(t *testing.T) { } } -func TestAbortRun(t *testing.T) { +func TestCancelRun(t *testing.T) { tool := ToolDef{Instructions: "What is the capital of the united states?"} run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) @@ -146,7 +147,7 @@ func TestAbortRun(t *testing.T) { <-run.Events() if err := run.Close(); err != nil { - t.Errorf("Error aborting run: %v", err) + t.Errorf("Error canceling run: %v", err) } if run.State() != Error { @@ -158,6 +159,77 @@ func TestAbortRun(t *testing.T) { } } +func TestAbortChatCompletionRun(t *testing.T) { + tool := ToolDef{Instructions: "What is the capital of the united states?"} + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Abort the run after the first event from the LLM + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." { + break + } + } + + if err := g.AbortRun(context.Background(), run); err != nil { + t.Errorf("Error aborting run: %v", err) + } + + // Wait for run to stop + for range run.Events() { + continue + } + + if run.State() != Finished { + t.Errorf("Unexpected run state: %s", run.State()) + } + + if out, err := run.Text(); err != nil { + t.Errorf("Error reading output: %v", err) + } else if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Errorf("Unexpected output: %s", out) + } +} + +func TestAbortCommandRun(t *testing.T) { + tool := ToolDef{Instructions: "#!/usr/bin/env bash\necho Hello, world!\nsleep 5\necho Hello, again!\nsleep 5"} + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Errorf("Error executing tool: %v", err) + } + + // Abort the run after the first event. + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeChat { + time.Sleep(2 * time.Second) + break + } + } + + if err := g.AbortRun(context.Background(), run); err != nil { + t.Errorf("Error aborting run: %v", err) + } + + // Wait for run to stop + for range run.Events() { + continue + } + + if run.State() != Finished { + t.Errorf("Unexpected run state: %s", run.State()) + } + + if out, err := run.Text(); err != nil { + t.Errorf("Error reading output: %v", err) + } else if !strings.Contains(out, "Hello, world!") || strings.Contains(out, "Hello, again!") || !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Errorf("Unexpected output: %s", out) + } +} + func TestSimpleEvaluate(t *testing.T) { tool := ToolDef{Instructions: "What is the capital of the united states?"} @@ -844,6 +916,69 @@ func TestToolChat(t *testing.T) { } } +func TestAbortChat(t *testing.T) { + tool := ToolDef{ + Chat: true, + Instructions: "You are a chat bot. Don't finish the conversation until I say 'bye'.", + Tools: []string{"sys.chat.finish"}, + } + + run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool) + if err != nil { + t.Fatalf("Error executing tool: %v", err) + } + inputs := []string{ + "Tell me a joke.", + "What was my first message?", + } + + // Just wait for the chat to start up. + for range run.Events() { + continue + } + + for i, input := range inputs { + run, err = run.NextChat(context.Background(), input) + if err != nil { + t.Fatalf("Error sending next input %q: %v", input, err) + } + + // Abort the run after the first event from the LLM + for e := range run.Events() { + if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." { + break + } + } + + if i == 0 { + if err := g.AbortRun(context.Background(), run); err != nil { + t.Fatalf("Error aborting run: %v", err) + } + } + + // Wait for the run to complete + for range run.Events() { + continue + } + + out, err := run.Text() + if err != nil { + t.Errorf("Error reading output: %s", run.ErrorOutput()) + t.Fatalf("Error reading output: %v", err) + } + + if i == 0 { + if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") { + t.Fatalf("Unexpected output: %s", out) + } + } else { + if !strings.Contains(out, "Tell me a joke") { + t.Errorf("Unexpected output: %s", out) + } + } + } +} + func TestFileChat(t *testing.T) { wd, err := os.Getwd() if err != nil { diff --git a/run.go b/run.go index 52ad6e4..558b388 100644 --- a/run.go +++ b/run.go @@ -37,6 +37,7 @@ type Run struct { basicCommand bool program *Program + id string callsLock sync.RWMutex calls CallFrames rawOutput map[string]any @@ -400,6 +401,7 @@ func (r *Run) request(ctx context.Context, payload any) (err error) { if event.Run.Type == EventTypeRunStart { r.callsLock.Lock() r.program = &event.Run.Program + r.id = event.Run.ID r.callsLock.Unlock() } else if event.Run.Type == EventTypeRunFinish && event.Run.Error != "" { r.state = Error