Skip to content

Add support for SendBatchToTarget #599

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions internal/testsuite/store_suite.go
Original file line number Diff line number Diff line change
@@ -152,6 +152,48 @@ func (s *StoreTestSuite) TestMessageStoreSaveMessageAndIncrementGetMessage() {
s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2]))
}

func (s *StoreTestSuite) TestMessageStoreSaveBatchAndIncrementGetMessage() {
s.Require().Nil(s.MsgStore.SetNextSenderMsgSeqNum(420))

// Given the following saved messages
expectedMsgsBySeqNum := map[int]string{
1: "In the frozen land of Nador",
2: "they were forced to eat Robin's minstrels",
3: "and there was much rejoicing",
}
var msgs [][]byte
for _, msg := range expectedMsgsBySeqNum {
msgs = append(msgs, []byte(msg))
}
s.Require().Nil(s.MsgStore.SaveBatchAndIncrNextSenderMsgSeqNum(1, msgs))
s.Equal(423, s.MsgStore.NextSenderMsgSeqNum())

// When the messages are retrieved from the MessageStore
actualMsgs, err := s.MsgStore.GetMessages(1, 3)
s.Require().Nil(err)

// Then the messages should be
s.Require().Len(actualMsgs, 3)
s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0]))
s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1]))
s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2]))

// When the store is refreshed from its backing store
s.Require().Nil(s.MsgStore.Refresh())

// And the messages are retrieved from the MessageStore
actualMsgs, err = s.MsgStore.GetMessages(1, 3)
s.Require().Nil(err)

s.Equal(423, s.MsgStore.NextSenderMsgSeqNum())

// Then the messages should still be
s.Require().Len(actualMsgs, 3)
s.Equal(expectedMsgsBySeqNum[1], string(actualMsgs[0]))
s.Equal(expectedMsgsBySeqNum[2], string(actualMsgs[1]))
s.Equal(expectedMsgsBySeqNum[3], string(actualMsgs[2]))
}

func (s *StoreTestSuite) TestMessageStoreGetMessagesEmptyStore() {
// When messages are retrieved from an empty store
messages, err := s.MsgStore.GetMessages(1, 2)
9 changes: 9 additions & 0 deletions memorystore.go
Original file line number Diff line number Diff line change
@@ -97,6 +97,15 @@ func (store *memoryStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg
return store.IncrNextSenderMsgSeqNum()
}

func (store *memoryStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error {
for offset, m := range msg {
if err := store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum+offset, m); err != nil {
return err
}
}
return nil
}

func (store *memoryStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) {
var msgs [][]byte
for seqNum := beginSeqNum; seqNum <= endSeqNum; seqNum++ {
19 changes: 19 additions & 0 deletions registry.go
Original file line number Diff line number Diff line change
@@ -64,6 +64,25 @@ func SendToTarget(m Messagable, sessionID SessionID) error {
return session.queueForSend(msg)
}

// SendBatchToTarget is similar to SendToTarget, but it sends application messages in batch to the sessionID.
// The entire batch would fail if:
// - any message in the batch fails ToApp() validation
// - any message in the batch is an admin message
// This is more efficient compare to SendToTarget in the case of sending a burst of application messages,
// especially when using a persistent store like SQLStore, because it allows batching at the storage layer.
func SendBatchToTarget(m []Messagable, sessionID SessionID) error {
session, ok := lookupSession(sessionID)
if !ok {
return errUnknownSession
}
msg := make([]*Message, len(m))
for i, v := range m {
msg[i] = v.ToMessage()
}

return session.queueBatchAppsForSend(msg)
}

// ResetSession resets session's sequence numbers.
func ResetSession(sessionID SessionID) error {
session, ok := lookupSession(sessionID)
53 changes: 53 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
@@ -346,6 +346,59 @@ func (s *session) persist(seqNum int, msgBytes []byte) error {
return s.store.IncrNextSenderMsgSeqNum()
}

// queueBatchAppsForSend will validate, persist, and queue the messages for send.
func (s *session) queueBatchAppsForSend(msg []*Message) error {
s.sendMutex.Lock()
defer s.sendMutex.Unlock()

msgBytes, err := s.prepBatchAppMessagesForSend(msg)
if err != nil {
return err
}

for _, mb := range msgBytes {
s.toSend = append(s.toSend, mb)
select {
case s.messageEvent <- true:
default:
}
}

return nil
}

func (s *session) prepBatchAppMessagesForSend(msg []*Message) (msgBytes [][]byte, err error) {
seqNum := s.store.NextSenderMsgSeqNum()
for i, m := range msg {
s.fillDefaultHeader(m, nil)
m.Header.SetField(tagMsgSeqNum, FIXInt(seqNum+i))
msgType, err := m.Header.GetBytes(tagMsgType)
if err != nil {
return nil, err
}
if isAdminMessageType(msgType) {
return nil, fmt.Errorf("cannot send admin messages in batch")
}
if errToApp := s.application.ToApp(m, s.sessionID); errToApp != nil {
return nil, errToApp
}
msgBytes = append(msgBytes, m.build())
}
err = s.persistBatch(seqNum, msgBytes)
if err != nil {
return nil, err
}
return msgBytes, nil
}

func (s *session) persistBatch(seqNum int, msgBytes [][]byte) error {
if !s.DisableMessagePersist {
return s.store.SaveBatchAndIncrNextSenderMsgSeqNum(seqNum, msgBytes)
}

return s.store.SetNextSenderMsgSeqNum(seqNum + len(msgBytes))
}

func (s *session) sendQueued() {
for _, msgBytes := range s.toSend {
s.sendBytes(msgBytes)
1 change: 1 addition & 0 deletions store.go
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@ type MessageStore interface {

SaveMessage(seqNum int, msg []byte) error
SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error
SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error
GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error)

Refresh() error
34 changes: 27 additions & 7 deletions store/file/filestore.go
Original file line number Diff line number Diff line change
@@ -323,20 +323,24 @@ func (store *fileStore) CreationTime() time.Time {
func (store *fileStore) SetCreationTime(_ time.Time) {
}

func (store *fileStore) SaveMessage(seqNum int, msg []byte) error {
func (store *fileStore) saveMessages(seqNum int, messages [][]byte) error {
offset, err := store.bodyFile.Seek(0, io.SeekEnd)
if err != nil {
return fmt.Errorf("unable to seek to end of file: %s: %s", store.bodyFname, err.Error())
}
if _, err := store.headerFile.Seek(0, io.SeekEnd); err != nil {
return fmt.Errorf("unable to seek to end of file: %s: %s", store.headerFname, err.Error())
}
if _, err := fmt.Fprintf(store.headerFile, "%d,%d,%d\n", seqNum, offset, len(msg)); err != nil {
return fmt.Errorf("unable to write to file: %s: %s", store.headerFname, err.Error())
}
msgOffset := offset
for seqOffset, msg := range messages {
if _, err := fmt.Fprintf(store.headerFile, "%d,%d,%d\n", seqNum+seqOffset, msgOffset, len(msg)); err != nil {
return fmt.Errorf("unable to write to file: %s: %s", store.headerFname, err.Error())
}

if _, err := store.bodyFile.Write(msg); err != nil {
return fmt.Errorf("unable to write to file: %s: %s", store.bodyFname, err.Error())
if _, err := store.bodyFile.Write(msg); err != nil {
return fmt.Errorf("unable to write to file: %s: %s", store.bodyFname, err.Error())
}
msgOffset = msgOffset + int64(len(msg))
}
if store.fileSync {
if err := store.bodyFile.Sync(); err != nil {
@@ -347,10 +351,18 @@ func (store *fileStore) SaveMessage(seqNum int, msg []byte) error {
}
}

store.offsets[seqNum] = msgDef{offset: offset, size: len(msg)}
msgOffset = offset
for seqOffset, msg := range messages {
store.offsets[seqNum+seqOffset] = msgDef{offset: msgOffset, size: len(msg)}
msgOffset = msgOffset + int64(len(msg))
}
return nil
}

func (store *fileStore) SaveMessage(seqNum int, msg []byte) error {
return store.saveMessages(seqNum, [][]byte{msg})
}

func (store *fileStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []byte) error {
err := store.SaveMessage(seqNum, msg)
if err != nil {
@@ -359,6 +371,14 @@ func (store *fileStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []
return store.IncrNextSenderMsgSeqNum()
}

func (store *fileStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error {
err := store.saveMessages(seqNum, msg)
if err != nil {
return err
}
return store.SetNextSenderMsgSeqNum(store.cache.NextSenderMsgSeqNum() + len(msg))
}

func (store *fileStore) getMessage(seqNum int) (msg []byte, found bool, err error) {
msgInfo, found := store.offsets[seqNum]
if !found {
50 changes: 50 additions & 0 deletions store/mongo/mongostore.go
Original file line number Diff line number Diff line change
@@ -338,6 +338,56 @@ func (store *mongoStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg [
return store.cache.SetNextSenderMsgSeqNum(next)
}

func (store *mongoStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, messages [][]byte) error {

if !store.allowTransactions {
for _, msg := range messages {
if err := store.SaveMessageAndIncrNextSenderMsgSeqNum(seqNum, msg); err != nil {
return err
}
}
return nil
}

// If the mongodb supports replicasets, perform this operation as a transaction instead-
var next int
err := store.db.UseSession(context.Background(), func(sessionCtx mongo.SessionContext) error {
if err := sessionCtx.StartTransaction(); err != nil {
return err
}

entries := make([]interface{}, 0, len(messages))
for _, msg := range messages {
msgFilter := generateMessageFilter(&store.sessionID)
msgFilter.Msgseq = seqNum
msgFilter.Message = msg
}
_, err := store.db.Database(store.mongoDatabase).Collection(store.messagesCollection).InsertMany(sessionCtx, entries)
if err != nil {
return err
}

next = store.cache.NextSenderMsgSeqNum() + len(messages)

msgFilter := generateMessageFilter(&store.sessionID)
sessionUpdate := generateMessageFilter(&store.sessionID)
sessionUpdate.IncomingSeqNum = store.cache.NextTargetMsgSeqNum()
sessionUpdate.OutgoingSeqNum = next
sessionUpdate.CreationTime = store.cache.CreationTime()
_, err = store.db.Database(store.mongoDatabase).Collection(store.sessionsCollection).UpdateOne(sessionCtx, msgFilter, bson.M{"$set": sessionUpdate})
if err != nil {
return err
}

return sessionCtx.CommitTransaction(context.Background())
})
if err != nil {
return err
}

return store.cache.SetNextSenderMsgSeqNum(next)
}

func (store *mongoStore) GetMessages(beginSeqNum, endSeqNum int) (msgs [][]byte, err error) {
msgFilter := generateMessageFilter(&store.sessionID)
// Marshal into database form.
51 changes: 51 additions & 0 deletions store/sql/sqlstore.go
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@ import (
"database/sql"
"fmt"
"regexp"
"strings"
"time"

"github.com/pkg/errors"
@@ -352,6 +353,56 @@ func (store *sqlStore) SaveMessageAndIncrNextSenderMsgSeqNum(seqNum int, msg []b
return store.cache.SetNextSenderMsgSeqNum(next)
}

func (store *sqlStore) SaveBatchAndIncrNextSenderMsgSeqNum(seqNum int, msg [][]byte) error {
s := store.sessionID

tx, err := store.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()

const values = "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"
placeholders := make([]string, 0, len(msg))
params := make([]interface{}, 0, len(msg)*10)
for offset, m := range msg {
placeholders = append(placeholders, values)
params = append(params, seqNum+offset, string(m),
s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
}
_, err = tx.Exec(sqlString(`INSERT INTO messages (
msgseqnum, message,
beginstring, session_qualifier,
sendercompid, sendersubid, senderlocid,
targetcompid, targetsubid, targetlocid)
VALUES`+strings.Join(placeholders, ","), store.placeholder),
params...)
if err != nil {
return err
}

next := store.cache.NextSenderMsgSeqNum() + len(msg)
_, err = tx.Exec(sqlString(`UPDATE sessions SET outgoing_seqnum = ?
WHERE beginstring=? AND session_qualifier=?
AND sendercompid=? AND sendersubid=? AND senderlocid=?
AND targetcompid=? AND targetsubid=? AND targetlocid=?`, store.placeholder),
next, s.BeginString, s.Qualifier,
s.SenderCompID, s.SenderSubID, s.SenderLocationID,
s.TargetCompID, s.TargetSubID, s.TargetLocationID)
if err != nil {
return err
}

err = tx.Commit()
if err != nil {
return err
}

return store.cache.SetNextSenderMsgSeqNum(next)
}

func (store *sqlStore) GetMessages(beginSeqNum, endSeqNum int) ([][]byte, error) {
s := store.sessionID
var msgs [][]byte