diff options
Diffstat (limited to 'src/cmd/compile/internal/rangefunc')
-rw-r--r-- | src/cmd/compile/internal/rangefunc/rangefunc_test.go | 1297 | ||||
-rw-r--r-- | src/cmd/compile/internal/rangefunc/rewrite.go | 1334 |
2 files changed, 2631 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/rangefunc/rangefunc_test.go b/src/cmd/compile/internal/rangefunc/rangefunc_test.go new file mode 100644 index 0000000..16856c6 --- /dev/null +++ b/src/cmd/compile/internal/rangefunc/rangefunc_test.go @@ -0,0 +1,1297 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +//go:build goexperiment.rangefunc + +package rangefunc_test + +import ( + "slices" + "testing" +) + +type Seq2[T1, T2 any] func(yield func(T1, T2) bool) + +// OfSliceIndex returns a Seq over the elements of s. It is equivalent +// to range s. +func OfSliceIndex[T any, S ~[]T](s S) Seq2[int, T] { + return func(yield func(int, T) bool) { + for i, v := range s { + if !yield(i, v) { + return + } + } + return + } +} + +// BadOfSliceIndex is "bad" because it ignores the return value from yield +// and just keeps on iterating. +func BadOfSliceIndex[T any, S ~[]T](s S) Seq2[int, T] { + return func(yield func(int, T) bool) { + for i, v := range s { + yield(i, v) + } + return + } +} + +// VeryBadOfSliceIndex is "very bad" because it ignores the return value from yield +// and just keeps on iterating, and also wraps that call in a defer-recover so it can +// keep on trying after the first panic. +func VeryBadOfSliceIndex[T any, S ~[]T](s S) Seq2[int, T] { + return func(yield func(int, T) bool) { + for i, v := range s { + func() { + defer func() { + recover() + }() + yield(i, v) + }() + } + return + } +} + +// CooperativeBadOfSliceIndex calls the loop body from a goroutine after +// a ping on a channel, and returns recover()on that same channel. +func CooperativeBadOfSliceIndex[T any, S ~[]T](s S, proceed chan any) Seq2[int, T] { + return func(yield func(int, T) bool) { + for i, v := range s { + if !yield(i, v) { + // if the body breaks, call yield just once in a goroutine + go func() { + <-proceed + defer func() { + proceed <- recover() + }() + yield(0, s[0]) + }() + return + } + } + return + } +} + +// TrickyIterator is a type intended to test whether an iterator that +// calls a yield function after loop exit must inevitably escape the +// closure; this might be relevant to future checking/optimization. +type TrickyIterator struct { + yield func(int, int) bool +} + +func (ti *TrickyIterator) iterAll(s []int) Seq2[int, int] { + return func(yield func(int, int) bool) { + ti.yield = yield // Save yield for future abuse + for i, v := range s { + if !yield(i, v) { + return + } + } + return + } +} + +func (ti *TrickyIterator) iterOne(s []int) Seq2[int, int] { + return func(yield func(int, int) bool) { + ti.yield = yield // Save yield for future abuse + if len(s) > 0 { // Not in a loop might escape differently + yield(0, s[0]) + } + return + } +} + +func (ti *TrickyIterator) iterZero(s []int) Seq2[int, int] { + return func(yield func(int, int) bool) { + ti.yield = yield // Save yield for future abuse + // Don't call it at all, maybe it won't escape + return + } +} + +func (ti *TrickyIterator) fail() { + if ti.yield != nil { + ti.yield(1, 1) + } +} + +// Check wraps the function body passed to iterator forall +// in code that ensures that it cannot (successfully) be called +// either after body return false (control flow out of loop) or +// forall itself returns (the iteration is now done). +// +// Note that this can catch errors before the inserted checks. +func Check[U, V any](forall Seq2[U, V]) Seq2[U, V] { + return func(body func(U, V) bool) { + ret := true + forall(func(u U, v V) bool { + if !ret { + panic("Checked iterator access after exit") + } + ret = body(u, v) + return ret + }) + ret = false + } +} + +func TestCheck(t *testing.T) { + i := 0 + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + } else { + t.Error("Wanted to see a failure") + } + }() + for _, x := range Check(BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) { + i += x + if i > 4*9 { + break + } + } +} + +func TestCooperativeBadOfSliceIndex(t *testing.T) { + i := 0 + proceed := make(chan any) + for _, x := range CooperativeBadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, proceed) { + i += x + if i >= 36 { + break + } + } + proceed <- true + if r := <-proceed; r != nil { + t.Logf("Saw expected panic '%v'", r) + } else { + t.Error("Wanted to see a failure") + } + if i != 36 { + t.Errorf("Expected i == 36, saw %d instead", i) + } else { + t.Logf("i = %d", i) + } +} + +func TestCheckCooperativeBadOfSliceIndex(t *testing.T) { + i := 0 + proceed := make(chan any) + for _, x := range Check(CooperativeBadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, proceed)) { + i += x + if i >= 36 { + break + } + } + proceed <- true + if r := <-proceed; r != nil { + t.Logf("Saw expected panic '%v'", r) + } else { + t.Error("Wanted to see a failure") + } + if i != 36 { + t.Errorf("Expected i == 36, saw %d instead", i) + } else { + t.Logf("i = %d", i) + } +} + +func TestTrickyIterAll(t *testing.T) { + trickItAll := TrickyIterator{} + i := 0 + for _, x := range trickItAll.iterAll([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + i += x + if i >= 36 { + break + } + } + + if i != 36 { + t.Errorf("Expected i == 36, saw %d instead", i) + } else { + t.Logf("i = %d", i) + } + + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + } else { + t.Error("Wanted to see a failure") + } + }() + + trickItAll.fail() +} + +func TestTrickyIterOne(t *testing.T) { + trickItOne := TrickyIterator{} + i := 0 + for _, x := range trickItOne.iterOne([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + i += x + if i >= 36 { + break + } + } + + // Don't care about value, ought to be 36 anyhow. + t.Logf("i = %d", i) + + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + } else { + t.Error("Wanted to see a failure") + } + }() + + trickItOne.fail() +} + +func TestTrickyIterZero(t *testing.T) { + trickItZero := TrickyIterator{} + i := 0 + for _, x := range trickItZero.iterZero([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + i += x + if i >= 36 { + break + } + } + + // Don't care about value, ought to be 0 anyhow. + t.Logf("i = %d", i) + + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + } else { + t.Error("Wanted to see a failure") + } + }() + + trickItZero.fail() +} + +func TestCheckTrickyIterZero(t *testing.T) { + trickItZero := TrickyIterator{} + i := 0 + for _, x := range Check(trickItZero.iterZero([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) { + i += x + if i >= 36 { + break + } + } + + // Don't care about value, ought to be 0 anyhow. + t.Logf("i = %d", i) + + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + } else { + t.Error("Wanted to see a failure") + } + }() + + trickItZero.fail() +} + +// TestBreak1 should just work, with well-behaved iterators. +// (The misbehaving iterator detector should not trigger.) +func TestBreak1(t *testing.T) { + var result []int + var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3} + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4}) { + if x == -4 { + break + } + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + break + } + result = append(result, y) + } + result = append(result, x) + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestBreak2 should just work, with well-behaved iterators. +// (The misbehaving iterator detector should not trigger.) +func TestBreak2(t *testing.T) { + var result []int + var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3} +outer: + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4}) { + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + break + } + if x == -4 { + break outer + } + + result = append(result, y) + } + result = append(result, x) + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestContinue should just work, with well-behaved iterators. +// (The misbehaving iterator detector should not trigger.) +func TestContinue(t *testing.T) { + var result []int + var expect = []int{-1, 1, 2, -2, 1, 2, -3, 1, 2, -4} +outer: + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4}) { + result = append(result, x) + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + continue outer + } + if x == -4 { + break outer + } + + result = append(result, y) + } + result = append(result, x-10) + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestBreak3 should just work, with well-behaved iterators. +// (The misbehaving iterator detector should not trigger.) +func TestBreak3(t *testing.T) { + var result []int + var expect = []int{100, 10, 2, 4, 200, 10, 2, 4, 20, 2, 4, 300, 10, 2, 4, 20, 2, 4, 30} +X: + for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) { + Y: + for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) { + if 10*y >= x { + break + } + result = append(result, y) + if y == 30 { + continue X + } + Z: + for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue Z + } + result = append(result, z) + if z >= 4 { + continue Y + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestBreak1BadA should end in a panic when the outer-loop's +// single-level break is ignore by BadOfSliceIndex +func TestBreak1BadA(t *testing.T) { + var result []int + var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3} + + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Error("Wanted to see a failure") + } + }() + + for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) { + if x == -4 { + break + } + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + break + } + result = append(result, y) + } + result = append(result, x) + } +} + +// TestBreak1BadB should end in a panic, sooner, when the inner-loop's +// (nested) single-level break is ignored by BadOfSliceIndex +func TestBreak1BadB(t *testing.T) { + var result []int + var expect = []int{1, 2} // inner breaks, panics, after before outer appends + + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Error("Wanted to see a failure") + } + }() + + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) { + if x == -4 { + break + } + for _, y := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + break + } + result = append(result, y) + } + result = append(result, x) + } +} + +// TestMultiCont0 tests multilevel continue with no bad iterators +// (it should just work) +func TestMultiCont0(t *testing.T) { + var result []int + var expect = []int{1000, 10, 2, 4, 2000} + +W: + for _, w := range OfSliceIndex([]int{1000, 2000}) { + result = append(result, w) + if w == 2000 { + break + } + for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) { + for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) { + result = append(result, y) + for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue + } + result = append(result, z) + if z >= 4 { + continue W // modified to be multilevel + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestMultiCont1 tests multilevel continue with a bad iterator +// in the outermost loop exited by the continue. +func TestMultiCont1(t *testing.T) { + var result []int + var expect = []int{1000, 10, 2, 4} + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Errorf("Wanted to see a failure, result was %v", result) + } + }() + +W: + for _, w := range OfSliceIndex([]int{1000, 2000}) { + result = append(result, w) + if w == 2000 { + break + } + for _, x := range BadOfSliceIndex([]int{100, 200, 300, 400}) { + for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) { + result = append(result, y) + for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue + } + result = append(result, z) + if z >= 4 { + continue W + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestMultiCont2 tests multilevel continue with a bad iterator +// in a middle loop exited by the continue. +func TestMultiCont2(t *testing.T) { + var result []int + var expect = []int{1000, 10, 2, 4} + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Errorf("Wanted to see a failure, result was %v", result) + } + }() + +W: + for _, w := range OfSliceIndex([]int{1000, 2000}) { + result = append(result, w) + if w == 2000 { + break + } + for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) { + for _, y := range BadOfSliceIndex([]int{10, 20, 30, 40}) { + result = append(result, y) + for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue + } + result = append(result, z) + if z >= 4 { + continue W + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestMultiCont3 tests multilevel continue with a bad iterator +// in the innermost loop exited by the continue. +func TestMultiCont3(t *testing.T) { + var result []int + var expect = []int{1000, 10, 2, 4} + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Errorf("Wanted to see a failure, result was %v", result) + } + }() + +W: + for _, w := range OfSliceIndex([]int{1000, 2000}) { + result = append(result, w) + if w == 2000 { + break + } + for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) { + for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) { + result = append(result, y) + for _, z := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue + } + result = append(result, z) + if z >= 4 { + continue W + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestMultiBreak0 tests multilevel break with a bad iterator +// in the outermost loop exited by the break (the outermost loop). +func TestMultiBreak0(t *testing.T) { + var result []int + var expect = []int{1000, 10, 2, 4} + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Errorf("Wanted to see a failure, result was %v", result) + } + }() + +W: + for _, w := range BadOfSliceIndex([]int{1000, 2000}) { + result = append(result, w) + if w == 2000 { + break + } + for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) { + for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) { + result = append(result, y) + for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue + } + result = append(result, z) + if z >= 4 { + break W + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestMultiBreak1 tests multilevel break with a bad iterator +// in an intermediate loop exited by the break. +func TestMultiBreak1(t *testing.T) { + var result []int + var expect = []int{1000, 10, 2, 4} + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Errorf("Wanted to see a failure, result was %v", result) + } + }() + +W: + for _, w := range OfSliceIndex([]int{1000, 2000}) { + result = append(result, w) + if w == 2000 { + break + } + for _, x := range BadOfSliceIndex([]int{100, 200, 300, 400}) { + for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) { + result = append(result, y) + for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue + } + result = append(result, z) + if z >= 4 { + break W + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestMultiBreak2 tests multilevel break with two bad iterators +// in intermediate loops exited by the break. +func TestMultiBreak2(t *testing.T) { + var result []int + var expect = []int{1000, 10, 2, 4} + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Errorf("Wanted to see a failure, result was %v", result) + } + }() + +W: + for _, w := range OfSliceIndex([]int{1000, 2000}) { + result = append(result, w) + if w == 2000 { + break + } + for _, x := range BadOfSliceIndex([]int{100, 200, 300, 400}) { + for _, y := range BadOfSliceIndex([]int{10, 20, 30, 40}) { + result = append(result, y) + for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue + } + result = append(result, z) + if z >= 4 { + break W + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestMultiBreak3 tests multilevel break with the bad iterator +// in the innermost loop exited by the break. +func TestMultiBreak3(t *testing.T) { + var result []int + var expect = []int{1000, 10, 2, 4} + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + } else { + t.Errorf("Wanted to see a failure, result was %v", result) + } + }() + +W: + for _, w := range OfSliceIndex([]int{1000, 2000}) { + result = append(result, w) + if w == 2000 { + break + } + for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) { + for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) { + result = append(result, y) + for _, z := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if z&1 == 1 { + continue + } + result = append(result, z) + if z >= 4 { + break W + } + } + result = append(result, -y) // should never be executed + } + result = append(result, x) + } + } + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// veryBad tests that a loop nest behaves sensibly in the face of a +// "very bad" iterator. In this case, "sensibly" means that the +// break out of X still occurs after the very bad iterator finally +// quits running (the control flow bread crumbs remain.) +func veryBad(s []int) []int { + var result []int +X: + for _, x := range OfSliceIndex([]int{1, 2, 3}) { + + result = append(result, x) + + for _, y := range VeryBadOfSliceIndex(s) { + result = append(result, y) + break X + } + for _, z := range OfSliceIndex([]int{100, 200, 300}) { + result = append(result, z) + if z == 100 { + break + } + } + } + return result +} + +// checkVeryBad wraps a "very bad" iterator with Check, +// demonstrating that the very bad iterator also hides panics +// thrown by Check. +func checkVeryBad(s []int) []int { + var result []int +X: + for _, x := range OfSliceIndex([]int{1, 2, 3}) { + + result = append(result, x) + + for _, y := range Check(VeryBadOfSliceIndex(s)) { + result = append(result, y) + break X + } + for _, z := range OfSliceIndex([]int{100, 200, 300}) { + result = append(result, z) + if z == 100 { + break + } + } + } + return result +} + +// okay is the not-bad version of veryBad. +// They should behave the same. +func okay(s []int) []int { + var result []int +X: + for _, x := range OfSliceIndex([]int{1, 2, 3}) { + + result = append(result, x) + + for _, y := range OfSliceIndex(s) { + result = append(result, y) + break X + } + for _, z := range OfSliceIndex([]int{100, 200, 300}) { + result = append(result, z) + if z == 100 { + break + } + } + } + return result +} + +// TestVeryBad1 checks the behavior of an extremely poorly behaved iterator. +func TestVeryBad1(t *testing.T) { + result := veryBad([]int{10, 20, 30, 40, 50}) // odd length + expect := []int{1, 10} + + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestVeryBad2 checks the behavior of an extremely poorly behaved iterator. +func TestVeryBad2(t *testing.T) { + result := veryBad([]int{10, 20, 30, 40}) // even length + expect := []int{1, 10} + + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestCheckVeryBad checks the behavior of an extremely poorly behaved iterator, +// which also suppresses the exceptions from "Check" +func TestCheckVeryBad(t *testing.T) { + result := checkVeryBad([]int{10, 20, 30, 40}) // even length + expect := []int{1, 10} + + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// TestOk is the nice version of the very bad iterator. +func TestOk(t *testing.T) { + result := okay([]int{10, 20, 30, 40, 50}) // odd length + expect := []int{1, 10} + + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } +} + +// testBreak1BadDefer checks that defer behaves properly even in +// the presence of loop bodies panicking out of bad iterators. +// (i.e., the instrumentation did not break defer in these loops) +func testBreak1BadDefer(t *testing.T) (result []int) { + var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3, -30, -20, -10} + + defer func() { + if r := recover(); r != nil { + t.Logf("Saw expected panic '%v'", r) + if !slices.Equal(expect, result) { + t.Errorf("(Inner) Expected %v, got %v", expect, result) + } + } else { + t.Error("Wanted to see a failure") + } + }() + + for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) { + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + break + } + result = append(result, y) + } + result = append(result, x) + } + return +} + +func TestBreak1BadDefer(t *testing.T) { + var result []int + var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3, -30, -20, -10} + result = testBreak1BadDefer(t) + if !slices.Equal(expect, result) { + t.Errorf("(Outer) Expected %v, got %v", expect, result) + } +} + +// testReturn1 has no bad iterators. +func testReturn1(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + return + } + result = append(result, y) + } + result = append(result, x) + } + return +} + +// testReturn2 has an outermost bad iterator +func testReturn2(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + return + } + result = append(result, y) + } + result = append(result, x) + } + return +} + +// testReturn3 has an innermost bad iterator +func testReturn3(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + return + } + result = append(result, y) + } + } + return +} + +// TestReturns checks that returns through bad iterators behave properly, +// for inner and outer bad iterators. +func TestReturns(t *testing.T) { + var result []int + var expect = []int{-1, 1, 2, -10} + var err any + + result, err = testReturn1(t) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + if err != nil { + t.Errorf("Unexpected error %v", err) + } + + result, err = testReturn2(t) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + if err == nil { + t.Errorf("Missing expected error") + } else { + t.Logf("Saw expected panic '%v'", err) + } + + result, err = testReturn3(t) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + if err == nil { + t.Errorf("Missing expected error") + } else { + t.Logf("Saw expected panic '%v'", err) + } + +} + +// testGotoA1 tests loop-nest-internal goto, no bad iterators. +func testGotoA1(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + goto A + } + result = append(result, y) + } + result = append(result, x) + A: + } + return +} + +// testGotoA2 tests loop-nest-internal goto, outer bad iterator. +func testGotoA2(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + goto A + } + result = append(result, y) + } + result = append(result, x) + A: + } + return +} + +// testGotoA3 tests loop-nest-internal goto, inner bad iterator. +func testGotoA3(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + goto A + } + result = append(result, y) + } + result = append(result, x) + A: + } + return +} + +func TestGotoA(t *testing.T) { + var result []int + var expect = []int{-1, 1, 2, -2, 1, 2, -3, 1, 2, -4, -30, -20, -10} + var expect3 = []int{-1, 1, 2, -10} // first goto becomes a panic + var err any + + result, err = testGotoA1(t) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + if err != nil { + t.Errorf("Unexpected error %v", err) + } + + result, err = testGotoA2(t) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + if err == nil { + t.Errorf("Missing expected error") + } else { + t.Logf("Saw expected panic '%v'", err) + } + + result, err = testGotoA3(t) + if !slices.Equal(expect3, result) { + t.Errorf("Expected %v, got %v", expect3, result) + } + if err == nil { + t.Errorf("Missing expected error") + } else { + t.Logf("Saw expected panic '%v'", err) + } +} + +// testGotoB1 tests loop-nest-exiting goto, no bad iterators. +func testGotoB1(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + goto B + } + result = append(result, y) + } + result = append(result, x) + } +B: + result = append(result, 999) + return +} + +// testGotoB2 tests loop-nest-exiting goto, outer bad iterator. +func testGotoB2(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + goto B + } + result = append(result, y) + } + result = append(result, x) + } +B: + result = append(result, 999) + return +} + +// testGotoB3 tests loop-nest-exiting goto, inner bad iterator. +func testGotoB3(t *testing.T) (result []int, err any) { + defer func() { + err = recover() + }() + for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) { + result = append(result, x) + if x == -4 { + break + } + defer func() { + result = append(result, x*10) + }() + for _, y := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + if y == 3 { + goto B + } + result = append(result, y) + } + result = append(result, x) + } +B: + result = append(result, 999) + return +} + +func TestGotoB(t *testing.T) { + var result []int + var expect = []int{-1, 1, 2, 999, -10} + var expectX = []int{-1, 1, 2, -10} + var err any + + result, err = testGotoB1(t) + if !slices.Equal(expect, result) { + t.Errorf("Expected %v, got %v", expect, result) + } + if err != nil { + t.Errorf("Unexpected error %v", err) + } + + result, err = testGotoB2(t) + if !slices.Equal(expectX, result) { + t.Errorf("Expected %v, got %v", expectX, result) + } + if err == nil { + t.Errorf("Missing expected error") + } else { + t.Logf("Saw expected panic '%v'", err) + } + + result, err = testGotoB3(t) + if !slices.Equal(expectX, result) { + t.Errorf("Expected %v, got %v", expectX, result) + } + if err == nil { + t.Errorf("Missing expected error") + } else { + t.Logf("Saw expected panic '%v'", err) + } +} diff --git a/src/cmd/compile/internal/rangefunc/rewrite.go b/src/cmd/compile/internal/rangefunc/rewrite.go new file mode 100644 index 0000000..d439412 --- /dev/null +++ b/src/cmd/compile/internal/rangefunc/rewrite.go @@ -0,0 +1,1334 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Package rangefunc rewrites range-over-func to code that doesn't use range-over-funcs. +Rewriting the construct in the front end, before noder, means the functions generated during +the rewrite are available in a noder-generated representation for inlining by the back end. + +# Theory of Operation + +The basic idea is to rewrite + + for x := range f { + ... + } + +into + + f(func(x T) bool { + ... + }) + +But it's not usually that easy. + +# Range variables + +For a range not using :=, the assigned variables cannot be function parameters +in the generated body function. Instead, we allocate fake parameters and +start the body with an assignment. For example: + + for expr1, expr2 = range f { + ... + } + +becomes + + f(func(#p1 T1, #p2 T2) bool { + expr1, expr2 = #p1, #p2 + ... + }) + +(All the generated variables have a # at the start to signal that they +are internal variables when looking at the generated code in a +debugger. Because variables have all been resolved to the specific +objects they represent, there is no danger of using plain "p1" and +colliding with a Go variable named "p1"; the # is just nice to have, +not for correctness.) + +It can also happen that there are fewer range variables than function +arguments, in which case we end up with something like + + f(func(x T1, _ T2) bool { + ... + }) + +or + + f(func(#p1 T1, #p2 T2, _ T3) bool { + expr1, expr2 = #p1, #p2 + ... + }) + +# Return + +If the body contains a "break", that break turns into "return false", +to tell f to stop. And if the body contains a "continue", that turns +into "return true", to tell f to proceed with the next value. +Those are the easy cases. + +If the body contains a return or a break/continue/goto L, then we need +to rewrite that into code that breaks out of the loop and then +triggers that control flow. In general we rewrite + + for x := range f { + ... + } + +into + + { + var #next int + f(func(x T1) bool { + ... + return true + }) + ... check #next ... + } + +The variable #next is an integer code that says what to do when f +returns. Each difficult statement sets #next and then returns false to +stop f. + +A plain "return" rewrites to {#next = -1; return false}. +The return false breaks the loop. Then when f returns, the "check +#next" section includes + + if #next == -1 { return } + +which causes the return we want. + +Return with arguments is more involved. We need somewhere to store the +arguments while we break out of f, so we add them to the var +declaration, like: + + { + var ( + #next int + #r1 type1 + #r2 type2 + ) + f(func(x T1) bool { + ... + { + // return a, b + #r1, #r2 = a, b + #next = -2 + return false + } + ... + return true + }) + if #next == -2 { return #r1, #r2 } + } + +TODO: What about: + + func f() (x bool) { + for range g(&x) { + return true + } + } + + func g(p *bool) func(func() bool) { + return func(yield func() bool) { + yield() + // Is *p true or false here? + } + } + +With this rewrite the "return true" is not visible after yield returns, +but maybe it should be? + +# Checking + +To permit checking that an iterator is well-behaved -- that is, that +it does not call the loop body again after it has returned false or +after the entire loop has exited (it might retain a copy of the body +function, or pass it to another goroutine) -- each generated loop has +its own #exitK flag that is checked before each iteration, and set both +at any early exit and after the iteration completes. + +For example: + + for x := range f { + ... + if ... { break } + ... + } + +becomes + + { + var #exit1 bool + f(func(x T1) bool { + if #exit1 { runtime.panicrangeexit() } + ... + if ... { #exit1 = true ; return false } + ... + return true + }) + #exit1 = true + } + +# Nested Loops + +So far we've only considered a single loop. If a function contains a +sequence of loops, each can be translated individually. But loops can +be nested. It would work to translate the innermost loop and then +translate the loop around it, and so on, except that there'd be a lot +of rewriting of rewritten code and the overall traversals could end up +taking time quadratic in the depth of the nesting. To avoid all that, +we use a single rewriting pass that handles a top-most range-over-func +loop and all the range-over-func loops it contains at the same time. + +If we need to return from inside a doubly-nested loop, the rewrites +above stay the same, but the check after the inner loop only says + + if #next < 0 { return false } + +to stop the outer loop so it can do the actual return. That is, + + for range f { + for range g { + ... + return a, b + ... + } + } + +becomes + + { + var ( + #next int + #r1 type1 + #r2 type2 + ) + var #exit1 bool + f(func() { + if #exit1 { runtime.panicrangeexit() } + var #exit2 bool + g(func() { + if #exit2 { runtime.panicrangeexit() } + ... + { + // return a, b + #r1, #r2 = a, b + #next = -2 + #exit1, #exit2 = true, true + return false + } + ... + return true + }) + #exit2 = true + if #next < 0 { + return false + } + return true + }) + #exit1 = true + if #next == -2 { + return #r1, #r2 + } + } + +Note that the #next < 0 after the inner loop handles both kinds of +return with a single check. + +# Labeled break/continue of range-over-func loops + +For a labeled break or continue of an outer range-over-func, we +use positive #next values. Any such labeled break or continue +really means "do N breaks" or "do N breaks and 1 continue". +We encode that as perLoopStep*N or perLoopStep*N+1 respectively. + +Loops that might need to propagate a labeled break or continue +add one or both of these to the #next checks: + + if #next >= 2 { + #next -= 2 + return false + } + + if #next == 1 { + #next = 0 + return true + } + +For example + + F: for range f { + for range g { + for range h { + ... + break F + ... + ... + continue F + ... + } + } + ... + } + +becomes + + { + var #next int + var #exit1 bool + f(func() { + if #exit1 { runtime.panicrangeexit() } + var #exit2 bool + g(func() { + if #exit2 { runtime.panicrangeexit() } + var #exit3 bool + h(func() { + if #exit3 { runtime.panicrangeexit() } + ... + { + // break F + #next = 4 + #exit1, #exit2, #exit3 = true, true, true + return false + } + ... + { + // continue F + #next = 3 + #exit2, #exit3 = true, true + return false + } + ... + return true + }) + #exit3 = true + if #next >= 2 { + #next -= 2 + return false + } + return true + }) + #exit2 = true + if #next >= 2 { + #next -= 2 + return false + } + if #next == 1 { + #next = 0 + return true + } + ... + return true + }) + #exit1 = true + } + +Note that the post-h checks only consider a break, +since no generated code tries to continue g. + +# Gotos and other labeled break/continue + +The final control flow translations are goto and break/continue of a +non-range-over-func statement. In both cases, we may need to break out +of one or more range-over-func loops before we can do the actual +control flow statement. Each such break/continue/goto L statement is +assigned a unique negative #next value (below -2, since -1 and -2 are +for the two kinds of return). Then the post-checks for a given loop +test for the specific codes that refer to labels directly targetable +from that block. Otherwise, the generic + + if #next < 0 { return false } + +check handles stopping the next loop to get one step closer to the label. + +For example + + Top: print("start\n") + for range f { + for range g { + ... + for range h { + ... + goto Top + ... + } + } + } + +becomes + + Top: print("start\n") + { + var #next int + var #exit1 bool + f(func() { + if #exit1 { runtime.panicrangeexit() } + var #exit2 bool + g(func() { + if #exit2 { runtime.panicrangeexit() } + ... + var #exit3 bool + h(func() { + if #exit3 { runtime.panicrangeexit() } + ... + { + // goto Top + #next = -3 + #exit1, #exit2, #exit3 = true, true, true + return false + } + ... + return true + }) + #exit3 = true + if #next < 0 { + return false + } + return true + }) + #exit2 = true + if #next < 0 { + return false + } + return true + }) + #exit1 = true + if #next == -3 { + #next = 0 + goto Top + } + } + +Labeled break/continue to non-range-over-funcs are handled the same +way as goto. + +# Defers + +The last wrinkle is handling defer statements. If we have + + for range f { + defer print("A") + } + +we cannot rewrite that into + + f(func() { + defer print("A") + }) + +because the deferred code will run at the end of the iteration, not +the end of the containing function. To fix that, the runtime provides +a special hook that lets us obtain a defer "token" representing the +outer function and then use it in a later defer to attach the deferred +code to that outer function. + +Normally, + + defer print("A") + +compiles to + + runtime.deferproc(func() { print("A") }) + +This changes in a range-over-func. For example: + + for range f { + defer print("A") + } + +compiles to + + var #defers = runtime.deferrangefunc() + f(func() { + runtime.deferprocat(func() { print("A") }, #defers) + }) + +For this rewriting phase, we insert the explicit initialization of +#defers and then attach the #defers variable to the CallStmt +representing the defer. That variable will be propagated to the +backend and will cause the backend to compile the defer using +deferprocat instead of an ordinary deferproc. + +TODO: Could call runtime.deferrangefuncend after f. +*/ +package rangefunc + +import ( + "cmd/compile/internal/base" + "cmd/compile/internal/syntax" + "cmd/compile/internal/types2" + "fmt" + "go/constant" + "os" +) + +// nopos is the zero syntax.Pos. +var nopos syntax.Pos + +// A rewriter implements rewriting the range-over-funcs in a given function. +type rewriter struct { + pkg *types2.Package + info *types2.Info + outer *syntax.FuncType + body *syntax.BlockStmt + + // References to important types and values. + any types2.Object + bool types2.Object + int types2.Object + true types2.Object + false types2.Object + + // Branch numbering, computed as needed. + branchNext map[branch]int // branch -> #next value + labelLoop map[string]*syntax.ForStmt // label -> innermost rangefunc loop it is declared inside (nil for no loop) + + // Stack of nodes being visited. + stack []syntax.Node // all nodes + forStack []*forLoop // range-over-func loops + + rewritten map[*syntax.ForStmt]syntax.Stmt + + // Declared variables in generated code for outermost loop. + declStmt *syntax.DeclStmt + nextVar types2.Object + retVars []types2.Object + defers types2.Object + exitVarCount int // exitvars are referenced from their respective loops +} + +// A branch is a single labeled branch. +type branch struct { + tok syntax.Token + label string +} + +// A forLoop describes a single range-over-func loop being processed. +type forLoop struct { + nfor *syntax.ForStmt // actual syntax + exitFlag *types2.Var // #exit variable for this loop + exitFlagDecl *syntax.VarDecl + + checkRet bool // add check for "return" after loop + checkRetArgs bool // add check for "return args" after loop + checkBreak bool // add check for "break" after loop + checkContinue bool // add check for "continue" after loop + checkBranch []branch // add check for labeled branch after loop +} + +// Rewrite rewrites all the range-over-funcs in the files. +func Rewrite(pkg *types2.Package, info *types2.Info, files []*syntax.File) { + for _, file := range files { + syntax.Inspect(file, func(n syntax.Node) bool { + switch n := n.(type) { + case *syntax.FuncDecl: + rewriteFunc(pkg, info, n.Type, n.Body) + return false + case *syntax.FuncLit: + rewriteFunc(pkg, info, n.Type, n.Body) + return false + } + return true + }) + } +} + +// rewriteFunc rewrites all the range-over-funcs in a single function (a top-level func or a func literal). +// The typ and body are the function's type and body. +func rewriteFunc(pkg *types2.Package, info *types2.Info, typ *syntax.FuncType, body *syntax.BlockStmt) { + if body == nil { + return + } + r := &rewriter{ + pkg: pkg, + info: info, + outer: typ, + body: body, + } + syntax.Inspect(body, r.inspect) + if (base.Flag.W != 0) && r.forStack != nil { + syntax.Fdump(os.Stderr, body) + } +} + +// checkFuncMisuse reports whether to check for misuse of iterator callbacks functions. +func (r *rewriter) checkFuncMisuse() bool { + return base.Debug.RangeFuncCheck != 0 +} + +// inspect is a callback for syntax.Inspect that drives the actual rewriting. +// If it sees a func literal, it kicks off a separate rewrite for that literal. +// Otherwise, it maintains a stack of range-over-func loops and +// converts each in turn. +func (r *rewriter) inspect(n syntax.Node) bool { + switch n := n.(type) { + case *syntax.FuncLit: + rewriteFunc(r.pkg, r.info, n.Type, n.Body) + return false + + default: + // Push n onto stack. + r.stack = append(r.stack, n) + if nfor, ok := forRangeFunc(n); ok { + loop := &forLoop{nfor: nfor} + r.forStack = append(r.forStack, loop) + r.startLoop(loop) + } + + case nil: + // n == nil signals that we are done visiting + // the top-of-stack node's children. Find it. + n = r.stack[len(r.stack)-1] + + // If we are inside a range-over-func, + // take this moment to replace any break/continue/goto/return + // statements directly contained in this node. + // Also replace any converted for statements + // with the rewritten block. + switch n := n.(type) { + case *syntax.BlockStmt: + for i, s := range n.List { + n.List[i] = r.editStmt(s) + } + case *syntax.CaseClause: + for i, s := range n.Body { + n.Body[i] = r.editStmt(s) + } + case *syntax.CommClause: + for i, s := range n.Body { + n.Body[i] = r.editStmt(s) + } + case *syntax.LabeledStmt: + n.Stmt = r.editStmt(n.Stmt) + } + + // Pop n. + if len(r.forStack) > 0 && r.stack[len(r.stack)-1] == r.forStack[len(r.forStack)-1].nfor { + r.endLoop(r.forStack[len(r.forStack)-1]) + r.forStack = r.forStack[:len(r.forStack)-1] + } + r.stack = r.stack[:len(r.stack)-1] + } + return true +} + +// startLoop sets up for converting a range-over-func loop. +func (r *rewriter) startLoop(loop *forLoop) { + // For first loop in function, allocate syntax for any, bool, int, true, and false. + if r.any == nil { + r.any = types2.Universe.Lookup("any") + r.bool = types2.Universe.Lookup("bool") + r.int = types2.Universe.Lookup("int") + r.true = types2.Universe.Lookup("true") + r.false = types2.Universe.Lookup("false") + r.rewritten = make(map[*syntax.ForStmt]syntax.Stmt) + } + if r.checkFuncMisuse() { + // declare the exit flag for this loop's body + loop.exitFlag, loop.exitFlagDecl = r.exitVar(loop.nfor.Pos()) + } +} + +// editStmt returns the replacement for the statement x, +// or x itself if it should be left alone. +// This includes the for loops we are converting, +// as left in x.rewritten by r.endLoop. +func (r *rewriter) editStmt(x syntax.Stmt) syntax.Stmt { + if x, ok := x.(*syntax.ForStmt); ok { + if s := r.rewritten[x]; s != nil { + return s + } + } + + if len(r.forStack) > 0 { + switch x := x.(type) { + case *syntax.BranchStmt: + return r.editBranch(x) + case *syntax.CallStmt: + if x.Tok == syntax.Defer { + return r.editDefer(x) + } + case *syntax.ReturnStmt: + return r.editReturn(x) + } + } + + return x +} + +// editDefer returns the replacement for the defer statement x. +// See the "Defers" section in the package doc comment above for more context. +func (r *rewriter) editDefer(x *syntax.CallStmt) syntax.Stmt { + if r.defers == nil { + // Declare and initialize the #defers token. + init := &syntax.CallExpr{ + Fun: runtimeSym(r.info, "deferrangefunc"), + } + tv := syntax.TypeAndValue{Type: r.any.Type()} + tv.SetIsValue() + init.SetTypeInfo(tv) + r.defers = r.declVar("#defers", r.any.Type(), init) + } + + // Attach the token as an "extra" argument to the defer. + x.DeferAt = r.useVar(r.defers) + setPos(x.DeferAt, x.Pos()) + return x +} + +func (r *rewriter) exitVar(pos syntax.Pos) (*types2.Var, *syntax.VarDecl) { + r.exitVarCount++ + + name := fmt.Sprintf("#exit%d", r.exitVarCount) + typ := r.bool.Type() + obj := types2.NewVar(pos, r.pkg, name, typ) + n := syntax.NewName(pos, name) + setValueType(n, typ) + r.info.Defs[n] = obj + + return obj, &syntax.VarDecl{NameList: []*syntax.Name{n}} +} + +// editReturn returns the replacement for the return statement x. +// See the "Return" section in the package doc comment above for more context. +func (r *rewriter) editReturn(x *syntax.ReturnStmt) syntax.Stmt { + // #next = -1 is return with no arguments; -2 is return with arguments. + var next int + if x.Results == nil { + next = -1 + r.forStack[0].checkRet = true + } else { + next = -2 + r.forStack[0].checkRetArgs = true + } + + // Tell the loops along the way to check for a return. + for _, loop := range r.forStack[1:] { + loop.checkRet = true + } + + // Assign results, set #next, and return false. + bl := &syntax.BlockStmt{} + if x.Results != nil { + if r.retVars == nil { + for i, a := range r.outer.ResultList { + obj := r.declVar(fmt.Sprintf("#r%d", i+1), a.Type.GetTypeInfo().Type, nil) + r.retVars = append(r.retVars, obj) + } + } + bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.useList(r.retVars), Rhs: x.Results}) + } + bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)}) + if r.checkFuncMisuse() { + // mark all enclosing loop bodies as exited + for i := 0; i < len(r.forStack); i++ { + bl.List = append(bl.List, r.setExitedAt(i)) + } + } + bl.List = append(bl.List, &syntax.ReturnStmt{Results: r.useVar(r.false)}) + setPos(bl, x.Pos()) + return bl +} + +// perLoopStep is part of the encoding of loop-spanning control flow +// for function range iterators. Each multiple of two encodes a "return false" +// passing control to an enclosing iterator; a terminal value of 1 encodes +// "return true" (i.e., local continue) from the body function, and a terminal +// value of 0 encodes executing the remainder of the body function. +const perLoopStep = 2 + +// editBranch returns the replacement for the branch statement x, +// or x itself if it should be left alone. +// See the package doc comment above for more context. +func (r *rewriter) editBranch(x *syntax.BranchStmt) syntax.Stmt { + if x.Tok == syntax.Fallthrough { + // Fallthrough is unaffected by the rewrite. + return x + } + + // Find target of break/continue/goto in r.forStack. + // (The target may not be in r.forStack at all.) + targ := x.Target + i := len(r.forStack) - 1 + if x.Label == nil && r.forStack[i].nfor != targ { + // Unlabeled break or continue that's not nfor must be inside nfor. Leave alone. + return x + } + for i >= 0 && r.forStack[i].nfor != targ { + i-- + } + // exitFrom is the index of the loop interior to the target of the control flow, + // if such a loop exists (it does not if i == len(r.forStack) - 1) + exitFrom := i + 1 + + // Compute the value to assign to #next and the specific return to use. + var next int + var ret *syntax.ReturnStmt + if x.Tok == syntax.Goto || i < 0 { + // goto Label + // or break/continue of labeled non-range-over-func loop. + // We may be able to leave it alone, or we may have to break + // out of one or more nested loops and then use #next to signal + // to complete the break/continue/goto. + // Figure out which range-over-func loop contains the label. + r.computeBranchNext() + nfor := r.forStack[len(r.forStack)-1].nfor + label := x.Label.Value + targ := r.labelLoop[label] + if nfor == targ { + // Label is in the innermost range-over-func loop; use it directly. + return x + } + + // Set #next to the code meaning break/continue/goto label. + next = r.branchNext[branch{x.Tok, label}] + + // Break out of nested loops up to targ. + i := len(r.forStack) - 1 + for i >= 0 && r.forStack[i].nfor != targ { + i-- + } + exitFrom = i + 1 + + // Mark loop we exit to get to targ to check for that branch. + // When i==-1 that's the outermost func body + top := r.forStack[i+1] + top.checkBranch = append(top.checkBranch, branch{x.Tok, label}) + + // Mark loops along the way to check for a plain return, so they break. + for j := i + 2; j < len(r.forStack); j++ { + r.forStack[j].checkRet = true + } + + // In the innermost loop, use a plain "return false". + ret = &syntax.ReturnStmt{Results: r.useVar(r.false)} + } else { + // break/continue of labeled range-over-func loop. + depth := len(r.forStack) - 1 - i + + // For continue of innermost loop, use "return true". + // Otherwise we are breaking the innermost loop, so "return false". + + if depth == 0 && x.Tok == syntax.Continue { + ret = &syntax.ReturnStmt{Results: r.useVar(r.true)} + setPos(ret, x.Pos()) + return ret + } + ret = &syntax.ReturnStmt{Results: r.useVar(r.false)} + + // If this is a simple break, mark this loop as exited and return false. + // No adjustments to #next. + if depth == 0 { + var stmts []syntax.Stmt + if r.checkFuncMisuse() { + stmts = []syntax.Stmt{r.setExited(), ret} + } else { + stmts = []syntax.Stmt{ret} + } + bl := &syntax.BlockStmt{ + List: stmts, + } + setPos(bl, x.Pos()) + return bl + } + + // The loop inside the one we are break/continue-ing + // needs to make that happen when we break out of it. + if x.Tok == syntax.Continue { + r.forStack[exitFrom].checkContinue = true + } else { + exitFrom = i + r.forStack[exitFrom].checkBreak = true + } + + // The loops along the way just need to break. + for j := exitFrom + 1; j < len(r.forStack); j++ { + r.forStack[j].checkBreak = true + } + + // Set next to break the appropriate number of times; + // the final time may be a continue, not a break. + next = perLoopStep * depth + if x.Tok == syntax.Continue { + next-- + } + } + + // Assign #next = next and do the return. + as := &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)} + bl := &syntax.BlockStmt{ + List: []syntax.Stmt{as}, + } + + if r.checkFuncMisuse() { + // Set #exitK for this loop and those exited by the control flow. + for i := exitFrom; i < len(r.forStack); i++ { + bl.List = append(bl.List, r.setExitedAt(i)) + } + } + + bl.List = append(bl.List, ret) + setPos(bl, x.Pos()) + return bl +} + +// computeBranchNext computes the branchNext numbering +// and determines which labels end up inside which range-over-func loop bodies. +func (r *rewriter) computeBranchNext() { + if r.labelLoop != nil { + return + } + + r.labelLoop = make(map[string]*syntax.ForStmt) + r.branchNext = make(map[branch]int) + + var labels []string + var stack []syntax.Node + var forStack []*syntax.ForStmt + forStack = append(forStack, nil) + syntax.Inspect(r.body, func(n syntax.Node) bool { + if n != nil { + stack = append(stack, n) + if nfor, ok := forRangeFunc(n); ok { + forStack = append(forStack, nfor) + } + if n, ok := n.(*syntax.LabeledStmt); ok { + l := n.Label.Value + labels = append(labels, l) + f := forStack[len(forStack)-1] + r.labelLoop[l] = f + } + } else { + n := stack[len(stack)-1] + stack = stack[:len(stack)-1] + if n == forStack[len(forStack)-1] { + forStack = forStack[:len(forStack)-1] + } + } + return true + }) + + // Assign numbers to all the labels we observed. + used := -2 + for _, l := range labels { + used -= 3 + r.branchNext[branch{syntax.Break, l}] = used + r.branchNext[branch{syntax.Continue, l}] = used + 1 + r.branchNext[branch{syntax.Goto, l}] = used + 2 + } +} + +// endLoop finishes the conversion of a range-over-func loop. +// We have inspected and rewritten the body of the loop and can now +// construct the body function and rewrite the for loop into a call +// bracketed by any declarations and checks it requires. +func (r *rewriter) endLoop(loop *forLoop) { + // Pick apart for range X { ... } + nfor := loop.nfor + start, end := nfor.Pos(), nfor.Body.Rbrace // start, end position of for loop + rclause := nfor.Init.(*syntax.RangeClause) + rfunc := types2.CoreType(rclause.X.GetTypeInfo().Type).(*types2.Signature) // type of X - func(func(...)bool) + if rfunc.Params().Len() != 1 { + base.Fatalf("invalid typecheck of range func") + } + ftyp := types2.CoreType(rfunc.Params().At(0).Type()).(*types2.Signature) // func(...) bool + if ftyp.Results().Len() != 1 { + base.Fatalf("invalid typecheck of range func") + } + + // Build X(bodyFunc) + call := &syntax.ExprStmt{ + X: &syntax.CallExpr{ + Fun: rclause.X, + ArgList: []syntax.Expr{ + r.bodyFunc(nfor.Body.List, syntax.UnpackListExpr(rclause.Lhs), rclause.Def, ftyp, start, end), + }, + }, + } + setPos(call, start) + + // Build checks based on #next after X(bodyFunc) + checks := r.checks(loop, end) + + // Rewrite for vars := range X { ... } to + // + // { + // r.declStmt + // call + // checks + // } + // + // The r.declStmt can be added to by this loop or any inner loop + // during the creation of r.bodyFunc; it is only emitted in the outermost + // converted range loop. + block := &syntax.BlockStmt{Rbrace: end} + setPos(block, start) + if len(r.forStack) == 1 && r.declStmt != nil { + setPos(r.declStmt, start) + block.List = append(block.List, r.declStmt) + } + + // declare the exitFlag here so it has proper scope and zeroing + if r.checkFuncMisuse() { + exitFlagDecl := &syntax.DeclStmt{DeclList: []syntax.Decl{loop.exitFlagDecl}} + block.List = append(block.List, exitFlagDecl) + } + + // iteratorFunc(bodyFunc) + block.List = append(block.List, call) + + if r.checkFuncMisuse() { + // iteratorFunc has exited, mark the exit flag for the body + block.List = append(block.List, r.setExited()) + } + block.List = append(block.List, checks...) + + if len(r.forStack) == 1 { // ending an outermost loop + r.declStmt = nil + r.nextVar = nil + r.retVars = nil + r.defers = nil + } + + r.rewritten[nfor] = block +} + +func (r *rewriter) setExited() *syntax.AssignStmt { + return r.setExitedAt(len(r.forStack) - 1) +} + +func (r *rewriter) setExitedAt(index int) *syntax.AssignStmt { + loop := r.forStack[index] + return &syntax.AssignStmt{ + Lhs: r.useVar(loop.exitFlag), + Rhs: r.useVar(r.true), + } +} + +// bodyFunc converts the loop body (control flow has already been updated) +// to a func literal that can be passed to the range function. +// +// vars is the range variables from the range statement. +// def indicates whether this is a := range statement. +// ftyp is the type of the function we are creating +// start and end are the syntax positions to use for new nodes +// that should be at the start or end of the loop. +func (r *rewriter) bodyFunc(body []syntax.Stmt, lhs []syntax.Expr, def bool, ftyp *types2.Signature, start, end syntax.Pos) *syntax.FuncLit { + // Starting X(bodyFunc); build up bodyFunc first. + var params, results []*types2.Var + results = append(results, types2.NewVar(start, nil, "", r.bool.Type())) + bodyFunc := &syntax.FuncLit{ + // Note: Type is ignored but needs to be non-nil to avoid panic in syntax.Inspect. + Type: &syntax.FuncType{}, + Body: &syntax.BlockStmt{ + List: []syntax.Stmt{}, + Rbrace: end, + }, + } + setPos(bodyFunc, start) + + for i := 0; i < ftyp.Params().Len(); i++ { + typ := ftyp.Params().At(i).Type() + var paramVar *types2.Var + if i < len(lhs) && def { + // Reuse range variable as parameter. + x := lhs[i] + paramVar = r.info.Defs[x.(*syntax.Name)].(*types2.Var) + } else { + // Declare new parameter and assign it to range expression. + paramVar = types2.NewVar(start, r.pkg, fmt.Sprintf("#p%d", 1+i), typ) + if i < len(lhs) { + x := lhs[i] + as := &syntax.AssignStmt{Lhs: x, Rhs: r.useVar(paramVar)} + as.SetPos(x.Pos()) + setPos(as.Rhs, x.Pos()) + bodyFunc.Body.List = append(bodyFunc.Body.List, as) + } + } + params = append(params, paramVar) + } + + tv := syntax.TypeAndValue{ + Type: types2.NewSignatureType(nil, nil, nil, + types2.NewTuple(params...), + types2.NewTuple(results...), + false), + } + tv.SetIsValue() + bodyFunc.SetTypeInfo(tv) + + loop := r.forStack[len(r.forStack)-1] + + if r.checkFuncMisuse() { + bodyFunc.Body.List = append(bodyFunc.Body.List, r.assertNotExited(start, loop)) + } + + // Original loop body (already rewritten by editStmt during inspect). + bodyFunc.Body.List = append(bodyFunc.Body.List, body...) + + // return true to continue at end of loop body + ret := &syntax.ReturnStmt{Results: r.useVar(r.true)} + ret.SetPos(end) + bodyFunc.Body.List = append(bodyFunc.Body.List, ret) + + return bodyFunc +} + +// checks returns the post-call checks that need to be done for the given loop. +func (r *rewriter) checks(loop *forLoop, pos syntax.Pos) []syntax.Stmt { + var list []syntax.Stmt + if len(loop.checkBranch) > 0 { + did := make(map[branch]bool) + for _, br := range loop.checkBranch { + if did[br] { + continue + } + did[br] = true + doBranch := &syntax.BranchStmt{Tok: br.tok, Label: &syntax.Name{Value: br.label}} + list = append(list, r.ifNext(syntax.Eql, r.branchNext[br], doBranch)) + } + } + if len(r.forStack) == 1 { + if loop.checkRetArgs { + list = append(list, r.ifNext(syntax.Eql, -2, retStmt(r.useList(r.retVars)))) + } + if loop.checkRet { + list = append(list, r.ifNext(syntax.Eql, -1, retStmt(nil))) + } + } else { + if loop.checkRetArgs || loop.checkRet { + // Note: next < 0 also handles gotos handled by outer loops. + // We set checkRet in that case to trigger this check. + list = append(list, r.ifNext(syntax.Lss, 0, retStmt(r.useVar(r.false)))) + } + if loop.checkBreak { + list = append(list, r.ifNext(syntax.Geq, perLoopStep, retStmt(r.useVar(r.false)))) + } + if loop.checkContinue { + list = append(list, r.ifNext(syntax.Eql, perLoopStep-1, retStmt(r.useVar(r.true)))) + } + } + + for _, j := range list { + setPos(j, pos) + } + return list +} + +// retStmt returns a return statement returning the given return values. +func retStmt(results syntax.Expr) *syntax.ReturnStmt { + return &syntax.ReturnStmt{Results: results} +} + +// ifNext returns the statement: +// +// if #next op c { adjust; then } +// +// When op is >=, adjust is #next -= c. +// When op is == and c is not -1 or -2, adjust is #next = 0. +// Otherwise adjust is omitted. +func (r *rewriter) ifNext(op syntax.Operator, c int, then syntax.Stmt) syntax.Stmt { + nif := &syntax.IfStmt{ + Cond: &syntax.Operation{Op: op, X: r.next(), Y: r.intConst(c)}, + Then: &syntax.BlockStmt{ + List: []syntax.Stmt{then}, + }, + } + tv := syntax.TypeAndValue{Type: r.bool.Type()} + tv.SetIsValue() + nif.Cond.SetTypeInfo(tv) + + if op == syntax.Geq { + sub := &syntax.AssignStmt{ + Op: syntax.Sub, + Lhs: r.next(), + Rhs: r.intConst(c), + } + nif.Then.List = []syntax.Stmt{sub, then} + } + if op == syntax.Eql && c != -1 && c != -2 { + clr := &syntax.AssignStmt{ + Lhs: r.next(), + Rhs: r.intConst(0), + } + nif.Then.List = []syntax.Stmt{clr, then} + } + + return nif +} + +// setValueType marks x as a value with type typ. +func setValueType(x syntax.Expr, typ syntax.Type) { + tv := syntax.TypeAndValue{Type: typ} + tv.SetIsValue() + x.SetTypeInfo(tv) +} + +// assertNotExited returns the statement: +// +// if #exitK { runtime.panicrangeexit() } +// +// where #exitK is the exit guard for loop. +func (r *rewriter) assertNotExited(start syntax.Pos, loop *forLoop) syntax.Stmt { + callPanicExpr := &syntax.CallExpr{ + Fun: runtimeSym(r.info, "panicrangeexit"), + } + setValueType(callPanicExpr, nil) // no result type + + callPanic := &syntax.ExprStmt{X: callPanicExpr} + + nif := &syntax.IfStmt{ + Cond: r.useVar(loop.exitFlag), + Then: &syntax.BlockStmt{ + List: []syntax.Stmt{callPanic}, + }, + } + setPos(nif, start) + return nif +} + +// next returns a reference to the #next variable. +func (r *rewriter) next() *syntax.Name { + if r.nextVar == nil { + r.nextVar = r.declVar("#next", r.int.Type(), nil) + } + return r.useVar(r.nextVar) +} + +// forRangeFunc checks whether n is a range-over-func. +// If so, it returns n.(*syntax.ForStmt), true. +// Otherwise it returns nil, false. +func forRangeFunc(n syntax.Node) (*syntax.ForStmt, bool) { + nfor, ok := n.(*syntax.ForStmt) + if !ok { + return nil, false + } + nrange, ok := nfor.Init.(*syntax.RangeClause) + if !ok { + return nil, false + } + _, ok = types2.CoreType(nrange.X.GetTypeInfo().Type).(*types2.Signature) + if !ok { + return nil, false + } + return nfor, true +} + +// intConst returns syntax for an integer literal with the given value. +func (r *rewriter) intConst(c int) *syntax.BasicLit { + lit := &syntax.BasicLit{ + Value: fmt.Sprint(c), + Kind: syntax.IntLit, + } + tv := syntax.TypeAndValue{Type: r.int.Type(), Value: constant.MakeInt64(int64(c))} + tv.SetIsValue() + lit.SetTypeInfo(tv) + return lit +} + +// useVar returns syntax for a reference to decl, which should be its declaration. +func (r *rewriter) useVar(obj types2.Object) *syntax.Name { + n := syntax.NewName(nopos, obj.Name()) + tv := syntax.TypeAndValue{Type: obj.Type()} + tv.SetIsValue() + n.SetTypeInfo(tv) + r.info.Uses[n] = obj + return n +} + +// useList is useVar for a list of decls. +func (r *rewriter) useList(vars []types2.Object) syntax.Expr { + var new []syntax.Expr + for _, obj := range vars { + new = append(new, r.useVar(obj)) + } + if len(new) == 1 { + return new[0] + } + return &syntax.ListExpr{ElemList: new} +} + +// declVar declares a variable with a given name type and initializer value. +func (r *rewriter) declVar(name string, typ types2.Type, init syntax.Expr) *types2.Var { + if r.declStmt == nil { + r.declStmt = &syntax.DeclStmt{} + } + stmt := r.declStmt + obj := types2.NewVar(stmt.Pos(), r.pkg, name, typ) + n := syntax.NewName(stmt.Pos(), name) + tv := syntax.TypeAndValue{Type: typ} + tv.SetIsValue() + n.SetTypeInfo(tv) + r.info.Defs[n] = obj + stmt.DeclList = append(stmt.DeclList, &syntax.VarDecl{ + NameList: []*syntax.Name{n}, + // Note: Type is ignored + Values: init, + }) + return obj +} + +// declType declares a type with the given name and type. +// This is more like "type name = typ" than "type name typ". +func declType(pos syntax.Pos, name string, typ types2.Type) *syntax.Name { + n := syntax.NewName(pos, name) + n.SetTypeInfo(syntax.TypeAndValue{Type: typ}) + return n +} + +// runtimePkg is a fake runtime package that contains what we need to refer to in package runtime. +var runtimePkg = func() *types2.Package { + var nopos syntax.Pos + pkg := types2.NewPackage("runtime", "runtime") + anyType := types2.Universe.Lookup("any").Type() + + // func deferrangefunc() unsafe.Pointer + obj := types2.NewFunc(nopos, pkg, "deferrangefunc", types2.NewSignatureType(nil, nil, nil, nil, types2.NewTuple(types2.NewParam(nopos, pkg, "extra", anyType)), false)) + pkg.Scope().Insert(obj) + + // func panicrangeexit() + obj = types2.NewFunc(nopos, pkg, "panicrangeexit", types2.NewSignatureType(nil, nil, nil, nil, nil, false)) + pkg.Scope().Insert(obj) + + return pkg +}() + +// runtimeSym returns a reference to a symbol in the fake runtime package. +func runtimeSym(info *types2.Info, name string) *syntax.Name { + obj := runtimePkg.Scope().Lookup(name) + n := syntax.NewName(nopos, "runtime."+name) + tv := syntax.TypeAndValue{Type: obj.Type()} + tv.SetIsValue() + tv.SetIsRuntimeHelper() + n.SetTypeInfo(tv) + info.Uses[n] = obj + return n +} + +// setPos walks the top structure of x that has no position assigned +// and assigns it all to have position pos. +// When setPos encounters a syntax node with a position assigned, +// setPos does not look inside that node. +// setPos only needs to handle syntax we create in this package; +// all other syntax should have positions assigned already. +func setPos(x syntax.Node, pos syntax.Pos) { + if x == nil { + return + } + syntax.Inspect(x, func(n syntax.Node) bool { + if n == nil || n.Pos() != nopos { + return false + } + n.SetPos(pos) + switch n := n.(type) { + case *syntax.BlockStmt: + if n.Rbrace == nopos { + n.Rbrace = pos + } + } + return true + }) +} |