diff --git a/pkg/cmd/flagoptions.go b/pkg/cmd/flagoptions.go index d257e38..4b4afa6 100644 --- a/pkg/cmd/flagoptions.go +++ b/pkg/cmd/flagoptions.go @@ -1,7 +1,6 @@ package cmd import ( - "bytes" "encoding/base64" "encoding/json" "fmt" @@ -492,22 +491,18 @@ func flagOptions( case EmptyBody: break case MultipartFormEncoded: - buf := new(bytes.Buffer) - writer := multipart.NewWriter(buf) - // For multipart/form-encoded, we need a map structure bodyMap, ok := requestContents.Body.(map[string]any) if !ok { return nil, fmt.Errorf("Cannot send a non-map value to a form-encoded endpoint: %v\n", requestContents.Body) } encodingFormat := apiform.FormatBrackets - if err := apiform.MarshalWithSettings(bodyMap, writer, encodingFormat); err != nil { - return nil, err - } - if err := writer.Close(); err != nil { - return nil, err - } - options = append(options, option.WithRequestBody(writer.FormDataContentType(), buf)) + contentType, body := newMultipartRequestBody(bodyMap, encodingFormat) + options = append(options, + option.WithRequestBody(contentType, body), + // Streaming request bodies cannot be replayed safely by the SDK retry loop. + option.WithMaxRetries(0), + ) case ApplicationJSON: bodyBytes, err := json.Marshal(requestContents.Body) @@ -538,6 +533,69 @@ func flagOptions( return options, nil } +func newMultipartRequestBody(bodyMap map[string]any, encodingFormat apiform.FormFormat) (string, io.Reader) { + reader, writer := io.Pipe() + multipartWriter := multipart.NewWriter(writer) + contentType := multipartWriter.FormDataContentType() + + go func() { + err := apiform.MarshalWithSettings(bodyMap, multipartWriter, encodingFormat) + if closeErr := multipartWriter.Close(); err == nil { + err = closeErr + } + if closeErr := closeFileUploads(bodyMap); err == nil { + err = closeErr + } + + if err != nil { + _ = writer.CloseWithError(err) + return + } + _ = writer.Close() + }() + + return contentType, reader +} + +func closeFileUploads(value any) error { + return closeFileUploadsValue(reflect.ValueOf(value)) +} + +func closeFileUploadsValue(v reflect.Value) error { + if !v.IsValid() { + return nil + } + + if v.Kind() == reflect.Interface || v.Kind() == reflect.Pointer { + if v.IsNil() { + return nil + } + return closeFileUploadsValue(v.Elem()) + } + + if upload, ok := v.Interface().(fileUpload); ok { + return upload.Close() + } + + switch v.Kind() { + case reflect.Map: + iter := v.MapRange() + for iter.Next() { + if err := closeFileUploadsValue(iter.Value()); err != nil { + return err + } + } + case reflect.Slice, reflect.Array: + for i := 0; i < v.Len(); i++ { + if err := closeFileUploadsValue(v.Index(i)); err != nil { + return err + } + } + } + + return nil +} + // FilePathValue is a string wrapper that marks a value as a file path whose contents should be read // and embedded in the request. Unlike a regular string, embedFilesValue always treats a FilePathValue // as a file path without needing the "@" prefix. diff --git a/pkg/cmd/flagoptions_test.go b/pkg/cmd/flagoptions_test.go index 00734ca..bca3386 100644 --- a/pkg/cmd/flagoptions_test.go +++ b/pkg/cmd/flagoptions_test.go @@ -6,8 +6,11 @@ import ( "os" "path/filepath" "strings" + "sync" "testing" + "time" + "github.com/openai/openai-cli/internal/apiform" "github.com/stretchr/testify/require" ) @@ -235,11 +238,12 @@ func TestEmbedFiles(t *testing.T) { t.Run(tt.name+" io.Reader", func(t *testing.T) { t.Parallel() - _, err := embedFiles(tt.input, EmbedIOReader, nil) + got, err := embedFiles(tt.input, EmbedIOReader, nil) if tt.wantErr { require.Error(t, err) } else { require.NoError(t, err) + require.NoError(t, closeFileUploads(got)) } }) } @@ -382,6 +386,97 @@ func TestEmbedFilesUploadMetadata(t *testing.T) { } } +func TestMultipartRequestBodyStreamsAndClosesUploads(t *testing.T) { + t.Parallel() + + upload := &blockingReadCloser{ + data: []byte("large file payload"), + readStarted: make(chan struct{}), + allowRead: make(chan struct{}), + closed: make(chan struct{}), + } + + contentType, body := newMultipartRequestBody(map[string]any{ + "file": fileUpload{ + Reader: upload, + filename: "large.txt", + contentType: "text/plain", + }, + }, apiform.FormatBrackets) + + require.Contains(t, contentType, "multipart/form-data; boundary=") + + select { + case <-upload.readStarted: + t.Fatal("multipart body read upload data before the request body was consumed") + case <-time.After(50 * time.Millisecond): + } + + type readResult struct { + data []byte + err error + } + resultCh := make(chan readResult, 1) + go func() { + data, err := io.ReadAll(body) + resultCh <- readResult{data: data, err: err} + }() + + select { + case <-upload.readStarted: + case <-time.After(time.Second): + t.Fatal("multipart body did not read upload data after the request body was consumed") + } + close(upload.allowRead) + + var result readResult + select { + case result = <-resultCh: + case <-time.After(time.Second): + t.Fatal("multipart body read did not finish") + } + require.NoError(t, result.err) + require.Contains(t, string(result.data), "large file payload") + + select { + case <-upload.closed: + case <-time.After(time.Second): + t.Fatal("multipart body did not close the uploaded file reader") + } +} + +type blockingReadCloser struct { + data []byte + readStarted chan struct{} + allowRead chan struct{} + closed chan struct{} + startOnce sync.Once + closeOnce sync.Once + offset int +} + +func (r *blockingReadCloser) Read(p []byte) (int, error) { + r.startOnce.Do(func() { + close(r.readStarted) + }) + <-r.allowRead + + if r.offset >= len(r.data) { + return 0, io.EOF + } + + n := copy(p, r.data[r.offset:]) + r.offset += n + return n, nil +} + +func (r *blockingReadCloser) Close() error { + r.closeOnce.Do(func() { + close(r.closed) + }) + return nil +} + func writeTestFile(t *testing.T, dir, filename, content string) { t.Helper()