summaryrefslogtreecommitdiffstats
path: root/src/internal/singleflight
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/internal/singleflight/singleflight.go123
-rw-r--r--src/internal/singleflight/singleflight_test.go186
2 files changed, 309 insertions, 0 deletions
diff --git a/src/internal/singleflight/singleflight.go b/src/internal/singleflight/singleflight.go
new file mode 100644
index 0000000..d0e6d2f
--- /dev/null
+++ b/src/internal/singleflight/singleflight.go
@@ -0,0 +1,123 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package singleflight provides a duplicate function call suppression
+// mechanism.
+package singleflight
+
+import "sync"
+
+// call is an in-flight or completed singleflight.Do call
+type call struct {
+ wg sync.WaitGroup
+
+ // These fields are written once before the WaitGroup is done
+ // and are only read after the WaitGroup is done.
+ val any
+ err error
+
+ // These fields are read and written with the singleflight
+ // mutex held before the WaitGroup is done, and are read but
+ // not written after the WaitGroup is done.
+ dups int
+ chans []chan<- Result
+}
+
+// Group represents a class of work and forms a namespace in
+// which units of work can be executed with duplicate suppression.
+type Group struct {
+ mu sync.Mutex // protects m
+ m map[string]*call // lazily initialized
+}
+
+// Result holds the results of Do, so they can be passed
+// on a channel.
+type Result struct {
+ Val any
+ Err error
+ Shared bool
+}
+
+// Do executes and returns the results of the given function, making
+// sure that only one execution is in-flight for a given key at a
+// time. If a duplicate comes in, the duplicate caller waits for the
+// original to complete and receives the same results.
+// The return value shared indicates whether v was given to multiple callers.
+func (g *Group) Do(key string, fn func() (any, error)) (v any, err error, shared bool) {
+ g.mu.Lock()
+ if g.m == nil {
+ g.m = make(map[string]*call)
+ }
+ if c, ok := g.m[key]; ok {
+ c.dups++
+ g.mu.Unlock()
+ c.wg.Wait()
+ return c.val, c.err, true
+ }
+ c := new(call)
+ c.wg.Add(1)
+ g.m[key] = c
+ g.mu.Unlock()
+
+ g.doCall(c, key, fn)
+ return c.val, c.err, c.dups > 0
+}
+
+// DoChan is like Do but returns a channel that will receive the
+// results when they are ready.
+func (g *Group) DoChan(key string, fn func() (any, error)) <-chan Result {
+ ch := make(chan Result, 1)
+ g.mu.Lock()
+ if g.m == nil {
+ g.m = make(map[string]*call)
+ }
+ if c, ok := g.m[key]; ok {
+ c.dups++
+ c.chans = append(c.chans, ch)
+ g.mu.Unlock()
+ return ch
+ }
+ c := &call{chans: []chan<- Result{ch}}
+ c.wg.Add(1)
+ g.m[key] = c
+ g.mu.Unlock()
+
+ go g.doCall(c, key, fn)
+
+ return ch
+}
+
+// doCall handles the single call for a key.
+func (g *Group) doCall(c *call, key string, fn func() (any, error)) {
+ c.val, c.err = fn()
+
+ g.mu.Lock()
+ c.wg.Done()
+ if g.m[key] == c {
+ delete(g.m, key)
+ }
+ for _, ch := range c.chans {
+ ch <- Result{c.val, c.err, c.dups > 0}
+ }
+ g.mu.Unlock()
+}
+
+// ForgetUnshared tells the singleflight to forget about a key if it is not
+// shared with any other goroutines. Future calls to Do for a forgotten key
+// will call the function rather than waiting for an earlier call to complete.
+// Returns whether the key was forgotten or unknown--that is, whether no
+// other goroutines are waiting for the result.
+func (g *Group) ForgetUnshared(key string) bool {
+ g.mu.Lock()
+ defer g.mu.Unlock()
+ c, ok := g.m[key]
+ if !ok {
+ return true
+ }
+ if c.dups == 0 {
+ delete(g.m, key)
+ return true
+ }
+ return false
+}
diff --git a/src/internal/singleflight/singleflight_test.go b/src/internal/singleflight/singleflight_test.go
new file mode 100644
index 0000000..279e1be
--- /dev/null
+++ b/src/internal/singleflight/singleflight_test.go
@@ -0,0 +1,186 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package singleflight
+
+import (
+ "errors"
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func TestDo(t *testing.T) {
+ var g Group
+ v, err, _ := g.Do("key", func() (any, error) {
+ return "bar", nil
+ })
+ if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
+ t.Errorf("Do = %v; want %v", got, want)
+ }
+ if err != nil {
+ t.Errorf("Do error = %v", err)
+ }
+}
+
+func TestDoErr(t *testing.T) {
+ var g Group
+ someErr := errors.New("some error")
+ v, err, _ := g.Do("key", func() (any, error) {
+ return nil, someErr
+ })
+ if err != someErr {
+ t.Errorf("Do error = %v; want someErr %v", err, someErr)
+ }
+ if v != nil {
+ t.Errorf("unexpected non-nil value %#v", v)
+ }
+}
+
+func TestDoDupSuppress(t *testing.T) {
+ var g Group
+ var wg1, wg2 sync.WaitGroup
+ c := make(chan string, 1)
+ var calls atomic.Int32
+ fn := func() (any, error) {
+ if calls.Add(1) == 1 {
+ // First invocation.
+ wg1.Done()
+ }
+ v := <-c
+ c <- v // pump; make available for any future calls
+
+ time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
+
+ return v, nil
+ }
+
+ const n = 10
+ wg1.Add(1)
+ for i := 0; i < n; i++ {
+ wg1.Add(1)
+ wg2.Add(1)
+ go func() {
+ defer wg2.Done()
+ wg1.Done()
+ v, err, _ := g.Do("key", fn)
+ if err != nil {
+ t.Errorf("Do error: %v", err)
+ return
+ }
+ if s, _ := v.(string); s != "bar" {
+ t.Errorf("Do = %T %v; want %q", v, v, "bar")
+ }
+ }()
+ }
+ wg1.Wait()
+ // At least one goroutine is in fn now and all of them have at
+ // least reached the line before the Do.
+ c <- "bar"
+ wg2.Wait()
+ if got := calls.Load(); got <= 0 || got >= n {
+ t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
+ }
+}
+
+func TestForgetUnshared(t *testing.T) {
+ var g Group
+
+ var firstStarted, firstFinished sync.WaitGroup
+
+ firstStarted.Add(1)
+ firstFinished.Add(1)
+
+ key := "key"
+ firstCh := make(chan struct{})
+ go func() {
+ g.Do(key, func() (i interface{}, e error) {
+ firstStarted.Done()
+ <-firstCh
+ return
+ })
+ firstFinished.Done()
+ }()
+
+ firstStarted.Wait()
+ g.ForgetUnshared(key) // from this point no two function using same key should be executed concurrently
+
+ secondCh := make(chan struct{})
+ go func() {
+ g.Do(key, func() (i interface{}, e error) {
+ // Notify that we started
+ secondCh <- struct{}{}
+ <-secondCh
+ return 2, nil
+ })
+ }()
+
+ <-secondCh
+
+ resultCh := g.DoChan(key, func() (i interface{}, e error) {
+ panic("third must not be started")
+ })
+
+ if g.ForgetUnshared(key) {
+ t.Errorf("Before first goroutine finished, key %q is shared, should return false", key)
+ }
+
+ close(firstCh)
+ firstFinished.Wait()
+
+ if g.ForgetUnshared(key) {
+ t.Errorf("After first goroutine finished, key %q is still shared, should return false", key)
+ }
+
+ secondCh <- struct{}{}
+
+ if result := <-resultCh; result.Val != 2 {
+ t.Errorf("We should receive result produced by second call, expected: 2, got %d", result.Val)
+ }
+}
+
+func TestDoAndForgetUnsharedRace(t *testing.T) {
+ t.Parallel()
+
+ var g Group
+ key := "key"
+ d := time.Millisecond
+ for {
+ var calls, shared atomic.Int64
+ const n = 1000
+ var wg sync.WaitGroup
+ wg.Add(n)
+ for i := 0; i < n; i++ {
+ go func() {
+ g.Do(key, func() (interface{}, error) {
+ time.Sleep(d)
+ return calls.Add(1), nil
+ })
+ if !g.ForgetUnshared(key) {
+ shared.Add(1)
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ if calls.Load() != 1 {
+ // The goroutines didn't park in g.Do in time,
+ // so the key was re-added and may have been shared after the call.
+ // Try again with more time to park.
+ d *= 2
+ continue
+ }
+
+ // All of the Do calls ended up sharing the first
+ // invocation, so the key should have been unused
+ // (and therefore unshared) when they returned.
+ if shared.Load() > 0 {
+ t.Errorf("after a single shared Do, ForgetUnshared returned false %d times", shared.Load())
+ }
+ break
+ }
+}