Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 69 additions & 11 deletions pkg/cmd/flagoptions.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cmd

import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -492,22 +491,18 @@ func flagOptions(
case EmptyBody:
break
case MultipartFormEncoded:
buf := new(bytes.Buffer)
writer := multipart.NewWriter(buf)

// For multipart/form-encoded, we need a map structure
bodyMap, ok := requestContents.Body.(map[string]any)
if !ok {
return nil, fmt.Errorf("Cannot send a non-map value to a form-encoded endpoint: %v\n", requestContents.Body)
}
encodingFormat := apiform.FormatBrackets
if err := apiform.MarshalWithSettings(bodyMap, writer, encodingFormat); err != nil {
return nil, err
}
if err := writer.Close(); err != nil {
return nil, err
}
options = append(options, option.WithRequestBody(writer.FormDataContentType(), buf))
contentType, body := newMultipartRequestBody(bodyMap, encodingFormat)
options = append(options,
option.WithRequestBody(contentType, body),
// Streaming request bodies cannot be replayed safely by the SDK retry loop.
option.WithMaxRetries(0),
)

case ApplicationJSON:
bodyBytes, err := json.Marshal(requestContents.Body)
Expand Down Expand Up @@ -538,6 +533,69 @@ func flagOptions(
return options, nil
}

func newMultipartRequestBody(bodyMap map[string]any, encodingFormat apiform.FormFormat) (string, io.Reader) {
reader, writer := io.Pipe()
multipartWriter := multipart.NewWriter(writer)
contentType := multipartWriter.FormDataContentType()

go func() {
err := apiform.MarshalWithSettings(bodyMap, multipartWriter, encodingFormat)
if closeErr := multipartWriter.Close(); err == nil {
err = closeErr
}
if closeErr := closeFileUploads(bodyMap); err == nil {
err = closeErr
}

if err != nil {
_ = writer.CloseWithError(err)
return
}
_ = writer.Close()
}()

return contentType, reader
}

func closeFileUploads(value any) error {
return closeFileUploadsValue(reflect.ValueOf(value))
}

func closeFileUploadsValue(v reflect.Value) error {
if !v.IsValid() {
return nil
}

if v.Kind() == reflect.Interface || v.Kind() == reflect.Pointer {
if v.IsNil() {
return nil
}
return closeFileUploadsValue(v.Elem())
}

if upload, ok := v.Interface().(fileUpload); ok {
return upload.Close()
}

switch v.Kind() {
case reflect.Map:
iter := v.MapRange()
for iter.Next() {
if err := closeFileUploadsValue(iter.Value()); err != nil {
return err
}
}
case reflect.Slice, reflect.Array:
for i := 0; i < v.Len(); i++ {
if err := closeFileUploadsValue(v.Index(i)); err != nil {
return err
}
}
}

return nil
}

// FilePathValue is a string wrapper that marks a value as a file path whose contents should be read
// and embedded in the request. Unlike a regular string, embedFilesValue always treats a FilePathValue
// as a file path without needing the "@" prefix.
Expand Down
97 changes: 96 additions & 1 deletion pkg/cmd/flagoptions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ import (
"os"
"path/filepath"
"strings"
"sync"
"testing"
"time"

"github.com/openai/openai-cli/internal/apiform"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -235,11 +238,12 @@ func TestEmbedFiles(t *testing.T) {
t.Run(tt.name+" io.Reader", func(t *testing.T) {
t.Parallel()

_, err := embedFiles(tt.input, EmbedIOReader, nil)
got, err := embedFiles(tt.input, EmbedIOReader, nil)
if tt.wantErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.NoError(t, closeFileUploads(got))
}
})
}
Expand Down Expand Up @@ -382,6 +386,97 @@ func TestEmbedFilesUploadMetadata(t *testing.T) {
}
}

func TestMultipartRequestBodyStreamsAndClosesUploads(t *testing.T) {
t.Parallel()

upload := &blockingReadCloser{
data: []byte("large file payload"),
readStarted: make(chan struct{}),
allowRead: make(chan struct{}),
closed: make(chan struct{}),
}

contentType, body := newMultipartRequestBody(map[string]any{
"file": fileUpload{
Reader: upload,
filename: "large.txt",
contentType: "text/plain",
},
}, apiform.FormatBrackets)

require.Contains(t, contentType, "multipart/form-data; boundary=")

select {
case <-upload.readStarted:
t.Fatal("multipart body read upload data before the request body was consumed")
case <-time.After(50 * time.Millisecond):
}

type readResult struct {
data []byte
err error
}
resultCh := make(chan readResult, 1)
go func() {
data, err := io.ReadAll(body)
resultCh <- readResult{data: data, err: err}
}()

select {
case <-upload.readStarted:
case <-time.After(time.Second):
t.Fatal("multipart body did not read upload data after the request body was consumed")
}
close(upload.allowRead)

var result readResult
select {
case result = <-resultCh:
case <-time.After(time.Second):
t.Fatal("multipart body read did not finish")
}
require.NoError(t, result.err)
require.Contains(t, string(result.data), "large file payload")

select {
case <-upload.closed:
case <-time.After(time.Second):
t.Fatal("multipart body did not close the uploaded file reader")
}
}

type blockingReadCloser struct {
data []byte
readStarted chan struct{}
allowRead chan struct{}
closed chan struct{}
startOnce sync.Once
closeOnce sync.Once
offset int
}

func (r *blockingReadCloser) Read(p []byte) (int, error) {
r.startOnce.Do(func() {
close(r.readStarted)
})
<-r.allowRead

if r.offset >= len(r.data) {
return 0, io.EOF
}

n := copy(p, r.data[r.offset:])
r.offset += n
return n, nil
}

func (r *blockingReadCloser) Close() error {
r.closeOnce.Do(func() {
close(r.closed)
})
return nil
}

func writeTestFile(t *testing.T, dir, filename, content string) {
t.Helper()

Expand Down