diff --git a/cmd/invoke.go b/cmd/invoke.go index 2ac29b0..e48cc83 100644 --- a/cmd/invoke.go +++ b/cmd/invoke.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "os" "os/signal" "strings" @@ -32,6 +33,23 @@ var invocationHistoryCmd = &cobra.Command{ RunE: runInvocationHistory, } +var batchCmd = &cobra.Command{ + Use: "batch ", + Short: "Invoke an action multiple times with different payloads (batch)", + Long: `Invoke an action multiple times with different payloads from a file. + +The payloads file should contain one JSON object per line (newline-delimited JSON). + +Example payloads file: + {"url": "https://example.com/page1"} + {"url": "https://example.com/page2"} + {"url": "https://example.com/page3"} + +Or pipe payloads via stdin: + cat payloads.jsonl | kernel invoke batch my-app analyze -`, + RunE: runBatch, +} + func init() { invokeCmd.Flags().StringP("version", "v", "latest", "Specify a version of the app to invoke (optional, defaults to 'latest')") invokeCmd.Flags().StringP("payload", "p", "", "JSON payload for the invocation (optional)") @@ -40,7 +58,12 @@ func init() { invocationHistoryCmd.Flags().Int("limit", 100, "Max invocations to return (default 100)") invocationHistoryCmd.Flags().StringP("app", "a", "", "Filter by app name") invocationHistoryCmd.Flags().String("version", "", "Filter by invocation version") + + batchCmd.Flags().StringP("version", "v", "latest", "Specify a version of the app to invoke (optional, defaults to 'latest')") + batchCmd.Flags().IntP("max-concurrency", "c", 0, "Maximum number of concurrent invocations (defaults to org limit)") + invokeCmd.AddCommand(invocationHistoryCmd) + invokeCmd.AddCommand(batchCmd) } func runInvoke(cmd *cobra.Command, args []string) error { @@ -291,3 +314,136 @@ func runInvocationHistory(cmd *cobra.Command, args []string) error { } return nil } + +func runBatch(cmd *cobra.Command, args []string) error { + if len(args) != 3 { + return fmt.Errorf("requires exactly 3 arguments: ") + } + + startTime := time.Now() + client := getKernelClient(cmd) + appName := args[0] + actionName := args[1] + payloadsFile := args[2] + version, _ := cmd.Flags().GetString("version") + maxConcurrency, _ := cmd.Flags().GetInt("max-concurrency") + + if version == "" { + return fmt.Errorf("version cannot be an empty string") + } + + // Read payloads from file or stdin + var payloads []string + var file *os.File + var err error + + if payloadsFile == "-" { + file = os.Stdin + pterm.Info.Println("Reading payloads from stdin (one JSON object per line)...") + } else { + file, err = os.Open(payloadsFile) + if err != nil { + return fmt.Errorf("failed to open payloads file: %w", err) + } + defer file.Close() + pterm.Info.Printf("Reading payloads from %s...\n", payloadsFile) + } + + // Read newline-delimited JSON + scanner := json.NewDecoder(file) + lineNum := 0 + for { + lineNum++ + var payload interface{} + if err := scanner.Decode(&payload); err != nil { + if err == io.EOF { + break + } + return fmt.Errorf("invalid JSON on line %d: %w", lineNum, err) + } + + // Marshal back to string for API + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload on line %d: %w", lineNum, err) + } + payloads = append(payloads, string(payloadBytes)) + } + + if len(payloads) == 0 { + return fmt.Errorf("no payloads provided") + } + + pterm.Info.Printf("Creating batch job with %d invocations...\n", len(payloads)) + + // Create batch job + params := kernel.InvocationNewBatchParams{ + AppName: appName, + ActionName: actionName, + Version: kernel.Opt(version), + Payloads: payloads, + } + if maxConcurrency > 0 { + params.MaxConcurrency = kernel.Opt(int64(maxConcurrency)) + } + + batchJob, err := client.Invocations.NewBatch(cmd.Context(), params, option.WithMaxRetries(0)) + if err != nil { + return handleSdkError(err) + } + + pterm.Success.Printf("Batch job created: %s\n", batchJob.BatchJobID) + pterm.Info.Printf("Total invocations: %d\n", batchJob.TotalCount) + + // Stream batch progress + pterm.Info.Println("Streaming progress...") + spinner, _ := pterm.DefaultSpinner.Start("Waiting for batch to complete...") + + stream := client.BatchJobs.StreamProgressStreaming(cmd.Context(), batchJob.BatchJobID) + succeeded := 0 + failed := 0 + var finalStatus string + + for stream.Next() { + ev := stream.Current() + + switch ev.Event { + case "batch_state": + stateEv := ev.AsBatchState() + finalStatus = stateEv.BatchJob.Status + succeeded = int(stateEv.BatchJob.SucceededCount) + failed = int(stateEv.BatchJob.FailedCount) + + switch finalStatus { + case "succeeded", "failed", "partially_failed": + spinner.Stop() + duration := time.Since(startTime) + + switch finalStatus { + case "succeeded": + pterm.Success.Printfln("✔ All invocations completed successfully in %s", duration.Round(time.Millisecond)) + case "failed": + pterm.Error.Printfln("✘ Batch failed - %d/%d invocations failed", failed, batchJob.TotalCount) + case "partially_failed": + pterm.Warning.Printfln("⚠ Batch partially failed - %d succeeded, %d failed (total: %d)", succeeded, failed, batchJob.TotalCount) + } + return nil + } + + case "batch_progress": + progressEv := ev.AsBatchProgress() + succeeded = int(progressEv.SucceededCount) + failed = int(progressEv.FailedCount) + total := batchJob.TotalCount + completed := succeeded + failed + spinner.UpdateText(fmt.Sprintf("Progress: %d/%d completed (%d succeeded, %d failed)", completed, total, succeeded, failed)) + } + } + + if serr := stream.Err(); serr != nil { + spinner.Stop() + return fmt.Errorf("stream error: %w", serr) + } + + return nil +} diff --git a/go.mod b/go.mod index c57bcea..eaf2738 100644 --- a/go.mod +++ b/go.mod @@ -58,3 +58,5 @@ require ( golang.org/x/text v0.24.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/onkernel/kernel-go-sdk => github.com/stainless-sdks/kernel-go v0.0.0-20251027151320-c81c686e75b7 diff --git a/go.sum b/go.sum index a2a65bb..a6cc81c 100644 --- a/go.sum +++ b/go.sum @@ -118,6 +118,8 @@ github.com/spf13/cobra v1.9.1 h1:CXSaggrXdbHK9CF+8ywj8Amf7PBRmPCOJugH954Nnlo= github.com/spf13/cobra v1.9.1/go.mod h1:nDyEzZ8ogv936Cinf6g1RU9MRY64Ir93oCnqb9wxYW0= github.com/spf13/pflag v1.0.6 h1:jFzHGLGAlb3ruxLB8MhbI6A8+AQX/2eW4qeyNZXNp2o= github.com/spf13/pflag v1.0.6/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stainless-sdks/kernel-go v0.0.0-20251027151320-c81c686e75b7 h1:ReyUyaBMdWOxJ6PDPIztEsIF32i58RGIk4T6T4+M3ZQ= +github.com/stainless-sdks/kernel-go v0.0.0-20251027151320-c81c686e75b7/go.mod h1:MjUR92i8UPqjrmneyVykae6GuB3GGSmnQtnjf1v74Dc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=