diff --git a/pkg/mq/msgdispatcher/dispatcher.go b/pkg/mq/msgdispatcher/dispatcher.go index 690301ecf6fde..4a614d653f0f8 100644 --- a/pkg/mq/msgdispatcher/dispatcher.go +++ b/pkg/mq/msgdispatcher/dispatcher.go @@ -91,7 +91,16 @@ func NewDispatcher(ctx context.Context, log := log.With(zap.String("pchannel", pchannel), zap.String("subName", subName), zap.Bool("isMain", isMain)) log.Info("creating dispatcher...") - stream, err := factory.NewTtMsgStream(ctx) + + var stream msgstream.MsgStream + var err error + defer func() { + if err != nil && stream != nil { + stream.Close() + } + }() + + stream, err = factory.NewTtMsgStream(ctx) if err != nil { return nil, err } @@ -106,7 +115,6 @@ func NewDispatcher(ctx context.Context, err = stream.Seek(ctx, []*Pos{position}, false) if err != nil { - stream.Close() log.Error("seek failed", zap.Error(err)) return nil, err } @@ -114,7 +122,7 @@ func NewDispatcher(ctx context.Context, log.Info("seek successfully", zap.Uint64("posTs", position.GetTimestamp()), zap.Time("posTime", posTime), zap.Duration("tsLag", time.Since(posTime))) } else { - err := stream.AsConsumer(ctx, []string{pchannel}, subName, subPos) + err = stream.AsConsumer(ctx, []string{pchannel}, subName, subPos) if err != nil { log.Error("asConsumer failed", zap.Error(err)) return nil, err diff --git a/pkg/mq/msgdispatcher/dispatcher_test.go b/pkg/mq/msgdispatcher/dispatcher_test.go index 0177064119e01..be76a52958689 100644 --- a/pkg/mq/msgdispatcher/dispatcher_test.go +++ b/pkg/mq/msgdispatcher/dispatcher_test.go @@ -56,6 +56,7 @@ func TestDispatcher(t *testing.T) { t.Run("test AsConsumer fail", func(t *testing.T) { ms := msgstream.NewMockMsgStream(t) ms.EXPECT().AsConsumer(mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("mock error")) + ms.EXPECT().Close().Return() factory := &msgstream.MockMqFactory{ NewMsgStreamFunc: func(ctx context.Context) (msgstream.MsgStream, error) { return ms, nil