diff --git a/go/logic/applier.go b/go/logic/applier.go index 9d59d9ec4..1be696909 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -14,10 +14,13 @@ import ( "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/binlog" - "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" - "github.com/openark/golib/log" + "context" + "database/sql/driver" + + "github.com/github/gh-ost/go/mysql" + drivermysql "github.com/go-sql-driver/mysql" "github.com/openark/golib/sqlutils" ) @@ -77,7 +80,8 @@ func NewApplier(migrationContext *base.MigrationContext) *Applier { func (this *Applier) InitDBConnections() (err error) { applierUri := this.connectionConfig.GetDBUri(this.migrationContext.DatabaseName) - if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, applierUri); err != nil { + uriWithMulti := fmt.Sprintf("%s&multiStatements=true", applierUri) + if this.db, _, err = mysql.GetDB(this.migrationContext.Uuid, uriWithMulti); err != nil { return err } singletonApplierUri := fmt.Sprintf("%s&timeout=0", applierUri) @@ -1207,44 +1211,80 @@ func (this *Applier) buildDMLEventQuery(dmlEvent *binlog.BinlogDMLEvent) []*dmlB // ApplyDMLEventQueries applies multiple DML queries onto the _ghost_ table func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent)) error { var totalDelta int64 + ctx := context.Background() err := func() error { - tx, err := this.db.Begin() + conn, err := this.db.Conn(ctx) if err != nil { return err } + defer conn.Close() + + sessionQuery := "SET /* gh-ost */ SESSION time_zone = '+00:00'" + sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery()) + if _, err := conn.ExecContext(ctx, sessionQuery); err != nil { + return err + } + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + return err + } rollback := func(err error) error { tx.Rollback() return err } - sessionQuery := "SET /* gh-ost */ SESSION time_zone = '+00:00'" - sessionQuery = fmt.Sprintf("%s, %s", sessionQuery, this.generateSqlModeQuery()) - - if _, err := tx.Exec(sessionQuery); err != nil { - return rollback(err) - } + buildResults := make([]*dmlBuildResult, 0, len(dmlEvents)) + nArgs := 0 for _, dmlEvent := range dmlEvents { for _, buildResult := range this.buildDMLEventQuery(dmlEvent) { if buildResult.err != nil { return rollback(buildResult.err) } - result, err := tx.Exec(buildResult.query, buildResult.args...) - if err != nil { - err = fmt.Errorf("%w; query=%s; args=%+v", err, buildResult.query, buildResult.args) - return rollback(err) - } + nArgs += len(buildResult.args) + buildResults = append(buildResults, buildResult) + } + } - rowsAffected, err := result.RowsAffected() - if err != nil { - log.Warningf("error getting rows affected from DML event query: %s. i'm going to assume that the DML affected a single row, but this may result in inaccurate statistics", err) - rowsAffected = 1 + // We batch together the DML queries into multi-statements to minimize network trips. + // We have to use the raw driver connection to access the rows affected + // for each statement in the multi-statement. + execErr := conn.Raw(func(driverConn any) error { + ex := driverConn.(driver.ExecerContext) + nvc := driverConn.(driver.NamedValueChecker) + + multiArgs := make([]driver.NamedValue, 0, nArgs) + multiQueryBuilder := strings.Builder{} + for _, buildResult := range buildResults { + for _, arg := range buildResult.args { + nv := driver.NamedValue{Value: driver.Value(arg)} + nvc.CheckNamedValue(&nv) + multiArgs = append(multiArgs, nv) } - // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1). - // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event - totalDelta += buildResult.rowsDelta * rowsAffected + + multiQueryBuilder.WriteString(buildResult.query) + multiQueryBuilder.WriteString(";\n") } + + res, err := ex.ExecContext(ctx, multiQueryBuilder.String(), multiArgs) + if err != nil { + err = fmt.Errorf("%w; query=%s; args=%+v", err, multiQueryBuilder.String(), multiArgs) + return err + } + + mysqlRes := res.(drivermysql.Result) + + // each DML is either a single insert (delta +1), update (delta +0) or delete (delta -1). + // multiplying by the rows actually affected (either 0 or 1) will give an accurate row delta for this DML event + for i, rowsAffected := range mysqlRes.AllRowsAffected() { + totalDelta += buildResults[i].rowsDelta * rowsAffected + } + return nil + }) + + if execErr != nil { + return rollback(execErr) } if err := tx.Commit(); err != nil { return err