summaryrefslogtreecommitdiffstats
path: root/src/cmd/compile/internal/walk/switch.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd/compile/internal/walk/switch.go')
-rw-r--r--src/cmd/compile/internal/walk/switch.go597
1 files changed, 597 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go
new file mode 100644
index 0000000..3705c5b
--- /dev/null
+++ b/src/cmd/compile/internal/walk/switch.go
@@ -0,0 +1,597 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package walk
+
+import (
+ "go/constant"
+ "go/token"
+ "sort"
+
+ "cmd/compile/internal/base"
+ "cmd/compile/internal/ir"
+ "cmd/compile/internal/typecheck"
+ "cmd/compile/internal/types"
+ "cmd/internal/src"
+)
+
+// walkSwitch walks a switch statement.
+func walkSwitch(sw *ir.SwitchStmt) {
+ // Guard against double walk, see #25776.
+ if sw.Walked() {
+ return // Was fatal, but eliminating every possible source of double-walking is hard
+ }
+ sw.SetWalked(true)
+
+ if sw.Tag != nil && sw.Tag.Op() == ir.OTYPESW {
+ walkSwitchType(sw)
+ } else {
+ walkSwitchExpr(sw)
+ }
+}
+
+// walkSwitchExpr generates an AST implementing sw. sw is an
+// expression switch.
+func walkSwitchExpr(sw *ir.SwitchStmt) {
+ lno := ir.SetPos(sw)
+
+ cond := sw.Tag
+ sw.Tag = nil
+
+ // convert switch {...} to switch true {...}
+ if cond == nil {
+ cond = ir.NewBool(true)
+ cond = typecheck.Expr(cond)
+ cond = typecheck.DefaultLit(cond, nil)
+ }
+
+ // Given "switch string(byteslice)",
+ // with all cases being side-effect free,
+ // use a zero-cost alias of the byte slice.
+ // Do this before calling walkExpr on cond,
+ // because walkExpr will lower the string
+ // conversion into a runtime call.
+ // See issue 24937 for more discussion.
+ if cond.Op() == ir.OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
+ cond := cond.(*ir.ConvExpr)
+ cond.SetOp(ir.OBYTES2STRTMP)
+ }
+
+ cond = walkExpr(cond, sw.PtrInit())
+ if cond.Op() != ir.OLITERAL && cond.Op() != ir.ONIL {
+ cond = copyExpr(cond, cond.Type(), &sw.Compiled)
+ }
+
+ base.Pos = lno
+
+ s := exprSwitch{
+ exprname: cond,
+ }
+
+ var defaultGoto ir.Node
+ var body ir.Nodes
+ for _, ncase := range sw.Cases {
+ label := typecheck.AutoLabel(".s")
+ jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)
+
+ // Process case dispatch.
+ if len(ncase.List) == 0 {
+ if defaultGoto != nil {
+ base.Fatalf("duplicate default case not detected during typechecking")
+ }
+ defaultGoto = jmp
+ }
+
+ for _, n1 := range ncase.List {
+ s.Add(ncase.Pos(), n1, jmp)
+ }
+
+ // Process body.
+ body.Append(ir.NewLabelStmt(ncase.Pos(), label))
+ body.Append(ncase.Body...)
+ if fall, pos := endsInFallthrough(ncase.Body); !fall {
+ br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
+ br.SetPos(pos)
+ body.Append(br)
+ }
+ }
+ sw.Cases = nil
+
+ if defaultGoto == nil {
+ br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
+ br.SetPos(br.Pos().WithNotStmt())
+ defaultGoto = br
+ }
+
+ s.Emit(&sw.Compiled)
+ sw.Compiled.Append(defaultGoto)
+ sw.Compiled.Append(body.Take()...)
+ walkStmtList(sw.Compiled)
+}
+
+// An exprSwitch walks an expression switch.
+type exprSwitch struct {
+ exprname ir.Node // value being switched on
+
+ done ir.Nodes
+ clauses []exprClause
+}
+
+type exprClause struct {
+ pos src.XPos
+ lo, hi ir.Node
+ jmp ir.Node
+}
+
+func (s *exprSwitch) Add(pos src.XPos, expr, jmp ir.Node) {
+ c := exprClause{pos: pos, lo: expr, hi: expr, jmp: jmp}
+ if types.IsOrdered[s.exprname.Type().Kind()] && expr.Op() == ir.OLITERAL {
+ s.clauses = append(s.clauses, c)
+ return
+ }
+
+ s.flush()
+ s.clauses = append(s.clauses, c)
+ s.flush()
+}
+
+func (s *exprSwitch) Emit(out *ir.Nodes) {
+ s.flush()
+ out.Append(s.done.Take()...)
+}
+
+func (s *exprSwitch) flush() {
+ cc := s.clauses
+ s.clauses = nil
+ if len(cc) == 0 {
+ return
+ }
+
+ // Caution: If len(cc) == 1, then cc[0] might not an OLITERAL.
+ // The code below is structured to implicitly handle this case
+ // (e.g., sort.Slice doesn't need to invoke the less function
+ // when there's only a single slice element).
+
+ if s.exprname.Type().IsString() && len(cc) >= 2 {
+ // Sort strings by length and then by value. It is
+ // much cheaper to compare lengths than values, and
+ // all we need here is consistency. We respect this
+ // sorting below.
+ sort.Slice(cc, func(i, j int) bool {
+ si := ir.StringVal(cc[i].lo)
+ sj := ir.StringVal(cc[j].lo)
+ if len(si) != len(sj) {
+ return len(si) < len(sj)
+ }
+ return si < sj
+ })
+
+ // runLen returns the string length associated with a
+ // particular run of exprClauses.
+ runLen := func(run []exprClause) int64 { return int64(len(ir.StringVal(run[0].lo))) }
+
+ // Collapse runs of consecutive strings with the same length.
+ var runs [][]exprClause
+ start := 0
+ for i := 1; i < len(cc); i++ {
+ if runLen(cc[start:]) != runLen(cc[i:]) {
+ runs = append(runs, cc[start:i])
+ start = i
+ }
+ }
+ runs = append(runs, cc[start:])
+
+ // Perform two-level binary search.
+ binarySearch(len(runs), &s.done,
+ func(i int) ir.Node {
+ return ir.NewBinaryExpr(base.Pos, ir.OLE, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(runs[i-1])))
+ },
+ func(i int, nif *ir.IfStmt) {
+ run := runs[i]
+ nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, ir.NewUnaryExpr(base.Pos, ir.OLEN, s.exprname), ir.NewInt(runLen(run)))
+ s.search(run, &nif.Body)
+ },
+ )
+ return
+ }
+
+ sort.Slice(cc, func(i, j int) bool {
+ return constant.Compare(cc[i].lo.Val(), token.LSS, cc[j].lo.Val())
+ })
+
+ // Merge consecutive integer cases.
+ if s.exprname.Type().IsInteger() {
+ consecutive := func(last, next constant.Value) bool {
+ delta := constant.BinaryOp(next, token.SUB, last)
+ return constant.Compare(delta, token.EQL, constant.MakeInt64(1))
+ }
+
+ merged := cc[:1]
+ for _, c := range cc[1:] {
+ last := &merged[len(merged)-1]
+ if last.jmp == c.jmp && consecutive(last.hi.Val(), c.lo.Val()) {
+ last.hi = c.lo
+ } else {
+ merged = append(merged, c)
+ }
+ }
+ cc = merged
+ }
+
+ s.search(cc, &s.done)
+}
+
+func (s *exprSwitch) search(cc []exprClause, out *ir.Nodes) {
+ binarySearch(len(cc), out,
+ func(i int) ir.Node {
+ return ir.NewBinaryExpr(base.Pos, ir.OLE, s.exprname, cc[i-1].hi)
+ },
+ func(i int, nif *ir.IfStmt) {
+ c := &cc[i]
+ nif.Cond = c.test(s.exprname)
+ nif.Body = []ir.Node{c.jmp}
+ },
+ )
+}
+
+func (c *exprClause) test(exprname ir.Node) ir.Node {
+ // Integer range.
+ if c.hi != c.lo {
+ low := ir.NewBinaryExpr(c.pos, ir.OGE, exprname, c.lo)
+ high := ir.NewBinaryExpr(c.pos, ir.OLE, exprname, c.hi)
+ return ir.NewLogicalExpr(c.pos, ir.OANDAND, low, high)
+ }
+
+ // Optimize "switch true { ...}" and "switch false { ... }".
+ if ir.IsConst(exprname, constant.Bool) && !c.lo.Type().IsInterface() {
+ if ir.BoolVal(exprname) {
+ return c.lo
+ } else {
+ return ir.NewUnaryExpr(c.pos, ir.ONOT, c.lo)
+ }
+ }
+
+ return ir.NewBinaryExpr(c.pos, ir.OEQ, exprname, c.lo)
+}
+
+func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool {
+ // In theory, we could be more aggressive, allowing any
+ // side-effect-free expressions in cases, but it's a bit
+ // tricky because some of that information is unavailable due
+ // to the introduction of temporaries during order.
+ // Restricting to constants is simple and probably powerful
+ // enough.
+
+ for _, ncase := range sw.Cases {
+ for _, v := range ncase.List {
+ if v.Op() != ir.OLITERAL {
+ return false
+ }
+ }
+ }
+ return true
+}
+
+// endsInFallthrough reports whether stmts ends with a "fallthrough" statement.
+func endsInFallthrough(stmts []ir.Node) (bool, src.XPos) {
+ // Search backwards for the index of the fallthrough
+ // statement. Do not assume it'll be in the last
+ // position, since in some cases (e.g. when the statement
+ // list contains autotmp_ variables), one or more OVARKILL
+ // nodes will be at the end of the list.
+
+ i := len(stmts) - 1
+ for i >= 0 && stmts[i].Op() == ir.OVARKILL {
+ i--
+ }
+ if i < 0 {
+ return false, src.NoXPos
+ }
+ return stmts[i].Op() == ir.OFALL, stmts[i].Pos()
+}
+
+// walkSwitchType generates an AST that implements sw, where sw is a
+// type switch.
+func walkSwitchType(sw *ir.SwitchStmt) {
+ var s typeSwitch
+ s.facename = sw.Tag.(*ir.TypeSwitchGuard).X
+ sw.Tag = nil
+
+ s.facename = walkExpr(s.facename, sw.PtrInit())
+ s.facename = copyExpr(s.facename, s.facename.Type(), &sw.Compiled)
+ s.okname = typecheck.Temp(types.Types[types.TBOOL])
+
+ // Get interface descriptor word.
+ // For empty interfaces this will be the type.
+ // For non-empty interfaces this will be the itab.
+ itab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s.facename)
+
+ // For empty interfaces, do:
+ // if e._type == nil {
+ // do nil case if it exists, otherwise default
+ // }
+ // h := e._type.hash
+ // Use a similar strategy for non-empty interfaces.
+ ifNil := ir.NewIfStmt(base.Pos, nil, nil, nil)
+ ifNil.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, itab, typecheck.NodNil())
+ base.Pos = base.Pos.WithNotStmt() // disable statement marks after the first check.
+ ifNil.Cond = typecheck.Expr(ifNil.Cond)
+ ifNil.Cond = typecheck.DefaultLit(ifNil.Cond, nil)
+ // ifNil.Nbody assigned at end.
+ sw.Compiled.Append(ifNil)
+
+ // Load hash from type or itab.
+ dotHash := typeHashFieldOf(base.Pos, itab)
+ s.hashname = copyExpr(dotHash, dotHash.Type(), &sw.Compiled)
+
+ br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
+ var defaultGoto, nilGoto ir.Node
+ var body ir.Nodes
+ for _, ncase := range sw.Cases {
+ caseVar := ncase.Var
+
+ // For single-type cases with an interface type,
+ // we initialize the case variable as part of the type assertion.
+ // In other cases, we initialize it in the body.
+ var singleType *types.Type
+ if len(ncase.List) == 1 && ncase.List[0].Op() == ir.OTYPE {
+ singleType = ncase.List[0].Type()
+ }
+ caseVarInitialized := false
+
+ label := typecheck.AutoLabel(".s")
+ jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)
+
+ if len(ncase.List) == 0 { // default:
+ if defaultGoto != nil {
+ base.Fatalf("duplicate default case not detected during typechecking")
+ }
+ defaultGoto = jmp
+ }
+
+ for _, n1 := range ncase.List {
+ if ir.IsNil(n1) { // case nil:
+ if nilGoto != nil {
+ base.Fatalf("duplicate nil case not detected during typechecking")
+ }
+ nilGoto = jmp
+ continue
+ }
+
+ if singleType != nil && singleType.IsInterface() {
+ s.Add(ncase.Pos(), n1, caseVar, jmp)
+ caseVarInitialized = true
+ } else {
+ s.Add(ncase.Pos(), n1, nil, jmp)
+ }
+ }
+
+ body.Append(ir.NewLabelStmt(ncase.Pos(), label))
+ if caseVar != nil && !caseVarInitialized {
+ val := s.facename
+ if singleType != nil {
+ // We have a single concrete type. Extract the data.
+ if singleType.IsInterface() {
+ base.Fatalf("singleType interface should have been handled in Add")
+ }
+ val = ifaceData(ncase.Pos(), s.facename, singleType)
+ }
+ if len(ncase.List) == 1 && ncase.List[0].Op() == ir.ODYNAMICTYPE {
+ dt := ncase.List[0].(*ir.DynamicType)
+ x := ir.NewDynamicTypeAssertExpr(ncase.Pos(), ir.ODYNAMICDOTTYPE, val, dt.X)
+ if dt.ITab != nil {
+ // TODO: make ITab a separate field in DynamicTypeAssertExpr?
+ x.T = dt.ITab
+ }
+ x.SetType(caseVar.Type())
+ x.SetTypecheck(1)
+ val = x
+ }
+ l := []ir.Node{
+ ir.NewDecl(ncase.Pos(), ir.ODCL, caseVar),
+ ir.NewAssignStmt(ncase.Pos(), caseVar, val),
+ }
+ typecheck.Stmts(l)
+ body.Append(l...)
+ }
+ body.Append(ncase.Body...)
+ body.Append(br)
+ }
+ sw.Cases = nil
+
+ if defaultGoto == nil {
+ defaultGoto = br
+ }
+ if nilGoto == nil {
+ nilGoto = defaultGoto
+ }
+ ifNil.Body = []ir.Node{nilGoto}
+
+ s.Emit(&sw.Compiled)
+ sw.Compiled.Append(defaultGoto)
+ sw.Compiled.Append(body.Take()...)
+
+ walkStmtList(sw.Compiled)
+}
+
+// typeHashFieldOf returns an expression to select the type hash field
+// from an interface's descriptor word (whether a *runtime._type or
+// *runtime.itab pointer).
+func typeHashFieldOf(pos src.XPos, itab *ir.UnaryExpr) *ir.SelectorExpr {
+ if itab.Op() != ir.OITAB {
+ base.Fatalf("expected OITAB, got %v", itab.Op())
+ }
+ var hashField *types.Field
+ if itab.X.Type().IsEmptyInterface() {
+ // runtime._type's hash field
+ if rtypeHashField == nil {
+ rtypeHashField = runtimeField("hash", int64(2*types.PtrSize), types.Types[types.TUINT32])
+ }
+ hashField = rtypeHashField
+ } else {
+ // runtime.itab's hash field
+ if itabHashField == nil {
+ itabHashField = runtimeField("hash", int64(2*types.PtrSize), types.Types[types.TUINT32])
+ }
+ hashField = itabHashField
+ }
+ return boundedDotPtr(pos, itab, hashField)
+}
+
+var rtypeHashField, itabHashField *types.Field
+
+// A typeSwitch walks a type switch.
+type typeSwitch struct {
+ // Temporary variables (i.e., ONAMEs) used by type switch dispatch logic:
+ facename ir.Node // value being type-switched on
+ hashname ir.Node // type hash of the value being type-switched on
+ okname ir.Node // boolean used for comma-ok type assertions
+
+ done ir.Nodes
+ clauses []typeClause
+}
+
+type typeClause struct {
+ hash uint32
+ body ir.Nodes
+}
+
+func (s *typeSwitch) Add(pos src.XPos, n1 ir.Node, caseVar *ir.Name, jmp ir.Node) {
+ typ := n1.Type()
+ var body ir.Nodes
+ if caseVar != nil {
+ l := []ir.Node{
+ ir.NewDecl(pos, ir.ODCL, caseVar),
+ ir.NewAssignStmt(pos, caseVar, nil),
+ }
+ typecheck.Stmts(l)
+ body.Append(l...)
+ } else {
+ caseVar = ir.BlankNode.(*ir.Name)
+ }
+
+ // cv, ok = iface.(type)
+ as := ir.NewAssignListStmt(pos, ir.OAS2, nil, nil)
+ as.Lhs = []ir.Node{caseVar, s.okname} // cv, ok =
+ switch n1.Op() {
+ case ir.OTYPE:
+ // Static type assertion (non-generic)
+ dot := ir.NewTypeAssertExpr(pos, s.facename, nil)
+ dot.SetType(typ) // iface.(type)
+ as.Rhs = []ir.Node{dot}
+ case ir.ODYNAMICTYPE:
+ // Dynamic type assertion (generic)
+ dt := n1.(*ir.DynamicType)
+ dot := ir.NewDynamicTypeAssertExpr(pos, ir.ODYNAMICDOTTYPE, s.facename, dt.X)
+ if dt.ITab != nil {
+ dot.T = dt.ITab
+ }
+ dot.SetType(typ)
+ dot.SetTypecheck(1)
+ as.Rhs = []ir.Node{dot}
+ default:
+ base.Fatalf("unhandled type case %s", n1.Op())
+ }
+ appendWalkStmt(&body, as)
+
+ // if ok { goto label }
+ nif := ir.NewIfStmt(pos, nil, nil, nil)
+ nif.Cond = s.okname
+ nif.Body = []ir.Node{jmp}
+ body.Append(nif)
+
+ if n1.Op() == ir.OTYPE && !typ.IsInterface() {
+ // Defer static, noninterface cases so they can be binary searched by hash.
+ s.clauses = append(s.clauses, typeClause{
+ hash: types.TypeHash(n1.Type()),
+ body: body,
+ })
+ return
+ }
+
+ s.flush()
+ s.done.Append(body.Take()...)
+}
+
+func (s *typeSwitch) Emit(out *ir.Nodes) {
+ s.flush()
+ out.Append(s.done.Take()...)
+}
+
+func (s *typeSwitch) flush() {
+ cc := s.clauses
+ s.clauses = nil
+ if len(cc) == 0 {
+ return
+ }
+
+ sort.Slice(cc, func(i, j int) bool { return cc[i].hash < cc[j].hash })
+
+ // Combine adjacent cases with the same hash.
+ merged := cc[:1]
+ for _, c := range cc[1:] {
+ last := &merged[len(merged)-1]
+ if last.hash == c.hash {
+ last.body.Append(c.body.Take()...)
+ } else {
+ merged = append(merged, c)
+ }
+ }
+ cc = merged
+
+ binarySearch(len(cc), &s.done,
+ func(i int) ir.Node {
+ return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashname, ir.NewInt(int64(cc[i-1].hash)))
+ },
+ func(i int, nif *ir.IfStmt) {
+ // TODO(mdempsky): Omit hash equality check if
+ // there's only one type.
+ c := cc[i]
+ nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, s.hashname, ir.NewInt(int64(c.hash)))
+ nif.Body.Append(c.body.Take()...)
+ },
+ )
+}
+
+// binarySearch constructs a binary search tree for handling n cases,
+// and appends it to out. It's used for efficiently implementing
+// switch statements.
+//
+// less(i) should return a boolean expression. If it evaluates true,
+// then cases before i will be tested; otherwise, cases i and later.
+//
+// leaf(i, nif) should setup nif (an OIF node) to test case i. In
+// particular, it should set nif.Left and nif.Nbody.
+func binarySearch(n int, out *ir.Nodes, less func(i int) ir.Node, leaf func(i int, nif *ir.IfStmt)) {
+ const binarySearchMin = 4 // minimum number of cases for binary search
+
+ var do func(lo, hi int, out *ir.Nodes)
+ do = func(lo, hi int, out *ir.Nodes) {
+ n := hi - lo
+ if n < binarySearchMin {
+ for i := lo; i < hi; i++ {
+ nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
+ leaf(i, nif)
+ base.Pos = base.Pos.WithNotStmt()
+ nif.Cond = typecheck.Expr(nif.Cond)
+ nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
+ out.Append(nif)
+ out = &nif.Else
+ }
+ return
+ }
+
+ half := lo + n/2
+ nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
+ nif.Cond = less(half)
+ base.Pos = base.Pos.WithNotStmt()
+ nif.Cond = typecheck.Expr(nif.Cond)
+ nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
+ do(lo, half, &nif.Body)
+ do(half, hi, &nif.Else)
+ out.Append(nif)
+ }
+
+ do(0, n, out)
+}