Skip to content

Commit 0798a8d

Browse files
🥅 BCS-1862 Catch errors that happen during PublishBatch
1 parent 4502ff6 commit 0798a8d

File tree

2 files changed

+48
-16
lines changed

2 files changed

+48
-16
lines changed

publisher/sns/sns.go

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@ package sns
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"strings"
78

89
"github.com/aws/aws-sdk-go/aws"
910
"github.com/aws/aws-sdk-go/aws/request"
1011
"github.com/aws/aws-sdk-go/aws/session"
1112
"github.com/aws/aws-sdk-go/service/sns"
12-
"github.com/google/uuid"
13+
"github.com/creatorstack/htsqs/constants"
14+
"github.com/creatorstack/htsqs/publisher/models"
1315
)
1416

1517
// sender is the interface to sns.SNS. Its sole purpose is to make
@@ -67,18 +69,22 @@ func (p *Publisher) Publish(ctx context.Context, msg interface{}) error {
6769
// kept under 100 messages so that all messages can be published in 10 tries. In case
6870
// of failure when parsing or publishing any of the messages, this function will stop
6971
// further publishing and return an error
70-
func (p *Publisher) PublishBatch(ctx context.Context, msgs []interface{}) error {
72+
func (p *Publisher) PublishBatch(ctx context.Context, msgs []models.Message) (map[string]error, int64, int64, error) {
7173
var (
7274
defaultMessageGroupID = "default"
75+
publishResult = make(map[string]error)
7376
err error
77+
78+
errorCount int64
79+
successCount int64
7480
)
7581

7682
isFifo := strings.Contains(strings.ToLower(p.cfg.TopicArn), "fifo")
7783

7884
var (
7985
numPublishedMessages = 0
8086
start = 0
81-
end = 10 // 10 is the maximum batch size for SNS.PublishBatch
87+
end = constants.MaxBatchSize
8288
)
8389
if end > len(msgs) {
8490
end = len(msgs)
@@ -90,14 +96,13 @@ func (p *Publisher) PublishBatch(ctx context.Context, msgs []interface{}) error
9096
for idx := start; idx < end; idx++ {
9197
msg := msgs[idx]
9298

93-
b, err := json.Marshal(msg)
99+
b, err := json.Marshal(msg.Data)
94100
if err != nil {
95-
return err
101+
return publishResult, successCount, errorCount, err
96102
}
97103

98-
entryId := uuid.New().String()
99104
requestEntry := &sns.PublishBatchRequestEntry{
100-
Id: aws.String(entryId),
105+
Id: aws.String(msg.ID),
101106
Message: aws.String(string(b)),
102107
}
103108

@@ -112,20 +117,38 @@ func (p *Publisher) PublishBatch(ctx context.Context, msgs []interface{}) error
112117
PublishBatchRequestEntries: requestEntries,
113118
TopicArn: &p.cfg.TopicArn,
114119
}
115-
_, err = p.sns.PublishBatchWithContext(ctx, input)
120+
response, err := p.sns.PublishBatchWithContext(ctx, input)
116121
if err != nil {
117-
return err
122+
return publishResult, successCount, errorCount, err
123+
}
124+
125+
for _, errEntry := range response.Failed {
126+
if errEntry != nil && errEntry.Id != nil {
127+
errMsg := "publish error"
128+
if errEntry.Message != nil {
129+
errMsg = *errEntry.Message
130+
}
131+
publishResult[*errEntry.Id] = errors.New(errMsg)
132+
errorCount++
133+
}
134+
}
135+
136+
for _, successEntry := range response.Successful {
137+
if successEntry != nil && successEntry.Id != nil {
138+
publishResult[*successEntry.Id] = nil
139+
successCount++
140+
}
118141
}
119142

120143
numPublishedMessages += len(requestEntries)
121144
start = end
122-
end += 10
145+
end += constants.MaxBatchSize
123146
if end > len(msgs) {
124147
end = len(msgs)
125148
}
126149
}
127150

128-
return err
151+
return publishResult, successCount, errorCount, err
129152
}
130153

131154
func defaultPublisherConfig(cfg *Config) {

publisher/sns/sns_test.go

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66

77
"github.com/aws/aws-sdk-go/aws/session"
8+
"github.com/creatorstack/htsqs/publisher/models"
89
"github.com/stretchr/testify/require"
910
)
1011

@@ -27,9 +28,15 @@ func TestPublisher(t *testing.T) {
2728
}
2829

2930
func TestPublisherBatch(t *testing.T) {
30-
inputs := []interface{}{
31-
jsonString(`{"key":"val1"}`),
32-
jsonString(`{"key":"val2"}`),
31+
inputs := []models.Message{
32+
{
33+
ID: "1",
34+
Data: jsonString(`{"key":"val1"}`),
35+
},
36+
{
37+
ID: "2",
38+
Data: jsonString(`{"key":"val2"}`),
39+
},
3340
}
3441

3542
queue := make(chan *string, len(inputs))
@@ -38,12 +45,14 @@ func TestPublisherBatch(t *testing.T) {
3845
pubs := New(Config{})
3946
pubs.sns = &snsPublisherMock{queue: queue}
4047

41-
require.NoError(t, pubs.PublishBatch(context.TODO(), inputs))
48+
_, _, _, err := pubs.PublishBatch(context.TODO(), inputs)
49+
50+
require.NoError(t, err)
4251

4352
idx := 0
4453
for v := range queue {
4554
publishedMessage := *v
46-
require.Equal(t, jsonString(publishedMessage), inputs[idx])
55+
require.Equal(t, jsonString(publishedMessage), inputs[idx].Data)
4756
idx++
4857
if idx >= len(inputs) {
4958
break

0 commit comments

Comments
 (0)