From fa16bdcde928eeae663671d1e8795f33db305475 Mon Sep 17 00:00:00 2001 From: qnnn <1543393961@qq.com> Date: Sun, 21 Jan 2024 03:47:26 +0800 Subject: [PATCH] Fix concurrency unsafety error of rand.Intn. --- runner/calldata.go | 28 +++++++++++++++++++++------- runner/calldata_test.go | 11 +++++++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/runner/calldata.go b/runner/calldata.go index 07b236ce..10e1181c 100644 --- a/runner/calldata.go +++ b/runner/calldata.go @@ -7,6 +7,7 @@ import ( htmlTemplate "html/template" "math/rand" "strings" + "sync" "text/template" "text/template/parse" "time" @@ -19,8 +20,11 @@ import ( const charset = "abcdefghijklmnopqrstuvwxyz" + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" -var seededRand *rand.Rand = rand.New( - rand.NewSource(time.Now().UnixNano())) +var seededRandPool = sync.Pool{ + New: func() interface{} { + return rand.New(rand.NewSource(time.Now().UnixNano())) + }, +} var sprigFuncMap htmlTemplate.FuncMap = sprig.FuncMap() @@ -178,7 +182,10 @@ func (td *CallData) ExecuteData(data string) ([]byte, error) { if len(data) > 0 { input := []byte(data) tpl, err := td.execute(data) - if err == nil && tpl != nil { + if err != nil { + return nil, err + } + if tpl != nil { input = tpl.Bytes() } @@ -227,15 +234,21 @@ const minLen = 2 func stringWithCharset(length int, charset string) string { b := make([]byte, length) + rng := seededRandPool.Get().(*rand.Rand) + defer seededRandPool.Put(rng) for i := range b { - b[i] = charset[seededRand.Intn(len(charset))] + b[i] = charset[rng.Intn(len(charset))] } return string(b) } func randomString(length int) string { if length <= 0 { - length = seededRand.Intn(maxLen-minLen+1) + minLen + func() { + rng := seededRandPool.Get().(*rand.Rand) + defer seededRandPool.Put(rng) + length = rng.Intn(maxLen-minLen+1) + minLen + }() } return stringWithCharset(length, charset) @@ -249,6 +262,7 @@ func randomInt(min, max int) int { if max <= 0 { max = 1 } - - return seededRand.Intn(max-min) + min + rng := seededRandPool.Get().(*rand.Rand) + defer seededRandPool.Put(rng) + return rng.Intn(max-min) + min } diff --git a/runner/calldata_test.go b/runner/calldata_test.go index b5404d8b..cd0effda 100644 --- a/runner/calldata_test.go +++ b/runner/calldata_test.go @@ -332,3 +332,14 @@ func TestCallTemplateData_ExecuteFuncs(t *testing.T) { assert.Equal(t, `{"trace_id":"ABCABCABC"}`, string(r)) }) } + +func BenchmarkCallData_randomString(b *testing.B) { + b.N = 100000000 + b.SetParallelism(1024) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = randomString(10) + } + }) + b.Logf("pass") +}