diff --git a/common/persistence/nosql/nosqlplugin/cassandra/gocql/batch.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/batch.go index efacee18a1..d4387b86f5 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/gocql/batch.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/gocql/batch.go @@ -36,13 +36,16 @@ func (b *Batch) Query(stmt string, args ...any) { b.gocqlBatch.Query(stmt, args...) } +// WithContext mutates b in place and returns it. The caller must not retain +// a reference to b before this call and use it concurrently afterward. func (b *Batch) WithContext(ctx context.Context) *Batch { - return newBatch(b.session, b.gocqlBatch.WithContext(ctx)) + b.gocqlBatch = b.gocqlBatch.WithContext(ctx) + return b } func (b *Batch) WithTimestamp(timestamp int64) *Batch { b.gocqlBatch.WithTimestamp(timestamp) - return newBatch(b.session, b.gocqlBatch) + return b } func mustConvertBatchType(batchType BatchType) gocql.BatchType { diff --git a/common/persistence/nosql/nosqlplugin/cassandra/gocql/query.go b/common/persistence/nosql/nosqlplugin/cassandra/gocql/query.go index 79ab9d6ce4..f023e9b53f 100644 --- a/common/persistence/nosql/nosqlplugin/cassandra/gocql/query.go +++ b/common/persistence/nosql/nosqlplugin/cassandra/gocql/query.go @@ -70,41 +70,46 @@ func (q *query) Iter() Iter { func (q *query) PageSize(n int) Query { q.gocqlQuery.PageSize(n) - return newQuery(q.session, q.gocqlQuery) + return q } func (q *query) PageState(state []byte) Query { q.gocqlQuery.PageState(state) - return newQuery(q.session, q.gocqlQuery) + return q } func (q *query) Consistency(c Consistency) Query { q.gocqlQuery.Consistency(mustConvertConsistency(c)) - return newQuery(q.session, q.gocqlQuery) + return q } func (q *query) WithTimestamp(timestamp int64) Query { q.gocqlQuery.WithTimestamp(timestamp) - return newQuery(q.session, q.gocqlQuery) + return q } +// WithContext mutates q in place and returns it. The caller must not retain +// a reference to q before this call and use it concurrently afterward. func (q *query) WithContext(ctx context.Context) Query { q2 := q.gocqlQuery.WithContext(ctx) if q2 == nil { return nil } - return newQuery(q.session, q2) + q.gocqlQuery = q2 + return q } func (q *query) Bind(v ...any) Query { q.gocqlQuery.Bind(v...) - return newQuery(q.session, q.gocqlQuery) + return q } func (q *query) Idempotent(value bool) Query { - return newQuery(q.session, q.gocqlQuery.Idempotent(value)) + q.gocqlQuery.Idempotent(value) + return q } func (q *query) SetSpeculativeExecutionPolicy(policy SpeculativeExecutionPolicy) Query { - return newQuery(q.session, q.gocqlQuery.SetSpeculativeExecutionPolicy(policy)) + q.gocqlQuery.SetSpeculativeExecutionPolicy(policy) + return q }