Skip to content

Commit

Permalink
Merge branch 'master' into ci-build-latest-go
Browse files Browse the repository at this point in the history
  • Loading branch information
timvaillancourt authored Oct 22, 2022
2 parents 6d35fa4 + 9b3fa79 commit b087399
Show file tree
Hide file tree
Showing 11 changed files with 149 additions and 45 deletions.
6 changes: 6 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,27 @@ linters:
disable:
- errcheck
enable:
- bodyclose
- containedctx
- contextcheck
- dogsled
- durationcheck
- errname
- errorlint
- execinquery
- gofmt
- ifshort
- misspell
- nilerr
- nilnil
- noctx
- nolintlint
- nosprintfhostport
- prealloc
- rowserrcheck
- sqlclosecheck
- unconvert
- unparam
- unused
- wastedassign
- whitespace
4 changes: 2 additions & 2 deletions go/base/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -858,7 +858,7 @@ func (this *MigrationContext) ReadConfigFile() error {
if cfg.Section("osc").HasKey("chunk_size") {
this.config.Osc.Chunk_Size, err = cfg.Section("osc").Key("chunk_size").Int64()
if err != nil {
return fmt.Errorf("Unable to read osc chunk size: %s", err.Error())
return fmt.Errorf("Unable to read osc chunk size: %w", err)
}
}

Expand All @@ -873,7 +873,7 @@ func (this *MigrationContext) ReadConfigFile() error {
if cfg.Section("osc").HasKey("max_lag_millis") {
this.config.Osc.Max_Lag_Millis, err = cfg.Section("osc").Key("max_lag_millis").Int64()
if err != nil {
return fmt.Errorf("Unable to read max lag millis: %s", err.Error())
return fmt.Errorf("Unable to read max lag millis: %w", err)
}
}

Expand Down
2 changes: 1 addition & 1 deletion go/logic/applier.go
Original file line number Diff line number Diff line change
Expand Up @@ -1134,7 +1134,7 @@ func (this *Applier) ApplyDMLEventQueries(dmlEvents [](*binlog.BinlogDMLEvent))
}
result, err := tx.Exec(buildResult.query, buildResult.args...)
if err != nil {
err = fmt.Errorf("%s; query=%s; args=%+v", err.Error(), buildResult.query, buildResult.args)
err = fmt.Errorf("%w; query=%s; args=%+v", err, buildResult.query, buildResult.args)
return rollback(err)
}

Expand Down
11 changes: 5 additions & 6 deletions go/logic/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package logic

import (
"fmt"
"io"
"os"
"os/exec"
"path/filepath"
Expand Down Expand Up @@ -34,18 +35,16 @@ const (

type HooksExecutor struct {
migrationContext *base.MigrationContext
writer io.Writer
}

func NewHooksExecutor(migrationContext *base.MigrationContext) *HooksExecutor {
return &HooksExecutor{
migrationContext: migrationContext,
writer: os.Stderr,
}
}

func (this *HooksExecutor) initHooks() error {
return nil
}

func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) []string {
env := os.Environ()
env = append(env, fmt.Sprintf("GH_OST_DATABASE_NAME=%s", this.migrationContext.DatabaseName))
Expand Down Expand Up @@ -76,13 +75,13 @@ func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) [
}

// executeHook executes a command, and sets relevant environment variables
// combined output & error are printed to gh-ost's standard error.
// combined output & error are printed to the configured writer.
func (this *HooksExecutor) executeHook(hook string, extraVariables ...string) error {
cmd := exec.Command(hook)
cmd.Env = this.applyEnvironmentVariables(extraVariables...)

combinedOutput, err := cmd.CombinedOutput()
fmt.Fprintln(os.Stderr, string(combinedOutput))
fmt.Fprintln(this.writer, string(combinedOutput))
return log.Errore(err)
}

Expand Down
113 changes: 113 additions & 0 deletions go/logic/hooks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
Copyright 2022 GitHub Inc.
See https://github.com/github/gh-ost/blob/master/LICENSE
*/

package logic

import (
"bufio"
"bytes"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"testing"
"time"

"github.com/openark/golib/tests"

"github.com/github/gh-ost/go/base"
)

func TestHooksExecutorExecuteHooks(t *testing.T) {
migrationContext := base.NewMigrationContext()
migrationContext.AlterStatement = "ENGINE=InnoDB"
migrationContext.DatabaseName = "test"
migrationContext.Hostname = "test.example.com"
migrationContext.OriginalTableName = "tablename"
migrationContext.RowsDeltaEstimate = 1
migrationContext.RowsEstimate = 122
migrationContext.TotalRowsCopied = 123456
migrationContext.SetETADuration(time.Minute)
migrationContext.SetProgressPct(50)
hooksExecutor := NewHooksExecutor(migrationContext)

writeTmpHookFunc := func(testName, hookName, script string) (path string, err error) {
if path, err = os.MkdirTemp("", testName); err != nil {
return path, err
}
err = os.WriteFile(filepath.Join(path, hookName), []byte(script), 0777)
return path, err
}

t.Run("does-not-exist", func(t *testing.T) {
migrationContext.HooksPath = "/does/not/exist"
tests.S(t).ExpectNil(hooksExecutor.executeHooks("test-hook"))
})

t.Run("failed", func(t *testing.T) {
var err error
if migrationContext.HooksPath, err = writeTmpHookFunc(
"TestHooksExecutorExecuteHooks-failed",
"failed-hook",
"#!/bin/sh\nexit 1",
); err != nil {
panic(err)
}
defer os.RemoveAll(migrationContext.HooksPath)
tests.S(t).ExpectNotNil(hooksExecutor.executeHooks("failed-hook"))
})

t.Run("success", func(t *testing.T) {
var err error
if migrationContext.HooksPath, err = writeTmpHookFunc(
"TestHooksExecutorExecuteHooks-success",
"success-hook",
"#!/bin/sh\nenv",
); err != nil {
panic(err)
}
defer os.RemoveAll(migrationContext.HooksPath)

var buf bytes.Buffer
hooksExecutor.writer = &buf
tests.S(t).ExpectNil(hooksExecutor.executeHooks("success-hook", "TEST="+t.Name()))

scanner := bufio.NewScanner(&buf)
for scanner.Scan() {
split := strings.SplitN(scanner.Text(), "=", 2)
switch split[0] {
case "GH_OST_COPIED_ROWS":
copiedRows, _ := strconv.ParseInt(split[1], 10, 64)
tests.S(t).ExpectEquals(copiedRows, migrationContext.TotalRowsCopied)
case "GH_OST_DATABASE_NAME":
tests.S(t).ExpectEquals(split[1], migrationContext.DatabaseName)
case "GH_OST_DDL":
tests.S(t).ExpectEquals(split[1], migrationContext.AlterStatement)
case "GH_OST_DRY_RUN":
tests.S(t).ExpectEquals(split[1], "false")
case "GH_OST_ESTIMATED_ROWS":
estimatedRows, _ := strconv.ParseInt(split[1], 10, 64)
tests.S(t).ExpectEquals(estimatedRows, int64(123))
case "GH_OST_ETA_SECONDS":
etaSeconds, _ := strconv.ParseInt(split[1], 10, 64)
tests.S(t).ExpectEquals(etaSeconds, int64(60))
case "GH_OST_EXECUTING_HOST":
tests.S(t).ExpectEquals(split[1], migrationContext.Hostname)
case "GH_OST_GHOST_TABLE_NAME":
tests.S(t).ExpectEquals(split[1], fmt.Sprintf("_%s_gho", migrationContext.OriginalTableName))
case "GH_OST_OLD_TABLE_NAME":
tests.S(t).ExpectEquals(split[1], fmt.Sprintf("_%s_del", migrationContext.OriginalTableName))
case "GH_OST_PROGRESS":
progress, _ := strconv.ParseFloat(split[1], 64)
tests.S(t).ExpectEquals(progress, 50.0)
case "GH_OST_TABLE_NAME":
tests.S(t).ExpectEquals(split[1], migrationContext.OriginalTableName)
case "TEST":
tests.S(t).ExpectEquals(split[1], t.Name())
}
}
})
}
7 changes: 3 additions & 4 deletions go/logic/inspect.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package logic
import (
"context"
gosql "database/sql"
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -554,13 +555,11 @@ func (this *Inspector) CountTableRows(ctx context.Context) error {
query := fmt.Sprintf(`select /* gh-ost */ count(*) as count_rows from %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName))
var rowsEstimate int64
if err := conn.QueryRowContext(ctx, query).Scan(&rowsEstimate); err != nil {
switch err {
case context.Canceled, context.DeadlineExceeded:
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
this.migrationContext.Log.Infof("exact row count cancelled (%s), likely because I'm about to cut over. I'm going to kill that query.", ctx.Err())
return mysql.Kill(this.db, connectionID)
default:
return err
}
return err
}

// row count query finished. nil out the cancel func, so the main migration thread
Expand Down
23 changes: 5 additions & 18 deletions go/logic/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ type Migrator struct {
func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator {
migrator := &Migrator{
appVersion: appVersion,
hooksExecutor: NewHooksExecutor(context),
migrationContext: context,
parser: sql.NewAlterTableParser(),
ghostTableMigrated: make(chan bool),
Expand All @@ -113,15 +114,6 @@ func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator {
return migrator
}

// initiateHooksExecutor
func (this *Migrator) initiateHooksExecutor() (err error) {
this.hooksExecutor = NewHooksExecutor(this.migrationContext)
if err := this.hooksExecutor.initHooks(); err != nil {
return err
}
return nil
}

// sleepWhileTrue sleeps indefinitely until the given function returns 'false'
// (or fails with error)
func (this *Migrator) sleepWhileTrue(operation func() (bool, error)) error {
Expand Down Expand Up @@ -342,9 +334,6 @@ func (this *Migrator) Migrate() (err error) {

go this.listenOnPanicAbort()

if err := this.initiateHooksExecutor(); err != nil {
return err
}
if err := this.hooksExecutor.onStartup(); err != nil {
return err
}
Expand Down Expand Up @@ -402,9 +391,9 @@ func (this *Migrator) Migrate() (err error) {
if err := this.applier.ReadMigrationRangeValues(); err != nil {
return err
}
if err := this.initiateThrottler(); err != nil {
return err
}

this.initiateThrottler()

if err := this.hooksExecutor.onBeforeRowCopy(); err != nil {
return err
}
Expand Down Expand Up @@ -1107,7 +1096,7 @@ func (this *Migrator) addDMLEventsListener() error {
}

// initiateThrottler kicks in the throttling collection and the throttling checks.
func (this *Migrator) initiateThrottler() error {
func (this *Migrator) initiateThrottler() {
this.throttler = NewThrottler(this.migrationContext, this.applier, this.inspector, this.appVersion)

go this.throttler.initiateThrottlerCollection(this.firstThrottlingCollected)
Expand All @@ -1117,8 +1106,6 @@ func (this *Migrator) initiateThrottler() error {
<-this.firstThrottlingCollected // other, general metrics
this.migrationContext.Log.Infof("First throttle metrics collected")
go this.throttler.initiateThrottlerChecks()

return nil
}

func (this *Migrator) initiateApplier() error {
Expand Down
2 changes: 2 additions & 0 deletions go/logic/throttler.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,8 @@ func (this *Throttler) collectThrottleHTTPStatus(firstThrottlingCollected chan<-
if err != nil {
return false, err
}
defer resp.Body.Close()

atomic.StoreInt64(&this.migrationContext.ThrottleHTTPStatusCode, int64(resp.StatusCode))
return false, nil
}
Expand Down
10 changes: 4 additions & 6 deletions go/sql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func NewParserFromAlterStatement(alterStatement string) *AlterTableParser {
return parser
}

func (this *AlterTableParser) tokenizeAlterStatement(alterStatement string) (tokens []string, err error) {
func (this *AlterTableParser) tokenizeAlterStatement(alterStatement string) (tokens []string) {
terminatingQuote := rune(0)
f := func(c rune) bool {
switch {
Expand All @@ -86,7 +86,7 @@ func (this *AlterTableParser) tokenizeAlterStatement(alterStatement string) (tok
for i := range tokens {
tokens[i] = strings.TrimSpace(tokens[i])
}
return tokens, nil
return tokens
}

func (this *AlterTableParser) sanitizeQuotesFromAlterStatement(alterStatement string) (strippedStatement string) {
Expand All @@ -95,7 +95,7 @@ func (this *AlterTableParser) sanitizeQuotesFromAlterStatement(alterStatement st
return strippedStatement
}

func (this *AlterTableParser) parseAlterToken(alterToken string) (err error) {
func (this *AlterTableParser) parseAlterToken(alterToken string) {
{
// rename
allStringSubmatch := renameColumnRegexp.FindAllStringSubmatch(alterToken, -1)
Expand Down Expand Up @@ -131,7 +131,6 @@ func (this *AlterTableParser) parseAlterToken(alterToken string) (err error) {
this.isAutoIncrementDefined = true
}
}
return nil
}

func (this *AlterTableParser) ParseAlterStatement(alterStatement string) (err error) {
Expand All @@ -151,8 +150,7 @@ func (this *AlterTableParser) ParseAlterStatement(alterStatement string) (err er
break
}
}
alterTokens, _ := this.tokenizeAlterStatement(this.alterStatementOptions)
for _, alterToken := range alterTokens {
for _, alterToken := range this.tokenizeAlterStatement(this.alterStatementOptions) {
alterToken = this.sanitizeQuotesFromAlterStatement(alterToken)
this.parseAlterToken(alterToken)
this.alterTokens = append(this.alterTokens, alterToken)
Expand Down
Loading

0 comments on commit b087399

Please sign in to comment.