diff options
Diffstat (limited to 'src/context')
-rw-r--r-- | src/context/afterfunc_test.go | 141 | ||||
-rw-r--r-- | src/context/benchmark_test.go | 190 | ||||
-rw-r--r-- | src/context/context.go | 785 | ||||
-rw-r--r-- | src/context/context_test.go | 297 | ||||
-rw-r--r-- | src/context/example_test.go | 263 | ||||
-rw-r--r-- | src/context/net_test.go | 21 | ||||
-rw-r--r-- | src/context/x_test.go | 956 |
7 files changed, 2653 insertions, 0 deletions
diff --git a/src/context/afterfunc_test.go b/src/context/afterfunc_test.go new file mode 100644 index 0000000..71f639a --- /dev/null +++ b/src/context/afterfunc_test.go @@ -0,0 +1,141 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + "context" + "sync" + "testing" + "time" +) + +// afterFuncContext is a context that's not one of the types +// defined in context.go, that supports registering AfterFuncs. +type afterFuncContext struct { + mu sync.Mutex + afterFuncs map[*struct{}]func() + done chan struct{} + err error +} + +func newAfterFuncContext() context.Context { + return &afterFuncContext{} +} + +func (c *afterFuncContext) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +func (c *afterFuncContext) Done() <-chan struct{} { + c.mu.Lock() + defer c.mu.Unlock() + if c.done == nil { + c.done = make(chan struct{}) + } + return c.done +} + +func (c *afterFuncContext) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *afterFuncContext) Value(key any) any { + return nil +} + +func (c *afterFuncContext) AfterFunc(f func()) func() bool { + c.mu.Lock() + defer c.mu.Unlock() + k := &struct{}{} + if c.afterFuncs == nil { + c.afterFuncs = make(map[*struct{}]func()) + } + c.afterFuncs[k] = f + return func() bool { + c.mu.Lock() + defer c.mu.Unlock() + _, ok := c.afterFuncs[k] + delete(c.afterFuncs, k) + return ok + } +} + +func (c *afterFuncContext) cancel(err error) { + c.mu.Lock() + defer c.mu.Unlock() + if c.err != nil { + return + } + c.err = err + for _, f := range c.afterFuncs { + go f() + } + c.afterFuncs = nil +} + +func TestCustomContextAfterFuncCancel(t *testing.T) { + ctx0 := &afterFuncContext{} + ctx1, cancel := context.WithCancel(ctx0) + defer cancel() + ctx0.cancel(context.Canceled) + <-ctx1.Done() +} + +func TestCustomContextAfterFuncTimeout(t *testing.T) { + ctx0 := &afterFuncContext{} + ctx1, cancel := context.WithTimeout(ctx0, veryLongDuration) + defer cancel() + ctx0.cancel(context.Canceled) + <-ctx1.Done() +} + +func TestCustomContextAfterFuncAfterFunc(t *testing.T) { + ctx0 := &afterFuncContext{} + donec := make(chan struct{}) + stop := context.AfterFunc(ctx0, func() { + close(donec) + }) + defer stop() + ctx0.cancel(context.Canceled) + <-donec +} + +func TestCustomContextAfterFuncUnregisterCancel(t *testing.T) { + ctx0 := &afterFuncContext{} + _, cancel := context.WithCancel(ctx0) + if got, want := len(ctx0.afterFuncs), 1; got != want { + t.Errorf("after WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want) + } + cancel() + if got, want := len(ctx0.afterFuncs), 0; got != want { + t.Errorf("after canceling WithCancel(ctx0): ctx0 has %v afterFuncs, want %v", got, want) + } +} + +func TestCustomContextAfterFuncUnregisterTimeout(t *testing.T) { + ctx0 := &afterFuncContext{} + _, cancel := context.WithTimeout(ctx0, veryLongDuration) + if got, want := len(ctx0.afterFuncs), 1; got != want { + t.Errorf("after WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want) + } + cancel() + if got, want := len(ctx0.afterFuncs), 0; got != want { + t.Errorf("after canceling WithTimeout(ctx0, d): ctx0 has %v afterFuncs, want %v", got, want) + } +} + +func TestCustomContextAfterFuncUnregisterAfterFunc(t *testing.T) { + ctx0 := &afterFuncContext{} + stop := context.AfterFunc(ctx0, func() {}) + if got, want := len(ctx0.afterFuncs), 1; got != want { + t.Errorf("after AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want) + } + stop() + if got, want := len(ctx0.afterFuncs), 0; got != want { + t.Errorf("after stopping AfterFunc(ctx0, f): ctx0 has %v afterFuncs, want %v", got, want) + } +} diff --git a/src/context/benchmark_test.go b/src/context/benchmark_test.go new file mode 100644 index 0000000..144f473 --- /dev/null +++ b/src/context/benchmark_test.go @@ -0,0 +1,190 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + . "context" + "fmt" + "runtime" + "sync" + "testing" + "time" +) + +func BenchmarkCommonParentCancel(b *testing.B) { + root := WithValue(Background(), "key", "value") + shared, sharedcancel := WithCancel(root) + defer sharedcancel() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + x := 0 + for pb.Next() { + ctx, cancel := WithCancel(shared) + if ctx.Value("key").(string) != "value" { + b.Fatal("should not be reached") + } + for i := 0; i < 100; i++ { + x /= x + 1 + } + cancel() + for i := 0; i < 100; i++ { + x /= x + 1 + } + } + }) +} + +func BenchmarkWithTimeout(b *testing.B) { + for concurrency := 40; concurrency <= 4e5; concurrency *= 100 { + name := fmt.Sprintf("concurrency=%d", concurrency) + b.Run(name, func(b *testing.B) { + benchmarkWithTimeout(b, concurrency) + }) + } +} + +func benchmarkWithTimeout(b *testing.B, concurrentContexts int) { + gomaxprocs := runtime.GOMAXPROCS(0) + perPContexts := concurrentContexts / gomaxprocs + root := Background() + + // Generate concurrent contexts. + var wg sync.WaitGroup + ccf := make([][]CancelFunc, gomaxprocs) + for i := range ccf { + wg.Add(1) + go func(i int) { + defer wg.Done() + cf := make([]CancelFunc, perPContexts) + for j := range cf { + _, cf[j] = WithTimeout(root, time.Hour) + } + ccf[i] = cf + }(i) + } + wg.Wait() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + wcf := make([]CancelFunc, 10) + for pb.Next() { + for i := range wcf { + _, wcf[i] = WithTimeout(root, time.Hour) + } + for _, f := range wcf { + f() + } + } + }) + b.StopTimer() + + for _, cf := range ccf { + for _, f := range cf { + f() + } + } +} + +func BenchmarkCancelTree(b *testing.B) { + depths := []int{1, 10, 100, 1000} + for _, d := range depths { + b.Run(fmt.Sprintf("depth=%d", d), func(b *testing.B) { + b.Run("Root=Background", func(b *testing.B) { + for i := 0; i < b.N; i++ { + buildContextTree(Background(), d) + } + }) + b.Run("Root=OpenCanceler", func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx, cancel := WithCancel(Background()) + buildContextTree(ctx, d) + cancel() + } + }) + b.Run("Root=ClosedCanceler", func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx, cancel := WithCancel(Background()) + cancel() + buildContextTree(ctx, d) + } + }) + }) + } +} + +func buildContextTree(root Context, depth int) { + for d := 0; d < depth; d++ { + root, _ = WithCancel(root) + } +} + +func BenchmarkCheckCanceled(b *testing.B) { + ctx, cancel := WithCancel(Background()) + cancel() + b.Run("Err", func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx.Err() + } + }) + b.Run("Done", func(b *testing.B) { + for i := 0; i < b.N; i++ { + select { + case <-ctx.Done(): + default: + } + } + }) +} + +func BenchmarkContextCancelDone(b *testing.B) { + ctx, cancel := WithCancel(Background()) + defer cancel() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + select { + case <-ctx.Done(): + default: + } + } + }) +} + +func BenchmarkDeepValueNewGoRoutine(b *testing.B) { + for _, depth := range []int{10, 20, 30, 50, 100} { + ctx := Background() + for i := 0; i < depth; i++ { + ctx = WithValue(ctx, i, i) + } + + b.Run(fmt.Sprintf("depth=%d", depth), func(b *testing.B) { + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + ctx.Value(-1) + }() + wg.Wait() + } + }) + } +} + +func BenchmarkDeepValueSameGoRoutine(b *testing.B) { + for _, depth := range []int{10, 20, 30, 50, 100} { + ctx := Background() + for i := 0; i < depth; i++ { + ctx = WithValue(ctx, i, i) + } + + b.Run(fmt.Sprintf("depth=%d", depth), func(b *testing.B) { + for i := 0; i < b.N; i++ { + ctx.Value(-1) + } + }) + } +} diff --git a/src/context/context.go b/src/context/context.go new file mode 100644 index 0000000..ee66b43 --- /dev/null +++ b/src/context/context.go @@ -0,0 +1,785 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context defines the Context type, which carries deadlines, +// cancellation signals, and other request-scoped values across API boundaries +// and between processes. +// +// Incoming requests to a server should create a [Context], and outgoing +// calls to servers should accept a Context. The chain of function +// calls between them must propagate the Context, optionally replacing +// it with a derived Context created using [WithCancel], [WithDeadline], +// [WithTimeout], or [WithValue]. When a Context is canceled, all +// Contexts derived from it are also canceled. +// +// The [WithCancel], [WithDeadline], and [WithTimeout] functions take a +// Context (the parent) and return a derived Context (the child) and a +// [CancelFunc]. Calling the CancelFunc cancels the child and its +// children, removes the parent's reference to the child, and stops +// any associated timers. Failing to call the CancelFunc leaks the +// child and its children until the parent is canceled or the timer +// fires. The go vet tool checks that CancelFuncs are used on all +// control-flow paths. +// +// The [WithCancelCause] function returns a [CancelCauseFunc], which +// takes an error and records it as the cancellation cause. Calling +// [Cause] on the canceled context or any of its children retrieves +// the cause. If no cause is specified, Cause(ctx) returns the same +// value as ctx.Err(). +// +// Programs that use Contexts should follow these rules to keep interfaces +// consistent across packages and enable static analysis tools to check context +// propagation: +// +// Do not store Contexts inside a struct type; instead, pass a Context +// explicitly to each function that needs it. The Context should be the first +// parameter, typically named ctx: +// +// func DoSomething(ctx context.Context, arg Arg) error { +// // ... use ctx ... +// } +// +// Do not pass a nil [Context], even if a function permits it. Pass [context.TODO] +// if you are unsure about which Context to use. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The same Context may be passed to functions running in different goroutines; +// Contexts are safe for simultaneous use by multiple goroutines. +// +// See https://blog.golang.org/context for example code for a server that uses +// Contexts. +package context + +import ( + "errors" + "internal/reflectlite" + "sync" + "sync/atomic" + "time" +) + +// A Context carries a deadline, a cancellation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context interface { + // Deadline returns the time when work done on behalf of this context + // should be canceled. Deadline returns ok==false when no deadline is + // set. Successive calls to Deadline return the same results. + Deadline() (deadline time.Time, ok bool) + + // Done returns a channel that's closed when work done on behalf of this + // context should be canceled. Done may return nil if this context can + // never be canceled. Successive calls to Done return the same value. + // The close of the Done channel may happen asynchronously, + // after the cancel function returns. + // + // WithCancel arranges for Done to be closed when cancel is called; + // WithDeadline arranges for Done to be closed when the deadline + // expires; WithTimeout arranges for Done to be closed when the timeout + // elapses. + // + // Done is provided for use in select statements: + // + // // Stream generates values with DoSomething and sends them to out + // // until DoSomething returns an error or ctx.Done is closed. + // func Stream(ctx context.Context, out chan<- Value) error { + // for { + // v, err := DoSomething(ctx) + // if err != nil { + // return err + // } + // select { + // case <-ctx.Done(): + // return ctx.Err() + // case out <- v: + // } + // } + // } + // + // See https://blog.golang.org/pipelines for more examples of how to use + // a Done channel for cancellation. + Done() <-chan struct{} + + // If Done is not yet closed, Err returns nil. + // If Done is closed, Err returns a non-nil error explaining why: + // Canceled if the context was canceled + // or DeadlineExceeded if the context's deadline passed. + // After Err returns a non-nil error, successive calls to Err return the same error. + Err() error + + // Value returns the value associated with this context for key, or nil + // if no value is associated with key. Successive calls to Value with + // the same key returns the same result. + // + // Use context values only for request-scoped data that transits + // processes and API boundaries, not for passing optional parameters to + // functions. + // + // A key identifies a specific value in a Context. Functions that wish + // to store values in Context typically allocate a key in a global + // variable then use that key as the argument to context.WithValue and + // Context.Value. A key can be any type that supports equality; + // packages should define keys as an unexported type to avoid + // collisions. + // + // Packages that define a Context key should provide type-safe accessors + // for the values stored using that key: + // + // // Package user defines a User type that's stored in Contexts. + // package user + // + // import "context" + // + // // User is the type of value stored in the Contexts. + // type User struct {...} + // + // // key is an unexported type for keys defined in this package. + // // This prevents collisions with keys defined in other packages. + // type key int + // + // // userKey is the key for user.User values in Contexts. It is + // // unexported; clients use user.NewContext and user.FromContext + // // instead of using this key directly. + // var userKey key + // + // // NewContext returns a new Context that carries value u. + // func NewContext(ctx context.Context, u *User) context.Context { + // return context.WithValue(ctx, userKey, u) + // } + // + // // FromContext returns the User value stored in ctx, if any. + // func FromContext(ctx context.Context) (*User, bool) { + // u, ok := ctx.Value(userKey).(*User) + // return u, ok + // } + Value(key any) any +} + +// Canceled is the error returned by [Context.Err] when the context is canceled. +var Canceled = errors.New("context canceled") + +// DeadlineExceeded is the error returned by [Context.Err] when the context's +// deadline passes. +var DeadlineExceeded error = deadlineExceededError{} + +type deadlineExceededError struct{} + +func (deadlineExceededError) Error() string { return "context deadline exceeded" } +func (deadlineExceededError) Timeout() bool { return true } +func (deadlineExceededError) Temporary() bool { return true } + +// An emptyCtx is never canceled, has no values, and has no deadline. +// It is the common base of backgroundCtx and todoCtx. +type emptyCtx struct{} + +func (emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (emptyCtx) Done() <-chan struct{} { + return nil +} + +func (emptyCtx) Err() error { + return nil +} + +func (emptyCtx) Value(key any) any { + return nil +} + +type backgroundCtx struct{ emptyCtx } + +func (backgroundCtx) String() string { + return "context.Background" +} + +type todoCtx struct{ emptyCtx } + +func (todoCtx) String() string { + return "context.TODO" +} + +// Background returns a non-nil, empty [Context]. It is never canceled, has no +// values, and has no deadline. It is typically used by the main function, +// initialization, and tests, and as the top-level Context for incoming +// requests. +func Background() Context { + return backgroundCtx{} +} + +// TODO returns a non-nil, empty [Context]. Code should use context.TODO when +// it's unclear which Context to use or it is not yet available (because the +// surrounding function has not yet been extended to accept a Context +// parameter). +func TODO() Context { + return todoCtx{} +} + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// A CancelFunc may be called by multiple goroutines simultaneously. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc func() + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + c := withCancel(parent) + return c, func() { c.cancel(true, Canceled, nil) } +} + +// A CancelCauseFunc behaves like a [CancelFunc] but additionally sets the cancellation cause. +// This cause can be retrieved by calling [Cause] on the canceled Context or on +// any of its derived Contexts. +// +// If the context has already been canceled, CancelCauseFunc does not set the cause. +// For example, if childContext is derived from parentContext: +// - if parentContext is canceled with cause1 before childContext is canceled with cause2, +// then Cause(parentContext) == Cause(childContext) == cause1 +// - if childContext is canceled with cause2 before parentContext is canceled with cause1, +// then Cause(parentContext) == cause1 and Cause(childContext) == cause2 +type CancelCauseFunc func(cause error) + +// WithCancelCause behaves like [WithCancel] but returns a [CancelCauseFunc] instead of a [CancelFunc]. +// Calling cancel with a non-nil error (the "cause") records that error in ctx; +// it can then be retrieved using Cause(ctx). +// Calling cancel with nil sets the cause to Canceled. +// +// Example use: +// +// ctx, cancel := context.WithCancelCause(parent) +// cancel(myError) +// ctx.Err() // returns context.Canceled +// context.Cause(ctx) // returns myError +func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) { + c := withCancel(parent) + return c, func(cause error) { c.cancel(true, Canceled, cause) } +} + +func withCancel(parent Context) *cancelCtx { + if parent == nil { + panic("cannot create context from nil parent") + } + c := &cancelCtx{} + c.propagateCancel(parent, c) + return c +} + +// Cause returns a non-nil error explaining why c was canceled. +// The first cancellation of c or one of its parents sets the cause. +// If that cancellation happened via a call to CancelCauseFunc(err), +// then [Cause] returns err. +// Otherwise Cause(c) returns the same value as c.Err(). +// Cause returns nil if c has not been canceled yet. +func Cause(c Context) error { + if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok { + cc.mu.Lock() + defer cc.mu.Unlock() + return cc.cause + } + return nil +} + +// AfterFunc arranges to call f in its own goroutine after ctx is done +// (cancelled or timed out). +// If ctx is already done, AfterFunc calls f immediately in its own goroutine. +// +// Multiple calls to AfterFunc on a context operate independently; +// one does not replace another. +// +// Calling the returned stop function stops the association of ctx with f. +// It returns true if the call stopped f from being run. +// If stop returns false, +// either the context is done and f has been started in its own goroutine; +// or f was already stopped. +// The stop function does not wait for f to complete before returning. +// If the caller needs to know whether f is completed, +// it must coordinate with f explicitly. +// +// If ctx has a "AfterFunc(func()) func() bool" method, +// AfterFunc will use it to schedule the call. +func AfterFunc(ctx Context, f func()) (stop func() bool) { + a := &afterFuncCtx{ + f: f, + } + a.cancelCtx.propagateCancel(ctx, a) + return func() bool { + stopped := false + a.once.Do(func() { + stopped = true + }) + if stopped { + a.cancel(true, Canceled, nil) + } + return stopped + } +} + +type afterFuncer interface { + AfterFunc(func()) func() bool +} + +type afterFuncCtx struct { + cancelCtx + once sync.Once // either starts running f or stops f from running + f func() +} + +func (a *afterFuncCtx) cancel(removeFromParent bool, err, cause error) { + a.cancelCtx.cancel(false, err, cause) + if removeFromParent { + removeChild(a.Context, a) + } + a.once.Do(func() { + go a.f() + }) +} + +// A stopCtx is used as the parent context of a cancelCtx when +// an AfterFunc has been registered with the parent. +// It holds the stop function used to unregister the AfterFunc. +type stopCtx struct { + Context + stop func() bool +} + +// goroutines counts the number of goroutines ever created; for testing. +var goroutines atomic.Int32 + +// &cancelCtxKey is the key that a cancelCtx returns itself for. +var cancelCtxKey int + +// parentCancelCtx returns the underlying *cancelCtx for parent. +// It does this by looking up parent.Value(&cancelCtxKey) to find +// the innermost enclosing *cancelCtx and then checking whether +// parent.Done() matches that *cancelCtx. (If not, the *cancelCtx +// has been wrapped in a custom implementation providing a +// different done channel, in which case we should not bypass it.) +func parentCancelCtx(parent Context) (*cancelCtx, bool) { + done := parent.Done() + if done == closedchan || done == nil { + return nil, false + } + p, ok := parent.Value(&cancelCtxKey).(*cancelCtx) + if !ok { + return nil, false + } + pdone, _ := p.done.Load().(chan struct{}) + if pdone != done { + return nil, false + } + return p, true +} + +// removeChild removes a context from its parent. +func removeChild(parent Context, child canceler) { + if s, ok := parent.(stopCtx); ok { + s.stop() + return + } + p, ok := parentCancelCtx(parent) + if !ok { + return + } + p.mu.Lock() + if p.children != nil { + delete(p.children, child) + } + p.mu.Unlock() +} + +// A canceler is a context type that can be canceled directly. The +// implementations are *cancelCtx and *timerCtx. +type canceler interface { + cancel(removeFromParent bool, err, cause error) + Done() <-chan struct{} +} + +// closedchan is a reusable closed channel. +var closedchan = make(chan struct{}) + +func init() { + close(closedchan) +} + +// A cancelCtx can be canceled. When canceled, it also cancels any children +// that implement canceler. +type cancelCtx struct { + Context + + mu sync.Mutex // protects following fields + done atomic.Value // of chan struct{}, created lazily, closed by first cancel call + children map[canceler]struct{} // set to nil by the first cancel call + err error // set to non-nil by the first cancel call + cause error // set to non-nil by the first cancel call +} + +func (c *cancelCtx) Value(key any) any { + if key == &cancelCtxKey { + return c + } + return value(c.Context, key) +} + +func (c *cancelCtx) Done() <-chan struct{} { + d := c.done.Load() + if d != nil { + return d.(chan struct{}) + } + c.mu.Lock() + defer c.mu.Unlock() + d = c.done.Load() + if d == nil { + d = make(chan struct{}) + c.done.Store(d) + } + return d.(chan struct{}) +} + +func (c *cancelCtx) Err() error { + c.mu.Lock() + err := c.err + c.mu.Unlock() + return err +} + +// propagateCancel arranges for child to be canceled when parent is. +// It sets the parent context of cancelCtx. +func (c *cancelCtx) propagateCancel(parent Context, child canceler) { + c.Context = parent + + done := parent.Done() + if done == nil { + return // parent is never canceled + } + + select { + case <-done: + // parent is already canceled + child.cancel(false, parent.Err(), Cause(parent)) + return + default: + } + + if p, ok := parentCancelCtx(parent); ok { + // parent is a *cancelCtx, or derives from one. + p.mu.Lock() + if p.err != nil { + // parent has already been canceled + child.cancel(false, p.err, p.cause) + } else { + if p.children == nil { + p.children = make(map[canceler]struct{}) + } + p.children[child] = struct{}{} + } + p.mu.Unlock() + return + } + + if a, ok := parent.(afterFuncer); ok { + // parent implements an AfterFunc method. + c.mu.Lock() + stop := a.AfterFunc(func() { + child.cancel(false, parent.Err(), Cause(parent)) + }) + c.Context = stopCtx{ + Context: parent, + stop: stop, + } + c.mu.Unlock() + return + } + + goroutines.Add(1) + go func() { + select { + case <-parent.Done(): + child.cancel(false, parent.Err(), Cause(parent)) + case <-child.Done(): + } + }() +} + +type stringer interface { + String() string +} + +func contextName(c Context) string { + if s, ok := c.(stringer); ok { + return s.String() + } + return reflectlite.TypeOf(c).String() +} + +func (c *cancelCtx) String() string { + return contextName(c.Context) + ".WithCancel" +} + +// cancel closes c.done, cancels each of c's children, and, if +// removeFromParent is true, removes c from its parent's children. +// cancel sets c.cause to cause if this is the first time c is canceled. +func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) { + if err == nil { + panic("context: internal error: missing cancel error") + } + if cause == nil { + cause = err + } + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return // already canceled + } + c.err = err + c.cause = cause + d, _ := c.done.Load().(chan struct{}) + if d == nil { + c.done.Store(closedchan) + } else { + close(d) + } + for child := range c.children { + // NOTE: acquiring the child's lock while holding parent's lock. + child.cancel(false, err, cause) + } + c.children = nil + c.mu.Unlock() + + if removeFromParent { + removeChild(c.Context, c) + } +} + +// WithoutCancel returns a copy of parent that is not canceled when parent is canceled. +// The returned context returns no Deadline or Err, and its Done channel is nil. +// Calling [Cause] on the returned context returns nil. +func WithoutCancel(parent Context) Context { + if parent == nil { + panic("cannot create context from nil parent") + } + return withoutCancelCtx{parent} +} + +type withoutCancelCtx struct { + c Context +} + +func (withoutCancelCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (withoutCancelCtx) Done() <-chan struct{} { + return nil +} + +func (withoutCancelCtx) Err() error { + return nil +} + +func (c withoutCancelCtx) Value(key any) any { + return value(c, key) +} + +func (c withoutCancelCtx) String() string { + return contextName(c.c) + ".WithoutCancel" +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// [Context.Done] channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this [Context] complete. +func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { + return WithDeadlineCause(parent, d, nil) +} + +// WithDeadlineCause behaves like [WithDeadline] but also sets the cause of the +// returned Context when the deadline is exceeded. The returned [CancelFunc] does +// not set the cause. +func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, CancelFunc) { + if parent == nil { + panic("cannot create context from nil parent") + } + if cur, ok := parent.Deadline(); ok && cur.Before(d) { + // The current deadline is already sooner than the new one. + return WithCancel(parent) + } + c := &timerCtx{ + deadline: d, + } + c.cancelCtx.propagateCancel(parent, c) + dur := time.Until(d) + if dur <= 0 { + c.cancel(true, DeadlineExceeded, cause) // deadline has already passed + return c, func() { c.cancel(false, Canceled, nil) } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.err == nil { + c.timer = time.AfterFunc(dur, func() { + c.cancel(true, DeadlineExceeded, cause) + }) + } + return c, func() { c.cancel(true, Canceled, nil) } +} + +// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to +// implement Done and Err. It implements cancel by stopping its timer then +// delegating to cancelCtx.cancel. +type timerCtx struct { + cancelCtx + timer *time.Timer // Under cancelCtx.mu. + + deadline time.Time +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { + return c.deadline, true +} + +func (c *timerCtx) String() string { + return contextName(c.cancelCtx.Context) + ".WithDeadline(" + + c.deadline.String() + " [" + + time.Until(c.deadline).String() + "])" +} + +func (c *timerCtx) cancel(removeFromParent bool, err, cause error) { + c.cancelCtx.cancel(false, err, cause) + if removeFromParent { + // Remove this timerCtx from its parent cancelCtx's children. + removeChild(c.cancelCtx.Context, c) + } + c.mu.Lock() + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + c.mu.Unlock() +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this [Context] complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithTimeoutCause behaves like [WithTimeout] but also sets the cause of the +// returned Context when the timeout expires. The returned [CancelFunc] does +// not set the cause. +func WithTimeoutCause(parent Context, timeout time.Duration, cause error) (Context, CancelFunc) { + return WithDeadlineCause(parent, time.Now().Add(timeout), cause) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The provided key must be comparable and should not be of type +// string or any other built-in type to avoid collisions between +// packages using context. Users of WithValue should define their own +// types for keys. To avoid allocating when assigning to an +// interface{}, context keys often have concrete type +// struct{}. Alternatively, exported context key variables' static +// type should be a pointer or interface. +func WithValue(parent Context, key, val any) Context { + if parent == nil { + panic("cannot create context from nil parent") + } + if key == nil { + panic("nil key") + } + if !reflectlite.TypeOf(key).Comparable() { + panic("key is not comparable") + } + return &valueCtx{parent, key, val} +} + +// A valueCtx carries a key-value pair. It implements Value for that key and +// delegates all other calls to the embedded Context. +type valueCtx struct { + Context + key, val any +} + +// stringify tries a bit to stringify v, without using fmt, since we don't +// want context depending on the unicode tables. This is only used by +// *valueCtx.String(). +func stringify(v any) string { + switch s := v.(type) { + case stringer: + return s.String() + case string: + return s + } + return "<not Stringer>" +} + +func (c *valueCtx) String() string { + return contextName(c.Context) + ".WithValue(type " + + reflectlite.TypeOf(c.key).String() + + ", val " + stringify(c.val) + ")" +} + +func (c *valueCtx) Value(key any) any { + if c.key == key { + return c.val + } + return value(c.Context, key) +} + +func value(c Context, key any) any { + for { + switch ctx := c.(type) { + case *valueCtx: + if key == ctx.key { + return ctx.val + } + c = ctx.Context + case *cancelCtx: + if key == &cancelCtxKey { + return c + } + c = ctx.Context + case withoutCancelCtx: + if key == &cancelCtxKey { + // This implements Cause(ctx) == nil + // when ctx is created using WithoutCancel. + return nil + } + c = ctx.c + case *timerCtx: + if key == &cancelCtxKey { + return &ctx.cancelCtx + } + c = ctx.Context + case backgroundCtx, todoCtx: + return nil + default: + return c.Value(key) + } + } +} diff --git a/src/context/context_test.go b/src/context/context_test.go new file mode 100644 index 0000000..57066c9 --- /dev/null +++ b/src/context/context_test.go @@ -0,0 +1,297 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context + +// Tests in package context cannot depend directly on package testing due to an import cycle. +// If your test does requires access to unexported members of the context package, +// add your test below as `func XTestFoo(t testingT)` and add a `TestFoo` to x_test.go +// that calls it. Otherwise, write a regular test in a test.go file in package context_test. + +import ( + "time" +) + +type testingT interface { + Deadline() (time.Time, bool) + Error(args ...any) + Errorf(format string, args ...any) + Fail() + FailNow() + Failed() bool + Fatal(args ...any) + Fatalf(format string, args ...any) + Helper() + Log(args ...any) + Logf(format string, args ...any) + Name() string + Parallel() + Skip(args ...any) + SkipNow() + Skipf(format string, args ...any) + Skipped() bool +} + +const veryLongDuration = 1000 * time.Hour // an arbitrary upper bound on the test's running time + +func contains(m map[canceler]struct{}, key canceler) bool { + _, ret := m[key] + return ret +} + +func XTestParentFinishesChild(t testingT) { + // Context tree: + // parent -> cancelChild + // parent -> valueChild -> timerChild + // parent -> afterChild + parent, cancel := WithCancel(Background()) + cancelChild, stop := WithCancel(parent) + defer stop() + valueChild := WithValue(parent, "key", "value") + timerChild, stop := WithTimeout(valueChild, veryLongDuration) + defer stop() + afterStop := AfterFunc(parent, func() {}) + defer afterStop() + + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + case x := <-cancelChild.Done(): + t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x) + case x := <-timerChild.Done(): + t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x) + case x := <-valueChild.Done(): + t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x) + default: + } + + // The parent's children should contain the three cancelable children. + pc := parent.(*cancelCtx) + cc := cancelChild.(*cancelCtx) + tc := timerChild.(*timerCtx) + pc.mu.Lock() + var ac *afterFuncCtx + for c := range pc.children { + if a, ok := c.(*afterFuncCtx); ok { + ac = a + break + } + } + if len(pc.children) != 3 || !contains(pc.children, cc) || !contains(pc.children, tc) || ac == nil { + t.Errorf("bad linkage: pc.children = %v, want %v, %v, and an afterFunc", + pc.children, cc, tc) + } + pc.mu.Unlock() + + if p, ok := parentCancelCtx(cc.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc) + } + if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) + } + if p, ok := parentCancelCtx(ac.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(afterChild.Context) = %v, %v want %v, true", p, ok, pc) + } + + cancel() + + pc.mu.Lock() + if len(pc.children) != 0 { + t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children) + } + pc.mu.Unlock() + + // parent and children should all be finished. + check := func(ctx Context, name string) { + select { + case <-ctx.Done(): + default: + t.Errorf("<-%s.Done() blocked, but shouldn't have", name) + } + if e := ctx.Err(); e != Canceled { + t.Errorf("%s.Err() == %v want %v", name, e, Canceled) + } + } + check(parent, "parent") + check(cancelChild, "cancelChild") + check(valueChild, "valueChild") + check(timerChild, "timerChild") + + // WithCancel should return a canceled context on a canceled parent. + precanceledChild := WithValue(parent, "key", "value") + select { + case <-precanceledChild.Done(): + default: + t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have") + } + if e := precanceledChild.Err(); e != Canceled { + t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled) + } +} + +func XTestChildFinishesFirst(t testingT) { + cancelable, stop := WithCancel(Background()) + defer stop() + for _, parent := range []Context{Background(), cancelable} { + child, cancel := WithCancel(parent) + + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + case x := <-child.Done(): + t.Errorf("<-child.Done() == %v want nothing (it should block)", x) + default: + } + + cc := child.(*cancelCtx) + pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background() + if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) { + t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok) + } + + if pcok { + pc.mu.Lock() + if len(pc.children) != 1 || !contains(pc.children, cc) { + t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) + } + pc.mu.Unlock() + } + + cancel() + + if pcok { + pc.mu.Lock() + if len(pc.children) != 0 { + t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children) + } + pc.mu.Unlock() + } + + // child should be finished. + select { + case <-child.Done(): + default: + t.Errorf("<-child.Done() blocked, but shouldn't have") + } + if e := child.Err(); e != Canceled { + t.Errorf("child.Err() == %v want %v", e, Canceled) + } + + // parent should not be finished. + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + default: + } + if e := parent.Err(); e != nil { + t.Errorf("parent.Err() == %v want nil", e) + } + } +} + +func XTestCancelRemoves(t testingT) { + checkChildren := func(when string, ctx Context, want int) { + if got := len(ctx.(*cancelCtx).children); got != want { + t.Errorf("%s: context has %d children, want %d", when, got, want) + } + } + + ctx, _ := WithCancel(Background()) + checkChildren("after creation", ctx, 0) + _, cancel := WithCancel(ctx) + checkChildren("with WithCancel child ", ctx, 1) + cancel() + checkChildren("after canceling WithCancel child", ctx, 0) + + ctx, _ = WithCancel(Background()) + checkChildren("after creation", ctx, 0) + _, cancel = WithTimeout(ctx, 60*time.Minute) + checkChildren("with WithTimeout child ", ctx, 1) + cancel() + checkChildren("after canceling WithTimeout child", ctx, 0) + + ctx, _ = WithCancel(Background()) + checkChildren("after creation", ctx, 0) + stop := AfterFunc(ctx, func() {}) + checkChildren("with AfterFunc child ", ctx, 1) + stop() + checkChildren("after stopping AfterFunc child ", ctx, 0) +} + +type myCtx struct { + Context +} + +type myDoneCtx struct { + Context +} + +func (d *myDoneCtx) Done() <-chan struct{} { + c := make(chan struct{}) + return c +} +func XTestCustomContextGoroutines(t testingT) { + g := goroutines.Load() + checkNoGoroutine := func() { + t.Helper() + now := goroutines.Load() + if now != g { + t.Fatalf("%d goroutines created", now-g) + } + } + checkCreatedGoroutine := func() { + t.Helper() + now := goroutines.Load() + if now != g+1 { + t.Fatalf("%d goroutines created, want 1", now-g) + } + g = now + } + + _, cancel0 := WithCancel(&myDoneCtx{Background()}) + cancel0() + checkCreatedGoroutine() + + _, cancel0 = WithTimeout(&myDoneCtx{Background()}, veryLongDuration) + cancel0() + checkCreatedGoroutine() + + checkNoGoroutine() + defer checkNoGoroutine() + + ctx1, cancel1 := WithCancel(Background()) + defer cancel1() + checkNoGoroutine() + + ctx2 := &myCtx{ctx1} + ctx3, cancel3 := WithCancel(ctx2) + defer cancel3() + checkNoGoroutine() + + _, cancel3b := WithCancel(&myDoneCtx{ctx2}) + defer cancel3b() + checkCreatedGoroutine() // ctx1 is not providing Done, must not be used + + ctx4, cancel4 := WithTimeout(ctx3, veryLongDuration) + defer cancel4() + checkNoGoroutine() + + ctx5, cancel5 := WithCancel(ctx4) + defer cancel5() + checkNoGoroutine() + + cancel5() + checkNoGoroutine() + + _, cancel6 := WithTimeout(ctx5, veryLongDuration) + defer cancel6() + checkNoGoroutine() + + // Check applied to canceled context. + cancel6() + cancel1() + _, cancel7 := WithCancel(ctx5) + defer cancel7() + checkNoGoroutine() +} diff --git a/src/context/example_test.go b/src/context/example_test.go new file mode 100644 index 0000000..03333b5 --- /dev/null +++ b/src/context/example_test.go @@ -0,0 +1,263 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" +) + +var neverReady = make(chan struct{}) // never closed + +// This example demonstrates the use of a cancelable context to prevent a +// goroutine leak. By the end of the example function, the goroutine started +// by gen will return without leaking. +func ExampleWithCancel() { + // gen generates integers in a separate goroutine and + // sends them to the returned channel. + // The callers of gen need to cancel the context once + // they are done consuming generated integers not to leak + // the internal goroutine started by gen. + gen := func(ctx context.Context) <-chan int { + dst := make(chan int) + n := 1 + go func() { + for { + select { + case <-ctx.Done(): + return // returning not to leak the goroutine + case dst <- n: + n++ + } + } + }() + return dst + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // cancel when we are finished consuming integers + + for n := range gen(ctx) { + fmt.Println(n) + if n == 5 { + break + } + } + // Output: + // 1 + // 2 + // 3 + // 4 + // 5 +} + +// This example passes a context with an arbitrary deadline to tell a blocking +// function that it should abandon its work as soon as it gets to it. +func ExampleWithDeadline() { + d := time.Now().Add(shortDuration) + ctx, cancel := context.WithDeadline(context.Background(), d) + + // Even though ctx will be expired, it is good practice to call its + // cancellation function in any case. Failure to do so may keep the + // context and its parent alive longer than necessary. + defer cancel() + + select { + case <-neverReady: + fmt.Println("ready") + case <-ctx.Done(): + fmt.Println(ctx.Err()) + } + + // Output: + // context deadline exceeded +} + +// This example passes a context with a timeout to tell a blocking function that +// it should abandon its work after the timeout elapses. +func ExampleWithTimeout() { + // Pass a context with a timeout to tell a blocking function that it + // should abandon its work after the timeout elapses. + ctx, cancel := context.WithTimeout(context.Background(), shortDuration) + defer cancel() + + select { + case <-neverReady: + fmt.Println("ready") + case <-ctx.Done(): + fmt.Println(ctx.Err()) // prints "context deadline exceeded" + } + + // Output: + // context deadline exceeded +} + +// This example demonstrates how a value can be passed to the context +// and also how to retrieve it if it exists. +func ExampleWithValue() { + type favContextKey string + + f := func(ctx context.Context, k favContextKey) { + if v := ctx.Value(k); v != nil { + fmt.Println("found value:", v) + return + } + fmt.Println("key not found:", k) + } + + k := favContextKey("language") + ctx := context.WithValue(context.Background(), k, "Go") + + f(ctx, k) + f(ctx, favContextKey("color")) + + // Output: + // found value: Go + // key not found: color +} + +// This example uses AfterFunc to define a function which waits on a sync.Cond, +// stopping the wait when a context is canceled. +func ExampleAfterFunc_cond() { + waitOnCond := func(ctx context.Context, cond *sync.Cond, conditionMet func() bool) error { + stopf := context.AfterFunc(ctx, func() { + // We need to acquire cond.L here to be sure that the Broadcast + // below won't occur before the call to Wait, which would result + // in a missed signal (and deadlock). + cond.L.Lock() + defer cond.L.Unlock() + + // If multiple goroutines are waiting on cond simultaneously, + // we need to make sure we wake up exactly this one. + // That means that we need to Broadcast to all of the goroutines, + // which will wake them all up. + // + // If there are N concurrent calls to waitOnCond, each of the goroutines + // will spuriously wake up O(N) other goroutines that aren't ready yet, + // so this will cause the overall CPU cost to be O(N²). + cond.Broadcast() + }) + defer stopf() + + // Since the wakeups are using Broadcast instead of Signal, this call to + // Wait may unblock due to some other goroutine's context becoming done, + // so to be sure that ctx is actually done we need to check it in a loop. + for !conditionMet() { + cond.Wait() + if ctx.Err() != nil { + return ctx.Err() + } + } + + return nil + } + + cond := sync.NewCond(new(sync.Mutex)) + + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + cond.L.Lock() + defer cond.L.Unlock() + + err := waitOnCond(ctx, cond, func() bool { return false }) + fmt.Println(err) + }() + } + wg.Wait() + + // Output: + // context deadline exceeded + // context deadline exceeded + // context deadline exceeded + // context deadline exceeded +} + +// This example uses AfterFunc to define a function which reads from a net.Conn, +// stopping the read when a context is canceled. +func ExampleAfterFunc_connection() { + readFromConn := func(ctx context.Context, conn net.Conn, b []byte) (n int, err error) { + stopc := make(chan struct{}) + stop := context.AfterFunc(ctx, func() { + conn.SetReadDeadline(time.Now()) + close(stopc) + }) + n, err = conn.Read(b) + if !stop() { + // The AfterFunc was started. + // Wait for it to complete, and reset the Conn's deadline. + <-stopc + conn.SetReadDeadline(time.Time{}) + return n, ctx.Err() + } + return n, err + } + + listener, err := net.Listen("tcp", ":0") + if err != nil { + fmt.Println(err) + return + } + defer listener.Close() + + conn, err := net.Dial(listener.Addr().Network(), listener.Addr().String()) + if err != nil { + fmt.Println(err) + return + } + defer conn.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + b := make([]byte, 1024) + _, err = readFromConn(ctx, conn, b) + fmt.Println(err) + + // Output: + // context deadline exceeded +} + +// This example uses AfterFunc to define a function which combines +// the cancellation signals of two Contexts. +func ExampleAfterFunc_merge() { + // mergeCancel returns a context that contains the values of ctx, + // and which is canceled when either ctx or cancelCtx is canceled. + mergeCancel := func(ctx, cancelCtx context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancelCause(ctx) + stop := context.AfterFunc(cancelCtx, func() { + cancel(context.Cause(cancelCtx)) + }) + return ctx, func() { + stop() + cancel(context.Canceled) + } + } + + ctx1, cancel1 := context.WithCancelCause(context.Background()) + defer cancel1(errors.New("ctx1 canceled")) + + ctx2, cancel2 := context.WithCancelCause(context.Background()) + + mergedCtx, mergedCancel := mergeCancel(ctx1, ctx2) + defer mergedCancel() + + cancel2(errors.New("ctx2 canceled")) + <-mergedCtx.Done() + fmt.Println(context.Cause(mergedCtx)) + + // Output: + // ctx2 canceled +} diff --git a/src/context/net_test.go b/src/context/net_test.go new file mode 100644 index 0000000..a007689 --- /dev/null +++ b/src/context/net_test.go @@ -0,0 +1,21 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + "context" + "net" + "testing" +) + +func TestDeadlineExceededIsNetError(t *testing.T) { + err, ok := context.DeadlineExceeded.(net.Error) + if !ok { + t.Fatal("DeadlineExceeded does not implement net.Error") + } + if !err.Timeout() || !err.Temporary() { + t.Fatalf("Timeout() = %v, Temporary() = %v, want true, true", err.Timeout(), err.Temporary()) + } +} diff --git a/src/context/x_test.go b/src/context/x_test.go new file mode 100644 index 0000000..57fe60b --- /dev/null +++ b/src/context/x_test.go @@ -0,0 +1,956 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + . "context" + "errors" + "fmt" + "math/rand" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +// Each XTestFoo in context_test.go must be called from a TestFoo here to run. +func TestParentFinishesChild(t *testing.T) { + XTestParentFinishesChild(t) // uses unexported context types +} +func TestChildFinishesFirst(t *testing.T) { + XTestChildFinishesFirst(t) // uses unexported context types +} +func TestCancelRemoves(t *testing.T) { + XTestCancelRemoves(t) // uses unexported context types +} +func TestCustomContextGoroutines(t *testing.T) { + XTestCustomContextGoroutines(t) // reads the context.goroutines counter +} + +// The following are regular tests in package context_test. + +// otherContext is a Context that's not one of the types defined in context.go. +// This lets us test code paths that differ based on the underlying type of the +// Context. +type otherContext struct { + Context +} + +const ( + shortDuration = 1 * time.Millisecond // a reasonable duration to block in a test + veryLongDuration = 1000 * time.Hour // an arbitrary upper bound on the test's running time +) + +// quiescent returns an arbitrary duration by which the program should have +// completed any remaining work and reached a steady (idle) state. +func quiescent(t *testing.T) time.Duration { + deadline, ok := t.Deadline() + if !ok { + return 5 * time.Second + } + + const arbitraryCleanupMargin = 1 * time.Second + return time.Until(deadline) - arbitraryCleanupMargin +} +func TestBackground(t *testing.T) { + c := Background() + if c == nil { + t.Fatalf("Background returned nil") + } + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + if got, want := fmt.Sprint(c), "context.Background"; got != want { + t.Errorf("Background().String() = %q want %q", got, want) + } +} + +func TestTODO(t *testing.T) { + c := TODO() + if c == nil { + t.Fatalf("TODO returned nil") + } + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + if got, want := fmt.Sprint(c), "context.TODO"; got != want { + t.Errorf("TODO().String() = %q want %q", got, want) + } +} + +func TestWithCancel(t *testing.T) { + c1, cancel := WithCancel(Background()) + + if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want { + t.Errorf("c1.String() = %q want %q", got, want) + } + + o := otherContext{c1} + c2, _ := WithCancel(o) + contexts := []Context{c1, o, c2} + + for i, c := range contexts { + if d := c.Done(); d == nil { + t.Errorf("c[%d].Done() == %v want non-nil", i, d) + } + if e := c.Err(); e != nil { + t.Errorf("c[%d].Err() == %v want nil", i, e) + } + + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + } + + cancel() // Should propagate synchronously. + for i, c := range contexts { + select { + case <-c.Done(): + default: + t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i) + } + if e := c.Err(); e != Canceled { + t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled) + } + } +} + +func testDeadline(c Context, name string, t *testing.T) { + t.Helper() + d := quiescent(t) + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-timer.C: + t.Fatalf("%s: context not timed out after %v", name, d) + case <-c.Done(): + } + if e := c.Err(); e != DeadlineExceeded { + t.Errorf("%s: c.Err() == %v; want %v", name, e, DeadlineExceeded) + } +} + +func TestDeadline(t *testing.T) { + t.Parallel() + + c, _ := WithDeadline(Background(), time.Now().Add(shortDuration)) + if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { + t.Errorf("c.String() = %q want prefix %q", got, prefix) + } + testDeadline(c, "WithDeadline", t) + + c, _ = WithDeadline(Background(), time.Now().Add(shortDuration)) + o := otherContext{c} + testDeadline(o, "WithDeadline+otherContext", t) + + c, _ = WithDeadline(Background(), time.Now().Add(shortDuration)) + o = otherContext{c} + c, _ = WithDeadline(o, time.Now().Add(veryLongDuration)) + testDeadline(c, "WithDeadline+otherContext+WithDeadline", t) + + c, _ = WithDeadline(Background(), time.Now().Add(-shortDuration)) + testDeadline(c, "WithDeadline+inthepast", t) + + c, _ = WithDeadline(Background(), time.Now()) + testDeadline(c, "WithDeadline+now", t) +} + +func TestTimeout(t *testing.T) { + t.Parallel() + + c, _ := WithTimeout(Background(), shortDuration) + if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { + t.Errorf("c.String() = %q want prefix %q", got, prefix) + } + testDeadline(c, "WithTimeout", t) + + c, _ = WithTimeout(Background(), shortDuration) + o := otherContext{c} + testDeadline(o, "WithTimeout+otherContext", t) + + c, _ = WithTimeout(Background(), shortDuration) + o = otherContext{c} + c, _ = WithTimeout(o, veryLongDuration) + testDeadline(c, "WithTimeout+otherContext+WithTimeout", t) +} + +func TestCanceledTimeout(t *testing.T) { + c, _ := WithTimeout(Background(), time.Second) + o := otherContext{c} + c, cancel := WithTimeout(o, veryLongDuration) + cancel() // Should propagate synchronously. + select { + case <-c.Done(): + default: + t.Errorf("<-c.Done() blocked, but shouldn't have") + } + if e := c.Err(); e != Canceled { + t.Errorf("c.Err() == %v want %v", e, Canceled) + } +} + +type key1 int +type key2 int + +var k1 = key1(1) +var k2 = key2(1) // same int as k1, different type +var k3 = key2(3) // same type as k2, different int + +func TestValues(t *testing.T) { + check := func(c Context, nm, v1, v2, v3 string) { + if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 { + t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0) + } + if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 { + t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0) + } + if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 { + t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0) + } + } + + c0 := Background() + check(c0, "c0", "", "", "") + + c1 := WithValue(Background(), k1, "c1k1") + check(c1, "c1", "c1k1", "", "") + + if got, want := fmt.Sprint(c1), `context.Background.WithValue(type context_test.key1, val c1k1)`; got != want { + t.Errorf("c.String() = %q want %q", got, want) + } + + c2 := WithValue(c1, k2, "c2k2") + check(c2, "c2", "c1k1", "c2k2", "") + + c3 := WithValue(c2, k3, "c3k3") + check(c3, "c2", "c1k1", "c2k2", "c3k3") + + c4 := WithValue(c3, k1, nil) + check(c4, "c4", "", "c2k2", "c3k3") + + o0 := otherContext{Background()} + check(o0, "o0", "", "", "") + + o1 := otherContext{WithValue(Background(), k1, "c1k1")} + check(o1, "o1", "c1k1", "", "") + + o2 := WithValue(o1, k2, "o2k2") + check(o2, "o2", "c1k1", "o2k2", "") + + o3 := otherContext{c4} + check(o3, "o3", "", "c2k2", "c3k3") + + o4 := WithValue(o3, k3, nil) + check(o4, "o4", "", "c2k2", "") +} + +func TestAllocs(t *testing.T) { + bg := Background() + for _, test := range []struct { + desc string + f func() + limit float64 + gccgoLimit float64 + }{ + { + desc: "Background()", + f: func() { Background() }, + limit: 0, + gccgoLimit: 0, + }, + { + desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1), + f: func() { + c := WithValue(bg, k1, nil) + c.Value(k1) + }, + limit: 3, + gccgoLimit: 3, + }, + { + desc: "WithTimeout(bg, 1*time.Nanosecond)", + f: func() { + c, _ := WithTimeout(bg, 1*time.Nanosecond) + <-c.Done() + }, + limit: 12, + gccgoLimit: 15, + }, + { + desc: "WithCancel(bg)", + f: func() { + c, cancel := WithCancel(bg) + cancel() + <-c.Done() + }, + limit: 5, + gccgoLimit: 8, + }, + { + desc: "WithTimeout(bg, 5*time.Millisecond)", + f: func() { + c, cancel := WithTimeout(bg, 5*time.Millisecond) + cancel() + <-c.Done() + }, + limit: 8, + gccgoLimit: 25, + }, + } { + limit := test.limit + if runtime.Compiler == "gccgo" { + // gccgo does not yet do escape analysis. + // TODO(iant): Remove this when gccgo does do escape analysis. + limit = test.gccgoLimit + } + numRuns := 100 + if testing.Short() { + numRuns = 10 + } + if n := testing.AllocsPerRun(numRuns, test.f); n > limit { + t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) + } + } +} + +func TestSimultaneousCancels(t *testing.T) { + root, cancel := WithCancel(Background()) + m := map[Context]CancelFunc{root: cancel} + q := []Context{root} + // Create a tree of contexts. + for len(q) != 0 && len(m) < 100 { + parent := q[0] + q = q[1:] + for i := 0; i < 4; i++ { + ctx, cancel := WithCancel(parent) + m[ctx] = cancel + q = append(q, ctx) + } + } + // Start all the cancels in a random order. + var wg sync.WaitGroup + wg.Add(len(m)) + for _, cancel := range m { + go func(cancel CancelFunc) { + cancel() + wg.Done() + }(cancel) + } + + d := quiescent(t) + stuck := make(chan struct{}) + timer := time.AfterFunc(d, func() { close(stuck) }) + defer timer.Stop() + + // Wait on all the contexts in a random order. + for ctx := range m { + select { + case <-ctx.Done(): + case <-stuck: + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out after %v waiting for <-ctx.Done(); stacks:\n%s", d, buf[:n]) + } + } + // Wait for all the cancel functions to return. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-stuck: + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out after %v waiting for cancel functions; stacks:\n%s", d, buf[:n]) + } +} + +func TestInterlockedCancels(t *testing.T) { + parent, cancelParent := WithCancel(Background()) + child, cancelChild := WithCancel(parent) + go func() { + <-parent.Done() + cancelChild() + }() + cancelParent() + d := quiescent(t) + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-child.Done(): + case <-timer.C: + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out after %v waiting for child.Done(); stacks:\n%s", d, buf[:n]) + } +} + +func TestLayersCancel(t *testing.T) { + testLayers(t, time.Now().UnixNano(), false) +} + +func TestLayersTimeout(t *testing.T) { + testLayers(t, time.Now().UnixNano(), true) +} + +func testLayers(t *testing.T, seed int64, testTimeout bool) { + t.Parallel() + + r := rand.New(rand.NewSource(seed)) + errorf := func(format string, a ...any) { + t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...) + } + const ( + minLayers = 30 + ) + type value int + var ( + vals []*value + cancels []CancelFunc + numTimers int + ctx = Background() + ) + for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ { + switch r.Intn(3) { + case 0: + v := new(value) + ctx = WithValue(ctx, v, v) + vals = append(vals, v) + case 1: + var cancel CancelFunc + ctx, cancel = WithCancel(ctx) + cancels = append(cancels, cancel) + case 2: + var cancel CancelFunc + d := veryLongDuration + if testTimeout { + d = shortDuration + } + ctx, cancel = WithTimeout(ctx, d) + cancels = append(cancels, cancel) + numTimers++ + } + } + checkValues := func(when string) { + for _, key := range vals { + if val := ctx.Value(key).(*value); key != val { + errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key) + } + } + } + if !testTimeout { + select { + case <-ctx.Done(): + errorf("ctx should not be canceled yet") + default: + } + } + if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) { + t.Errorf("ctx.String() = %q want prefix %q", s, prefix) + } + t.Log(ctx) + checkValues("before cancel") + if testTimeout { + d := quiescent(t) + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + errorf("ctx should have timed out after %v", d) + } + checkValues("after timeout") + } else { + cancel := cancels[r.Intn(len(cancels))] + cancel() + select { + case <-ctx.Done(): + default: + errorf("ctx should be canceled") + } + checkValues("after cancel") + } +} + +func TestWithCancelCanceledParent(t *testing.T) { + parent, pcancel := WithCancelCause(Background()) + cause := fmt.Errorf("Because!") + pcancel(cause) + + c, _ := WithCancel(parent) + select { + case <-c.Done(): + default: + t.Errorf("child not done immediately upon construction") + } + if got, want := c.Err(), Canceled; got != want { + t.Errorf("child not canceled; got = %v, want = %v", got, want) + } + if got, want := Cause(c), cause; got != want { + t.Errorf("child has wrong cause; got = %v, want = %v", got, want) + } +} + +func TestWithCancelSimultaneouslyCanceledParent(t *testing.T) { + // Cancel the parent goroutine concurrently with creating a child. + for i := 0; i < 100; i++ { + parent, pcancel := WithCancelCause(Background()) + cause := fmt.Errorf("Because!") + go pcancel(cause) + + c, _ := WithCancel(parent) + <-c.Done() + if got, want := c.Err(), Canceled; got != want { + t.Errorf("child not canceled; got = %v, want = %v", got, want) + } + if got, want := Cause(c), cause; got != want { + t.Errorf("child has wrong cause; got = %v, want = %v", got, want) + } + } +} + +func TestWithValueChecksKey(t *testing.T) { + panicVal := recoveredValue(func() { _ = WithValue(Background(), []byte("foo"), "bar") }) + if panicVal == nil { + t.Error("expected panic") + } + panicVal = recoveredValue(func() { _ = WithValue(Background(), nil, "bar") }) + if got, want := fmt.Sprint(panicVal), "nil key"; got != want { + t.Errorf("panic = %q; want %q", got, want) + } +} + +func TestInvalidDerivedFail(t *testing.T) { + panicVal := recoveredValue(func() { _, _ = WithCancel(nil) }) + if panicVal == nil { + t.Error("expected panic") + } + panicVal = recoveredValue(func() { _, _ = WithDeadline(nil, time.Now().Add(shortDuration)) }) + if panicVal == nil { + t.Error("expected panic") + } + panicVal = recoveredValue(func() { _ = WithValue(nil, "foo", "bar") }) + if panicVal == nil { + t.Error("expected panic") + } +} + +func recoveredValue(fn func()) (v any) { + defer func() { v = recover() }() + fn() + return +} + +func TestDeadlineExceededSupportsTimeout(t *testing.T) { + i, ok := DeadlineExceeded.(interface { + Timeout() bool + }) + if !ok { + t.Fatal("DeadlineExceeded does not support Timeout interface") + } + if !i.Timeout() { + t.Fatal("wrong value for timeout") + } +} +func TestCause(t *testing.T) { + var ( + forever = 1e6 * time.Second + parentCause = fmt.Errorf("parentCause") + childCause = fmt.Errorf("childCause") + tooSlow = fmt.Errorf("tooSlow") + finishedEarly = fmt.Errorf("finishedEarly") + ) + for _, test := range []struct { + name string + ctx func() Context + err error + cause error + }{ + { + name: "Background", + ctx: Background, + err: nil, + cause: nil, + }, + { + name: "TODO", + ctx: TODO, + err: nil, + cause: nil, + }, + { + name: "WithCancel", + ctx: func() Context { + ctx, cancel := WithCancel(Background()) + cancel() + return ctx + }, + err: Canceled, + cause: Canceled, + }, + { + name: "WithCancelCause", + ctx: func() Context { + ctx, cancel := WithCancelCause(Background()) + cancel(parentCause) + return ctx + }, + err: Canceled, + cause: parentCause, + }, + { + name: "WithCancelCause nil", + ctx: func() Context { + ctx, cancel := WithCancelCause(Background()) + cancel(nil) + return ctx + }, + err: Canceled, + cause: Canceled, + }, + { + name: "WithCancelCause: parent cause before child", + ctx: func() Context { + ctx, cancelParent := WithCancelCause(Background()) + ctx, cancelChild := WithCancelCause(ctx) + cancelParent(parentCause) + cancelChild(childCause) + return ctx + }, + err: Canceled, + cause: parentCause, + }, + { + name: "WithCancelCause: parent cause after child", + ctx: func() Context { + ctx, cancelParent := WithCancelCause(Background()) + ctx, cancelChild := WithCancelCause(ctx) + cancelChild(childCause) + cancelParent(parentCause) + return ctx + }, + err: Canceled, + cause: childCause, + }, + { + name: "WithCancelCause: parent cause before nil", + ctx: func() Context { + ctx, cancelParent := WithCancelCause(Background()) + ctx, cancelChild := WithCancel(ctx) + cancelParent(parentCause) + cancelChild() + return ctx + }, + err: Canceled, + cause: parentCause, + }, + { + name: "WithCancelCause: parent cause after nil", + ctx: func() Context { + ctx, cancelParent := WithCancelCause(Background()) + ctx, cancelChild := WithCancel(ctx) + cancelChild() + cancelParent(parentCause) + return ctx + }, + err: Canceled, + cause: Canceled, + }, + { + name: "WithCancelCause: child cause after nil", + ctx: func() Context { + ctx, cancelParent := WithCancel(Background()) + ctx, cancelChild := WithCancelCause(ctx) + cancelParent() + cancelChild(childCause) + return ctx + }, + err: Canceled, + cause: Canceled, + }, + { + name: "WithCancelCause: child cause before nil", + ctx: func() Context { + ctx, cancelParent := WithCancel(Background()) + ctx, cancelChild := WithCancelCause(ctx) + cancelChild(childCause) + cancelParent() + return ctx + }, + err: Canceled, + cause: childCause, + }, + { + name: "WithTimeout", + ctx: func() Context { + ctx, cancel := WithTimeout(Background(), 0) + cancel() + return ctx + }, + err: DeadlineExceeded, + cause: DeadlineExceeded, + }, + { + name: "WithTimeout canceled", + ctx: func() Context { + ctx, cancel := WithTimeout(Background(), forever) + cancel() + return ctx + }, + err: Canceled, + cause: Canceled, + }, + { + name: "WithTimeoutCause", + ctx: func() Context { + ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow) + cancel() + return ctx + }, + err: DeadlineExceeded, + cause: tooSlow, + }, + { + name: "WithTimeoutCause canceled", + ctx: func() Context { + ctx, cancel := WithTimeoutCause(Background(), forever, tooSlow) + cancel() + return ctx + }, + err: Canceled, + cause: Canceled, + }, + { + name: "WithTimeoutCause stacked", + ctx: func() Context { + ctx, cancel := WithCancelCause(Background()) + ctx, _ = WithTimeoutCause(ctx, 0, tooSlow) + cancel(finishedEarly) + return ctx + }, + err: DeadlineExceeded, + cause: tooSlow, + }, + { + name: "WithTimeoutCause stacked canceled", + ctx: func() Context { + ctx, cancel := WithCancelCause(Background()) + ctx, _ = WithTimeoutCause(ctx, forever, tooSlow) + cancel(finishedEarly) + return ctx + }, + err: Canceled, + cause: finishedEarly, + }, + { + name: "WithoutCancel", + ctx: func() Context { + return WithoutCancel(Background()) + }, + err: nil, + cause: nil, + }, + { + name: "WithoutCancel canceled", + ctx: func() Context { + ctx, cancel := WithCancelCause(Background()) + ctx = WithoutCancel(ctx) + cancel(finishedEarly) + return ctx + }, + err: nil, + cause: nil, + }, + { + name: "WithoutCancel timeout", + ctx: func() Context { + ctx, cancel := WithTimeoutCause(Background(), 0, tooSlow) + ctx = WithoutCancel(ctx) + cancel() + return ctx + }, + err: nil, + cause: nil, + }, + } { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + ctx := test.ctx() + if got, want := ctx.Err(), test.err; want != got { + t.Errorf("ctx.Err() = %v want %v", got, want) + } + if got, want := Cause(ctx), test.cause; want != got { + t.Errorf("Cause(ctx) = %v want %v", got, want) + } + }) + } +} + +func TestCauseRace(t *testing.T) { + cause := errors.New("TestCauseRace") + ctx, cancel := WithCancelCause(Background()) + go func() { + cancel(cause) + }() + for { + // Poll Cause, rather than waiting for Done, to test that + // access to the underlying cause is synchronized properly. + if err := Cause(ctx); err != nil { + if err != cause { + t.Errorf("Cause returned %v, want %v", err, cause) + } + break + } + runtime.Gosched() + } +} + +func TestWithoutCancel(t *testing.T) { + key, value := "key", "value" + ctx := WithValue(Background(), key, value) + ctx = WithoutCancel(ctx) + if d, ok := ctx.Deadline(); !d.IsZero() || ok != false { + t.Errorf("ctx.Deadline() = %v, %v want zero, false", d, ok) + } + if done := ctx.Done(); done != nil { + t.Errorf("ctx.Deadline() = %v want nil", done) + } + if err := ctx.Err(); err != nil { + t.Errorf("ctx.Err() = %v want nil", err) + } + if v := ctx.Value(key); v != value { + t.Errorf("ctx.Value(%q) = %q want %q", key, v, value) + } +} + +type customDoneContext struct { + Context + donec chan struct{} +} + +func (c *customDoneContext) Done() <-chan struct{} { + return c.donec +} + +func TestCustomContextPropagation(t *testing.T) { + cause := errors.New("TestCustomContextPropagation") + donec := make(chan struct{}) + ctx1, cancel1 := WithCancelCause(Background()) + ctx2 := &customDoneContext{ + Context: ctx1, + donec: donec, + } + ctx3, cancel3 := WithCancel(ctx2) + defer cancel3() + + cancel1(cause) + close(donec) + + <-ctx3.Done() + if got, want := ctx3.Err(), Canceled; got != want { + t.Errorf("child not canceled; got = %v, want = %v", got, want) + } + if got, want := Cause(ctx3), cause; got != want { + t.Errorf("child has wrong cause; got = %v, want = %v", got, want) + } +} + +func TestAfterFuncCalledAfterCancel(t *testing.T) { + ctx, cancel := WithCancel(Background()) + donec := make(chan struct{}) + stop := AfterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + t.Fatalf("AfterFunc called before context is done") + case <-time.After(shortDuration): + } + cancel() + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } + if stop() { + t.Fatalf("stop() = true, want false") + } +} + +func TestAfterFuncCalledAfterTimeout(t *testing.T) { + ctx, cancel := WithTimeout(Background(), shortDuration) + defer cancel() + donec := make(chan struct{}) + AfterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } +} + +func TestAfterFuncCalledImmediately(t *testing.T) { + ctx, cancel := WithCancel(Background()) + cancel() + donec := make(chan struct{}) + AfterFunc(ctx, func() { + close(donec) + }) + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called for already-canceled context") + } +} + +func TestAfterFuncNotCalledAfterStop(t *testing.T) { + ctx, cancel := WithCancel(Background()) + donec := make(chan struct{}) + stop := AfterFunc(ctx, func() { + close(donec) + }) + if !stop() { + t.Fatalf("stop() = false, want true") + } + cancel() + select { + case <-donec: + t.Fatalf("AfterFunc called for already-canceled context") + case <-time.After(shortDuration): + } + if stop() { + t.Fatalf("stop() = true, want false") + } +} + +// This test verifies that cancelling a context does not block waiting for AfterFuncs to finish. +func TestAfterFuncCalledAsynchronously(t *testing.T) { + ctx, cancel := WithCancel(Background()) + donec := make(chan struct{}) + stop := AfterFunc(ctx, func() { + // The channel send blocks until donec is read from. + donec <- struct{}{} + }) + defer stop() + cancel() + // After cancel returns, read from donec and unblock the AfterFunc. + select { + case <-donec: + case <-time.After(veryLongDuration): + t.Fatalf("AfterFunc not called after context is canceled") + } +} |