From b9dcbe5fc1caf082c0981ffb1b52238ef782f532 Mon Sep 17 00:00:00 2001 From: Jeff Hodges Date: Wed, 30 Aug 2023 01:24:56 -0700 Subject: [PATCH] x/sync/singleflight/v2: add package This adds a new package, x/sync/singleflight/v2, which is a version of x/sync/singleflight that uses generics instead of `string`, and `interface{}` in its API. The biggest changes are to the `Do` methods, and the `Result` type. The `Do` methods now take a `comparable` generic type instead of `string` for its inputs (matching that of `map` and similar). The output type of the `Do` method and the `Val` field in `Result` are now types that user can specify instead `interface{}`. Along the way, some tests received modifications to remove some now-unneeded `fmt.Sprintf` calls or add an `empty` struct for nil return tests. Also, `ExampleGroup` also received some additions to make it clear that non-`string` input types are acceptable. This is following a similar pattern as discussed with the `math/rand/v2` project. There is, however, one difference in affordances between packages in the stdlib and outside of the stdlib that we try to accomadate here. Stdlib packages like `math/rand/v2` can rely on the Go compiler version being the one our package is released with and, for packages outside of the stdlib, the errors can sometimes be cryptic when the package needs a more modern Go compile than a user attempted to use. For instance, `singleflight/v2` has a build tag specifying the need for Go 1.18 or later to be compiled in its necessary code files. When an older compiler is used to build it, the user will get an error starting with "build constraints exclude all Go files in". This makes sense when you read all of the files, but a user may be using it as a transitive dependency and won't know whether the build tags aren't matching because their OS is unsupported or their CPU architecture or what. To ameliorate user confusion, an extra `notgo18.go` file has been added with build tags saying that it's only built if the Go compiler isn't version 1.18 or later. And that file has a compile error in it that will error like so: $ go1.17.2 build . # golang.org/x/sync/singleflight/v2 ./notgo18.go:16:25: undefined: REQUIRES_GO_1_18_OR_LATER That error message is an attempt to be more clear to users about what needs to change in their build. Another alternative to that would be to change x/sync's `go.mod` to require Go 1.18 or greater. However, the rest of the x/sync packages don't require Go 1.18, and that seems like too large of a breaking change for singleflight/v2 alone. --- singleflight/v2/notgo18.go | 16 ++ singleflight/v2/singleflight.go | 208 ++++++++++++++ singleflight/v2/singleflight_test.go | 403 +++++++++++++++++++++++++++ 3 files changed, 627 insertions(+) create mode 100644 singleflight/v2/notgo18.go create mode 100644 singleflight/v2/singleflight.go create mode 100644 singleflight/v2/singleflight_test.go diff --git a/singleflight/v2/notgo18.go b/singleflight/v2/notgo18.go new file mode 100644 index 0000000..3b218ee --- /dev/null +++ b/singleflight/v2/notgo18.go @@ -0,0 +1,16 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build !go1.18 +// +build !go1.18 + +package singleflight // import "golang.org/x/sync/singleflight/v2" + +// singleflight/v2 requires Go 1.18 or later for generics support. To avoid a +// confusing "build constraints exclude all Go files in" compile error on Go +// 1.17 and earlier, we add this file and the below code. The code will fail to +// compile on Go 1.17 or earlier and should help folks understand what they need +// to do. + +const versionRequired = REQUIRES_GO_1_18_OR_LATER diff --git a/singleflight/v2/singleflight.go b/singleflight/v2/singleflight.go new file mode 100644 index 0000000..0443413 --- /dev/null +++ b/singleflight/v2/singleflight.go @@ -0,0 +1,208 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.18 +// +build go1.18 + +// Package singleflight provides a duplicate function call suppression +// mechanism. +package singleflight // import "golang.org/x/sync/singleflight/v2" + +import ( + "bytes" + "errors" + "fmt" + "runtime" + "runtime/debug" + "sync" +) + +// errGoexit indicates the runtime.Goexit was called in +// the user given function. +var errGoexit = errors.New("runtime.Goexit was called") + +// A panicError is an arbitrary value recovered from a panic +// with the stack trace during the execution of given function. +type panicError struct { + value any + stack []byte +} + +// Error implements error interface. +func (p *panicError) Error() string { + return fmt.Sprintf("%v\n\n%s", p.value, p.stack) +} + +func newPanicError(v any) error { + stack := debug.Stack() + + // The first line of the stack trace is of the form "goroutine N [status]:" + // but by the time the panic reaches Do the goroutine may no longer exist + // and its status will have changed. Trim out the misleading line. + if line := bytes.IndexByte(stack[:], '\n'); line >= 0 { + stack = stack[line+1:] + } + return &panicError{value: v, stack: stack} +} + +// call is an in-flight or completed singleflight.Do call +type call[V any] struct { + wg sync.WaitGroup + + // These fields are written once before the WaitGroup is done + // and are only read after the WaitGroup is done. + val V + err error + + // These fields are read and written with the singleflight + // mutex held before the WaitGroup is done, and are read but + // not written after the WaitGroup is done. + dups int + chans []chan<- Result[V] +} + +// Group represents a class of work and forms a namespace in +// which units of work can be executed with duplicate suppression. +type Group[K comparable, V any] struct { + mu sync.Mutex // protects m + m map[K]*call[V] // lazily initialized +} + +// Result holds the results of Do, so they can be passed +// on a channel. +type Result[V any] struct { + Val V + Err error + Shared bool +} + +// Do executes and returns the results of the given function, making +// sure that only one execution is in-flight for a given key at a +// time. If a duplicate comes in, the duplicate caller waits for the +// original to complete and receives the same results. +// The return value shared indicates whether v was given to multiple callers. +func (g *Group[K, V]) Do(key K, fn func() (V, error)) (V, error, bool) { + g.mu.Lock() + if g.m == nil { + g.m = make(map[K]*call[V]) + } + if c, ok := g.m[key]; ok { + c.dups++ + g.mu.Unlock() + c.wg.Wait() + + if e, ok := c.err.(*panicError); ok { + panic(e) + } else if c.err == errGoexit { + runtime.Goexit() + } + return c.val, c.err, true + } + c := new(call[V]) + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + g.doCall(c, key, fn) + return c.val, c.err, c.dups > 0 +} + +// DoChan is like Do but returns a channel that will receive the +// results when they are ready. +// +// The returned channel will not be closed. +func (g *Group[K, V]) DoChan(key K, fn func() (V, error)) <-chan Result[V] { + ch := make(chan Result[V], 1) + g.mu.Lock() + if g.m == nil { + g.m = make(map[K]*call[V]) + } + if c, ok := g.m[key]; ok { + c.dups++ + c.chans = append(c.chans, ch) + g.mu.Unlock() + return ch + } + c := &call[V]{chans: []chan<- Result[V]{ch}} + c.wg.Add(1) + g.m[key] = c + g.mu.Unlock() + + go g.doCall(c, key, fn) + + return ch +} + +// doCall handles the single call for a key. +func (g *Group[K, V]) doCall(c *call[V], key K, fn func() (V, error)) { + normalReturn := false + recovered := false + + // use double-defer to distinguish panic from runtime.Goexit, + // more details see https://golang.org/cl/134395 + defer func() { + // the given function invoked runtime.Goexit + if !normalReturn && !recovered { + c.err = errGoexit + } + + g.mu.Lock() + defer g.mu.Unlock() + c.wg.Done() + if g.m[key] == c { + delete(g.m, key) + } + + if e, ok := c.err.(*panicError); ok { + // In order to prevent the waiting channels from being blocked forever, + // needs to ensure that this panic cannot be recovered. + if len(c.chans) > 0 { + go panic(e) + select {} // Keep this goroutine around so that it will appear in the crash dump. + } else { + panic(e) + } + } else if c.err == errGoexit { + // Already in the process of goexit, no need to call again + } else { + // Normal return + for _, ch := range c.chans { + ch <- Result[V]{c.val, c.err, c.dups > 0} + } + } + }() + + func() { + defer func() { + if !normalReturn { + // Ideally, we would wait to take a stack trace until we've determined + // whether this is a panic or a runtime.Goexit. + // + // Unfortunately, the only way we can distinguish the two is to see + // whether the recover stopped the goroutine from terminating, and by + // the time we know that, the part of the stack trace relevant to the + // panic has been discarded. + if r := recover(); r != nil { + c.err = newPanicError(r) + } + } + }() + + c.val, c.err = fn() + normalReturn = true + }() + + if !normalReturn { + recovered = true + } +} + +// Forget tells the singleflight to forget about a key. Future calls +// to Do for this key will call the function rather than waiting for +// an earlier call to complete. +func (g *Group[K, V]) Forget(key K) { + g.mu.Lock() + delete(g.m, key) + g.mu.Unlock() +} diff --git a/singleflight/v2/singleflight_test.go b/singleflight/v2/singleflight_test.go new file mode 100644 index 0000000..df7bf44 --- /dev/null +++ b/singleflight/v2/singleflight_test.go @@ -0,0 +1,403 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build go1.18 +// +build go1.18 + +package singleflight + +import ( + "bytes" + "errors" + "fmt" + "os" + "os/exec" + "runtime" + "runtime/debug" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestDo(t *testing.T) { + var g Group[string, string] + got, err, _ := g.Do("key", func() (string, error) { + return "bar", nil + }) + want := "bar" + if got != want { + t.Errorf("Do = %#v; want %#v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +func TestDoWithNonStringKey(t *testing.T) { + type fancyKey struct { + x int + y string + } + var g Group[fancyKey, bool] + key := fancyKey{x: 1, y: "foo"} + got, err, _ := g.Do(key, func() (bool, error) { + return true, nil + }) + + want := true + if got != want { + t.Errorf("Do = %v; want %v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } + +} + +func TestDoErr(t *testing.T) { + type empty struct{} + var g Group[string, *empty] + someErr := errors.New("Some error") + v, err, _ := g.Do("key", func() (*empty, error) { + return nil, someErr + }) + if err != someErr { + t.Errorf("Do error = %v; want someErr %v", err, someErr) + } + if v != nil { + t.Errorf("unexpected non-nil value %#v", v) + } +} + +func TestDoDupSuppress(t *testing.T) { + var g Group[string, string] + var wg1, wg2 sync.WaitGroup + c := make(chan string, 1) + var calls int32 + fn := func() (string, error) { + if atomic.AddInt32(&calls, 1) == 1 { + // First invocation. + wg1.Done() + } + v := <-c + c <- v // pump; make available for any future calls + + time.Sleep(10 * time.Millisecond) // let more goroutines enter Do + + return v, nil + } + + const n = 10 + wg1.Add(1) + for i := 0; i < n; i++ { + wg1.Add(1) + wg2.Add(1) + go func() { + defer wg2.Done() + wg1.Done() + v, err, _ := g.Do("key", fn) + if err != nil { + t.Errorf("Do error: %v", err) + return + } + if v != "bar" { + t.Errorf("Do = %T %v; want %q", v, v, "bar") + } + }() + } + wg1.Wait() + // At least one goroutine is in fn now and all of them have at + // least reached the line before the Do. + c <- "bar" + wg2.Wait() + if got := atomic.LoadInt32(&calls); got <= 0 || got >= n { + t.Errorf("number of calls = %d; want over 0 and less than %d", got, n) + } +} + +// Test that singleflight behaves correctly after Forget called. +// See https://github.com/golang/go/issues/31420 +func TestForget(t *testing.T) { + var g Group[string, int] + + var ( + firstStarted = make(chan struct{}) + unblockFirst = make(chan struct{}) + firstFinished = make(chan struct{}) + ) + + go func() { + g.Do("key", func() (i int, e error) { + close(firstStarted) + <-unblockFirst + close(firstFinished) + return + }) + }() + <-firstStarted + g.Forget("key") + + unblockSecond := make(chan struct{}) + secondResult := g.DoChan("key", func() (i int, e error) { + <-unblockSecond + return 2, nil + }) + + close(unblockFirst) + <-firstFinished + + thirdResult := g.DoChan("key", func() (i int, e error) { + return 3, nil + }) + + close(unblockSecond) + <-secondResult + r := <-thirdResult + if r.Val != 2 { + t.Errorf("We should receive result produced by second call, expected: 2, got %d", r.Val) + } +} + +func TestDoChan(t *testing.T) { + var g Group[string, string] + ch := g.DoChan("key", func() (string, error) { + return "bar", nil + }) + + res := <-ch + got := res.Val + err := res.Err + want := "bar" + if got != want { + t.Errorf("Do = %#v; want %#v", got, want) + } + if err != nil { + t.Errorf("Do error = %v", err) + } +} + +// Test singleflight behaves correctly after Do panic. +// See https://github.com/golang/go/issues/41133 +func TestPanicDo(t *testing.T) { + var g Group[string, string] + fn := func() (string, error) { + panic("invalid memory address or nil pointer dereference") + } + + const n = 5 + waited := int32(n) + panicCount := int32(0) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + defer func() { + if err := recover(); err != nil { + t.Logf("Got panic: %v\n%s", err, debug.Stack()) + atomic.AddInt32(&panicCount, 1) + } + + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + + g.Do("key", fn) + }() + } + + select { + case <-done: + if panicCount != n { + t.Errorf("Expect %d panic, but got %d", n, panicCount) + } + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func TestGoexitDo(t *testing.T) { + type empty struct{} + var g Group[string, *empty] + fn := func() (*empty, error) { + runtime.Goexit() + return nil, nil + } + + const n = 5 + waited := int32(n) + done := make(chan struct{}) + for i := 0; i < n; i++ { + go func() { + var err error + defer func() { + if err != nil { + t.Errorf("Error should be nil, but got: %v", err) + } + if atomic.AddInt32(&waited, -1) == 0 { + close(done) + } + }() + _, err, _ = g.Do("key", fn) + }() + } + + select { + case <-done: + case <-time.After(time.Second): + t.Fatalf("Do hangs") + } +} + +func executable(t testing.TB) string { + exe, err := os.Executable() + if err != nil { + t.Skipf("skipping: test executable not found") + } + + // Control case: check whether exec.Command works at all. + // (For example, it might fail with a permission error on iOS.) + cmd := exec.Command(exe, "-test.list=^$") + cmd.Env = []string{} + if err := cmd.Run(); err != nil { + t.Skipf("skipping: exec appears not to work on %s: %v", runtime.GOOS, err) + } + + return exe +} + +func TestPanicDoChan(t *testing.T) { + if os.Getenv("TEST_PANIC_DOCHAN") != "" { + defer func() { + recover() + }() + + g := new(Group[string, string]) + ch := g.DoChan("", func() (string, error) { + panic("Panicking in DoChan") + }) + <-ch + t.Fatalf("DoChan unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(executable(t), "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in DoChan") + } + if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if !bytes.Contains(out.Bytes(), []byte("Panicking in DoChan")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in DoChan") + } +} + +func TestPanicDoSharedByDoChan(t *testing.T) { + if os.Getenv("TEST_PANIC_DOCHAN") != "" { + blocked := make(chan struct{}) + unblock := make(chan struct{}) + + g := new(Group[string, string]) + go func() { + defer func() { + recover() + }() + g.Do("", func() (string, error) { + close(blocked) + <-unblock + panic("Panicking in Do") + }) + }() + + <-blocked + ch := g.DoChan("", func() (string, error) { + panic("DoChan unexpectedly executed callback") + }) + close(unblock) + <-ch + t.Fatalf("DoChan unexpectedly returned") + } + + t.Parallel() + + cmd := exec.Command(executable(t), "-test.run="+t.Name(), "-test.v") + cmd.Env = append(os.Environ(), "TEST_PANIC_DOCHAN=1") + out := new(bytes.Buffer) + cmd.Stdout = out + cmd.Stderr = out + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + err := cmd.Wait() + t.Logf("%s:\n%s", strings.Join(cmd.Args, " "), out) + if err == nil { + t.Errorf("Test subprocess passed; want a crash due to panic in Do shared by DoChan") + } + if bytes.Contains(out.Bytes(), []byte("DoChan unexpectedly")) { + t.Errorf("Test subprocess failed with an unexpected failure mode.") + } + if !bytes.Contains(out.Bytes(), []byte("Panicking in Do")) { + t.Errorf("Test subprocess failed, but the crash isn't caused by panicking in Do") + } +} + +func ExampleGroup() { + g := new(Group[string, string]) + + block := make(chan struct{}) + res1c := g.DoChan("key", func() (string, error) { + <-block + return "func 1", nil + }) + res2c := g.DoChan("key", func() (string, error) { + <-block + return "func 2", nil + }) + close(block) + + res1 := <-res1c + res2 := <-res2c + + // Results are shared by functions executed with duplicate keys. + fmt.Println("Shared:", res2.Shared) + // Only the first function is executed: it is registered and started with "key", + // and doesn't complete before the second funtion is registered with a duplicate key. + fmt.Println("Equal results:", res1.Val == res2.Val) + fmt.Println("Result:", res1.Val) + + // Any comparable type may be used as the key. + type fancyKey struct { + x int + y string + } + + g2 := &Group[fancyKey, int]{} + res, err, _ := g2.Do(fancyKey{1, "neat"}, func() (int, error) { + return 39, nil + }) + + fmt.Println("Comparable key result:", res) + fmt.Println("Comparable key error:", err) + + // Output: + // Shared: true + // Equal results: true + // Result: func 1 + // Comparable key result: 39 + // Comparable key error: +}