Skip to content

Commit

Permalink
refactor(statement-logger): statement logger data-racy
Browse files Browse the repository at this point in the history
Statement Logger had a flaw: it assumed it had ownership of the
query, but statements are reused, which means it produced queries that
are not even possible to execute. (This does not fix the issue of more
columns expected than returned.) Since the logger uses a single thread
to write to the file (queue), it has to be reworked to format the query
and the values before they are passed to the channel (which works like a
queue); this way a query is owned by the channel inside a bytes.Buffer.

PrettyCQL has been rewritten to use `bytes.Buffer` instead of
`strings.Builder`, as the output of PrettyCQL is intended to be stored
inside a file, and files work with `[]byte` rather than strings.

Signed-off-by: Dusan Malusev <[email protected]>
  • Loading branch information
CodeLieutenant committed Nov 1, 2024
1 parent 8cebb9c commit 2390abf
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 133 deletions.
1 change: 1 addition & 0 deletions pkg/querycache/querycache.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ func genDeleteStmtCache(s *typedef.Schema, t *typedef.Table) *typedef.StmtCache
builder = builder.Where(qb.GtOrEq(ck.Name)).Where(qb.LtOrEq(ck.Name))
allTypes = append(allTypes, ck.Type, ck.Type)
}

return &typedef.StmtCache{
Query: builder,
Types: allTypes,
Expand Down
202 changes: 106 additions & 96 deletions pkg/stmtlogger/filelogger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,149 +15,159 @@
package stmtlogger

import (
"bufio"
"bytes"
"context"
"io"
"log"
"os"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/pkg/errors"
"go.uber.org/multierr"

"github.com/scylladb/gemini/pkg/typedef"
)

const (
defaultChanSize = 1000
defaultChanSize = 1024
defaultBufferSize = 2048
errorsOnFileLimit = 5
)

type StmtToFile interface {
LogStmt(*typedef.Stmt)
LogStmtWithTimeStamp(stmt *typedef.Stmt, ts time.Time)
Close() error
}
type (
StmtToFile interface {
LogStmt(stmt *typedef.Stmt, ts ...time.Time)
Close() error
}

type logger struct {
fd io.Writer
activeChannel atomic.Pointer[loggerChan]
channel loggerChan
isFileNonOperational bool
}
logger struct {
writer *bufio.Writer
fd io.Writer
channel chan *bytes.Buffer
cancel context.CancelFunc
pool sync.Pool
wg sync.WaitGroup
active atomic.Bool
}
)

type loggerChan chan logRec
func NewFileLogger(filename string) (StmtToFile, error) {
if filename == "" {
return &nopFileLogger{}, nil
}

type logRec struct {
stmt *typedef.Stmt
ts time.Time
fd, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
return nil, err
}

return NewLogger(fd)
}

func (fl *logger) LogStmt(stmt *typedef.Stmt) {
ch := fl.activeChannel.Load()
if ch != nil {
*ch <- logRec{
stmt: stmt,
}
func NewLogger(w io.Writer) (StmtToFile, error) {
ctx, cancel := context.WithCancel(context.Background())

out := &logger{
writer: bufio.NewWriterSize(w, 8192),
fd: w,
channel: make(chan *bytes.Buffer, defaultChanSize),
cancel: cancel,
pool: sync.Pool{
New: func() any {
return bytes.NewBuffer(make([]byte, 0, defaultBufferSize))
},
},
}
out.active.Store(true)

go out.committer(ctx)
return out, nil
}

func (fl *logger) LogStmtWithTimeStamp(stmt *typedef.Stmt, ts time.Time) {
ch := fl.activeChannel.Load()
if ch != nil {
*ch <- logRec{
stmt: stmt,
ts: ts,
}
func (fl *logger) LogStmt(stmt *typedef.Stmt, ts ...time.Time) {
buffer := fl.pool.Get().(*bytes.Buffer)
if err := stmt.PrettyCQLBuffered(buffer); err != nil {
log.Printf("failed to pretty print query: %s", err)
return
}
}

func (fl *logger) Close() error {
if closer, ok := fl.fd.(io.Closer); ok {
return closer.Close()
opType := stmt.QueryType.OpType()

if len(ts) > 0 && !ts[0].IsZero() && (opType == typedef.OpInsert || opType == typedef.OpUpdate || opType == typedef.OpDelete) {
buffer.WriteString(" USING TIMESTAMP ")
buffer.WriteString(strconv.FormatInt(ts[0].UnixMicro(), 10))
}

return nil
buffer.WriteString(";\n")

if fl.active.Load() {
fl.channel <- buffer
}
}

func (fl *logger) committer() {
var err2 error
func (fl *logger) Close() error {
fl.cancel()
fl.active.Swap(false)
close(fl.channel)

defer func() {
fl.activeChannel.Swap(nil)
close(fl.channel)
}()
// Wait for committer to drain the channel
fl.wg.Wait()

errsAtRow := 0
err := multierr.Append(nil, fl.writer.Flush())

for rec := range fl.channel {
if fl.isFileNonOperational {
continue
}
if closer, ok := fl.fd.(io.Closer); ok {
err = multierr.Append(err, closer.Close())
}

query, err := rec.stmt.PrettyCQL()
if err != nil {
log.Printf("failed to pretty print query: %s", err)
continue
}
return err
}

_, err1 := fl.fd.Write([]byte(query))
opType := rec.stmt.QueryType.OpType()
if rec.ts.IsZero() || !(opType == typedef.OpInsert || opType == typedef.OpUpdate || opType == typedef.OpDelete) {
_, err2 = fl.fd.Write([]byte(";\n"))
func (fl *logger) committer(ctx context.Context) {
fl.wg.Add(1)
defer fl.wg.Done()
errsAtRow := 0

drain := func(rec *bytes.Buffer) {
defer func() {
rec.Reset()
fl.pool.Put(rec)
}()

if _, err := rec.WriteTo(fl.writer); err != nil {
if errors.Is(err, os.ErrClosed) || errsAtRow > errorsOnFileLimit {
return
}
errsAtRow++
log.Printf("failed to write to writer %+v", err)
} else {
_, err2 = fl.fd.Write([]byte(" USING TIMESTAMP " + strconv.FormatInt(rec.ts.UnixNano()/1000, 10) + ";\n"))
}
if err2 == nil && err1 == nil {
errsAtRow = 0
continue
}
}

if errors.Is(err2, os.ErrClosed) || errors.Is(err1, os.ErrClosed) {
fl.isFileNonOperational = true
for {
select {
case <-ctx.Done():
for rec := range fl.channel {
drain(rec)
}
return
}
case rec, ok := <-fl.channel:
if !ok {
return
}

errsAtRow++
if errsAtRow > errorsOnFileLimit {
fl.isFileNonOperational = true
drain(rec)
}

if err2 != nil {
err1 = err2
}

log.Printf("failed to write to writer %v", err1)
return
}
}

func NewFileLogger(filename string) (StmtToFile, error) {
if filename == "" {
return &nopFileLogger{}, nil
}
fd, err := os.OpenFile(filename, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0o644)
if err != nil {
return nil, err
}

return NewLogger(fd)
}

func NewLogger(w io.Writer) (StmtToFile, error) {
out := &logger{
fd: w,
channel: make(loggerChan, defaultChanSize),
}
out.activeChannel.Store(&out.channel)

go out.committer()
return out, nil
}

type nopFileLogger struct{}

func (n *nopFileLogger) LogStmtWithTimeStamp(_ *typedef.Stmt, _ time.Time) {}
func (n *nopFileLogger) LogStmt(_ *typedef.Stmt, _ ...time.Time) {}

func (n *nopFileLogger) Close() error { return nil }

func (n *nopFileLogger) LogStmt(_ *typedef.Stmt) {}
13 changes: 9 additions & 4 deletions pkg/store/cqlstore.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,14 @@ func (cs *cqlStore) mutate(ctx context.Context, stmt *typedef.Stmt) (err error)
func (cs *cqlStore) doMutate(ctx context.Context, stmt *typedef.Stmt, ts time.Time) error {
queryBody, _ := stmt.Query.ToCql()
query := cs.session.Query(queryBody, stmt.Values...).WithContext(ctx)
defer query.Release()

if cs.useServerSideTimestamps {
query = query.DefaultTimestamp(false)
cs.stmtLogger.LogStmt(stmt)
} else {
query = query.WithTimestamp(ts.UnixNano() / 1000)
cs.stmtLogger.LogStmtWithTimeStamp(stmt, ts)
cs.stmtLogger.LogStmt(stmt, ts)
}

if err := query.Exec(); err != nil {
Expand All @@ -94,14 +96,17 @@ func (cs *cqlStore) doMutate(ctx context.Context, stmt *typedef.Stmt, ts time.Ti
}

func (cs *cqlStore) load(ctx context.Context, stmt *typedef.Stmt) (result []map[string]any, err error) {
query, _ := stmt.Query.ToCql()
cql, _ := stmt.Query.ToCql()
cs.stmtLogger.LogStmt(stmt)
iter := cs.session.Query(query, stmt.Values...).WithContext(ctx).Iter()
query := cs.session.Query(cql, stmt.Values...).WithContext(ctx)
defer query.Release()

iter := query.Iter()
cs.ops.WithLabelValues(cs.system, opType(stmt)).Inc()
return loadSet(iter), iter.Close()
}

func (cs cqlStore) close() error {
func (cs *cqlStore) close() error {
cs.session.Close()
return nil
}
Expand Down
7 changes: 3 additions & 4 deletions pkg/typedef/bag.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
package typedef

import (
"bytes"
"github.com/pkg/errors"
"math"
"reflect"
"strings"

"github.com/pkg/errors"

"github.com/gocql/gocql"
"golang.org/x/exp/rand"
Expand Down Expand Up @@ -62,7 +61,7 @@ func (ct *BagType) CQLHolder() string {

type Tuple []any

func (ct *BagType) CQLPretty(builder *strings.Builder, value any) error {
func (ct *BagType) CQLPretty(builder *bytes.Buffer, value any) error {
if reflect.TypeOf(value).Kind() != reflect.Slice {
return errors.Errorf("expected slice, got [%T]%v", value, value)
}
Expand Down
5 changes: 2 additions & 3 deletions pkg/typedef/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
package typedef

import (
"strings"

"bytes"
"github.com/gocql/gocql"
"golang.org/x/exp/rand"
)
Expand All @@ -25,7 +24,7 @@ type Type interface {
Name() string
CQLDef() string
CQLHolder() string
CQLPretty(*strings.Builder, any) error
CQLPretty(*bytes.Buffer, any) error
GenValue(*rand.Rand, *PartitionRangeConfig) []any
GenJSONValue(*rand.Rand, *PartitionRangeConfig) any
LenValue() int
Expand Down
4 changes: 2 additions & 2 deletions pkg/typedef/simple_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
package typedef

import (
"bytes"
"encoding/hex"
"fmt"
"math"
"math/big"
"net"
"strconv"
"strings"
"time"

"github.com/gocql/gocql"
Expand Down Expand Up @@ -69,7 +69,7 @@ func (st SimpleType) LenValue() int {
return 1
}

func (st SimpleType) CQLPretty(builder *strings.Builder, value any) error {
func (st SimpleType) CQLPretty(builder *bytes.Buffer, value any) error {
switch st {
case TYPE_INET:
builder.WriteRune('\'')
Expand Down
3 changes: 2 additions & 1 deletion pkg/typedef/tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package typedef

import (
"bytes"
"strings"

"github.com/pkg/errors"
Expand Down Expand Up @@ -56,7 +57,7 @@ func (t *TupleType) CQLHolder() string {
return "(" + strings.TrimRight(strings.Repeat("?,", len(t.ValueTypes)), ",") + ")"
}

func (t *TupleType) CQLPretty(builder *strings.Builder, value any) error {
func (t *TupleType) CQLPretty(builder *bytes.Buffer, value any) error {
values, ok := value.([]any)
if !ok {
values, ok = value.(Values)
Expand Down
Loading

0 comments on commit 2390abf

Please sign in to comment.