Diffstat (limited to 'src/cmd/compile/internal/ssa/sccp.go')
-rw-r--r--  src/cmd/compile/internal/ssa/sccp.go  585
1 files changed, 585 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/ssa/sccp.go b/src/cmd/compile/internal/ssa/sccp.go
new file mode 100644
index 0000000..77a6f50
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/sccp.go
@@ -0,0 +1,585 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+import (
+ "fmt"
+)
+
+// ----------------------------------------------------------------------------
+// Sparse Conditional Constant Propagation
+//
+// Described in
+// Mark N. Wegman, F. Kenneth Zadeck: Constant Propagation with Conditional Branches.
+// TOPLAS 1991.
+//
+// This algorithm uses a three-level lattice for each SSA value
+//
+//      Top         undefined
+//    /  |  \
+// .. 1  2  3 ..    constant
+//    \  |  /
+//      Bottom      not a constant
+//
+// It starts by optimistically assuming that all SSA values are Top and then
+// propagates constant facts only along reachable control flow paths. Because
+// some basic blocks have not been visited yet, the corresponding inputs of a
+// phi remain Top; we use meet(phi) to compute the phi's lattice:
+//
+// Top ∩ any = any
+// Bottom ∩ any = Bottom
+// ConstantA ∩ ConstantA = ConstantA
+// ConstantA ∩ ConstantB = Bottom
+//
+// Each lattice value is lowered at most twice (Top to Constant, Constant to
+// Bottom) due to the lattice depth, so the algorithm converges quickly. In
+// this way, sccp can discover optimization opportunities that cannot be found
+// by merely combining constant folding, constant propagation, and dead code
+// elimination separately.
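+//
+// As an illustrative sketch (not an excerpt from the paper): plain constant
+// folding alone cannot simplify the phi for y below, but sccp can, because it
+// first proves that the false edge is unreachable, so that edge's input stays
+// Top during meet:
+//
+//	x := 0
+//	y := 2
+//	if x == 0 { // condition folds to the constant true
+//		y = 1
+//	}
+//	_ = y // phi(1, 2); only the visited edge carries 1, so y is the constant 1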
+
+// The three-level lattice holds compile-time knowledge about an SSA value.
+const (
+ top int8 = iota // undefined
+ constant // constant
+ bottom // not a constant
+)
+
+type lattice struct {
+ tag int8 // lattice type
+ val *Value // constant value
+}
+
+type worklist struct {
+ f *Func // the function to be optimized
+ edges []Edge // propagate constant facts through edges
+ uses []*Value // re-visiting set
+ visited map[Edge]bool // visited edges
+ latticeCells map[*Value]lattice // constant lattices
+ defUse map[*Value][]*Value // def-use chains for some values
+ defBlock map[*Value][]*Block // use blocks of def
+ visitedBlock []bool // visited block
+}
+
+// sccp stands for sparse conditional constant propagation. It propagates
+// constants through the CFG conditionally and applies constant folding,
+// constant replacement, and dead code elimination all together.
+func sccp(f *Func) {
+ var t worklist
+ t.f = f
+ t.edges = make([]Edge, 0)
+ t.visited = make(map[Edge]bool)
+ t.edges = append(t.edges, Edge{f.Entry, 0})
+ t.defUse = make(map[*Value][]*Value)
+ t.defBlock = make(map[*Value][]*Block)
+ t.latticeCells = make(map[*Value]lattice)
+ t.visitedBlock = f.Cache.allocBoolSlice(f.NumBlocks())
+ defer f.Cache.freeBoolSlice(t.visitedBlock)
+
+ // build def-use chains early since we rely heavily on them later
+ t.buildDefUses()
+
+ // pick up either an edge or an SSA value from the worklist and process it
+ for {
+ if len(t.edges) > 0 {
+ edge := t.edges[0]
+ t.edges = t.edges[1:]
+ if _, exist := t.visited[edge]; !exist {
+ dest := edge.b
+ destVisited := t.visitedBlock[dest.ID]
+
+ // mark edge as visited
+ t.visited[edge] = true
+ t.visitedBlock[dest.ID] = true
+ for _, val := range dest.Values {
+ if val.Op == OpPhi || !destVisited {
+ t.visitValue(val)
+ }
+ }
+ // propagate constant facts through the CFG, taking the condition
+ // test into account
+ if !destVisited {
+ t.propagate(dest)
+ }
+ }
+ continue
+ }
+ if len(t.uses) > 0 {
+ use := t.uses[0]
+ t.uses = t.uses[1:]
+ t.visitValue(use)
+ continue
+ }
+ break
+ }
+
+ // apply optimizations based on discovered constants
+ constCnt, rewireCnt := t.replaceConst()
+ if f.pass.debug > 0 {
+ if constCnt > 0 || rewireCnt > 0 {
+ fmt.Printf("Phase SCCP for %v : %v constants, %v dce\n", f.Name, constCnt, rewireCnt)
+ }
+ }
+}
+
+func equals(a, b lattice) bool {
+ if a == b {
+ // fast path
+ return true
+ }
+ if a.tag != b.tag {
+ return false
+ }
+ if a.tag == constant {
+ // Two distinct const Values may represent the same constant;
+ // compare their Op and AuxInt rather than pointer identity.
+ v1 := a.val
+ v2 := b.val
+ return v1.Op == v2.Op && v1.AuxInt == v2.AuxInt
+ }
+ return true
+}
+
+// possibleConst checks whether a Value can be folded to a constant. For Values
+// that can never become constants (e.g. StaticCall), we don't make futile
+// efforts.
+func possibleConst(val *Value) bool {
+ if isConst(val) {
+ return true
+ }
+ switch val.Op {
+ case OpCopy:
+ return true
+ case OpPhi:
+ return true
+ case
+ // negate
+ OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
+ OpCom8, OpCom16, OpCom32, OpCom64,
+ // math
+ OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
+ // conversion
+ OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
+ OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
+ OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
+ OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
+ OpCvtBoolToUint8,
+ OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
+ OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
+ OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
+ // bit
+ OpCtz8, OpCtz16, OpCtz32, OpCtz64,
+ // mask
+ OpSlicemask,
+ // safety check
+ OpIsNonNil,
+ // not
+ OpNot:
+ return true
+ case
+ // add
+ OpAdd64, OpAdd32, OpAdd16, OpAdd8,
+ OpAdd32F, OpAdd64F,
+ // sub
+ OpSub64, OpSub32, OpSub16, OpSub8,
+ OpSub32F, OpSub64F,
+ // mul
+ OpMul64, OpMul32, OpMul16, OpMul8,
+ OpMul32F, OpMul64F,
+ // div
+ OpDiv32F, OpDiv64F,
+ OpDiv8, OpDiv16, OpDiv32, OpDiv64,
+ OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u,
+ OpMod8, OpMod16, OpMod32, OpMod64,
+ OpMod8u, OpMod16u, OpMod32u, OpMod64u,
+ // compare
+ OpEq64, OpEq32, OpEq16, OpEq8,
+ OpEq32F, OpEq64F,
+ OpLess64, OpLess32, OpLess16, OpLess8,
+ OpLess64U, OpLess32U, OpLess16U, OpLess8U,
+ OpLess32F, OpLess64F,
+ OpLeq64, OpLeq32, OpLeq16, OpLeq8,
+ OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
+ OpLeq32F, OpLeq64F,
+ OpEqB, OpNeqB,
+ // shift
+ OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
+ OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
+ OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
+ // safety check
+ OpIsInBounds, OpIsSliceInBounds,
+ // bit
+ OpAnd8, OpAnd16, OpAnd32, OpAnd64,
+ OpOr8, OpOr16, OpOr32, OpOr64,
+ OpXor8, OpXor16, OpXor32, OpXor64:
+ return true
+ default:
+ return false
+ }
+}
+
+func (t *worklist) getLatticeCell(val *Value) lattice {
+ if !possibleConst(val) {
+ // such values are always the worst case (Bottom)
+ return lattice{bottom, nil}
+ }
+ lt, exist := t.latticeCells[val]
+ if !exist {
+ return lattice{top, nil} // optimistically assume Top for unvisited values
+ }
+ return lt
+}
+
+func isConst(val *Value) bool {
+ switch val.Op {
+ case OpConst64, OpConst32, OpConst16, OpConst8,
+ OpConstBool, OpConst32F, OpConst64F:
+ return true
+ default:
+ return false
+ }
+}
+
+// buildDefUses builds def-use chains for some values early, because once the
+// lattice of a value changes, we need to update the lattices of its uses. We
+// don't need all uses, though: only uses that can themselves become constants
+// are added to the re-visit worklist, since no matter how many times the
+// other uses are revisited, their lattices remain unchanged, i.e. Bottom.
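+//
+// For example (an illustrative sketch):
+//
+//	v1 = Const64 [1]
+//	v2 = Const64 [2]
+//	v3 = Add64 v1 v2 // defUse[v1] = {v3}, defUse[v2] = {v3}
+//
+// so whenever the lattice of v1 or v2 is lowered, v3 is re-visited.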
+func (t *worklist) buildDefUses() {
+ for _, block := range t.f.Blocks {
+ for _, val := range block.Values {
+ for _, arg := range val.Args {
+ // find its uses; only uses that can become constants are taken into account
+ if possibleConst(arg) && possibleConst(val) {
+ if _, exist := t.defUse[arg]; !exist {
+ t.defUse[arg] = make([]*Value, 0, arg.Uses)
+ }
+ t.defUse[arg] = append(t.defUse[arg], val)
+ }
+ }
+ }
+ for _, ctl := range block.ControlValues() {
+ // for control values that can become constants, find their use blocks
+ if possibleConst(ctl) {
+ t.defBlock[ctl] = append(t.defBlock[ctl], block)
+ }
+ }
+ }
+}
+
+// addUses finds all uses of the value and appends them to the worklist for
+// further processing
+func (t *worklist) addUses(val *Value) {
+ for _, use := range t.defUse[val] {
+ if val == use {
+ // a phi may refer to itself as a use; skip such self-references
+ // to avoid pointlessly re-visiting the phi
+ continue
+ }
+ t.uses = append(t.uses, use)
+ }
+ for _, block := range t.defBlock[val] {
+ if t.visitedBlock[block.ID] {
+ t.propagate(block)
+ }
+ }
+}
+
+// meet meets all of the phi's arguments and computes the resulting lattice.
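+//
+// For example (a sketch): for a phi with three incoming edges of which only
+// the first and third have been visited,
+//
+//	v = Phi(const 1, <unvisited>, const 1) // meet: constant 1
+//	w = Phi(const 1, <unvisited>, const 2) // meet: Bottom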
+func (t *worklist) meet(val *Value) lattice {
+ optimisticLt := lattice{top, nil}
+ for i := 0; i < len(val.Args); i++ {
+ edge := Edge{val.Block, i}
+ // If an incoming edge of the phi has not been visited, optimistically
+ // assume Top. According to the meet rules:
+ // Top ∩ any = any
+ // Top participates in meet() but does not affect the result, so we
+ // ignore Top here and take only the other lattices into consideration.
+ if _, exist := t.visited[edge]; exist {
+ lt := t.getLatticeCell(val.Args[i])
+ if lt.tag == constant {
+ if optimisticLt.tag == top {
+ optimisticLt = lt
+ } else {
+ if !equals(optimisticLt, lt) {
+ // ConstantA ∩ ConstantB = Bottom
+ return lattice{bottom, nil}
+ }
+ }
+ } else if lt.tag == bottom {
+ // Bottom ∩ any = Bottom
+ return lattice{bottom, nil}
+ } else {
+ // Top ∩ any = any
+ }
+ } else {
+ // Top ∩ any = any
+ }
+ }
+
+ // ConstantA ∩ ConstantA = ConstantA or Top ∩ any = any
+ return optimisticLt
+}
+
+func computeLattice(f *Func, val *Value, args ...*Value) lattice {
+ // In general, we need to perform constant evaluation based on constant args:
+ //
+ // res := lattice{constant, nil}
+ // switch op {
+ // case OpAdd16:
+ // res.val = newConst(argLt1.val.AuxInt16() + argLt2.val.AuxInt16())
+ // case OpAdd32:
+ // res.val = newConst(argLt1.val.AuxInt32() + argLt2.val.AuxInt32())
+ // case OpDiv8:
+ // if !isDivideByZero(argLt2.val.AuxInt8()) {
+ // res.val = newConst(argLt1.val.AuxInt8() / argLt2.val.AuxInt8())
+ // }
+ // ...
+ // }
+ //
+ // However, this would create a huge switch over all opcodes that can be
+ // evaluated at compile time. Moreover, some operations can be evaluated
+ // only if their arguments satisfy additional conditions (e.g. no division
+ // by zero); that would be fragile and error-prone. Instead we use a trick:
+ // we reuse the existing generic rewrite rules for compile-time evaluation.
+ // But generic rules rewrite the original value in place, which is undesired
+ // here, because a value's lattice may change multiple times during the
+ // analysis; once a value has been rewritten we lose the opportunity to
+ // revise it later, which can lead to errors. For example, we cannot rewrite
+ // a phi immediately after visiting it, because some of its incoming edges
+ // may not have been visited yet. So we evaluate on a temporary value instead.
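+ //
+ // For example (a sketch): for val = (Add64 x y) whose arguments were proven
+ // to be const 2 and const 3, we build a temporary (Add64 (Const64 [2])
+ // (Const64 [3])) below, and rewriteValuegeneric folds it to (Const64 [5]).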
+ constValue := f.newValue(val.Op, val.Type, f.Entry, val.Pos)
+ constValue.AddArgs(args...)
+ matched := rewriteValuegeneric(constValue)
+ if matched {
+ if isConst(constValue) {
+ return lattice{constant, constValue}
+ }
+ }
+ // Either no generic rule matched the given value or its arguments do not
+ // satisfy additional constraints (e.g. division by zero). In these cases,
+ // clean up the temporary value immediately so that it is not left around
+ // undominated by its arguments.
+ constValue.reset(OpInvalid)
+ return lattice{bottom, nil}
+}
+
+func (t *worklist) visitValue(val *Value) {
+ if !possibleConst(val) {
+ // fast path for Values that can never be constants: no lowering ever
+ // happens on them, so their lattices are Bottom from the start.
+ return
+ }
+
+ oldLt := t.getLatticeCell(val)
+ defer func() {
+ // re-visit all uses of the value if its lattice has changed
+ newLt := t.getLatticeCell(val)
+ if !equals(newLt, oldLt) {
+ if oldLt.tag > newLt.tag {
+ t.f.Fatalf("Must lower lattice\n")
+ }
+ t.addUses(val)
+ }
+ }()
+
+ switch val.Op {
+ // they are constant values, aren't they?
+ case OpConst64, OpConst32, OpConst16, OpConst8,
+ OpConstBool, OpConst32F, OpConst64F: //TODO: support ConstNil ConstString etc
+ t.latticeCells[val] = lattice{constant, val}
+ // lattice value of copy(x) actually means lattice value of (x)
+ case OpCopy:
+ t.latticeCells[val] = t.getLatticeCell(val.Args[0])
+ // phi should be processed specially
+ case OpPhi:
+ t.latticeCells[val] = t.meet(val)
+ // fold 1-input operations:
+ case
+ // negate
+ OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
+ OpCom8, OpCom16, OpCom32, OpCom64,
+ // math
+ OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
+ // conversion
+ OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
+ OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
+ OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
+ OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
+ OpCvtBoolToUint8,
+ OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
+ OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
+ OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
+ // bit
+ OpCtz8, OpCtz16, OpCtz32, OpCtz64,
+ // mask
+ OpSlicemask,
+ // safety check
+ OpIsNonNil,
+ // not
+ OpNot:
+ lt1 := t.getLatticeCell(val.Args[0])
+
+ if lt1.tag == constant {
+ // here we take a shortcut by reusing generic rules to fold constants
+ t.latticeCells[val] = computeLattice(t.f, val, lt1.val)
+ } else {
+ t.latticeCells[val] = lattice{lt1.tag, nil}
+ }
+ // fold 2-input operations
+ case
+ // add
+ OpAdd64, OpAdd32, OpAdd16, OpAdd8,
+ OpAdd32F, OpAdd64F,
+ // sub
+ OpSub64, OpSub32, OpSub16, OpSub8,
+ OpSub32F, OpSub64F,
+ // mul
+ OpMul64, OpMul32, OpMul16, OpMul8,
+ OpMul32F, OpMul64F,
+ // div
+ OpDiv32F, OpDiv64F,
+ OpDiv8, OpDiv16, OpDiv32, OpDiv64,
+ OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u, //TODO: support div128u
+ // mod
+ OpMod8, OpMod16, OpMod32, OpMod64,
+ OpMod8u, OpMod16u, OpMod32u, OpMod64u,
+ // compare
+ OpEq64, OpEq32, OpEq16, OpEq8,
+ OpEq32F, OpEq64F,
+ OpLess64, OpLess32, OpLess16, OpLess8,
+ OpLess64U, OpLess32U, OpLess16U, OpLess8U,
+ OpLess32F, OpLess64F,
+ OpLeq64, OpLeq32, OpLeq16, OpLeq8,
+ OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
+ OpLeq32F, OpLeq64F,
+ OpEqB, OpNeqB,
+ // shift
+ OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
+ OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
+ OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
+ // safety check
+ OpIsInBounds, OpIsSliceInBounds,
+ // bit
+ OpAnd8, OpAnd16, OpAnd32, OpAnd64,
+ OpOr8, OpOr16, OpOr32, OpOr64,
+ OpXor8, OpXor16, OpXor32, OpXor64:
+ lt1 := t.getLatticeCell(val.Args[0])
+ lt2 := t.getLatticeCell(val.Args[1])
+
+ if lt1.tag == constant && lt2.tag == constant {
+ // here we take a shortcut by reusing generic rules to fold constants
+ t.latticeCells[val] = computeLattice(t.f, val, lt1.val, lt2.val)
+ } else {
+ if lt1.tag == bottom || lt2.tag == bottom {
+ t.latticeCells[val] = lattice{bottom, nil}
+ } else {
+ t.latticeCells[val] = lattice{top, nil}
+ }
+ }
+ default:
+ // Any other kind of value can never be a constant; it always stays
+ // at the worst lattice (Bottom)
+ }
+}
+
+// propagate propagates constant facts through the CFG. If the block has a
+// single successor, add that successor unconditionally. If the block has
+// multiple successors, only add the branch destination that corresponds to
+// the lattice value of the condition value.
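+//
+// For example (a sketch): for a block
+//
+//	b1: ... If v -> b2 b3
+//
+// whose control value v was proven to be the constant true, only the edge to
+// b2 is added to the worklist; b3 is visited only if some other path leads
+// to it.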
+func (t *worklist) propagate(block *Block) {
+ switch block.Kind {
+ case BlockExit, BlockRet, BlockRetJmp, BlockInvalid:
+ // control flow ends, nothing to do
+ case BlockDefer:
+ // we know nothing about control flow, add all branch destinations
+ t.edges = append(t.edges, block.Succs...)
+ case BlockFirst:
+ fallthrough // always takes the first branch
+ case BlockPlain:
+ t.edges = append(t.edges, block.Succs[0])
+ case BlockIf, BlockJumpTable:
+ cond := block.ControlValues()[0]
+ condLattice := t.getLatticeCell(cond)
+ if condLattice.tag == bottom {
+ // we know nothing about control flow, add all branch destinations
+ t.edges = append(t.edges, block.Succs...)
+ } else if condLattice.tag == constant {
+ // add only the branch destination selected by the constant condition
+ var branchIdx int64
+ if block.Kind == BlockIf {
+ branchIdx = 1 - condLattice.val.AuxInt
+ } else {
+ branchIdx = condLattice.val.AuxInt
+ }
+ t.edges = append(t.edges, block.Succs[branchIdx])
+ } else {
+ // condition value is not visited yet, don't propagate it now
+ }
+ default:
+ t.f.Fatalf("All kind of block should be processed above.")
+ }
+}
+
+// rewireSuccessor rewires the block's successors according to the constant
+// value discovered by the preceding analysis. As a result, some successors
+// become unreachable and thus can be removed by a later deadcode pass.
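+//
+// For example (a sketch): if the control value of
+//
+//	b1: ... If v -> b2 b3
+//
+// was proven to be the constant true, b1 becomes
+//
+//	b1: ... Plain -> b2
+//
+// and b3 can be removed later if no other edge reaches it.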
+func rewireSuccessor(block *Block, constVal *Value) bool {
+ switch block.Kind {
+ case BlockIf:
+ block.removeEdge(int(constVal.AuxInt))
+ block.Kind = BlockPlain
+ block.Likely = BranchUnknown
+ block.ResetControls()
+ return true
+ case BlockJumpTable:
+ // Remove everything but the known taken branch.
+ idx := int(constVal.AuxInt)
+ if idx < 0 || idx >= len(block.Succs) {
+ // This can only happen in unreachable code,
+ // as an invariant of jump tables is that their
+ // input index is in range.
+ // See issue 64826.
+ return false
+ }
+ block.swapSuccessorsByIdx(0, idx)
+ for len(block.Succs) > 1 {
+ block.removeEdge(1)
+ }
+ block.Kind = BlockPlain
+ block.Likely = BranchUnknown
+ block.ResetControls()
+ return true
+ default:
+ return false
+ }
+}
+
+// replaceConst rewrites values that sccp has proven to be constants (and that
+// are not already constant ops) into the corresponding constant values, and
+// rewires block successors controlled by such constants.
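+//
+// For example (a sketch): a value proven to be the constant 5, say
+//
+//	v4 = Add64 <int> v1 v2
+//
+// is rewritten in place to
+//
+//	v4 = Const64 <int> [5]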
+func (t *worklist) replaceConst() (int, int) {
+ constCnt, rewireCnt := 0, 0
+ for val, lt := range t.latticeCells {
+ if lt.tag == constant {
+ if !isConst(val) {
+ if t.f.pass.debug > 0 {
+ fmt.Printf("Replace %v with %v\n", val.LongString(), lt.val.LongString())
+ }
+ val.reset(lt.val.Op)
+ val.AuxInt = lt.val.AuxInt
+ constCnt++
+ }
+ // If the const value controls this block, rewire the block's successors
+ // according to its value
+ ctrlBlock := t.defBlock[val]
+ for _, block := range ctrlBlock {
+ if rewireSuccessor(block, lt.val) {
+ rewireCnt++
+ if t.f.pass.debug > 0 {
+ fmt.Printf("Rewire %v %v successors\n", block.Kind, block)
+ }
+ }
+ }
+ }
+ }
+ return constCnt, rewireCnt
+}