Skip to content

Commit

Permalink
feat(client): send retry count header (#278)
Browse files Browse the repository at this point in the history
  • Loading branch information
stainless-app[bot] committed Sep 25, 2024
1 parent 565bdb2 commit be9a0e0
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 5 deletions.
78 changes: 74 additions & 4 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"fmt"
"net/http"
"reflect"
"testing"
"time"

Expand Down Expand Up @@ -45,12 +46,12 @@ func TestUserAgentHeader(t *testing.T) {
}

func TestRetryAfter(t *testing.T) {
attempts := 0
retryCountHeaders := make([]string, 0)
client := moderntreasury.NewClient(
option.WithHTTPClient(&http.Client{
Transport: &closureTransport{
fn: func(req *http.Request) (*http.Response, error) {
attempts++
retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count"))
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{
Expand All @@ -67,8 +68,77 @@ func TestRetryAfter(t *testing.T) {
if err == nil || res != nil {
t.Error("Expected there to be a cancel error and for the response to be nil")
}
if want := 3; attempts != want {
t.Errorf("Expected %d attempts, got %d", want, attempts)

attempts := len(retryCountHeaders)
if attempts != 3 {
t.Errorf("Expected %d attempts, got %d", 3, attempts)
}

expectedRetryCountHeaders := []string{"0", "1", "2"}
if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) {
t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders)
}
}

func TestDeleteRetryCountHeader(t *testing.T) {
retryCountHeaders := make([]string, 0)
client := moderntreasury.NewClient(
option.WithHTTPClient(&http.Client{
Transport: &closureTransport{
fn: func(req *http.Request) (*http.Response, error) {
retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count"))
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{
http.CanonicalHeaderKey("Retry-After"): []string{"0.1"},
},
}, nil
},
},
}),
option.WithHeaderDel("X-Stainless-Retry-Count"),
)
res, err := client.Counterparties.New(context.Background(), moderntreasury.CounterpartyNewParams{
Name: moderntreasury.F("my first counterparty"),
})
if err == nil || res != nil {
t.Error("Expected there to be a cancel error and for the response to be nil")
}

expectedRetryCountHeaders := []string{"", "", ""}
if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) {
t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders)
}
}

func TestOverwriteRetryCountHeader(t *testing.T) {
retryCountHeaders := make([]string, 0)
client := moderntreasury.NewClient(
option.WithHTTPClient(&http.Client{
Transport: &closureTransport{
fn: func(req *http.Request) (*http.Response, error) {
retryCountHeaders = append(retryCountHeaders, req.Header.Get("X-Stainless-Retry-Count"))
return &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{
http.CanonicalHeaderKey("Retry-After"): []string{"0.1"},
},
}, nil
},
},
}),
option.WithHeader("X-Stainless-Retry-Count", "42"),
)
res, err := client.Counterparties.New(context.Background(), moderntreasury.CounterpartyNewParams{
Name: moderntreasury.F("my first counterparty"),
})
if err == nil || res != nil {
t.Error("Expected there to be a cancel error and for the response to be nil")
}

expectedRetryCountHeaders := []string{"42", "42", "42"}
if !reflect.DeepEqual(retryCountHeaders, expectedRetryCountHeaders) {
t.Errorf("Expected %v retry count headers, got %v", expectedRetryCountHeaders, retryCountHeaders)
}
}

Expand Down
11 changes: 10 additions & 1 deletion internal/requestconfig/requestconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ func NewRequestConfig(ctx context.Context, method string, u string, body interfa
req.Header.Set("Idempotency-Key", "stainless-go-"+uuid.New().String())
}
req.Header.Set("Accept", "application/json")
req.Header.Set("X-Stainless-Retry-Count", "0")
for k, v := range getDefaultHeaders() {
req.Header.Add(k, v)
}
Expand Down Expand Up @@ -337,6 +338,9 @@ func (cfg *RequestConfig) Execute() (err error) {
handler = applyMiddleware(cfg.Middlewares[i], handler)
}

// Don't send the current retry count in the headers if the caller modified the header defaults.
shouldSendRetryCount := cfg.Request.Header.Get("X-Stainless-Retry-Count") == "0"

var res *http.Response
for retryCount := 0; retryCount <= cfg.MaxRetries; retryCount += 1 {
ctx := cfg.Request.Context()
Expand All @@ -346,7 +350,12 @@ func (cfg *RequestConfig) Execute() (err error) {
defer cancel()
}

res, err = handler(cfg.Request.Clone(ctx))
req := cfg.Request.Clone(ctx)
if shouldSendRetryCount {
req.Header.Set("X-Stainless-Retry-Count", strconv.Itoa(retryCount))
}

res, err = handler(req)
if ctx != nil && ctx.Err() != nil {
return ctx.Err()
}
Expand Down

0 comments on commit be9a0e0

Please sign in to comment.