diff --git a/client.go b/client.go
index d965040e8..0a19bb4ba 100644
--- a/client.go
+++ b/client.go
@@ -42,6 +42,11 @@ type Client struct {
 	//
 	// If nil, DefaultTransport is used.
 	Transport RoundTripper
+
+	// Limit for message size, applied after the compression stage.
+	//
+	// If zero, no limit is applied.
+	MaxMessageBytes int64
 }
 
 // A ConsumerGroup and Topic as these are both strings we define a type for
diff --git a/produce.go b/produce.go
index 72d1ed09b..52bee09d7 100644
--- a/produce.go
+++ b/produce.go
@@ -159,12 +159,15 @@ func (c *Client) Produce(ctx context.Context, req *ProduceRequest) (*ProduceResp
 				},
 			}},
 		}},
+		MaxMessageBytes: c.MaxMessageBytes,
 	})
 
 	switch {
 	case err == nil:
 	case errors.Is(err, protocol.ErrNoRecord):
 		return new(ProduceResponse), nil
+	case protocol.IsMaxMessageBytesExceeded(err):
+		return nil, MessageTooLargeError{}
 	default:
 		return nil, fmt.Errorf("kafka.(*Client).Produce: %w", err)
 	}
diff --git a/protocol/error.go b/protocol/error.go
index 52c5d0833..b1c754bb3 100644
--- a/protocol/error.go
+++ b/protocol/error.go
@@ -1,6 +1,7 @@
 package protocol
 
 import (
+	"errors"
 	"fmt"
 )
 
@@ -89,3 +90,28 @@ func (e *TopicPartitionError) Error() string {
 func (e *TopicPartitionError) Unwrap() error {
 	return e.Err
 }
+
+// MaxMessageBytesExceededError is implemented by errors reporting that a
+// serialized message exceeded the configured max.message.bytes limit.
+type MaxMessageBytesExceededError interface {
+	error
+	IsMaxMessageBytesExceeded()
+}
+
+// IsMaxMessageBytesExceeded reports whether any error in err's chain
+// indicates that the max.message.bytes limit was exceeded.
+func IsMaxMessageBytesExceeded(err error) bool {
+	var target MaxMessageBytesExceededError
+	return errors.As(err, &target)
+}
+
+type baseMaxMessageBytesExceededError struct{ error }
+
+// IsMaxMessageBytesExceeded is the marker method that makes this type satisfy
+// MaxMessageBytesExceededError. It must match the interface's method name
+// exactly, otherwise errors.As in IsMaxMessageBytesExceeded never matches.
+func (b baseMaxMessageBytesExceededError) IsMaxMessageBytesExceeded() {}
+
+func (b baseMaxMessageBytesExceededError) Unwrap() error {
+	return b.error
+}
+
+// NewBaseMaxMessageBytesExceededError wraps err so that it is recognized by
+// IsMaxMessageBytesExceeded. It returns nil when err is nil.
+func NewBaseMaxMessageBytesExceededError(err error) error {
+	if err == nil {
+		return nil
+	}
+	return &baseMaxMessageBytesExceededError{error: err}
+}
diff --git a/protocol/produce/produce.go b/protocol/produce/produce.go
index 6d337c3cf..21560a42e 100644
--- a/protocol/produce/produce.go
+++ b/protocol/produce/produce.go
@@
-15,10 +15,15 @@ type Request struct {
 	Acks    int16          `kafka:"min=v0,max=v8"`
 	Timeout int32          `kafka:"min=v0,max=v8"`
 	Topics  []RequestTopic `kafka:"min=v0,max=v8"`
+
+	// MaxMessageBytes carries the topic's max.message.bytes limit (not serialized).
+	MaxMessageBytes int64 `kafka:"-"`
 }
 
 func (r *Request) ApiKey() protocol.ApiKey { return protocol.Produce }
 
+func (r *Request) MaxMessageBytesSize() int64 { return r.MaxMessageBytes }
+
 func (r *Request) Broker(cluster protocol.Cluster) (protocol.Broker, error) {
 	broker := protocol.Broker{ID: -1}
diff --git a/protocol/protocol.go b/protocol/protocol.go
index 3d0a7b8dd..6af8e6535 100644
--- a/protocol/protocol.go
+++ b/protocol/protocol.go
@@ -527,6 +527,13 @@ type Merger interface {
 	Merge(messages []Message, results []interface{}) (Message, error)
 }
 
+// MaxMessageBytesKeeper is an extension of the Message interface that
+// carries the topic's max.message.bytes parameter.
+type MaxMessageBytesKeeper interface {
+	// MaxMessageBytesSize returns the locally stored max.message.bytes value.
+	MaxMessageBytesSize() int64
+}
+
 // Result converts r to a Message or an error, or panics if r could not be
 // converted to these types.
 func Result(r interface{}) (Message, error) {
diff --git a/protocol/request.go b/protocol/request.go
index 135b938bb..98361b700 100644
--- a/protocol/request.go
+++ b/protocol/request.go
@@ -125,7 +125,15 @@ func WriteRequest(w io.Writer, apiVersion int16, correlationID int32, clientID s
 	err := e.err
 
 	if err == nil {
-		size := packUint32(uint32(b.Size()) - 4)
+		messageSize := uint32(b.Size()) - 4
+
+		if p, ok := msg.(MaxMessageBytesKeeper); ok && p.MaxMessageBytesSize() != 0 {
+			if messageSize > uint32(p.MaxMessageBytesSize()) {
+				return NewBaseMaxMessageBytesExceededError(fmt.Errorf("message size: %d exceeded max.message.bytes: %d", messageSize, p.MaxMessageBytesSize()))
+			}
+		}
+
+		size := packUint32(messageSize)
 		b.WriteAt(size[:], 0)
 		_, err = b.WriteTo(w)
 	}
diff --git a/writer.go b/writer.go
index 3817bf538..51d4a6ce4 100644
--- a/writer.go
+++ b/writer.go
@@ -4,8 +4,10 @@ import (
 	"bytes"
 	"context"
 	"errors"
+	"fmt"
 	"io"
 	"net"
+	"strconv"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -124,6 +126,27 @@ type Writer struct {
 	// The default is to use a kafka default value of 1048576.
 	BatchBytes int64
 
+	// Setting this flag to true causes WriteMessages to derive 'BatchBytes'
+	// from each topic's 'max.message.bytes' setting. If the writer is used to
+	// write to multiple topics, each topic's 'max.message.bytes' is applied
+	// appropriately. This keeps a single source of truth for the limit: the
+	// topic configuration on the broker side.
+	//
+	// The default is false
+	AutoDeriveBatchBytes bool
+
+	// Setting this flag to true causes WriteMessages to apply 'BatchBytes' as
+	// a limiting factor after the compression stage.
+	// When this flag is false, a message may exceed 'max.message.bytes'
+	// before compression while fitting under the limit afterwards, causing
+	// WriteMessages to return an error when in fact there is none.
+	//
+	// Note that 'BatchBytes' also has a second role - forming batches - and
+	// this option does not affect that role.
+	//
+	// The default is false
+	ApplyBatchBytesAfterCompression bool
+
 	// Time limit on how often incomplete message batches will be flushed to
 	// kafka.
 	//
@@ -220,6 +243,10 @@
 
 	// non-nil when a transport was created by NewWriter, remove in 1.0.
 	transport *Transport
+
+	// maxMessageBytesPerTopic caches each topic's max.message.bytes value.
+	// Used only when AutoDeriveBatchBytes is true.
+	maxMessageBytesPerTopic sync.Map
 }
 
 // WriterConfig is a configuration type used to create new instances of Writer.
@@ -621,17 +648,24 @@ func (w *Writer) WriteMessages(ctx context.Context, msgs ...Message) error {
 	}
 
 	balancer := w.balancer()
-	batchBytes := w.batchBytes()
-
-	for i := range msgs {
-		n := int64(msgs[i].totalSize())
-		if n > batchBytes {
-			// This error is left for backward compatibility with historical
-			// behavior, but it can yield O(N^2) behaviors. The expectations
-			// are that the program will check if WriteMessages returned a
-			// MessageTooLargeError, discard the message that was exceeding
-			// the maximum size, and try again.
-			return messageTooLarge(msgs, i)
+	if w.AutoDeriveBatchBytes {
+		err := w.deriveBatchBytes(msgs)
+		if err != nil {
+			return err
+		}
+	}
+
+	if !w.ApplyBatchBytesAfterCompression {
+		for i := range msgs {
+			// Resolve the effective topic (message-level or the writer-level
+			// Topic) so the limit lookup uses the same key deriveBatchBytes
+			// stored; msgs[i].Topic may be empty when Writer.Topic is set,
+			// in which case a lookup by msgs[i].Topic would miss the cache.
+			topic, err := w.chooseTopic(msgs[i])
+			if err != nil {
+				return err
+			}
+			n := int64(msgs[i].totalSize())
+			if n > w.batchBytes(topic) {
+				// This error is left for backward compatibility with historical
+				// behavior, but it can yield O(N^2) behaviors. The expectations
+				// are that the program will check if WriteMessages returned a
+				// MessageTooLargeError, discard the message that was exceeding
+				// the maximum size, and try again.
+				return messageTooLarge(msgs, i)
+			}
 		}
 	}
@@ -730,7 +764,12 @@ func (w *Writer) produce(key topicPartition, batch *writeBatch) (*ProduceRespons
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 	defer cancel()
 
-	return w.client(timeout).Produce(ctx, &ProduceRequest{
+	client := w.client(timeout)
+	if w.ApplyBatchBytesAfterCompression {
+		client.MaxMessageBytes = w.batchBytes(key.topic)
+	}
+
+	return client.Produce(ctx, &ProduceRequest{
 		Partition:    int(key.partition),
 		Topic:        key.topic,
 		RequiredAcks: w.RequiredAcks,
@@ -815,7 +854,55 @@
 	return 100
 }
 
-func (w *Writer) batchBytes() int64 {
+// deriveBatchBytes resolves and caches the max.message.bytes setting of every
+// topic the given messages will be written to. Topics already present in the
+// cache are skipped, so the broker is queried at most once per topic.
+func (w *Writer) deriveBatchBytes(msgs []Message) error {
+	for _, msg := range msgs {
+		topic, err := w.chooseTopic(msg)
+		if err != nil {
+			return err
+		}
+
+		if _, ok := w.maxMessageBytesPerTopic.Load(topic); ok {
+			continue
+		}
+
+		// NOTE(review): consider threading the WriteMessages ctx through here
+		// instead of context.Background().
+		describeResp, err := w.client(w.readTimeout()).DescribeConfigs(context.Background(), &DescribeConfigsRequest{
+			Resources: []DescribeConfigRequestResource{{
+				ResourceType: ResourceTypeTopic,
+				// Use the resolved topic, not msg.Topic: msg.Topic may be
+				// empty when the writer-level Topic is set, and the cache key
+				// stored below must match the topic we describe.
+				ResourceName: topic,
+				ConfigNames: []string{
+					"max.message.bytes",
+				},
+			}},
+		})
+		if err != nil {
+			return err
+		}
+		if len(describeResp.Resources) != 1 {
+			return fmt.Errorf("unexpected number of 'Resources' entries in DescribeConfigs response: %d", len(describeResp.Resources))
+		}
+		if len(describeResp.Resources[0].ConfigEntries) != 1 {
+			return fmt.Errorf("unexpected number of 'ConfigEntries' entries in describeResp.Resources[0]: %d", len(describeResp.Resources[0].ConfigEntries))
+		}
+		maxMessageBytesStr := describeResp.Resources[0].ConfigEntries[0].ConfigValue
+		// ParseInt (not Atoi) so the value is parsed as int64 regardless of
+		// the platform's int width.
+		maxMessageBytes, err := strconv.ParseInt(maxMessageBytesStr, 10, 64)
+		if err != nil {
+			return err
+		}
+		w.maxMessageBytesPerTopic.Store(topic, maxMessageBytes)
+	}
+	return nil
+}
+
+// batchBytes returns the message size limit for topic: the broker-derived
+// max.message.bytes when AutoDeriveBatchBytes is enabled, otherwise the
+// configured BatchBytes (or the kafka default).
+func (w *Writer) batchBytes(topic string) int64 {
+	if w.AutoDeriveBatchBytes {
+		if result, ok := w.maxMessageBytesPerTopic.Load(topic); ok {
+			return result.(int64)
+		}
+		// batchBytes expects to be called only after deriveBatchBytes(msgs),
+		// so there should be no unknown topics at this point.
+
+		// NOTE(review): panicking is harsh for library code; this branch is
+		// reachable whenever batchBytes is called for a topic that
+		// deriveBatchBytes never resolved - consider falling back to
+		// BatchBytes or the default instead of crashing the caller.
+		panic(fmt.Sprintf("unknown topic: %s", topic))
+	}
+
+	if w.BatchBytes > 0 {
+		return w.BatchBytes
+	}
@@ -1028,7 +1115,7 @@ func (ptw *partitionWriter) writeMessages(msgs []Message, indexes []int32) map[*
 	defer ptw.mutex.Unlock()
 
 	batchSize := ptw.w.batchSize()
-	batchBytes := ptw.w.batchBytes()
+	batchBytes := ptw.w.batchBytes(ptw.meta.topic)
 
 	var batches map[*writeBatch][]int32
 	if !ptw.w.Async {