summaryrefslogtreecommitdiffstats
path: root/src/cmd/compile/internal/rangefunc
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd/compile/internal/rangefunc')
-rw-r--r--src/cmd/compile/internal/rangefunc/rangefunc_test.go1297
-rw-r--r--src/cmd/compile/internal/rangefunc/rewrite.go1334
2 files changed, 2631 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/rangefunc/rangefunc_test.go b/src/cmd/compile/internal/rangefunc/rangefunc_test.go
new file mode 100644
index 0000000..16856c6
--- /dev/null
+++ b/src/cmd/compile/internal/rangefunc/rangefunc_test.go
@@ -0,0 +1,1297 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build goexperiment.rangefunc
+
+package rangefunc_test
+
+import (
+ "slices"
+ "testing"
+)
+
+type Seq2[T1, T2 any] func(yield func(T1, T2) bool)
+
+// OfSliceIndex returns a Seq over the elements of s. It is equivalent
+// to range s.
+func OfSliceIndex[T any, S ~[]T](s S) Seq2[int, T] {
+ return func(yield func(int, T) bool) {
+ for i, v := range s {
+ if !yield(i, v) {
+ return
+ }
+ }
+ return
+ }
+}
+
+// BadOfSliceIndex is "bad" because it ignores the return value from yield
+// and just keeps on iterating.
+func BadOfSliceIndex[T any, S ~[]T](s S) Seq2[int, T] {
+ return func(yield func(int, T) bool) {
+ for i, v := range s {
+ yield(i, v)
+ }
+ return
+ }
+}
+
+// VeryBadOfSliceIndex is "very bad" because it ignores the return value from yield
+// and just keeps on iterating, and also wraps that call in a defer-recover so it can
+// keep on trying after the first panic.
+func VeryBadOfSliceIndex[T any, S ~[]T](s S) Seq2[int, T] {
+ return func(yield func(int, T) bool) {
+ for i, v := range s {
+ func() {
+ defer func() {
+ recover()
+ }()
+ yield(i, v)
+ }()
+ }
+ return
+ }
+}
+
+// CooperativeBadOfSliceIndex calls the loop body from a goroutine after
+// a ping on a channel, and returns recover()on that same channel.
+func CooperativeBadOfSliceIndex[T any, S ~[]T](s S, proceed chan any) Seq2[int, T] {
+ return func(yield func(int, T) bool) {
+ for i, v := range s {
+ if !yield(i, v) {
+ // if the body breaks, call yield just once in a goroutine
+ go func() {
+ <-proceed
+ defer func() {
+ proceed <- recover()
+ }()
+ yield(0, s[0])
+ }()
+ return
+ }
+ }
+ return
+ }
+}
+
+// TrickyIterator is a type intended to test whether an iterator that
+// calls a yield function after loop exit must inevitably escape the
+// closure; this might be relevant to future checking/optimization.
+type TrickyIterator struct {
+ yield func(int, int) bool
+}
+
+func (ti *TrickyIterator) iterAll(s []int) Seq2[int, int] {
+ return func(yield func(int, int) bool) {
+ ti.yield = yield // Save yield for future abuse
+ for i, v := range s {
+ if !yield(i, v) {
+ return
+ }
+ }
+ return
+ }
+}
+
+func (ti *TrickyIterator) iterOne(s []int) Seq2[int, int] {
+ return func(yield func(int, int) bool) {
+ ti.yield = yield // Save yield for future abuse
+ if len(s) > 0 { // Not in a loop might escape differently
+ yield(0, s[0])
+ }
+ return
+ }
+}
+
+func (ti *TrickyIterator) iterZero(s []int) Seq2[int, int] {
+ return func(yield func(int, int) bool) {
+ ti.yield = yield // Save yield for future abuse
+ // Don't call it at all, maybe it won't escape
+ return
+ }
+}
+
+func (ti *TrickyIterator) fail() {
+ if ti.yield != nil {
+ ti.yield(1, 1)
+ }
+}
+
+// Check wraps the function body passed to iterator forall
+// in code that ensures that it cannot (successfully) be called
+// either after body return false (control flow out of loop) or
+// forall itself returns (the iteration is now done).
+//
+// Note that this can catch errors before the inserted checks.
+func Check[U, V any](forall Seq2[U, V]) Seq2[U, V] {
+ return func(body func(U, V) bool) {
+ ret := true
+ forall(func(u U, v V) bool {
+ if !ret {
+ panic("Checked iterator access after exit")
+ }
+ ret = body(u, v)
+ return ret
+ })
+ ret = false
+ }
+}
+
+func TestCheck(t *testing.T) {
+ i := 0
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ }()
+ for _, x := range Check(BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) {
+ i += x
+ if i > 4*9 {
+ break
+ }
+ }
+}
+
+func TestCooperativeBadOfSliceIndex(t *testing.T) {
+ i := 0
+ proceed := make(chan any)
+ for _, x := range CooperativeBadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, proceed) {
+ i += x
+ if i >= 36 {
+ break
+ }
+ }
+ proceed <- true
+ if r := <-proceed; r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ if i != 36 {
+ t.Errorf("Expected i == 36, saw %d instead", i)
+ } else {
+ t.Logf("i = %d", i)
+ }
+}
+
+func TestCheckCooperativeBadOfSliceIndex(t *testing.T) {
+ i := 0
+ proceed := make(chan any)
+ for _, x := range Check(CooperativeBadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, proceed)) {
+ i += x
+ if i >= 36 {
+ break
+ }
+ }
+ proceed <- true
+ if r := <-proceed; r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ if i != 36 {
+ t.Errorf("Expected i == 36, saw %d instead", i)
+ } else {
+ t.Logf("i = %d", i)
+ }
+}
+
+func TestTrickyIterAll(t *testing.T) {
+ trickItAll := TrickyIterator{}
+ i := 0
+ for _, x := range trickItAll.iterAll([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ i += x
+ if i >= 36 {
+ break
+ }
+ }
+
+ if i != 36 {
+ t.Errorf("Expected i == 36, saw %d instead", i)
+ } else {
+ t.Logf("i = %d", i)
+ }
+
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ }()
+
+ trickItAll.fail()
+}
+
+func TestTrickyIterOne(t *testing.T) {
+ trickItOne := TrickyIterator{}
+ i := 0
+ for _, x := range trickItOne.iterOne([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ i += x
+ if i >= 36 {
+ break
+ }
+ }
+
+ // Don't care about value, ought to be 36 anyhow.
+ t.Logf("i = %d", i)
+
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ }()
+
+ trickItOne.fail()
+}
+
+func TestTrickyIterZero(t *testing.T) {
+ trickItZero := TrickyIterator{}
+ i := 0
+ for _, x := range trickItZero.iterZero([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ i += x
+ if i >= 36 {
+ break
+ }
+ }
+
+ // Don't care about value, ought to be 0 anyhow.
+ t.Logf("i = %d", i)
+
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ }()
+
+ trickItZero.fail()
+}
+
+func TestCheckTrickyIterZero(t *testing.T) {
+ trickItZero := TrickyIterator{}
+ i := 0
+ for _, x := range Check(trickItZero.iterZero([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})) {
+ i += x
+ if i >= 36 {
+ break
+ }
+ }
+
+ // Don't care about value, ought to be 0 anyhow.
+ t.Logf("i = %d", i)
+
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ }()
+
+ trickItZero.fail()
+}
+
+// TestBreak1 should just work, with well-behaved iterators.
+// (The misbehaving iterator detector should not trigger.)
+func TestBreak1(t *testing.T) {
+ var result []int
+ var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3}
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4}) {
+ if x == -4 {
+ break
+ }
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ break
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestBreak2 should just work, with well-behaved iterators.
+// (The misbehaving iterator detector should not trigger.)
+func TestBreak2(t *testing.T) {
+ var result []int
+ var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3}
+outer:
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4}) {
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ break
+ }
+ if x == -4 {
+ break outer
+ }
+
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestContinue should just work, with well-behaved iterators.
+// (The misbehaving iterator detector should not trigger.)
+func TestContinue(t *testing.T) {
+ var result []int
+ var expect = []int{-1, 1, 2, -2, 1, 2, -3, 1, 2, -4}
+outer:
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4}) {
+ result = append(result, x)
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ continue outer
+ }
+ if x == -4 {
+ break outer
+ }
+
+ result = append(result, y)
+ }
+ result = append(result, x-10)
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestBreak3 should just work, with well-behaved iterators.
+// (The misbehaving iterator detector should not trigger.)
+func TestBreak3(t *testing.T) {
+ var result []int
+ var expect = []int{100, 10, 2, 4, 200, 10, 2, 4, 20, 2, 4, 300, 10, 2, 4, 20, 2, 4, 30}
+X:
+ for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) {
+ Y:
+ for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) {
+ if 10*y >= x {
+ break
+ }
+ result = append(result, y)
+ if y == 30 {
+ continue X
+ }
+ Z:
+ for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue Z
+ }
+ result = append(result, z)
+ if z >= 4 {
+ continue Y
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestBreak1BadA should end in a panic when the outer-loop's
+// single-level break is ignore by BadOfSliceIndex
+func TestBreak1BadA(t *testing.T) {
+ var result []int
+ var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3}
+
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ }()
+
+ for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ if x == -4 {
+ break
+ }
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ break
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+}
+
+// TestBreak1BadB should end in a panic, sooner, when the inner-loop's
+// (nested) single-level break is ignored by BadOfSliceIndex
+func TestBreak1BadB(t *testing.T) {
+ var result []int
+ var expect = []int{1, 2} // inner breaks, panics, after before outer appends
+
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ }()
+
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ if x == -4 {
+ break
+ }
+ for _, y := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ break
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+}
+
+// TestMultiCont0 tests multilevel continue with no bad iterators
+// (it should just work)
+func TestMultiCont0(t *testing.T) {
+ var result []int
+ var expect = []int{1000, 10, 2, 4, 2000}
+
+W:
+ for _, w := range OfSliceIndex([]int{1000, 2000}) {
+ result = append(result, w)
+ if w == 2000 {
+ break
+ }
+ for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) {
+ for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) {
+ result = append(result, y)
+ for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue
+ }
+ result = append(result, z)
+ if z >= 4 {
+ continue W // modified to be multilevel
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestMultiCont1 tests multilevel continue with a bad iterator
+// in the outermost loop exited by the continue.
+func TestMultiCont1(t *testing.T) {
+ var result []int
+ var expect = []int{1000, 10, 2, 4}
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Errorf("Wanted to see a failure, result was %v", result)
+ }
+ }()
+
+W:
+ for _, w := range OfSliceIndex([]int{1000, 2000}) {
+ result = append(result, w)
+ if w == 2000 {
+ break
+ }
+ for _, x := range BadOfSliceIndex([]int{100, 200, 300, 400}) {
+ for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) {
+ result = append(result, y)
+ for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue
+ }
+ result = append(result, z)
+ if z >= 4 {
+ continue W
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestMultiCont2 tests multilevel continue with a bad iterator
+// in a middle loop exited by the continue.
+func TestMultiCont2(t *testing.T) {
+ var result []int
+ var expect = []int{1000, 10, 2, 4}
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Errorf("Wanted to see a failure, result was %v", result)
+ }
+ }()
+
+W:
+ for _, w := range OfSliceIndex([]int{1000, 2000}) {
+ result = append(result, w)
+ if w == 2000 {
+ break
+ }
+ for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) {
+ for _, y := range BadOfSliceIndex([]int{10, 20, 30, 40}) {
+ result = append(result, y)
+ for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue
+ }
+ result = append(result, z)
+ if z >= 4 {
+ continue W
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestMultiCont3 tests multilevel continue with a bad iterator
+// in the innermost loop exited by the continue.
+func TestMultiCont3(t *testing.T) {
+ var result []int
+ var expect = []int{1000, 10, 2, 4}
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Errorf("Wanted to see a failure, result was %v", result)
+ }
+ }()
+
+W:
+ for _, w := range OfSliceIndex([]int{1000, 2000}) {
+ result = append(result, w)
+ if w == 2000 {
+ break
+ }
+ for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) {
+ for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) {
+ result = append(result, y)
+ for _, z := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue
+ }
+ result = append(result, z)
+ if z >= 4 {
+ continue W
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestMultiBreak0 tests multilevel break with a bad iterator
+// in the outermost loop exited by the break (the outermost loop).
+func TestMultiBreak0(t *testing.T) {
+ var result []int
+ var expect = []int{1000, 10, 2, 4}
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Errorf("Wanted to see a failure, result was %v", result)
+ }
+ }()
+
+W:
+ for _, w := range BadOfSliceIndex([]int{1000, 2000}) {
+ result = append(result, w)
+ if w == 2000 {
+ break
+ }
+ for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) {
+ for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) {
+ result = append(result, y)
+ for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue
+ }
+ result = append(result, z)
+ if z >= 4 {
+ break W
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestMultiBreak1 tests multilevel break with a bad iterator
+// in an intermediate loop exited by the break.
+func TestMultiBreak1(t *testing.T) {
+ var result []int
+ var expect = []int{1000, 10, 2, 4}
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Errorf("Wanted to see a failure, result was %v", result)
+ }
+ }()
+
+W:
+ for _, w := range OfSliceIndex([]int{1000, 2000}) {
+ result = append(result, w)
+ if w == 2000 {
+ break
+ }
+ for _, x := range BadOfSliceIndex([]int{100, 200, 300, 400}) {
+ for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) {
+ result = append(result, y)
+ for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue
+ }
+ result = append(result, z)
+ if z >= 4 {
+ break W
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestMultiBreak2 tests multilevel break with two bad iterators
+// in intermediate loops exited by the break.
+func TestMultiBreak2(t *testing.T) {
+ var result []int
+ var expect = []int{1000, 10, 2, 4}
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Errorf("Wanted to see a failure, result was %v", result)
+ }
+ }()
+
+W:
+ for _, w := range OfSliceIndex([]int{1000, 2000}) {
+ result = append(result, w)
+ if w == 2000 {
+ break
+ }
+ for _, x := range BadOfSliceIndex([]int{100, 200, 300, 400}) {
+ for _, y := range BadOfSliceIndex([]int{10, 20, 30, 40}) {
+ result = append(result, y)
+ for _, z := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue
+ }
+ result = append(result, z)
+ if z >= 4 {
+ break W
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestMultiBreak3 tests multilevel break with the bad iterator
+// in the innermost loop exited by the break.
+func TestMultiBreak3(t *testing.T) {
+ var result []int
+ var expect = []int{1000, 10, 2, 4}
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Errorf("Wanted to see a failure, result was %v", result)
+ }
+ }()
+
+W:
+ for _, w := range OfSliceIndex([]int{1000, 2000}) {
+ result = append(result, w)
+ if w == 2000 {
+ break
+ }
+ for _, x := range OfSliceIndex([]int{100, 200, 300, 400}) {
+ for _, y := range OfSliceIndex([]int{10, 20, 30, 40}) {
+ result = append(result, y)
+ for _, z := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if z&1 == 1 {
+ continue
+ }
+ result = append(result, z)
+ if z >= 4 {
+ break W
+ }
+ }
+ result = append(result, -y) // should never be executed
+ }
+ result = append(result, x)
+ }
+ }
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// veryBad tests that a loop nest behaves sensibly in the face of a
+// "very bad" iterator. In this case, "sensibly" means that the
+// break out of X still occurs after the very bad iterator finally
+// quits running (the control flow bread crumbs remain.)
+func veryBad(s []int) []int {
+ var result []int
+X:
+ for _, x := range OfSliceIndex([]int{1, 2, 3}) {
+
+ result = append(result, x)
+
+ for _, y := range VeryBadOfSliceIndex(s) {
+ result = append(result, y)
+ break X
+ }
+ for _, z := range OfSliceIndex([]int{100, 200, 300}) {
+ result = append(result, z)
+ if z == 100 {
+ break
+ }
+ }
+ }
+ return result
+}
+
+// checkVeryBad wraps a "very bad" iterator with Check,
+// demonstrating that the very bad iterator also hides panics
+// thrown by Check.
+func checkVeryBad(s []int) []int {
+ var result []int
+X:
+ for _, x := range OfSliceIndex([]int{1, 2, 3}) {
+
+ result = append(result, x)
+
+ for _, y := range Check(VeryBadOfSliceIndex(s)) {
+ result = append(result, y)
+ break X
+ }
+ for _, z := range OfSliceIndex([]int{100, 200, 300}) {
+ result = append(result, z)
+ if z == 100 {
+ break
+ }
+ }
+ }
+ return result
+}
+
+// okay is the not-bad version of veryBad.
+// They should behave the same.
+func okay(s []int) []int {
+ var result []int
+X:
+ for _, x := range OfSliceIndex([]int{1, 2, 3}) {
+
+ result = append(result, x)
+
+ for _, y := range OfSliceIndex(s) {
+ result = append(result, y)
+ break X
+ }
+ for _, z := range OfSliceIndex([]int{100, 200, 300}) {
+ result = append(result, z)
+ if z == 100 {
+ break
+ }
+ }
+ }
+ return result
+}
+
+// TestVeryBad1 checks the behavior of an extremely poorly behaved iterator.
+func TestVeryBad1(t *testing.T) {
+ result := veryBad([]int{10, 20, 30, 40, 50}) // odd length
+ expect := []int{1, 10}
+
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestVeryBad2 checks the behavior of an extremely poorly behaved iterator.
+func TestVeryBad2(t *testing.T) {
+ result := veryBad([]int{10, 20, 30, 40}) // even length
+ expect := []int{1, 10}
+
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestCheckVeryBad checks the behavior of an extremely poorly behaved iterator,
+// which also suppresses the exceptions from "Check"
+func TestCheckVeryBad(t *testing.T) {
+ result := checkVeryBad([]int{10, 20, 30, 40}) // even length
+ expect := []int{1, 10}
+
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// TestOk is the nice version of the very bad iterator.
+func TestOk(t *testing.T) {
+ result := okay([]int{10, 20, 30, 40, 50}) // odd length
+ expect := []int{1, 10}
+
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+}
+
+// testBreak1BadDefer checks that defer behaves properly even in
+// the presence of loop bodies panicking out of bad iterators.
+// (i.e., the instrumentation did not break defer in these loops)
+func testBreak1BadDefer(t *testing.T) (result []int) {
+ var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3, -30, -20, -10}
+
+ defer func() {
+ if r := recover(); r != nil {
+ t.Logf("Saw expected panic '%v'", r)
+ if !slices.Equal(expect, result) {
+ t.Errorf("(Inner) Expected %v, got %v", expect, result)
+ }
+ } else {
+ t.Error("Wanted to see a failure")
+ }
+ }()
+
+ for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ break
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+ return
+}
+
+func TestBreak1BadDefer(t *testing.T) {
+ var result []int
+ var expect = []int{1, 2, -1, 1, 2, -2, 1, 2, -3, -30, -20, -10}
+ result = testBreak1BadDefer(t)
+ if !slices.Equal(expect, result) {
+ t.Errorf("(Outer) Expected %v, got %v", expect, result)
+ }
+}
+
+// testReturn1 has no bad iterators.
+func testReturn1(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ return
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+ return
+}
+
+// testReturn2 has an outermost bad iterator
+func testReturn2(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ return
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+ return
+}
+
+// testReturn3 has an innermost bad iterator
+func testReturn3(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ return
+ }
+ result = append(result, y)
+ }
+ }
+ return
+}
+
+// TestReturns checks that returns through bad iterators behave properly,
+// for inner and outer bad iterators.
+func TestReturns(t *testing.T) {
+ var result []int
+ var expect = []int{-1, 1, 2, -10}
+ var err any
+
+ result, err = testReturn1(t)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ if err != nil {
+ t.Errorf("Unexpected error %v", err)
+ }
+
+ result, err = testReturn2(t)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ if err == nil {
+ t.Errorf("Missing expected error")
+ } else {
+ t.Logf("Saw expected panic '%v'", err)
+ }
+
+ result, err = testReturn3(t)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ if err == nil {
+ t.Errorf("Missing expected error")
+ } else {
+ t.Logf("Saw expected panic '%v'", err)
+ }
+
+}
+
+// testGotoA1 tests loop-nest-internal goto, no bad iterators.
+func testGotoA1(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ goto A
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ A:
+ }
+ return
+}
+
+// testGotoA2 tests loop-nest-internal goto, outer bad iterator.
+func testGotoA2(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ goto A
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ A:
+ }
+ return
+}
+
+// testGotoA3 tests loop-nest-internal goto, inner bad iterator.
+func testGotoA3(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ goto A
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ A:
+ }
+ return
+}
+
+func TestGotoA(t *testing.T) {
+ var result []int
+ var expect = []int{-1, 1, 2, -2, 1, 2, -3, 1, 2, -4, -30, -20, -10}
+ var expect3 = []int{-1, 1, 2, -10} // first goto becomes a panic
+ var err any
+
+ result, err = testGotoA1(t)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ if err != nil {
+ t.Errorf("Unexpected error %v", err)
+ }
+
+ result, err = testGotoA2(t)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ if err == nil {
+ t.Errorf("Missing expected error")
+ } else {
+ t.Logf("Saw expected panic '%v'", err)
+ }
+
+ result, err = testGotoA3(t)
+ if !slices.Equal(expect3, result) {
+ t.Errorf("Expected %v, got %v", expect3, result)
+ }
+ if err == nil {
+ t.Errorf("Missing expected error")
+ } else {
+ t.Logf("Saw expected panic '%v'", err)
+ }
+}
+
+// testGotoB1 tests loop-nest-exiting goto, no bad iterators.
+func testGotoB1(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ goto B
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+B:
+ result = append(result, 999)
+ return
+}
+
+// testGotoB2 tests loop-nest-exiting goto, outer bad iterator.
+func testGotoB2(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range BadOfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range OfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ goto B
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+B:
+ result = append(result, 999)
+ return
+}
+
+// testGotoB3 tests loop-nest-exiting goto, inner bad iterator.
+func testGotoB3(t *testing.T) (result []int, err any) {
+ defer func() {
+ err = recover()
+ }()
+ for _, x := range OfSliceIndex([]int{-1, -2, -3, -4, -5}) {
+ result = append(result, x)
+ if x == -4 {
+ break
+ }
+ defer func() {
+ result = append(result, x*10)
+ }()
+ for _, y := range BadOfSliceIndex([]int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) {
+ if y == 3 {
+ goto B
+ }
+ result = append(result, y)
+ }
+ result = append(result, x)
+ }
+B:
+ result = append(result, 999)
+ return
+}
+
+func TestGotoB(t *testing.T) {
+ var result []int
+ var expect = []int{-1, 1, 2, 999, -10}
+ var expectX = []int{-1, 1, 2, -10}
+ var err any
+
+ result, err = testGotoB1(t)
+ if !slices.Equal(expect, result) {
+ t.Errorf("Expected %v, got %v", expect, result)
+ }
+ if err != nil {
+ t.Errorf("Unexpected error %v", err)
+ }
+
+ result, err = testGotoB2(t)
+ if !slices.Equal(expectX, result) {
+ t.Errorf("Expected %v, got %v", expectX, result)
+ }
+ if err == nil {
+ t.Errorf("Missing expected error")
+ } else {
+ t.Logf("Saw expected panic '%v'", err)
+ }
+
+ result, err = testGotoB3(t)
+ if !slices.Equal(expectX, result) {
+ t.Errorf("Expected %v, got %v", expectX, result)
+ }
+ if err == nil {
+ t.Errorf("Missing expected error")
+ } else {
+ t.Logf("Saw expected panic '%v'", err)
+ }
+}
diff --git a/src/cmd/compile/internal/rangefunc/rewrite.go b/src/cmd/compile/internal/rangefunc/rewrite.go
new file mode 100644
index 0000000..d439412
--- /dev/null
+++ b/src/cmd/compile/internal/rangefunc/rewrite.go
@@ -0,0 +1,1334 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package rangefunc rewrites range-over-func to code that doesn't use range-over-funcs.
+Rewriting the construct in the front end, before noder, means the functions generated during
+the rewrite are available in a noder-generated representation for inlining by the back end.
+
+# Theory of Operation
+
+The basic idea is to rewrite
+
+ for x := range f {
+ ...
+ }
+
+into
+
+ f(func(x T) bool {
+ ...
+ })
+
+But it's not usually that easy.
+
+# Range variables
+
+For a range not using :=, the assigned variables cannot be function parameters
+in the generated body function. Instead, we allocate fake parameters and
+start the body with an assignment. For example:
+
+ for expr1, expr2 = range f {
+ ...
+ }
+
+becomes
+
+ f(func(#p1 T1, #p2 T2) bool {
+ expr1, expr2 = #p1, #p2
+ ...
+ })
+
+(All the generated variables have a # at the start to signal that they
+are internal variables when looking at the generated code in a
+debugger. Because variables have all been resolved to the specific
+objects they represent, there is no danger of using plain "p1" and
+colliding with a Go variable named "p1"; the # is just nice to have,
+not for correctness.)
+
+It can also happen that there are fewer range variables than function
+arguments, in which case we end up with something like
+
+ f(func(x T1, _ T2) bool {
+ ...
+ })
+
+or
+
+ f(func(#p1 T1, #p2 T2, _ T3) bool {
+ expr1, expr2 = #p1, #p2
+ ...
+ })
+
+# Return
+
+If the body contains a "break", that break turns into "return false",
+to tell f to stop. And if the body contains a "continue", that turns
+into "return true", to tell f to proceed with the next value.
+Those are the easy cases.
+
+If the body contains a return or a break/continue/goto L, then we need
+to rewrite that into code that breaks out of the loop and then
+triggers that control flow. In general we rewrite
+
+ for x := range f {
+ ...
+ }
+
+into
+
+ {
+ var #next int
+ f(func(x T1) bool {
+ ...
+ return true
+ })
+ ... check #next ...
+ }
+
+The variable #next is an integer code that says what to do when f
+returns. Each difficult statement sets #next and then returns false to
+stop f.
+
+A plain "return" rewrites to {#next = -1; return false}.
+The return false breaks the loop. Then when f returns, the "check
+#next" section includes
+
+ if #next == -1 { return }
+
+which causes the return we want.
+
+Return with arguments is more involved. We need somewhere to store the
+arguments while we break out of f, so we add them to the var
+declaration, like:
+
+ {
+ var (
+ #next int
+ #r1 type1
+ #r2 type2
+ )
+ f(func(x T1) bool {
+ ...
+ {
+ // return a, b
+ #r1, #r2 = a, b
+ #next = -2
+ return false
+ }
+ ...
+ return true
+ })
+ if #next == -2 { return #r1, #r2 }
+ }
+
+TODO: What about:
+
+ func f() (x bool) {
+ for range g(&x) {
+ return true
+ }
+ }
+
+ func g(p *bool) func(func() bool) {
+ return func(yield func() bool) {
+ yield()
+ // Is *p true or false here?
+ }
+ }
+
+With this rewrite the "return true" is not visible after yield returns,
+but maybe it should be?
+
+# Checking
+
+To permit checking that an iterator is well-behaved -- that is, that
+it does not call the loop body again after it has returned false or
+after the entire loop has exited (it might retain a copy of the body
+function, or pass it to another goroutine) -- each generated loop has
+its own #exitK flag that is checked before each iteration, and set both
+at any early exit and after the iteration completes.
+
+For example:
+
+ for x := range f {
+ ...
+ if ... { break }
+ ...
+ }
+
+becomes
+
+ {
+ var #exit1 bool
+ f(func(x T1) bool {
+ if #exit1 { runtime.panicrangeexit() }
+ ...
+ if ... { #exit1 = true ; return false }
+ ...
+ return true
+ })
+ #exit1 = true
+ }
+
+# Nested Loops
+
+So far we've only considered a single loop. If a function contains a
+sequence of loops, each can be translated individually. But loops can
+be nested. It would work to translate the innermost loop and then
+translate the loop around it, and so on, except that there'd be a lot
+of rewriting of rewritten code and the overall traversals could end up
+taking time quadratic in the depth of the nesting. To avoid all that,
+we use a single rewriting pass that handles a top-most range-over-func
+loop and all the range-over-func loops it contains at the same time.
+
+If we need to return from inside a doubly-nested loop, the rewrites
+above stay the same, but the check after the inner loop only says
+
+ if #next < 0 { return false }
+
+to stop the outer loop so it can do the actual return. That is,
+
+ for range f {
+ for range g {
+ ...
+ return a, b
+ ...
+ }
+ }
+
+becomes
+
+ {
+ var (
+ #next int
+ #r1 type1
+ #r2 type2
+ )
+ var #exit1 bool
+ f(func() {
+ if #exit1 { runtime.panicrangeexit() }
+ var #exit2 bool
+ g(func() {
+ if #exit2 { runtime.panicrangeexit() }
+ ...
+ {
+ // return a, b
+ #r1, #r2 = a, b
+ #next = -2
+ #exit1, #exit2 = true, true
+ return false
+ }
+ ...
+ return true
+ })
+ #exit2 = true
+ if #next < 0 {
+ return false
+ }
+ return true
+ })
+ #exit1 = true
+ if #next == -2 {
+ return #r1, #r2
+ }
+ }
+
+Note that the #next < 0 after the inner loop handles both kinds of
+return with a single check.
+
+# Labeled break/continue of range-over-func loops
+
+For a labeled break or continue of an outer range-over-func, we
+use positive #next values. Any such labeled break or continue
+really means "do N breaks" or "do N breaks and 1 continue".
+We encode that as perLoopStep*N or perLoopStep*N+1 respectively.
+
+Loops that might need to propagate a labeled break or continue
+add one or both of these to the #next checks:
+
+ if #next >= 2 {
+ #next -= 2
+ return false
+ }
+
+ if #next == 1 {
+ #next = 0
+ return true
+ }
+
+For example
+
+ F: for range f {
+ for range g {
+ for range h {
+ ...
+ break F
+ ...
+ ...
+ continue F
+ ...
+ }
+ }
+ ...
+ }
+
+becomes
+
+ {
+ var #next int
+ var #exit1 bool
+ f(func() {
+ if #exit1 { runtime.panicrangeexit() }
+ var #exit2 bool
+ g(func() {
+ if #exit2 { runtime.panicrangeexit() }
+ var #exit3 bool
+ h(func() {
+ if #exit3 { runtime.panicrangeexit() }
+ ...
+ {
+ // break F
+ #next = 4
+ #exit1, #exit2, #exit3 = true, true, true
+ return false
+ }
+ ...
+ {
+ // continue F
+ #next = 3
+ #exit2, #exit3 = true, true
+ return false
+ }
+ ...
+ return true
+ })
+ #exit3 = true
+ if #next >= 2 {
+ #next -= 2
+ return false
+ }
+ return true
+ })
+ #exit2 = true
+ if #next >= 2 {
+ #next -= 2
+ return false
+ }
+ if #next == 1 {
+ #next = 0
+ return true
+ }
+ ...
+ return true
+ })
+ #exit1 = true
+ }
+
+Note that the post-h checks only consider a break,
+since no generated code tries to continue g.
+
+# Gotos and other labeled break/continue
+
+The final control flow translations are goto and break/continue of a
+non-range-over-func statement. In both cases, we may need to break out
+of one or more range-over-func loops before we can do the actual
+control flow statement. Each such break/continue/goto L statement is
+assigned a unique negative #next value (below -2, since -1 and -2 are
+for the two kinds of return). Then the post-checks for a given loop
+test for the specific codes that refer to labels directly targetable
+from that block. Otherwise, the generic
+
+ if #next < 0 { return false }
+
+check handles stopping the next loop to get one step closer to the label.
+
+For example
+
+ Top: print("start\n")
+ for range f {
+ for range g {
+ ...
+ for range h {
+ ...
+ goto Top
+ ...
+ }
+ }
+ }
+
+becomes
+
+ Top: print("start\n")
+ {
+ var #next int
+ var #exit1 bool
+ f(func() {
+ if #exit1 { runtime.panicrangeexit() }
+ var #exit2 bool
+ g(func() {
+ if #exit2 { runtime.panicrangeexit() }
+ ...
+ var #exit3 bool
+ h(func() {
+ if #exit3 { runtime.panicrangeexit() }
+ ...
+ {
+ // goto Top
+ #next = -3
+ #exit1, #exit2, #exit3 = true, true, true
+ return false
+ }
+ ...
+ return true
+ })
+ #exit3 = true
+ if #next < 0 {
+ return false
+ }
+ return true
+ })
+ #exit2 = true
+ if #next < 0 {
+ return false
+ }
+ return true
+ })
+ #exit1 = true
+ if #next == -3 {
+ #next = 0
+ goto Top
+ }
+ }
+
+Labeled break/continue to non-range-over-funcs are handled the same
+way as goto.
+
+# Defers
+
+The last wrinkle is handling defer statements. If we have
+
+ for range f {
+ defer print("A")
+ }
+
+we cannot rewrite that into
+
+ f(func() {
+ defer print("A")
+ })
+
+because the deferred code will run at the end of the iteration, not
+the end of the containing function. To fix that, the runtime provides
+a special hook that lets us obtain a defer "token" representing the
+outer function and then use it in a later defer to attach the deferred
+code to that outer function.
+
+Normally,
+
+ defer print("A")
+
+compiles to
+
+ runtime.deferproc(func() { print("A") })
+
+This changes in a range-over-func. For example:
+
+ for range f {
+ defer print("A")
+ }
+
+compiles to
+
+ var #defers = runtime.deferrangefunc()
+ f(func() {
+ runtime.deferprocat(func() { print("A") }, #defers)
+ })
+
+For this rewriting phase, we insert the explicit initialization of
+#defers and then attach the #defers variable to the CallStmt
+representing the defer. That variable will be propagated to the
+backend and will cause the backend to compile the defer using
+deferprocat instead of an ordinary deferproc.
+
+TODO: Could call runtime.deferrangefuncend after f.
+*/
+package rangefunc
+
+import (
+ "cmd/compile/internal/base"
+ "cmd/compile/internal/syntax"
+ "cmd/compile/internal/types2"
+ "fmt"
+ "go/constant"
+ "os"
+)
+
+// nopos is the zero syntax.Pos.
+var nopos syntax.Pos
+
+// A rewriter implements rewriting the range-over-funcs in a given function.
+type rewriter struct {
+ pkg *types2.Package
+ info *types2.Info
+ outer *syntax.FuncType
+ body *syntax.BlockStmt
+
+ // References to important types and values.
+ any types2.Object
+ bool types2.Object
+ int types2.Object
+ true types2.Object
+ false types2.Object
+
+ // Branch numbering, computed as needed.
+ branchNext map[branch]int // branch -> #next value
+ labelLoop map[string]*syntax.ForStmt // label -> innermost rangefunc loop it is declared inside (nil for no loop)
+
+ // Stack of nodes being visited.
+ stack []syntax.Node // all nodes
+ forStack []*forLoop // range-over-func loops
+
+ rewritten map[*syntax.ForStmt]syntax.Stmt
+
+ // Declared variables in generated code for outermost loop.
+ declStmt *syntax.DeclStmt
+ nextVar types2.Object
+ retVars []types2.Object
+ defers types2.Object
+ exitVarCount int // exitvars are referenced from their respective loops
+}
+
+// A branch is a single labeled branch.
+type branch struct {
+ tok syntax.Token
+ label string
+}
+
+// A forLoop describes a single range-over-func loop being processed.
+type forLoop struct {
+ nfor *syntax.ForStmt // actual syntax
+ exitFlag *types2.Var // #exit variable for this loop
+ exitFlagDecl *syntax.VarDecl
+
+ checkRet bool // add check for "return" after loop
+ checkRetArgs bool // add check for "return args" after loop
+ checkBreak bool // add check for "break" after loop
+ checkContinue bool // add check for "continue" after loop
+ checkBranch []branch // add check for labeled branch after loop
+}
+
+// Rewrite rewrites all the range-over-funcs in the files.
+func Rewrite(pkg *types2.Package, info *types2.Info, files []*syntax.File) {
+ for _, file := range files {
+ syntax.Inspect(file, func(n syntax.Node) bool {
+ switch n := n.(type) {
+ case *syntax.FuncDecl:
+ rewriteFunc(pkg, info, n.Type, n.Body)
+ return false
+ case *syntax.FuncLit:
+ rewriteFunc(pkg, info, n.Type, n.Body)
+ return false
+ }
+ return true
+ })
+ }
+}
+
+// rewriteFunc rewrites all the range-over-funcs in a single function (a top-level func or a func literal).
+// The typ and body are the function's type and body.
+func rewriteFunc(pkg *types2.Package, info *types2.Info, typ *syntax.FuncType, body *syntax.BlockStmt) {
+ if body == nil {
+ return
+ }
+ r := &rewriter{
+ pkg: pkg,
+ info: info,
+ outer: typ,
+ body: body,
+ }
+ syntax.Inspect(body, r.inspect)
+ if (base.Flag.W != 0) && r.forStack != nil {
+ syntax.Fdump(os.Stderr, body)
+ }
+}
+
+// checkFuncMisuse reports whether to check for misuse of iterator callbacks functions.
+func (r *rewriter) checkFuncMisuse() bool {
+ return base.Debug.RangeFuncCheck != 0
+}
+
+// inspect is a callback for syntax.Inspect that drives the actual rewriting.
+// If it sees a func literal, it kicks off a separate rewrite for that literal.
+// Otherwise, it maintains a stack of range-over-func loops and
+// converts each in turn.
+func (r *rewriter) inspect(n syntax.Node) bool {
+ switch n := n.(type) {
+ case *syntax.FuncLit:
+ rewriteFunc(r.pkg, r.info, n.Type, n.Body)
+ return false
+
+ default:
+ // Push n onto stack.
+ r.stack = append(r.stack, n)
+ if nfor, ok := forRangeFunc(n); ok {
+ loop := &forLoop{nfor: nfor}
+ r.forStack = append(r.forStack, loop)
+ r.startLoop(loop)
+ }
+
+ case nil:
+ // n == nil signals that we are done visiting
+ // the top-of-stack node's children. Find it.
+ n = r.stack[len(r.stack)-1]
+
+ // If we are inside a range-over-func,
+ // take this moment to replace any break/continue/goto/return
+ // statements directly contained in this node.
+ // Also replace any converted for statements
+ // with the rewritten block.
+ switch n := n.(type) {
+ case *syntax.BlockStmt:
+ for i, s := range n.List {
+ n.List[i] = r.editStmt(s)
+ }
+ case *syntax.CaseClause:
+ for i, s := range n.Body {
+ n.Body[i] = r.editStmt(s)
+ }
+ case *syntax.CommClause:
+ for i, s := range n.Body {
+ n.Body[i] = r.editStmt(s)
+ }
+ case *syntax.LabeledStmt:
+ n.Stmt = r.editStmt(n.Stmt)
+ }
+
+ // Pop n.
+ if len(r.forStack) > 0 && r.stack[len(r.stack)-1] == r.forStack[len(r.forStack)-1].nfor {
+ r.endLoop(r.forStack[len(r.forStack)-1])
+ r.forStack = r.forStack[:len(r.forStack)-1]
+ }
+ r.stack = r.stack[:len(r.stack)-1]
+ }
+ return true
+}
+
+// startLoop sets up for converting a range-over-func loop.
+func (r *rewriter) startLoop(loop *forLoop) {
+ // For first loop in function, allocate syntax for any, bool, int, true, and false.
+ if r.any == nil {
+ r.any = types2.Universe.Lookup("any")
+ r.bool = types2.Universe.Lookup("bool")
+ r.int = types2.Universe.Lookup("int")
+ r.true = types2.Universe.Lookup("true")
+ r.false = types2.Universe.Lookup("false")
+ r.rewritten = make(map[*syntax.ForStmt]syntax.Stmt)
+ }
+ if r.checkFuncMisuse() {
+ // declare the exit flag for this loop's body
+ loop.exitFlag, loop.exitFlagDecl = r.exitVar(loop.nfor.Pos())
+ }
+}
+
+// editStmt returns the replacement for the statement x,
+// or x itself if it should be left alone.
+// This includes the for loops we are converting,
+// as left in x.rewritten by r.endLoop.
+func (r *rewriter) editStmt(x syntax.Stmt) syntax.Stmt {
+ if x, ok := x.(*syntax.ForStmt); ok {
+ if s := r.rewritten[x]; s != nil {
+ return s
+ }
+ }
+
+ if len(r.forStack) > 0 {
+ switch x := x.(type) {
+ case *syntax.BranchStmt:
+ return r.editBranch(x)
+ case *syntax.CallStmt:
+ if x.Tok == syntax.Defer {
+ return r.editDefer(x)
+ }
+ case *syntax.ReturnStmt:
+ return r.editReturn(x)
+ }
+ }
+
+ return x
+}
+
+// editDefer returns the replacement for the defer statement x.
+// See the "Defers" section in the package doc comment above for more context.
+func (r *rewriter) editDefer(x *syntax.CallStmt) syntax.Stmt {
+ if r.defers == nil {
+ // Declare and initialize the #defers token.
+ init := &syntax.CallExpr{
+ Fun: runtimeSym(r.info, "deferrangefunc"),
+ }
+ tv := syntax.TypeAndValue{Type: r.any.Type()}
+ tv.SetIsValue()
+ init.SetTypeInfo(tv)
+ r.defers = r.declVar("#defers", r.any.Type(), init)
+ }
+
+ // Attach the token as an "extra" argument to the defer.
+ x.DeferAt = r.useVar(r.defers)
+ setPos(x.DeferAt, x.Pos())
+ return x
+}
+
+func (r *rewriter) exitVar(pos syntax.Pos) (*types2.Var, *syntax.VarDecl) {
+ r.exitVarCount++
+
+ name := fmt.Sprintf("#exit%d", r.exitVarCount)
+ typ := r.bool.Type()
+ obj := types2.NewVar(pos, r.pkg, name, typ)
+ n := syntax.NewName(pos, name)
+ setValueType(n, typ)
+ r.info.Defs[n] = obj
+
+ return obj, &syntax.VarDecl{NameList: []*syntax.Name{n}}
+}
+
+// editReturn returns the replacement for the return statement x.
+// See the "Return" section in the package doc comment above for more context.
+func (r *rewriter) editReturn(x *syntax.ReturnStmt) syntax.Stmt {
+ // #next = -1 is return with no arguments; -2 is return with arguments.
+ var next int
+ if x.Results == nil {
+ next = -1
+ r.forStack[0].checkRet = true
+ } else {
+ next = -2
+ r.forStack[0].checkRetArgs = true
+ }
+
+ // Tell the loops along the way to check for a return.
+ for _, loop := range r.forStack[1:] {
+ loop.checkRet = true
+ }
+
+ // Assign results, set #next, and return false.
+ bl := &syntax.BlockStmt{}
+ if x.Results != nil {
+ if r.retVars == nil {
+ for i, a := range r.outer.ResultList {
+ obj := r.declVar(fmt.Sprintf("#r%d", i+1), a.Type.GetTypeInfo().Type, nil)
+ r.retVars = append(r.retVars, obj)
+ }
+ }
+ bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.useList(r.retVars), Rhs: x.Results})
+ }
+ bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)})
+ if r.checkFuncMisuse() {
+ // mark all enclosing loop bodies as exited
+ for i := 0; i < len(r.forStack); i++ {
+ bl.List = append(bl.List, r.setExitedAt(i))
+ }
+ }
+ bl.List = append(bl.List, &syntax.ReturnStmt{Results: r.useVar(r.false)})
+ setPos(bl, x.Pos())
+ return bl
+}
+
+// perLoopStep is part of the encoding of loop-spanning control flow
+// for function range iterators. Each multiple of two encodes a "return false"
+// passing control to an enclosing iterator; a terminal value of 1 encodes
+// "return true" (i.e., local continue) from the body function, and a terminal
+// value of 0 encodes executing the remainder of the body function.
+const perLoopStep = 2
+
+// editBranch returns the replacement for the branch statement x,
+// or x itself if it should be left alone.
+// See the package doc comment above for more context.
+func (r *rewriter) editBranch(x *syntax.BranchStmt) syntax.Stmt {
+ if x.Tok == syntax.Fallthrough {
+ // Fallthrough is unaffected by the rewrite.
+ return x
+ }
+
+ // Find target of break/continue/goto in r.forStack.
+ // (The target may not be in r.forStack at all.)
+ targ := x.Target
+ i := len(r.forStack) - 1
+ if x.Label == nil && r.forStack[i].nfor != targ {
+ // Unlabeled break or continue that's not nfor must be inside nfor. Leave alone.
+ return x
+ }
+ for i >= 0 && r.forStack[i].nfor != targ {
+ i--
+ }
+ // exitFrom is the index of the loop interior to the target of the control flow,
+ // if such a loop exists (it does not if i == len(r.forStack) - 1)
+ exitFrom := i + 1
+
+ // Compute the value to assign to #next and the specific return to use.
+ var next int
+ var ret *syntax.ReturnStmt
+ if x.Tok == syntax.Goto || i < 0 {
+ // goto Label
+ // or break/continue of labeled non-range-over-func loop.
+ // We may be able to leave it alone, or we may have to break
+ // out of one or more nested loops and then use #next to signal
+ // to complete the break/continue/goto.
+ // Figure out which range-over-func loop contains the label.
+ r.computeBranchNext()
+ nfor := r.forStack[len(r.forStack)-1].nfor
+ label := x.Label.Value
+ targ := r.labelLoop[label]
+ if nfor == targ {
+ // Label is in the innermost range-over-func loop; use it directly.
+ return x
+ }
+
+ // Set #next to the code meaning break/continue/goto label.
+ next = r.branchNext[branch{x.Tok, label}]
+
+ // Break out of nested loops up to targ.
+ i := len(r.forStack) - 1
+ for i >= 0 && r.forStack[i].nfor != targ {
+ i--
+ }
+ exitFrom = i + 1
+
+ // Mark loop we exit to get to targ to check for that branch.
+ // When i==-1 that's the outermost func body
+ top := r.forStack[i+1]
+ top.checkBranch = append(top.checkBranch, branch{x.Tok, label})
+
+ // Mark loops along the way to check for a plain return, so they break.
+ for j := i + 2; j < len(r.forStack); j++ {
+ r.forStack[j].checkRet = true
+ }
+
+ // In the innermost loop, use a plain "return false".
+ ret = &syntax.ReturnStmt{Results: r.useVar(r.false)}
+ } else {
+ // break/continue of labeled range-over-func loop.
+ depth := len(r.forStack) - 1 - i
+
+ // For continue of innermost loop, use "return true".
+ // Otherwise we are breaking the innermost loop, so "return false".
+
+ if depth == 0 && x.Tok == syntax.Continue {
+ ret = &syntax.ReturnStmt{Results: r.useVar(r.true)}
+ setPos(ret, x.Pos())
+ return ret
+ }
+ ret = &syntax.ReturnStmt{Results: r.useVar(r.false)}
+
+ // If this is a simple break, mark this loop as exited and return false.
+ // No adjustments to #next.
+ if depth == 0 {
+ var stmts []syntax.Stmt
+ if r.checkFuncMisuse() {
+ stmts = []syntax.Stmt{r.setExited(), ret}
+ } else {
+ stmts = []syntax.Stmt{ret}
+ }
+ bl := &syntax.BlockStmt{
+ List: stmts,
+ }
+ setPos(bl, x.Pos())
+ return bl
+ }
+
+ // The loop inside the one we are break/continue-ing
+ // needs to make that happen when we break out of it.
+ if x.Tok == syntax.Continue {
+ r.forStack[exitFrom].checkContinue = true
+ } else {
+ exitFrom = i
+ r.forStack[exitFrom].checkBreak = true
+ }
+
+ // The loops along the way just need to break.
+ for j := exitFrom + 1; j < len(r.forStack); j++ {
+ r.forStack[j].checkBreak = true
+ }
+
+ // Set next to break the appropriate number of times;
+ // the final time may be a continue, not a break.
+ next = perLoopStep * depth
+ if x.Tok == syntax.Continue {
+ next--
+ }
+ }
+
+ // Assign #next = next and do the return.
+ as := &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)}
+ bl := &syntax.BlockStmt{
+ List: []syntax.Stmt{as},
+ }
+
+ if r.checkFuncMisuse() {
+ // Set #exitK for this loop and those exited by the control flow.
+ for i := exitFrom; i < len(r.forStack); i++ {
+ bl.List = append(bl.List, r.setExitedAt(i))
+ }
+ }
+
+ bl.List = append(bl.List, ret)
+ setPos(bl, x.Pos())
+ return bl
+}
+
+// computeBranchNext computes the branchNext numbering
+// and determines which labels end up inside which range-over-func loop bodies.
+func (r *rewriter) computeBranchNext() {
+ if r.labelLoop != nil {
+ return
+ }
+
+ r.labelLoop = make(map[string]*syntax.ForStmt)
+ r.branchNext = make(map[branch]int)
+
+ var labels []string
+ var stack []syntax.Node
+ var forStack []*syntax.ForStmt
+ forStack = append(forStack, nil)
+ syntax.Inspect(r.body, func(n syntax.Node) bool {
+ if n != nil {
+ stack = append(stack, n)
+ if nfor, ok := forRangeFunc(n); ok {
+ forStack = append(forStack, nfor)
+ }
+ if n, ok := n.(*syntax.LabeledStmt); ok {
+ l := n.Label.Value
+ labels = append(labels, l)
+ f := forStack[len(forStack)-1]
+ r.labelLoop[l] = f
+ }
+ } else {
+ n := stack[len(stack)-1]
+ stack = stack[:len(stack)-1]
+ if n == forStack[len(forStack)-1] {
+ forStack = forStack[:len(forStack)-1]
+ }
+ }
+ return true
+ })
+
+ // Assign numbers to all the labels we observed.
+ used := -2
+ for _, l := range labels {
+ used -= 3
+ r.branchNext[branch{syntax.Break, l}] = used
+ r.branchNext[branch{syntax.Continue, l}] = used + 1
+ r.branchNext[branch{syntax.Goto, l}] = used + 2
+ }
+}
+
+// endLoop finishes the conversion of a range-over-func loop.
+// We have inspected and rewritten the body of the loop and can now
+// construct the body function and rewrite the for loop into a call
+// bracketed by any declarations and checks it requires.
+func (r *rewriter) endLoop(loop *forLoop) {
+ // Pick apart for range X { ... }
+ nfor := loop.nfor
+ start, end := nfor.Pos(), nfor.Body.Rbrace // start, end position of for loop
+ rclause := nfor.Init.(*syntax.RangeClause)
+ rfunc := types2.CoreType(rclause.X.GetTypeInfo().Type).(*types2.Signature) // type of X - func(func(...)bool)
+ if rfunc.Params().Len() != 1 {
+ base.Fatalf("invalid typecheck of range func")
+ }
+ ftyp := types2.CoreType(rfunc.Params().At(0).Type()).(*types2.Signature) // func(...) bool
+ if ftyp.Results().Len() != 1 {
+ base.Fatalf("invalid typecheck of range func")
+ }
+
+ // Build X(bodyFunc)
+ call := &syntax.ExprStmt{
+ X: &syntax.CallExpr{
+ Fun: rclause.X,
+ ArgList: []syntax.Expr{
+ r.bodyFunc(nfor.Body.List, syntax.UnpackListExpr(rclause.Lhs), rclause.Def, ftyp, start, end),
+ },
+ },
+ }
+ setPos(call, start)
+
+ // Build checks based on #next after X(bodyFunc)
+ checks := r.checks(loop, end)
+
+ // Rewrite for vars := range X { ... } to
+ //
+ // {
+ // r.declStmt
+ // call
+ // checks
+ // }
+ //
+ // The r.declStmt can be added to by this loop or any inner loop
+ // during the creation of r.bodyFunc; it is only emitted in the outermost
+ // converted range loop.
+ block := &syntax.BlockStmt{Rbrace: end}
+ setPos(block, start)
+ if len(r.forStack) == 1 && r.declStmt != nil {
+ setPos(r.declStmt, start)
+ block.List = append(block.List, r.declStmt)
+ }
+
+ // declare the exitFlag here so it has proper scope and zeroing
+ if r.checkFuncMisuse() {
+ exitFlagDecl := &syntax.DeclStmt{DeclList: []syntax.Decl{loop.exitFlagDecl}}
+ block.List = append(block.List, exitFlagDecl)
+ }
+
+ // iteratorFunc(bodyFunc)
+ block.List = append(block.List, call)
+
+ if r.checkFuncMisuse() {
+ // iteratorFunc has exited, mark the exit flag for the body
+ block.List = append(block.List, r.setExited())
+ }
+ block.List = append(block.List, checks...)
+
+ if len(r.forStack) == 1 { // ending an outermost loop
+ r.declStmt = nil
+ r.nextVar = nil
+ r.retVars = nil
+ r.defers = nil
+ }
+
+ r.rewritten[nfor] = block
+}
+
+func (r *rewriter) setExited() *syntax.AssignStmt {
+ return r.setExitedAt(len(r.forStack) - 1)
+}
+
+func (r *rewriter) setExitedAt(index int) *syntax.AssignStmt {
+ loop := r.forStack[index]
+ return &syntax.AssignStmt{
+ Lhs: r.useVar(loop.exitFlag),
+ Rhs: r.useVar(r.true),
+ }
+}
+
+// bodyFunc converts the loop body (control flow has already been updated)
+// to a func literal that can be passed to the range function.
+//
+// vars is the range variables from the range statement.
+// def indicates whether this is a := range statement.
+// ftyp is the type of the function we are creating
+// start and end are the syntax positions to use for new nodes
+// that should be at the start or end of the loop.
+func (r *rewriter) bodyFunc(body []syntax.Stmt, lhs []syntax.Expr, def bool, ftyp *types2.Signature, start, end syntax.Pos) *syntax.FuncLit {
+ // Starting X(bodyFunc); build up bodyFunc first.
+ var params, results []*types2.Var
+ results = append(results, types2.NewVar(start, nil, "", r.bool.Type()))
+ bodyFunc := &syntax.FuncLit{
+ // Note: Type is ignored but needs to be non-nil to avoid panic in syntax.Inspect.
+ Type: &syntax.FuncType{},
+ Body: &syntax.BlockStmt{
+ List: []syntax.Stmt{},
+ Rbrace: end,
+ },
+ }
+ setPos(bodyFunc, start)
+
+ for i := 0; i < ftyp.Params().Len(); i++ {
+ typ := ftyp.Params().At(i).Type()
+ var paramVar *types2.Var
+ if i < len(lhs) && def {
+ // Reuse range variable as parameter.
+ x := lhs[i]
+ paramVar = r.info.Defs[x.(*syntax.Name)].(*types2.Var)
+ } else {
+ // Declare new parameter and assign it to range expression.
+ paramVar = types2.NewVar(start, r.pkg, fmt.Sprintf("#p%d", 1+i), typ)
+ if i < len(lhs) {
+ x := lhs[i]
+ as := &syntax.AssignStmt{Lhs: x, Rhs: r.useVar(paramVar)}
+ as.SetPos(x.Pos())
+ setPos(as.Rhs, x.Pos())
+ bodyFunc.Body.List = append(bodyFunc.Body.List, as)
+ }
+ }
+ params = append(params, paramVar)
+ }
+
+ tv := syntax.TypeAndValue{
+ Type: types2.NewSignatureType(nil, nil, nil,
+ types2.NewTuple(params...),
+ types2.NewTuple(results...),
+ false),
+ }
+ tv.SetIsValue()
+ bodyFunc.SetTypeInfo(tv)
+
+ loop := r.forStack[len(r.forStack)-1]
+
+ if r.checkFuncMisuse() {
+ bodyFunc.Body.List = append(bodyFunc.Body.List, r.assertNotExited(start, loop))
+ }
+
+ // Original loop body (already rewritten by editStmt during inspect).
+ bodyFunc.Body.List = append(bodyFunc.Body.List, body...)
+
+ // return true to continue at end of loop body
+ ret := &syntax.ReturnStmt{Results: r.useVar(r.true)}
+ ret.SetPos(end)
+ bodyFunc.Body.List = append(bodyFunc.Body.List, ret)
+
+ return bodyFunc
+}
+
+// checks returns the post-call checks that need to be done for the given loop.
+func (r *rewriter) checks(loop *forLoop, pos syntax.Pos) []syntax.Stmt {
+ var list []syntax.Stmt
+ if len(loop.checkBranch) > 0 {
+ did := make(map[branch]bool)
+ for _, br := range loop.checkBranch {
+ if did[br] {
+ continue
+ }
+ did[br] = true
+ doBranch := &syntax.BranchStmt{Tok: br.tok, Label: &syntax.Name{Value: br.label}}
+ list = append(list, r.ifNext(syntax.Eql, r.branchNext[br], doBranch))
+ }
+ }
+ if len(r.forStack) == 1 {
+ if loop.checkRetArgs {
+ list = append(list, r.ifNext(syntax.Eql, -2, retStmt(r.useList(r.retVars))))
+ }
+ if loop.checkRet {
+ list = append(list, r.ifNext(syntax.Eql, -1, retStmt(nil)))
+ }
+ } else {
+ if loop.checkRetArgs || loop.checkRet {
+ // Note: next < 0 also handles gotos handled by outer loops.
+ // We set checkRet in that case to trigger this check.
+ list = append(list, r.ifNext(syntax.Lss, 0, retStmt(r.useVar(r.false))))
+ }
+ if loop.checkBreak {
+ list = append(list, r.ifNext(syntax.Geq, perLoopStep, retStmt(r.useVar(r.false))))
+ }
+ if loop.checkContinue {
+ list = append(list, r.ifNext(syntax.Eql, perLoopStep-1, retStmt(r.useVar(r.true))))
+ }
+ }
+
+ for _, j := range list {
+ setPos(j, pos)
+ }
+ return list
+}
+
+// retStmt returns a return statement returning the given return values.
+func retStmt(results syntax.Expr) *syntax.ReturnStmt {
+ return &syntax.ReturnStmt{Results: results}
+}
+
+// ifNext returns the statement:
+//
+// if #next op c { adjust; then }
+//
+// When op is >=, adjust is #next -= c.
+// When op is == and c is not -1 or -2, adjust is #next = 0.
+// Otherwise adjust is omitted.
+func (r *rewriter) ifNext(op syntax.Operator, c int, then syntax.Stmt) syntax.Stmt {
+ nif := &syntax.IfStmt{
+ Cond: &syntax.Operation{Op: op, X: r.next(), Y: r.intConst(c)},
+ Then: &syntax.BlockStmt{
+ List: []syntax.Stmt{then},
+ },
+ }
+ tv := syntax.TypeAndValue{Type: r.bool.Type()}
+ tv.SetIsValue()
+ nif.Cond.SetTypeInfo(tv)
+
+ if op == syntax.Geq {
+ sub := &syntax.AssignStmt{
+ Op: syntax.Sub,
+ Lhs: r.next(),
+ Rhs: r.intConst(c),
+ }
+ nif.Then.List = []syntax.Stmt{sub, then}
+ }
+ if op == syntax.Eql && c != -1 && c != -2 {
+ clr := &syntax.AssignStmt{
+ Lhs: r.next(),
+ Rhs: r.intConst(0),
+ }
+ nif.Then.List = []syntax.Stmt{clr, then}
+ }
+
+ return nif
+}
+
+// setValueType marks x as a value with type typ.
+func setValueType(x syntax.Expr, typ syntax.Type) {
+ tv := syntax.TypeAndValue{Type: typ}
+ tv.SetIsValue()
+ x.SetTypeInfo(tv)
+}
+
+// assertNotExited returns the statement:
+//
+// if #exitK { runtime.panicrangeexit() }
+//
+// where #exitK is the exit guard for loop.
+func (r *rewriter) assertNotExited(start syntax.Pos, loop *forLoop) syntax.Stmt {
+ callPanicExpr := &syntax.CallExpr{
+ Fun: runtimeSym(r.info, "panicrangeexit"),
+ }
+ setValueType(callPanicExpr, nil) // no result type
+
+ callPanic := &syntax.ExprStmt{X: callPanicExpr}
+
+ nif := &syntax.IfStmt{
+ Cond: r.useVar(loop.exitFlag),
+ Then: &syntax.BlockStmt{
+ List: []syntax.Stmt{callPanic},
+ },
+ }
+ setPos(nif, start)
+ return nif
+}
+
+// next returns a reference to the #next variable.
+func (r *rewriter) next() *syntax.Name {
+ if r.nextVar == nil {
+ r.nextVar = r.declVar("#next", r.int.Type(), nil)
+ }
+ return r.useVar(r.nextVar)
+}
+
+// forRangeFunc checks whether n is a range-over-func.
+// If so, it returns n.(*syntax.ForStmt), true.
+// Otherwise it returns nil, false.
+func forRangeFunc(n syntax.Node) (*syntax.ForStmt, bool) {
+ nfor, ok := n.(*syntax.ForStmt)
+ if !ok {
+ return nil, false
+ }
+ nrange, ok := nfor.Init.(*syntax.RangeClause)
+ if !ok {
+ return nil, false
+ }
+ _, ok = types2.CoreType(nrange.X.GetTypeInfo().Type).(*types2.Signature)
+ if !ok {
+ return nil, false
+ }
+ return nfor, true
+}
+
+// intConst returns syntax for an integer literal with the given value.
+func (r *rewriter) intConst(c int) *syntax.BasicLit {
+ lit := &syntax.BasicLit{
+ Value: fmt.Sprint(c),
+ Kind: syntax.IntLit,
+ }
+ tv := syntax.TypeAndValue{Type: r.int.Type(), Value: constant.MakeInt64(int64(c))}
+ tv.SetIsValue()
+ lit.SetTypeInfo(tv)
+ return lit
+}
+
+// useVar returns syntax for a reference to decl, which should be its declaration.
+func (r *rewriter) useVar(obj types2.Object) *syntax.Name {
+ n := syntax.NewName(nopos, obj.Name())
+ tv := syntax.TypeAndValue{Type: obj.Type()}
+ tv.SetIsValue()
+ n.SetTypeInfo(tv)
+ r.info.Uses[n] = obj
+ return n
+}
+
+// useList is useVar for a list of decls.
+func (r *rewriter) useList(vars []types2.Object) syntax.Expr {
+ var new []syntax.Expr
+ for _, obj := range vars {
+ new = append(new, r.useVar(obj))
+ }
+ if len(new) == 1 {
+ return new[0]
+ }
+ return &syntax.ListExpr{ElemList: new}
+}
+
+// declVar declares a variable with a given name type and initializer value.
+func (r *rewriter) declVar(name string, typ types2.Type, init syntax.Expr) *types2.Var {
+ if r.declStmt == nil {
+ r.declStmt = &syntax.DeclStmt{}
+ }
+ stmt := r.declStmt
+ obj := types2.NewVar(stmt.Pos(), r.pkg, name, typ)
+ n := syntax.NewName(stmt.Pos(), name)
+ tv := syntax.TypeAndValue{Type: typ}
+ tv.SetIsValue()
+ n.SetTypeInfo(tv)
+ r.info.Defs[n] = obj
+ stmt.DeclList = append(stmt.DeclList, &syntax.VarDecl{
+ NameList: []*syntax.Name{n},
+ // Note: Type is ignored
+ Values: init,
+ })
+ return obj
+}
+
+// declType declares a type with the given name and type.
+// This is more like "type name = typ" than "type name typ".
+func declType(pos syntax.Pos, name string, typ types2.Type) *syntax.Name {
+ n := syntax.NewName(pos, name)
+ n.SetTypeInfo(syntax.TypeAndValue{Type: typ})
+ return n
+}
+
+// runtimePkg is a fake runtime package that contains what we need to refer to in package runtime.
+var runtimePkg = func() *types2.Package {
+ var nopos syntax.Pos
+ pkg := types2.NewPackage("runtime", "runtime")
+ anyType := types2.Universe.Lookup("any").Type()
+
+ // func deferrangefunc() unsafe.Pointer
+ obj := types2.NewFunc(nopos, pkg, "deferrangefunc", types2.NewSignatureType(nil, nil, nil, nil, types2.NewTuple(types2.NewParam(nopos, pkg, "extra", anyType)), false))
+ pkg.Scope().Insert(obj)
+
+ // func panicrangeexit()
+ obj = types2.NewFunc(nopos, pkg, "panicrangeexit", types2.NewSignatureType(nil, nil, nil, nil, nil, false))
+ pkg.Scope().Insert(obj)
+
+ return pkg
+}()
+
+// runtimeSym returns a reference to a symbol in the fake runtime package.
+func runtimeSym(info *types2.Info, name string) *syntax.Name {
+ obj := runtimePkg.Scope().Lookup(name)
+ n := syntax.NewName(nopos, "runtime."+name)
+ tv := syntax.TypeAndValue{Type: obj.Type()}
+ tv.SetIsValue()
+ tv.SetIsRuntimeHelper()
+ n.SetTypeInfo(tv)
+ info.Uses[n] = obj
+ return n
+}
+
+// setPos walks the top structure of x that has no position assigned
+// and assigns it all to have position pos.
+// When setPos encounters a syntax node with a position assigned,
+// setPos does not look inside that node.
+// setPos only needs to handle syntax we create in this package;
+// all other syntax should have positions assigned already.
+func setPos(x syntax.Node, pos syntax.Pos) {
+ if x == nil {
+ return
+ }
+ syntax.Inspect(x, func(n syntax.Node) bool {
+ if n == nil || n.Pos() != nopos {
+ return false
+ }
+ n.SetPos(pos)
+ switch n := n.(type) {
+ case *syntax.BlockStmt:
+ if n.Rbrace == nopos {
+ n.Rbrace = pos
+ }
+ }
+ return true
+ })
+}