diff --git a/consumergroup.go b/consumergroup.go index b9d0a7e2e..6b5a0c7be 100644 --- a/consumergroup.go +++ b/consumergroup.go @@ -1,8 +1,6 @@ package kafka import ( - "bufio" - "bytes" "context" "errors" "fmt" @@ -13,6 +11,8 @@ import ( "strings" "sync" "time" + + "github.com/segmentio/kafka-go/protocol/consumer" ) // ErrGroupClosed is returned by ConsumerGroup.Next when the group has already @@ -168,7 +168,6 @@ type ConsumerGroupConfig struct { // Validate method validates ConsumerGroupConfig properties and sets relevant // defaults. func (config *ConsumerGroupConfig) Validate() error { - if len(config.Brokers) == 0 { return errors.New("cannot create a consumer group with an empty list of broker addresses") } @@ -925,12 +924,12 @@ func (cg *ConsumerGroup) coordinator() (coordinator, error) { // the leader. Otherwise, GroupMemberAssignments will be nil. // // Possible kafka error codes returned: -// * GroupLoadInProgress: -// * GroupCoordinatorNotAvailable: -// * NotCoordinatorForGroup: -// * InconsistentGroupProtocol: -// * InvalidSessionTimeout: -// * GroupAuthorizationFailed: +// - GroupLoadInProgress: +// - GroupCoordinatorNotAvailable: +// - NotCoordinatorForGroup: +// - InconsistentGroupProtocol: +// - InvalidSessionTimeout: +// - GroupAuthorizationFailed: func (cg *ConsumerGroup) joinGroup(conn coordinator, memberID string) (string, int32, GroupMemberAssignments, error) { request, err := cg.makeJoinGroupRequestV1(memberID) if err != nil { @@ -951,7 +950,6 @@ func (cg *ConsumerGroup) joinGroup(conn coordinator, memberID string) (string, i cg.withLogger(func(l Logger) { l.Printf("joined group %s as member %s in generation %d", cg.config.ID, memberID, generationID) }) - var assignments GroupMemberAssignments if iAmLeader := response.MemberID == response.LeaderID; iAmLeader { v, err := cg.assignTopicPartitions(conn, response) @@ -990,15 +988,19 @@ func (cg *ConsumerGroup) makeJoinGroupRequestV1(memberID string) (joinGroupReque for _, balancer := range cg.config.GroupBalancers { userData, err := balancer.UserData() if err != nil { - return joinGroupRequestV1{}, fmt.Errorf("unable to construct protocol metadata for member, %v: %w", balancer.ProtocolName(), err) + return joinGroupRequestV1{}, fmt.Errorf("unable to construct protocol metadata user data for member, %v: %w", balancer.ProtocolName(), err) + } + pm, err := (&consumer.Subscription{ + Version: 1, + Topics: cg.config.Topics, + UserData: userData, + }).Bytes() + if err != nil { + return joinGroupRequestV1{}, fmt.Errorf("unable to construct protocol metadata subscription for member, %v: %w", balancer.ProtocolName(), err) } request.GroupProtocols = append(request.GroupProtocols, joinGroupRequestGroupProtocolV1{ - ProtocolName: balancer.ProtocolName(), - ProtocolMetadata: groupMetadata{ - Version: 1, - Topics: cg.config.Topics, - UserData: userData, - }.bytes(), + ProtocolName: balancer.ProtocolName(), + ProtocolMetadata: pm, }) } @@ -1053,9 +1055,9 @@ func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroup func (cg *ConsumerGroup) makeMemberProtocolMetadata(in []joinGroupResponseMemberV1) ([]GroupMember, error) { members := make([]GroupMember, 0, len(in)) for _, item := range in { - metadata := groupMetadata{} - reader := bufio.NewReader(bytes.NewReader(item.MemberMetadata)) - if remain, err := (&metadata).readFrom(reader, len(item.MemberMetadata)); err != nil || remain != 0 { + var metadata consumer.Subscription + err := metadata.FromBytes(item.MemberMetadata) + if err != nil { return nil, fmt.Errorf("unable to read metadata for member, %v: %w", item.MemberID, err) } @@ -1073,13 +1075,16 @@ func (cg *ConsumerGroup) makeMemberProtocolMetadata(in []joinGroupResponseMember // Readers subscriptions topic => partitions // // Possible kafka error codes returned: -// * GroupCoordinatorNotAvailable: -// * NotCoordinatorForGroup: -// * IllegalGeneration: -// * RebalanceInProgress: -// * GroupAuthorizationFailed: +// - GroupCoordinatorNotAvailable: +// - NotCoordinatorForGroup: +// - IllegalGeneration: +// - RebalanceInProgress: +// - GroupAuthorizationFailed: func (cg *ConsumerGroup) syncGroup(conn coordinator, memberID string, generationID int32, memberAssignments GroupMemberAssignments) (map[string][]int32, error) { - request := cg.makeSyncGroupRequestV0(memberID, generationID, memberAssignments) + request, err := cg.makeSyncGroupRequestV0(memberID, generationID, memberAssignments) + if err != nil { + return nil, err + } response, err := conn.syncGroup(request) if err == nil && response.ErrorCode != 0 { err = Error(response.ErrorCode) @@ -1088,13 +1093,13 @@ func (cg *ConsumerGroup) syncGroup(conn coordinator, memberID string, generation return nil, err } - assignments := groupAssignment{} - reader := bufio.NewReader(bytes.NewReader(response.MemberAssignments)) - if _, err := (&assignments).readFrom(reader, len(response.MemberAssignments)); err != nil { + var assignment consumer.Assignment + err = assignment.FromBytes(response.MemberAssignments) + if err != nil { return nil, err } - if len(assignments.Topics) == 0 { + if len(assignment.AssignedPartitions) == 0 { cg.withLogger(func(l Logger) { l.Printf("received empty assignments for group, %v as member %s for generation %d", cg.config.ID, memberID, generationID) }) @@ -1104,10 +1109,15 @@ func (cg *ConsumerGroup) syncGroup(conn coordinator, memberID string, generation l.Printf("sync group finished for group, %v", cg.config.ID) }) - return assignments.Topics, nil + assignments := make(map[string][]int32, len(assignment.AssignedPartitions)) + for _, ap := range assignment.AssignedPartitions { + assignments[ap.Topic] = ap.Partitions + } + + return assignments, nil } -func (cg *ConsumerGroup) makeSyncGroupRequestV0(memberID string, generationID int32, memberAssignments GroupMemberAssignments) syncGroupRequestV0 { +func (cg *ConsumerGroup) makeSyncGroupRequestV0(memberID string, generationID int32, memberAssignments GroupMemberAssignments) (syncGroupRequestV0, error) { request := syncGroupRequestV0{ GroupID: cg.config.ID, GenerationID: generationID, @@ -1118,20 +1128,27 @@ func (cg *ConsumerGroup) makeSyncGroupRequestV0(memberID string, generationID in request.GroupAssignments = make([]syncGroupRequestGroupAssignmentV0, 0, 1) for memberID, topics := range memberAssignments { - topics32 := make(map[string][]int32) + assignedPartitions := make([]consumer.TopicPartition, 0, len(topics)) for topic, partitions := range topics { - partitions32 := make([]int32, len(partitions)) + topic := consumer.TopicPartition{ + Topic: topic, + Partitions: make([]int32, len(partitions)), + } for i := range partitions { - partitions32[i] = int32(partitions[i]) + topic.Partitions[i] = int32(partitions[i]) } - topics32[topic] = partitions32 + assignedPartitions = append(assignedPartitions, topic) + } + assignments, err := (&consumer.Assignment{ + Version: 1, + AssignedPartitions: assignedPartitions, + }).Bytes() + if err != nil { + return request, err } request.GroupAssignments = append(request.GroupAssignments, syncGroupRequestGroupAssignmentV0{ - MemberID: memberID, - MemberAssignments: groupAssignment{ - Version: 1, - Topics: topics32, - }.bytes(), + MemberID: memberID, + MemberAssignments: assignments, }) } @@ -1140,7 +1157,7 @@ func (cg *ConsumerGroup) makeSyncGroupRequestV0(memberID string, generationID in }) } - return request + return request, nil } func (cg *ConsumerGroup) fetchOffsets(conn coordinator, subs map[string][]int32) (map[string]map[int]int64, error) { diff --git a/consumergroup_test.go b/consumergroup_test.go index 0d3e290a9..4da9d4524 100644 --- a/consumergroup_test.go +++ b/consumergroup_test.go @@ -8,6 +8,9 @@ import ( "sync" "testing" "time" + + "github.com/segmentio/kafka-go/protocol" + "github.com/segmentio/kafka-go/protocol/consumer" ) var _ coordinator = mockCoordinator{} @@ -146,11 +149,15 @@ func TestReaderAssignTopicPartitions(t *testing.T) { } for memberID, topics := range topicsByMemberID { + mm, err := protocol.Marshal(1, consumer.Subscription{ + Topics: topics, + }) + if err != nil { + t.Errorf("error marshaling consumer subscription: %v", err) + } resp.Members = append(resp.Members, joinGroupResponseMemberV1{ - MemberID: memberID, - MemberMetadata: groupMetadata{ - Topics: topics, - }.bytes(), + MemberID: memberID, + MemberMetadata: mm, }) } @@ -553,7 +560,6 @@ func TestConsumerGroupErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.scenario, func(t *testing.T) { - tt.prepare(&mc) group, err := NewConsumerGroup(ConsumerGroupConfig{ diff --git a/describegroups.go b/describegroups.go index 4faf7a01b..f0ac160a0 100644 --- a/describegroups.go +++ b/describegroups.go @@ -1,12 +1,10 @@ package kafka import ( - "bufio" - "bytes" "context" - "fmt" "net" + "github.com/segmentio/kafka-go/protocol/consumer" "github.com/segmentio/kafka-go/protocol/describegroups" ) @@ -168,54 +166,26 @@ func decodeMemberMetadata(rawMetadata []byte) (DescribeGroupsResponseMemberMetad return mm, nil } - buf := bytes.NewBuffer(rawMetadata) - bufReader := bufio.NewReader(buf) - remain := len(rawMetadata) - - var err error - var version16 int16 - - if remain, err = readInt16(bufReader, remain, &version16); err != nil { - return mm, err - } - mm.Version = int(version16) - - if remain, err = readStringArray(bufReader, remain, &mm.Topics); err != nil { - return mm, err - } - if remain, err = readBytes(bufReader, remain, &mm.UserData); err != nil { + var sub consumer.Subscription + err := sub.FromBytes(rawMetadata) + if err != nil { return mm, err } - if mm.Version == 1 && remain > 0 { - fn := func(r *bufio.Reader, size int) (fnRemain int, fnErr error) { - op := DescribeGroupsResponseMemberMetadataOwnedPartition{} - if fnRemain, fnErr = readString(r, size, &op.Topic); fnErr != nil { - return - } - - ps := []int32{} - if fnRemain, fnErr = readInt32Array(r, fnRemain, &ps); fnErr != nil { - return - } - - for _, p := range ps { - op.Partitions = append(op.Partitions, int(p)) - } - - mm.OwnedPartitions = append(mm.OwnedPartitions, op) - return + mm.Version = int(sub.Version) + mm.Topics = sub.Topics + mm.UserData = sub.UserData + mm.OwnedPartitions = make([]DescribeGroupsResponseMemberMetadataOwnedPartition, len(sub.OwnedPartitions)) + for i, op := range sub.OwnedPartitions { + mm.OwnedPartitions[i] = DescribeGroupsResponseMemberMetadataOwnedPartition{ + Topic: op.Topic, + Partitions: make([]int, len(op.Partitions)), } - - if remain, err = readArrayWith(bufReader, remain, fn); err != nil { - return mm, err + for j, part := range op.Partitions { + mm.OwnedPartitions[i].Partitions[j] = int(part) } } - if remain != 0 { - return mm, fmt.Errorf("Got non-zero number of bytes remaining: %d", remain) - } - return mm, nil } @@ -231,68 +201,24 @@ func decodeMemberAssignments(rawAssignments []byte) (DescribeGroupsResponseAssig return ma, nil } - buf := bytes.NewBuffer(rawAssignments) - bufReader := bufio.NewReader(buf) - remain := len(rawAssignments) - - var err error - var version16 int16 - - if remain, err = readInt16(bufReader, remain, &version16); err != nil { + var assignment consumer.Assignment + err := assignment.FromBytes(rawAssignments) + if err != nil { return ma, err } - ma.Version = int(version16) - - fn := func(r *bufio.Reader, size int) (fnRemain int, fnErr error) { - item := GroupMemberTopic{} - - if fnRemain, fnErr = readString(r, size, &item.Topic); fnErr != nil { - return - } - - partitions := []int32{} - if fnRemain, fnErr = readInt32Array(r, fnRemain, &partitions); fnErr != nil { - return + ma.Version = int(assignment.Version) + ma.UserData = assignment.UserData + ma.Topics = make([]GroupMemberTopic, len(assignment.AssignedPartitions)) + for i, topic := range assignment.AssignedPartitions { + ma.Topics[i] = GroupMemberTopic{ + Topic: topic.Topic, + Partitions: make([]int, len(topic.Partitions)), } - for _, partition := range partitions { - item.Partitions = append(item.Partitions, int(partition)) + for j, part := range topic.Partitions { + ma.Topics[i].Partitions[j] = int(part) } - - ma.Topics = append(ma.Topics, item) - return - } - if remain, err = readArrayWith(bufReader, remain, fn); err != nil { - return ma, err - } - - if remain, err = readBytes(bufReader, remain, &ma.UserData); err != nil { - return ma, err - } - - if remain != 0 { - return ma, fmt.Errorf("Got non-zero number of bytes remaining: %d", remain) } return ma, nil } - -// readInt32Array reads an array of int32s. It's adapted from the implementation of -// readStringArray. -func readInt32Array(r *bufio.Reader, sz int, v *[]int32) (remain int, err error) { - var content []int32 - fn := func(r *bufio.Reader, size int) (fnRemain int, fnErr error) { - var value int32 - if fnRemain, fnErr = readInt32(r, size, &value); fnErr != nil { - return - } - content = append(content, value) - return - } - if remain, err = readArrayWith(r, sz, fn); err != nil { - return - } - - *v = content - return -} diff --git a/joingroup.go b/joingroup.go index 30823a69a..aaf5d2a68 100644 --- a/joingroup.go +++ b/joingroup.go @@ -2,7 +2,6 @@ package kafka import ( "bufio" - "bytes" "context" "fmt" "net" @@ -189,43 +188,6 @@ func (c *Client) JoinGroup(ctx context.Context, req *JoinGroupRequest) (*JoinGro return res, nil } -type groupMetadata struct { - Version int16 - Topics []string - UserData []byte -} - -func (t groupMetadata) size() int32 { - return sizeofInt16(t.Version) + - sizeofStringArray(t.Topics) + - sizeofBytes(t.UserData) -} - -func (t groupMetadata) writeTo(wb *writeBuffer) { - wb.writeInt16(t.Version) - wb.writeStringArray(t.Topics) - wb.writeBytes(t.UserData) -} - -func (t groupMetadata) bytes() []byte { - buf := bytes.NewBuffer(nil) - t.writeTo(&writeBuffer{w: buf}) - return buf.Bytes() -} - -func (t *groupMetadata) readFrom(r *bufio.Reader, size int) (remain int, err error) { - if remain, err = readInt16(r, size, &t.Version); err != nil { - return - } - if remain, err = readStringArray(r, remain, &t.Topics); err != nil { - return - } - if remain, err = readBytes(r, remain, &t.UserData); err != nil { - return - } - return -} - type joinGroupRequestGroupProtocolV1 struct { ProtocolName string ProtocolMetadata []byte diff --git a/joingroup_test.go b/joingroup_test.go index 926f5b4a6..9bca21ceb 100644 --- a/joingroup_test.go +++ b/joingroup_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/segmentio/kafka-go/protocol/consumer" ktesting "github.com/segmentio/kafka-go/testing" ) @@ -130,12 +131,20 @@ func TestSaramaCompatibility(t *testing.T) { // // See consumer_group_members_test.go // + groupMemberMetadataV1MissingOwnedPartitions = []byte{ + 0, 1, // Version + 0, 0, 0, 2, // Topic array length + 0, 3, 'o', 'n', 'e', // Topic one + 0, 3, 't', 'w', 'o', // Topic two + 0, 0, 0, 3, 0x01, 0x02, 0x03, // Userdata + } groupMemberMetadata = []byte{ 0, 1, // Version 0, 0, 0, 2, // Topic array length 0, 3, 'o', 'n', 'e', // Topic one 0, 3, 't', 'w', 'o', // Topic two 0, 0, 0, 3, 0x01, 0x02, 0x03, // Userdata + 0, 0, 0, 0, // OwnedPartitions KIP-429 } groupMemberAssignment = []byte{ 0, 1, // Version @@ -147,15 +156,12 @@ func TestSaramaCompatibility(t *testing.T) { } ) - t.Run("verify metadata", func(t *testing.T) { - var item groupMetadata - remain, err := (&item).readFrom(bufio.NewReader(bytes.NewReader(groupMemberMetadata)), len(groupMemberMetadata)) + t.Run("verify metadata v1 missing OwnedPartitions", func(t *testing.T) { + var item consumer.Subscription + err := item.FromBytes(groupMemberMetadataV1MissingOwnedPartitions) if err != nil { t.Fatalf("bad err: %v", err) } - if remain != 0 { - t.Fatalf("expected 0; got %v", remain) - } if v := item.Version; v != 1 { t.Errorf("expected Version 1; got %v", v) @@ -168,53 +174,49 @@ func TestSaramaCompatibility(t *testing.T) { } }) - t.Run("verify assignments", func(t *testing.T) { - var item groupAssignment - remain, err := (&item).readFrom(bufio.NewReader(bytes.NewReader(groupMemberAssignment)), len(groupMemberAssignment)) + t.Run("verify metadata", func(t *testing.T) { + var item consumer.Subscription + err := item.FromBytes(groupMemberMetadata) if err != nil { t.Fatalf("bad err: %v", err) } - if remain != 0 { - t.Fatalf("expected 0; got %v", remain) - } if v := item.Version; v != 1 { t.Errorf("expected Version 1; got %v", v) } - if v := item.Topics; !reflect.DeepEqual(map[string][]int32{"one": {0, 2, 4}}, v) { - t.Errorf(`expected map[string][]int32{"one": {0, 2, 4}}; got %v`, v) + if v := item.Topics; !reflect.DeepEqual([]string{"one", "two"}, v) { + t.Errorf(`expected {"one", "two"}; got %v`, v) } if v := item.UserData; !reflect.DeepEqual([]byte{0x01, 0x02, 0x03}, v) { t.Errorf("expected []byte{0x01, 0x02, 0x03}; got %v", v) } + if v := item.OwnedPartitions; len(v) != 0 { + t.Errorf("expected no owned partitions; got %v", item.OwnedPartitions) + } }) -} -func TestMemberMetadata(t *testing.T) { - item := groupMetadata{ - Version: 1, - Topics: []string{"a", "b"}, - UserData: []byte(`blah`), - } - - b := bytes.NewBuffer(nil) - w := &writeBuffer{w: b} - item.writeTo(w) + t.Run("verify assignments", func(t *testing.T) { + var item consumer.Assignment + err := item.FromBytes(groupMemberAssignment) + if err != nil { + t.Fatalf("bad err: %v", err) + } - var found groupMetadata - remain, err := (&found).readFrom(bufio.NewReader(b), b.Len()) - if err != nil { - t.Error(err) - t.FailNow() - } - if remain != 0 { - t.Errorf("expected 0 remain, got %v", remain) - t.FailNow() - } - if !reflect.DeepEqual(item, found) { - t.Error("expected item and found to be the same") - t.FailNow() - } + if v := item.Version; v != 1 { + t.Errorf("expected Version 1; got %v", v) + } + if v := item.AssignedPartitions; !reflect.DeepEqual([]consumer.TopicPartition{ + { + Topic: "one", + Partitions: []int32{0, 2, 4}, + }, + }, v) { + t.Errorf(`expected []{{Topic: "one", Partitions: []{0,2,4}}}; got %v`, v) + } + if v := item.UserData; !reflect.DeepEqual([]byte{0x01, 0x02, 0x03}, v) { + t.Errorf("expected []byte{0x01, 0x02, 0x03}; got %v", v) + } + }) } func TestJoinGroupResponseV1(t *testing.T) { diff --git a/protocol/consumer/consumer.go b/protocol/consumer/consumer.go index ab643105d..8919cca45 100644 --- a/protocol/consumer/consumer.go +++ b/protocol/consumer/consumer.go @@ -1,21 +1,85 @@ package consumer -const MaxVersionSupported = 1 +import ( + "encoding/binary" + "errors" + "io" + + "github.com/segmentio/kafka-go/protocol" +) + +const MaxVersionSupported = 3 type Subscription struct { - Version int16 `kafka:"min=v0,max=v1"` - Topics []string `kafka:"min=v0,max=v1"` - UserData []byte `kafka:"min=v0,max=v1,nullable"` - OwnedPartitions []TopicPartition `kafka:"min=v1,max=v1"` + Version int16 `kafka:"min=v0,max=v3"` + Topics []string `kafka:"min=v0,max=v3"` + UserData []byte `kafka:"min=v0,max=v3,nullable"` + OwnedPartitions []TopicPartition `kafka:"min=v1,max=v3"` + GenerationID int32 `kafka:"min=v2,max=v3"` + RackID string `kafka:"min=v3,max=v3,nullable"` +} + +type subscriptionBackwardsCompat struct { + Version int16 `kafka:"min=v0,max=v1"` + Topics []string `kafka:"min=v0,max=v1"` + UserData []byte `kafka:"min=v0,max=v1,nullable"` +} + +func (s *subscriptionBackwardsCompat) FromBytes(b []byte) error { + // This type is only intended to maintain backwards compatibility with + // this library and support other clients in the wild sending + // version 1 supscription data without OwnedPartitionsy + return protocol.Unmarshal(b, 1, s) +} + +func (s *Subscription) FromBytes(b []byte) error { + if len(b) < 2 { + return io.ErrUnexpectedEOF + } + version := readInt16(b[0:2]) + err := protocol.Unmarshal(b, version, s) + if err != nil && version >= 1 && errors.Is(err, io.ErrUnexpectedEOF) { + var sub subscriptionBackwardsCompat + if err = sub.FromBytes(b); err != nil { + return err + } + s.Version = sub.Version + s.Topics = sub.Topics + s.UserData = sub.UserData + return nil + + } + + return err +} + +func (s *Subscription) Bytes() ([]byte, error) { + return protocol.Marshal(s.Version, *s) } type Assignment struct { - Version int16 `kafka:"min=v0,max=v1"` - AssignedPartitions []TopicPartition `kafka:"min=v0,max=v1"` - UserData []byte `kafka:"min=v0,max=v1,nullable"` + Version int16 `kafka:"min=v0,max=v3"` + AssignedPartitions []TopicPartition `kafka:"min=v0,max=v3"` + UserData []byte `kafka:"min=v0,max=v3,nullable"` +} + +func (a *Assignment) FromBytes(b []byte) error { + if len(b) < 2 { + return io.ErrUnexpectedEOF + } + version := readInt16(b[0:2]) + return protocol.Unmarshal(b, version, a) +} + +func (a *Assignment) Bytes() ([]byte, error) { + return protocol.Marshal(a.Version, *a) } type TopicPartition struct { - Topic string `kafka:"min=v0,max=v1"` - Partitions []int32 `kafka:"min=v0,max=v1"` + Topic string `kafka:"min=v0,max=v3"` + Partitions []int32 `kafka:"min=v0,max=v3"` +} + +func readInt16(b []byte) int16 { + return int16(binary.BigEndian.Uint16(b)) } diff --git a/protocol/consumer/consumer_test.go b/protocol/consumer/consumer_test.go index 760336e7d..bb7f0589d 100644 --- a/protocol/consumer/consumer_test.go +++ b/protocol/consumer/consumer_test.go @@ -1,6 +1,7 @@ package consumer_test import ( + "bytes" "reflect" "testing" @@ -18,13 +19,22 @@ func TestSubscription(t *testing.T) { Partitions: []int32{1, 2, 3}, }, }, + GenerationID: 10, + RackID: "rack", } - for _, version := range []int16{1, 0} { + for _, version := range []int16{3, 2, 1, 0} { + sub := subscription if version == 0 { - subscription.OwnedPartitions = nil + sub.OwnedPartitions = nil } - data, err := protocol.Marshal(version, subscription) + if version < 2 { + sub.GenerationID = 0 + } + if version < 3 { + sub.RackID = "" + } + data, err := protocol.Marshal(version, sub) if err != nil { t.Fatal(err) } @@ -33,8 +43,36 @@ func TestSubscription(t *testing.T) { if err != nil { t.Fatal(err) } - if !reflect.DeepEqual(subscription, gotSubscription) { + if !reflect.DeepEqual(sub, gotSubscription) { t.Fatalf("unexpected result after marshal/unmarshal \nexpected\n %#v\ngot\n %#v", subscription, gotSubscription) } } } + +func TestInvalidVersion1(t *testing.T) { + groupMemberMetadataV1MissingOwnedPartitions := []byte{ + 0, 1, // Version + 0, 0, 0, 2, // Topic array length + 0, 3, 'o', 'n', 'e', // Topic one + 0, 3, 't', 'w', 'o', // Topic two + 0, 0, 0, 3, 0x01, 0x02, 0x03, // Userdata + } + + var s consumer.Subscription + err := s.FromBytes(groupMemberMetadataV1MissingOwnedPartitions) + if err != nil { + t.Fatal(err) + } + + if s.Version != 1 { + t.Fatalf("expected version to be 1 got: %d", s.Version) + } + + if !reflect.DeepEqual(s.Topics, []string{"one", "two"}) { + t.Fatalf("expected topics to be [one two] got: %v", s.Topics) + } + + if !bytes.Equal(s.UserData, []byte{0x01, 0x02, 0x03}) { + t.Fatalf(`expected user data to be [1 2 3] got: %v`, s.UserData) + } +} diff --git a/syncgroup.go b/syncgroup.go index ff37569e7..1e3e104ba 100644 --- a/syncgroup.go +++ b/syncgroup.go @@ -2,13 +2,11 @@ package kafka import ( "bufio" - "bytes" "context" "fmt" "net" "time" - "github.com/segmentio/kafka-go/protocol" "github.com/segmentio/kafka-go/protocol/consumer" "github.com/segmentio/kafka-go/protocol/syncgroup" ) @@ -93,7 +91,7 @@ func (c *Client) SyncGroup(ctx context.Context, req *SyncGroupRequest) (*SyncGro for _, assignment := range req.Assignments { assign := consumer.Assignment{ - Version: consumer.MaxVersionSupported, + Version: 1, AssignedPartitions: make([]consumer.TopicPartition, 0, len(assignment.Assignment.AssignedPartitions)), UserData: assignment.Assignment.UserData, } @@ -109,7 +107,7 @@ func (c *Client) SyncGroup(ctx context.Context, req *SyncGroupRequest) (*SyncGro assign.AssignedPartitions = append(assign.AssignedPartitions, tp) } - assignBytes, err := protocol.Marshal(consumer.MaxVersionSupported, assign) + assignBytes, err := assign.Bytes() if err != nil { return nil, fmt.Errorf("kafka.(*Client).SyncGroup: %w", err) } @@ -128,7 +126,7 @@ func (c *Client) SyncGroup(ctx context.Context, req *SyncGroupRequest) (*SyncGro r := m.(*syncgroup.Response) var assignment consumer.Assignment - err = protocol.Unmarshal(r.Assignments, consumer.MaxVersionSupported, &assignment) + err = assignment.FromBytes(r.Assignments) if err != nil { return nil, fmt.Errorf("kafka.(*Client).SyncGroup: %w", err) } @@ -154,62 +152,6 @@ func (c *Client) SyncGroup(ctx context.Context, req *SyncGroupRequest) (*SyncGro return res, nil } -type groupAssignment struct { - Version int16 - Topics map[string][]int32 - UserData []byte -} - -func (t groupAssignment) size() int32 { - sz := sizeofInt16(t.Version) + sizeofInt16(int16(len(t.Topics))) - - for topic, partitions := range t.Topics { - sz += sizeofString(topic) + sizeofInt32Array(partitions) - } - - return sz + sizeofBytes(t.UserData) -} - -func (t groupAssignment) writeTo(wb *writeBuffer) { - wb.writeInt16(t.Version) - wb.writeInt32(int32(len(t.Topics))) - - for topic, partitions := range t.Topics { - wb.writeString(topic) - wb.writeInt32Array(partitions) - } - - wb.writeBytes(t.UserData) -} - -func (t *groupAssignment) readFrom(r *bufio.Reader, size int) (remain int, err error) { - // I came across this case when testing for compatibility with bsm/sarama-cluster. It - // appears in some cases, sarama-cluster can send a nil array entry. Admittedly, I - // didn't look too closely at it. - if size == 0 { - t.Topics = map[string][]int32{} - return 0, nil - } - - if remain, err = readInt16(r, size, &t.Version); err != nil { - return - } - if remain, err = readMapStringInt32(r, remain, &t.Topics); err != nil { - return - } - if remain, err = readBytes(r, remain, &t.UserData); err != nil { - return - } - - return -} - -func (t groupAssignment) bytes() []byte { - buf := bytes.NewBuffer(nil) - t.writeTo(&writeBuffer{w: buf}) - return buf.Bytes() -} - type syncGroupRequestGroupAssignmentV0 struct { // MemberID assigned by the group coordinator MemberID string diff --git a/syncgroup_test.go b/syncgroup_test.go index 930696bde..af9e177c0 100644 --- a/syncgroup_test.go +++ b/syncgroup_test.go @@ -156,52 +156,6 @@ func TestClientSyncGroup(t *testing.T) { } } -func TestGroupAssignment(t *testing.T) { - item := groupAssignment{ - Version: 1, - Topics: map[string][]int32{ - "a": {1, 2, 3}, - "b": {4, 5}, - }, - UserData: []byte(`blah`), - } - - b := bytes.NewBuffer(nil) - w := &writeBuffer{w: b} - item.writeTo(w) - - var found groupAssignment - remain, err := (&found).readFrom(bufio.NewReader(b), b.Len()) - if err != nil { - t.Error(err) - t.FailNow() - } - if remain != 0 { - t.Errorf("expected 0 remain, got %v", remain) - t.FailNow() - } - if !reflect.DeepEqual(item, found) { - t.Error("expected item and found to be the same") - t.FailNow() - } -} - -func TestGroupAssignmentReadsFromZeroSize(t *testing.T) { - var item groupAssignment - remain, err := (&item).readFrom(bufio.NewReader(bytes.NewReader(nil)), 0) - if err != nil { - t.Error(err) - t.FailNow() - } - if remain != 0 { - t.Errorf("expected 0 remain, got %v", remain) - t.FailNow() - } - if item.Topics == nil { - t.Error("expected non nil Topics to be assigned") - } -} - func TestSyncGroupResponseV0(t *testing.T) { item := syncGroupResponseV0{ ErrorCode: 2,