From 2de14eec0e4c093eba73e1d5b7b7b66850ec48d3 Mon Sep 17 00:00:00 2001 From: Aleksandr Nogikh Date: Mon, 1 Apr 2024 17:10:16 +0200 Subject: [PATCH] pkg/fuzzer: use MAB to distinguish between exec fuzz and exec gen Let's try to use a plain epsilon-greedy MAB for this purpose. To better track its effect, also calculate moving averages of the "new max signal" / "execution time" ratios for exec fuzz and exec gen. --- pkg/fuzzer/fuzzer.go | 80 +++++++++++++++++++++++++++---------- pkg/fuzzer/fuzzer_test.go | 2 + pkg/fuzzer/stats.go | 2 + pkg/ipc/ipc.go | 8 ++++ pkg/learning/mab.go | 65 ++++++++++++++++++++++++++++++ pkg/learning/mab_test.go | 66 ++++++++++++++++++++++++++++++ pkg/learning/window.go | 67 +++++++++++++++++++++++++++++++ pkg/learning/window_test.go | 35 ++++++++++++++++ pkg/rpctype/rpctype.go | 5 ++- syz-fuzzer/fuzzer.go | 8 +++- syz-fuzzer/proc.go | 12 +++--- syz-manager/http.go | 2 +- syz-manager/rpc.go | 9 ++++- 13 files changed, 329 insertions(+), 32 deletions(-) create mode 100644 pkg/learning/mab.go create mode 100644 pkg/learning/mab_test.go create mode 100644 pkg/learning/window.go create mode 100644 pkg/learning/window_test.go diff --git a/pkg/fuzzer/fuzzer.go b/pkg/fuzzer/fuzzer.go index df6fbd598f47..ff86637a35cb 100644 --- a/pkg/fuzzer/fuzzer.go +++ b/pkg/fuzzer/fuzzer.go @@ -14,6 +14,7 @@ import ( "github.com/google/syzkaller/pkg/corpus" "github.com/google/syzkaller/pkg/ipc" + "github.com/google/syzkaller/pkg/learning" "github.com/google/syzkaller/pkg/rpctype" "github.com/google/syzkaller/pkg/signal" "github.com/google/syzkaller/prog" @@ -34,6 +35,12 @@ type Fuzzer struct { ctMu sync.Mutex // TODO: use RWLock. ctRegenerate chan struct{} + // Use a MAB to determine the right distribution of + // exec fuzz and exec gen. 
+ genFuzzMAB *learning.PlainMAB[string] + genSignalSpeed *learning.RunningRatioAverage[float64] + fuzzSignalSpeed *learning.RunningRatioAverage[float64] + nextExec *priorityQueue[*Request] nextJobID atomic.Int64 @@ -43,6 +50,12 @@ type Fuzzer struct { func NewFuzzer(ctx context.Context, cfg *Config, rnd *rand.Rand, target *prog.Target) *Fuzzer { + genFuzzMAB := &learning.PlainMAB[string]{ + ExplorationRate: 0.02, + MinLearningRate: 0.001, + } + genFuzzMAB.AddArms(statFuzz, statGenerate) + f := &Fuzzer{ Config: cfg, Cover: &Cover{}, @@ -54,7 +67,10 @@ func NewFuzzer(ctx context.Context, cfg *Config, rnd *rand.Rand, // We're okay to lose some of the messages -- if we are already // regenerating the table, we don't want to repeat it right away. - ctRegenerate: make(chan struct{}), + ctRegenerate: make(chan struct{}), + genFuzzMAB: genFuzzMAB, + genSignalSpeed: learning.NewRunningRatioAverage[float64](20000), + fuzzSignalSpeed: learning.NewRunningRatioAverage[float64](20000), nextExec: makePriorityQueue[*Request](), } @@ -91,22 +107,26 @@ type Request struct { flags ProgTypes stat string resultC chan *Result + + genFuzzAction *learning.Action[string] } type Result struct { - Info *ipc.ProgInfo - Stop bool + Info *ipc.ProgInfo + Stop bool + ElapsedSec float64 } func (fuzzer *Fuzzer) Done(req *Request, res *Result) { // Triage individual calls. // We do it before unblocking the waiting threads because // it may result it concurrent modification of req.Prog. + var newSignal int if req.NeedSignal != rpctype.NoSignal && res.Info != nil { for call, info := range res.Info.Calls { - fuzzer.triageProgCall(req.Prog, &info, call, req.flags) + newSignal += fuzzer.triageProgCall(req.Prog, &info, call, req.flags) } - fuzzer.triageProgCall(req.Prog, &res.Info.Extra, -1, req.flags) + newSignal += fuzzer.triageProgCall(req.Prog, &res.Info.Extra, -1, req.flags) } // Unblock threads that wait for the result. 
if req.resultC != nil { @@ -116,20 +136,38 @@ func (fuzzer *Fuzzer) Done(req *Request, res *Result) { fuzzer.mu.Lock() fuzzer.stats[req.stat]++ fuzzer.mu.Unlock() + // Update the MAB(s). + reward := 0.0 + if res.ElapsedSec > 0 { + // Similarly to the "SyzVegas: Beating Kernel Fuzzing Odds with Reinforcement Learning" + // paper, let's use the ratio of "new max signal" to "execution time". + // Unlike the paper, let's take the raw value of it instead of its ratio to the average one. + reward = float64(newSignal) / res.ElapsedSec + if req.stat == statGenerate { + fuzzer.genSignalSpeed.Save(float64(newSignal), res.ElapsedSec) + } else if req.stat == statFuzz { + fuzzer.fuzzSignalSpeed.Save(float64(newSignal), res.ElapsedSec) + } + } + if req.genFuzzAction != nil { + fuzzer.mu.Lock() + fuzzer.genFuzzMAB.SaveReward(*req.genFuzzAction, reward) + fuzzer.mu.Unlock() + } } func (fuzzer *Fuzzer) triageProgCall(p *prog.Prog, info *ipc.CallInfo, call int, - flags ProgTypes) { + flags ProgTypes) int { prio := signalPrio(p, info, call) newMaxSignal := fuzzer.Cover.addRawMaxSignal(info.Signal, prio) if newMaxSignal.Empty() { - return + return 0 } if flags&progInTriage > 0 { // We are already triaging this exact prog. // All newly found coverage is flaky. fuzzer.Logf(2, "found new flaky signal in call %d in %s", call, p) - return + return newMaxSignal.Len() } fuzzer.Logf(2, "found new signal in call %d in %s", call, p) fuzzer.startJob(&triageJob{ @@ -140,6 +178,7 @@ func (fuzzer *Fuzzer) triageProgCall(p *prog.Prog, info *ipc.CallInfo, call int, flags: flags, jobPriority: triageJobPrio(flags), }) + return newMaxSignal.Len() } func signalPrio(p *prog.Prog, info *ipc.CallInfo, call int) (prio uint8) { @@ -184,21 +223,20 @@ func (fuzzer *Fuzzer) nextInput() *Request { } } - // Either generate a new input or mutate an existing one. 
- mutateRate := 0.95 - if !fuzzer.Config.Coverage { - // If we don't have real coverage signal, generate programs - // more frequently because fallback signal is weak. - mutateRate = 0.5 - } rnd := fuzzer.rand() - if rnd.Float64() < mutateRate { - req := mutateProgRequest(fuzzer, rnd) - if req != nil { - return req - } + fuzzer.mu.Lock() + action := fuzzer.genFuzzMAB.Action(rnd) + fuzzer.mu.Unlock() + + var req *Request + if action.Arm == statFuzz { + req = mutateProgRequest(fuzzer, rnd) } - return genProgRequest(fuzzer, rnd) + if req == nil { + req = genProgRequest(fuzzer, rnd) + } + req.genFuzzAction = &action + return req } func (fuzzer *Fuzzer) startJob(newJob job) { diff --git a/pkg/fuzzer/fuzzer_test.go b/pkg/fuzzer/fuzzer_test.go index 5c09201097be..4d9e0d541992 100644 --- a/pkg/fuzzer/fuzzer_test.go +++ b/pkg/fuzzer/fuzzer_test.go @@ -85,6 +85,8 @@ func TestFuzz(t *testing.T) { t.Logf("%s", p.Serialize()) } + t.Logf("stats: %+v", fuzzer.Stats().Named) + assert.Equal(t, len(tf.expectedCrashes), len(tf.crashes), "not all expected crashes were found") } diff --git a/pkg/fuzzer/stats.go b/pkg/fuzzer/stats.go index 044febc64712..223336422231 100644 --- a/pkg/fuzzer/stats.go +++ b/pkg/fuzzer/stats.go @@ -42,5 +42,7 @@ func (fuzzer *Fuzzer) Stats() Stats { for k, v := range fuzzer.stats { ret.Named[k] = v } + ret.Named["exec gen, sig/sec*1000"] = uint64(fuzzer.genSignalSpeed.Load() * 1000) + ret.Named["exec fuzz, sig/sec*1000"] = uint64(fuzzer.fuzzSignalSpeed.Load() * 1000) return ret } diff --git a/pkg/ipc/ipc.go b/pkg/ipc/ipc.go index c90e56caf936..0c8dd48336c3 100644 --- a/pkg/ipc/ipc.go +++ b/pkg/ipc/ipc.go @@ -253,6 +253,12 @@ var rateLimit = time.NewTicker(1 * time.Second) // hanged: program hanged and was killed // err0: failed to start the process or bug in executor itself. 
func (env *Env) Exec(opts *ExecOpts, p *prog.Prog) (output []byte, info *ProgInfo, hanged bool, err0 error) { + output, info, hanged, _, err0 = env.ExecWithElapsed(opts, p) + return +} + +func (env *Env) ExecWithElapsed(opts *ExecOpts, p *prog.Prog) (output []byte, + info *ProgInfo, hanged bool, elapsed time.Duration, err0 error) { // Copy-in serialized program. progSize, err := p.SerializeForExec(env.in) if err != nil { @@ -275,7 +281,9 @@ func (env *Env) Exec(opts *ExecOpts, p *prog.Prog) (output []byte, info *ProgInf return } + start := time.Now() output, hanged, err0 = env.cmd.exec(opts, progData) + elapsed = time.Since(start) if err0 != nil { env.cmd.close() env.cmd = nil diff --git a/pkg/learning/mab.go b/pkg/learning/mab.go new file mode 100644 index 000000000000..1a54d28fb517 --- /dev/null +++ b/pkg/learning/mab.go @@ -0,0 +1,65 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. + +package learning + +import ( + "math/rand" +) + +type Action[T comparable] struct { + Arm T + index int +} + +type countedValue struct { + value float64 + count int64 +} + +func (cv *countedValue) update(value, minStep float64) { + // Using larger steps at the beginning allows us to + // converge faster to the actual value. + // The minStep limit ensures that we can still track + // non-stationary problems. + cv.count++ + step := 1.0 / float64(cv.count) + if step < minStep { + step = minStep + } + cv.value += (value - cv.value) * step +} + +// PlainMAB is a very simple epsilon-greedy MAB implementation. +// It's not thread-safe. 
+type PlainMAB[T comparable] struct { + MinLearningRate float64 + ExplorationRate float64 + arms []T + weights []countedValue +} + +func (p *PlainMAB[T]) AddArms(arms ...T) { + for _, arm := range arms { + p.arms = append(p.arms, arm) + p.weights = append(p.weights, countedValue{0, 0}) + } +} + +func (p *PlainMAB[T]) Action(r *rand.Rand) Action[T] { + var pos int + if r.Float64() < p.ExplorationRate { + pos = r.Intn(len(p.arms)) + } else { + for i := 1; i < len(p.arms); i++ { + if p.weights[i].value > p.weights[pos].value { + pos = i + } + } + } + return Action[T]{Arm: p.arms[pos], index: pos} +} + +func (p *PlainMAB[T]) SaveReward(action Action[T], reward float64) { + p.weights[action.index].update(reward, p.MinLearningRate) +} diff --git a/pkg/learning/mab_test.go b/pkg/learning/mab_test.go new file mode 100644 index 000000000000..2f93e6c2fa5e --- /dev/null +++ b/pkg/learning/mab_test.go @@ -0,0 +1,66 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package learning + +import ( + "math/rand" + "testing" + + "github.com/google/syzkaller/pkg/testutil" + "github.com/stretchr/testify/assert" +) + +func TestMABSmallDiff(t *testing.T) { + r := rand.New(testutil.RandSource(t)) + bandit := &PlainMAB[int]{ + MinLearningRate: 0.0001, + ExplorationRate: 0.1, + } + arms := []float64{0.65, 0.7} + for i := range arms { + bandit.AddArms(i) + } + const steps = 40000 + counts := runMAB(r, bandit, arms, steps) + t.Logf("counts: %v", counts) + assert.Greater(t, counts[1], steps/4*3) +} + +func TestNonStationaryMAB(t *testing.T) { + r := rand.New(testutil.RandSource(t)) + bandit := &PlainMAB[int]{ + MinLearningRate: 0.02, + ExplorationRate: 0.04, + } + + arms := []float64{0.2, 0.7, 0.5, 0.1} + for i := range arms { + bandit.AddArms(i) + } + + const steps = 25000 + counts := runMAB(r, bandit, arms, steps) + t.Logf("initially: %v", counts) + + // Ensure that we've found the best arm. + assert.Greater(t, counts[1], steps/2) + + // Now change the best arm's avg reward. + arms[3] = 0.9 + counts = runMAB(r, bandit, arms, steps) + t.Logf("after reward change: %v", counts) + assert.Greater(t, counts[3], steps/2) +} + +func runMAB(r *rand.Rand, bandit *PlainMAB[int], arms []float64, steps int) []int { + counts := make([]int, len(arms)) + for i := 0; i < steps; i++ { + action := bandit.Action(r) + // TODO: use normal distribution? + reward := r.Float64() * arms[action.Arm] + counts[action.Arm]++ + bandit.SaveReward(action, reward) + } + return counts +} diff --git a/pkg/learning/window.go b/pkg/learning/window.go new file mode 100644 index 000000000000..2cbfe5d3bfec --- /dev/null +++ b/pkg/learning/window.go @@ -0,0 +1,67 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package learning + +import "sync" + +type Number interface { + int | int64 | float64 +} + +type RunningAverage[T Number] struct { + window []T + mu sync.RWMutex + pos int + total T +} + +func NewRunningAverage[T Number](size int) *RunningAverage[T] { + return &RunningAverage[T]{ + window: make([]T, size), + } +} + +func (ra *RunningAverage[T]) SaveInt(val int) { + ra.Save(T(val)) +} + +func (ra *RunningAverage[T]) Save(val T) { + ra.mu.Lock() + defer ra.mu.Unlock() + prev := ra.window[ra.pos] + ra.window[ra.pos] = val + ra.total += val - prev + ra.pos = (ra.pos + 1) % len(ra.window) +} + +func (ra *RunningAverage[T]) Load() T { + ra.mu.RLock() + defer ra.mu.RUnlock() + return ra.total +} + +type RunningRatioAverage[T Number] struct { + values *RunningAverage[T] + divideBy *RunningAverage[T] +} + +func NewRunningRatioAverage[T Number](size int) *RunningRatioAverage[T] { + return &RunningRatioAverage[T]{ + values: NewRunningAverage[T](size), + divideBy: NewRunningAverage[T](size), + } +} + +func (rra *RunningRatioAverage[T]) Save(nomDelta, denomDelta T) { + rra.values.Save(nomDelta) + rra.divideBy.Save(denomDelta) +} + +func (rra *RunningRatioAverage[T]) Load() float64 { + denom := rra.divideBy.Load() + if denom == 0 { + return 0 + } + return float64(rra.values.Load()) / float64(denom) +} diff --git a/pkg/learning/window_test.go b/pkg/learning/window_test.go new file mode 100644 index 000000000000..b01dd300c607 --- /dev/null +++ b/pkg/learning/window_test.go @@ -0,0 +1,35 @@ +// Copyright 2024 syzkaller project authors. All rights reserved. +// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. 
+ +package learning + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunningRatioAverage(t *testing.T) { + ra := NewRunningRatioAverage[float64](3) + for i := 0; i < 4; i++ { + ra.Save(2.0, 1.0) + } + assert.InDelta(t, 2.0, ra.Load(), 0.1) + for i := 0; i < 4; i++ { + ra.Save(3.0, 2.0) + } + assert.InDelta(t, 1.5, ra.Load(), 0.1) +} + +func TestRunningAverage(t *testing.T) { + ra := NewRunningAverage[int](3) + assert.Equal(t, 0, ra.Load()) + ra.Save(1) + assert.Equal(t, 1, ra.Load()) + ra.Save(2) + assert.Equal(t, 3, ra.Load()) + for i := 4; i <= 10; i++ { + ra.SaveInt(i) + } + assert.Equal(t, 8+9+10, ra.Load()) +} diff --git a/pkg/rpctype/rpctype.go b/pkg/rpctype/rpctype.go index b31f9ec967d7..a086386808fc 100644 --- a/pkg/rpctype/rpctype.go +++ b/pkg/rpctype/rpctype.go @@ -36,8 +36,9 @@ type ExecutionRequest struct { // ExecutionResult is sent after ExecutionRequest is completed. type ExecutionResult struct { - ID int64 - Info ipc.ProgInfo + ID int64 + Info ipc.ProgInfo + Elapsed time.Duration // execution time } // ExchangeInfoRequest is periodically sent by syz-fuzzer to syz-manager. diff --git a/syz-fuzzer/fuzzer.go b/syz-fuzzer/fuzzer.go index fb7f57b30fd2..1cfc62161afb 100644 --- a/syz-fuzzer/fuzzer.go +++ b/syz-fuzzer/fuzzer.go @@ -57,7 +57,8 @@ type FuzzerTool struct { // to the communication thread. type executionResult struct { rpctype.ExecutionRequest - info *ipc.ProgInfo + info *ipc.ProgInfo + elapsed time.Duration } // executionRequest offloads prog deseralization to another thread. 
@@ -391,7 +392,10 @@ func (tool *FuzzerTool) exchangeDataWorker() { } func (tool *FuzzerTool) convertExecutionResult(res executionResult) rpctype.ExecutionResult { - ret := rpctype.ExecutionResult{ID: res.ID} + ret := rpctype.ExecutionResult{ + ID: res.ID, + Elapsed: res.elapsed, + } if res.info != nil { if res.NeedSignal == rpctype.NewSignal { tool.diffMaxSignal(res.info) diff --git a/syz-fuzzer/proc.go b/syz-fuzzer/proc.go index 76e8ec43538d..b4b57e147aba 100644 --- a/syz-fuzzer/proc.go +++ b/syz-fuzzer/proc.go @@ -64,12 +64,13 @@ func (proc *Proc) loop() { (req.NeedCover || req.NeedSignal != rpctype.NoSignal || req.NeedHints) { proc.env.ForceRestart() } - info := proc.executeRaw(&opts, req.prog) + info, elapsed := proc.executeRaw(&opts, req.prog) // Let's perform signal filtering in a separate thread to get the most // exec/sec out of a syz-executor instance. proc.tool.results <- executionResult{ ExecutionRequest: req.ExecutionRequest, info: info, + elapsed: elapsed, } } } @@ -86,11 +87,12 @@ func (proc *Proc) nextRequest() executionRequest { return <-proc.tool.inputs } -func (proc *Proc) executeRaw(opts *ipc.ExecOpts, p *prog.Prog) *ipc.ProgInfo { +func (proc *Proc) executeRaw(opts *ipc.ExecOpts, p *prog.Prog) (*ipc.ProgInfo, time.Duration) { for try := 0; ; try++ { var output []byte var info *ipc.ProgInfo var hanged bool + var elapsed time.Duration // On a heavily loaded VM, syz-executor may take significant time to start. // Let's do it outside of the gate ticket. err := proc.env.RestartIfNeeded(p.Target) @@ -98,7 +100,7 @@ func (proc *Proc) executeRaw(opts *ipc.ExecOpts, p *prog.Prog) *ipc.ProgInfo { // Limit concurrency. 
ticket := proc.tool.gate.Enter() proc.logProgram(opts, p) - output, info, hanged, err = proc.env.Exec(opts, p) + output, info, hanged, elapsed, err = proc.env.ExecWithElapsed(opts, p) proc.tool.gate.Leave(ticket) } if err != nil { @@ -107,7 +109,7 @@ func (proc *Proc) executeRaw(opts *ipc.ExecOpts, p *prog.Prog) *ipc.ProgInfo { // but so far we don't have a better handling than counting this. // This error is observed a lot on the seeded syz_mount_image calls. proc.tool.bufferTooSmall.Add(1) - return nil + return nil, elapsed } if try > 10 { log.SyzFatalf("executor %v failed %v times: %v\n%s", proc.pid, try, err, output) @@ -118,7 +120,7 @@ func (proc *Proc) executeRaw(opts *ipc.ExecOpts, p *prog.Prog) *ipc.ProgInfo { continue } log.Logf(2, "result hanged=%v: %s", hanged, output) - return info + return info, elapsed } } diff --git a/syz-manager/http.go b/syz-manager/http.go index 22af491fde57..d20c8fefba30 100644 --- a/syz-manager/http.go +++ b/syz-manager/http.go @@ -176,7 +176,7 @@ func (mgr *Manager) collectStats() []UIStat { for k, v := range rawStats { val := "" switch { - case k == "fuzzer jobs" || strings.HasPrefix(k, "rpc exchange"): + case k == "fuzzer jobs" || strings.HasPrefix(k, "rpc exchange") || strings.Contains(k, "/sec"): val = fmt.Sprint(v) default: val = rateStat(v, secs) diff --git a/syz-manager/rpc.go b/syz-manager/rpc.go index 0c94a5793de2..dd6f4ec9fdd5 100644 --- a/syz-manager/rpc.go +++ b/syz-manager/rpc.go @@ -325,7 +325,14 @@ func (runner *Runner) doneRequest(resp rpctype.ExecutionResult, fuzzerObj *fuzze } info.Extra.Cover = runner.instModules.Canonicalize(info.Extra.Cover) info.Extra.Signal = runner.instModules.Canonicalize(info.Extra.Signal) - fuzzerObj.Done(req, &fuzzer.Result{Info: info}) + // The fuzzer may mess with time, so let's be cautious about the elapsed time values. 
+ res := &fuzzer.Result{ + Info: info, + } + if seconds := resp.Elapsed.Seconds(); seconds > 0 && seconds < 10 { + res.ElapsedSec = seconds + } + fuzzerObj.Done(req, res) } func (runner *Runner) newRequest(req *fuzzer.Request) rpctype.ExecutionRequest {