diff --git a/batch.go b/batch.go index 9b943621e..86cc8fed3 100644 --- a/batch.go +++ b/batch.go @@ -72,6 +72,13 @@ func (b *Batch) Queue(query string, arguments ...any) *QueuedQuery { return qq } +// BatchN returns a new Batch with an initial capacity of n. +// This is useful to avoid allocations if you know beforehand how many queries +// at least you want to enqueue. +func BatchN(n int) Batch { + return Batch{queuedQueries: make([]*QueuedQuery, 0, n)} +} + // Len returns number of queries that have been queued so far. func (b *Batch) Len() int { return len(b.queuedQueries) diff --git a/batch_test.go b/batch_test.go index 9ff2417ff..5883c329f 100644 --- a/batch_test.go +++ b/batch_test.go @@ -31,7 +31,7 @@ func TestConnSendBatch(t *testing.T) { );` mustExec(t, conn, sql) - batch := &pgx.Batch{} + batch := pgx.BatchN(7) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q1", 1) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q2", 2) batch.Queue("insert into ledger(description, amount) values($1, $2)", "q3", 3) @@ -40,7 +40,7 @@ func TestConnSendBatch(t *testing.T) { batch.Queue("select * from ledger where false") batch.Queue("select sum(amount) from ledger") - br := conn.SendBatch(ctx, batch) + br := conn.SendBatch(ctx, &batch) ct, err := br.Exec() if err != nil {