Diffstat (limited to 'src/cmd/compile/internal/ssa/sccp.go')
-rw-r--r-- | src/cmd/compile/internal/ssa/sccp.go | 585 |
1 file changed, 585 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/ssa/sccp.go b/src/cmd/compile/internal/ssa/sccp.go
new file mode 100644
index 0000000..77a6f50
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/sccp.go
@@ -0,0 +1,585 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+import (
+	"fmt"
+)
+
+// ----------------------------------------------------------------------------
+// Sparse Conditional Constant Propagation
+//
+// Described in
+// Mark N. Wegman, F. Kenneth Zadeck: Constant Propagation with Conditional Branches.
+// TOPLAS 1991.
+//
+// This algorithm uses a three-level lattice for each SSA value:
+//
+//	      Top          undefined
+//	     / | \
+//	.. 1  2  3 ..      constant
+//	     \ | /
+//	     Bottom        not constant
+//
+// It starts by optimistically assuming that all SSA values are initially Top
+// and then propagates constant facts only along reachable control flow paths.
+// Since some basic blocks may not have been visited yet, the corresponding
+// inputs of a phi are Top; we use meet(phi) to compute its lattice:
+//
+//	Top ∩ any = any
+//	Bottom ∩ any = Bottom
+//	ConstantA ∩ ConstantA = ConstantA
+//	ConstantA ∩ ConstantB = Bottom
+//
+// Because of the lattice depth, each lattice value is lowered at most twice
+// (Top to Constant, Constant to Bottom), which gives the algorithm a fast
+// convergence speed. In this way, sccp can discover optimization opportunities
+// that cannot be found by combining constant folding, constant propagation,
+// and dead code elimination separately.
+
+// Three-level lattice holds compile-time knowledge about an SSA value.
+const (
+	top      int8 = iota // undefined
+	constant             // constant
+	bottom               // not a constant
+)
+
+type lattice struct {
+	tag int8   // lattice type
+	val *Value // constant value
+}
+
+type worklist struct {
+	f            *Func               // the target function to be optimized
+	edges        []Edge              // propagate constant facts through edges
+	uses         []*Value            // re-visiting set
+	visited      map[Edge]bool       // visited edges
+	latticeCells map[*Value]lattice  // constant lattices
+	defUse       map[*Value][]*Value // def-use chains for some values
+	defBlock     map[*Value][]*Block // use blocks of def
+	visitedBlock []bool              // visited block
+}
+
+// sccp stands for sparse conditional constant propagation. It propagates constants
+// through the CFG conditionally and applies constant folding, constant replacement,
+// and dead code elimination all together.
+func sccp(f *Func) {
+	var t worklist
+	t.f = f
+	t.edges = make([]Edge, 0)
+	t.visited = make(map[Edge]bool)
+	t.edges = append(t.edges, Edge{f.Entry, 0})
+	t.defUse = make(map[*Value][]*Value)
+	t.defBlock = make(map[*Value][]*Block)
+	t.latticeCells = make(map[*Value]lattice)
+	t.visitedBlock = f.Cache.allocBoolSlice(f.NumBlocks())
+	defer f.Cache.freeBoolSlice(t.visitedBlock)
+
+	// build it early since we rely heavily on the def-use chain later
+	t.buildDefUses()
+
+	// pick up either an edge or an SSA value from the worklist and process it
+	for {
+		if len(t.edges) > 0 {
+			edge := t.edges[0]
+			t.edges = t.edges[1:]
+			if _, exist := t.visited[edge]; !exist {
+				dest := edge.b
+				destVisited := t.visitedBlock[dest.ID]
+
+				// mark edge as visited
+				t.visited[edge] = true
+				t.visitedBlock[dest.ID] = true
+				for _, val := range dest.Values {
+					if val.Op == OpPhi || !destVisited {
+						t.visitValue(val)
+					}
+				}
+				// propagate constant facts through the CFG, taking the
+				// condition test into account
+				if !destVisited {
+					t.propagate(dest)
+				}
+			}
+			continue
+		}
+		if len(t.uses) > 0 {
+			use := t.uses[0]
+			t.uses = t.uses[1:]
+			t.visitValue(use)
+			continue
+		}
+		break
+	}
+
+	// apply optimizations based on discovered constants
+	constCnt, rewireCnt := t.replaceConst()
+	if f.pass.debug > 0 {
+		if constCnt > 0 || rewireCnt > 0 {
+			fmt.Printf("Phase SCCP for %v : %v constants, %v dce\n", f.Name, constCnt, rewireCnt)
+		}
+	}
+}
+
+func equals(a, b lattice) bool {
+	if a == b {
+		// fast path
+		return true
+	}
+	if a.tag != b.tag {
+		return false
+	}
+	if a.tag == constant {
+		// Two distinct Values may hold the same constant content, so compare
+		// their Op and AuxInt instead.
+		v1 := a.val
+		v2 := b.val
+		if v1.Op == v2.Op && v1.AuxInt == v2.AuxInt {
+			return true
+		} else {
+			return false
+		}
+	}
+	return true
+}
+
+// possibleConst checks whether the Value can be folded to a constant. For Values
+// that can never become constants (e.g. StaticCall), we don't make futile efforts.
+func possibleConst(val *Value) bool {
+	if isConst(val) {
+		return true
+	}
+	switch val.Op {
+	case OpCopy:
+		return true
+	case OpPhi:
+		return true
+	case
+		// negate
+		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
+		OpCom8, OpCom16, OpCom32, OpCom64,
+		// math
+		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
+		// conversion
+		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
+		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
+		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
+		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
+		OpCvtBoolToUint8,
+		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
+		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
+		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
+		// bit
+		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
+		// mask
+		OpSlicemask,
+		// safety check
+		OpIsNonNil,
+		// not
+		OpNot:
+		return true
+	case
+		// add
+		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
+		OpAdd32F, OpAdd64F,
+		// sub
+		OpSub64, OpSub32, OpSub16, OpSub8,
+		OpSub32F, OpSub64F,
+		// mul
+		OpMul64, OpMul32, OpMul16, OpMul8,
+		OpMul32F, OpMul64F,
+		// div
+		OpDiv32F, OpDiv64F,
+		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
+		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u,
+		OpMod8, OpMod16, OpMod32, OpMod64,
+		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
+		// compare
+		OpEq64, OpEq32, OpEq16, OpEq8,
+		OpEq32F, OpEq64F,
+		OpLess64, OpLess32, OpLess16, OpLess8,
+		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
+		OpLess32F, OpLess64F,
+		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
+		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
+		OpLeq32F, OpLeq64F,
+		OpEqB, OpNeqB,
+		// shift
+		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
+		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
+		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
+		// safety check
+		OpIsInBounds, OpIsSliceInBounds,
+		// bit
+		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
+		OpOr8, OpOr16, OpOr32, OpOr64,
+		OpXor8, OpXor16, OpXor32, OpXor64:
+		return true
+	default:
+		return false
+	}
+}
+
+func (t *worklist) getLatticeCell(val *Value) lattice {
+	if !possibleConst(val) {
+		// such values are always the worst, i.e. Bottom
+		return lattice{bottom, nil}
+	}
+	lt, exist := t.latticeCells[val]
+	if !exist {
+		return lattice{top, nil} // optimistically Top for an un-visited value
+	}
+	return lt
+}
+
+func isConst(val *Value) bool {
+	switch val.Op {
+	case OpConst64, OpConst32, OpConst16, OpConst8,
+		OpConstBool, OpConst32F, OpConst64F:
+		return true
+	default:
+		return false
+	}
+}
+
+// buildDefUses builds def-use chains for some values early, because once the
+// lattice of a value changes, we need to update the lattices of its uses. We
+// don't need all uses, though: only uses that can themselves become constants
+// are added to the re-visit worklist, since no matter how many times the other
+// uses are revisited, their lattice remains unchanged, i.e. Bottom.
+func (t *worklist) buildDefUses() {
+	for _, block := range t.f.Blocks {
+		for _, val := range block.Values {
+			for _, arg := range val.Args {
+				// find its uses; only uses that can become constants are taken into account
+				if possibleConst(arg) && possibleConst(val) {
+					if _, exist := t.defUse[arg]; !exist {
+						t.defUse[arg] = make([]*Value, 0, arg.Uses)
+					}
+					t.defUse[arg] = append(t.defUse[arg], val)
+				}
+			}
+		}
+		for _, ctl := range block.ControlValues() {
+			// for control values that can become constants, find their use blocks
+			if possibleConst(ctl) {
+				t.defBlock[ctl] = append(t.defBlock[ctl], block)
+			}
+		}
+	}
+}
+
+// addUses finds all uses of the value and appends them to the worklist for further processing.
+func (t *worklist) addUses(val *Value) {
+	for _, use := range t.defUse[val] {
+		if val == use {
+			// A Phi may refer to itself as a use; ignore it to avoid
+			// re-visiting the phi, for performance reasons.
+			continue
+		}
+		t.uses = append(t.uses, use)
+	}
+	for _, block := range t.defBlock[val] {
+		if t.visitedBlock[block.ID] {
+			t.propagate(block)
+		}
+	}
+}
+
+// meet meets all phi arguments and computes the resulting lattice.
+func (t *worklist) meet(val *Value) lattice {
+	optimisticLt := lattice{top, nil}
+	for i := 0; i < len(val.Args); i++ {
+		edge := Edge{val.Block, i}
+		// If the incoming edge for the phi has not been visited, assume Top
+		// optimistically. According to the rules of meet:
+		//	Top ∩ any = any
+		// Top participates in meet() but does not affect the result, so we
+		// ignore Top here and only take the other lattices into consideration.
+		if _, exist := t.visited[edge]; exist {
+			lt := t.getLatticeCell(val.Args[i])
+			if lt.tag == constant {
+				if optimisticLt.tag == top {
+					optimisticLt = lt
+				} else {
+					if !equals(optimisticLt, lt) {
+						// ConstantA ∩ ConstantB = Bottom
+						return lattice{bottom, nil}
+					}
+				}
+			} else if lt.tag == bottom {
+				// Bottom ∩ any = Bottom
+				return lattice{bottom, nil}
+			} else {
+				// Top ∩ any = any
+			}
+		} else {
+			// Top ∩ any = any
+		}
+	}
+
+	// ConstantA ∩ ConstantA = ConstantA or Top ∩ any = any
+	return optimisticLt
+}
+
+func computeLattice(f *Func, val *Value, args ...*Value) lattice {
+	// In general, we need to perform constant evaluation based on constant args:
+	//
+	//	res := lattice{constant, nil}
+	//	switch op {
+	//	case OpAdd16:
+	//		res.val = newConst(argLt1.val.AuxInt16() + argLt2.val.AuxInt16())
+	//	case OpAdd32:
+	//		res.val = newConst(argLt1.val.AuxInt32() + argLt2.val.AuxInt32())
+	//	case OpDiv8:
+	//		if !isDivideByZero(argLt2.val.AuxInt8()) {
+	//			res.val = newConst(argLt1.val.AuxInt8() / argLt2.val.AuxInt8())
+	//		}
+	//	...
+	//	}
+	//
+	// However, this would create a huge switch for all opcodes that can be
+	// evaluated at compile time. Moreover, some operations can be evaluated
+	// only if their arguments satisfy additional conditions (e.g. no division
+	// by zero). That is fragile and error-prone. Instead we take a shortcut and
+	// reuse the existing generic rewrite rules for compile-time evaluation.
+	// But generic rules rewrite the original value; this is undesired, because
+	// the lattice of a value may change multiple times, and once the value has
+	// been rewritten we lose the opportunity to change it again, which can lead
+	// to errors. For example, we cannot change a value immediately after
+	// visiting a Phi, because some of its input edges may not have been visited
+	// yet at that moment.
+	constValue := f.newValue(val.Op, val.Type, f.Entry, val.Pos)
+	constValue.AddArgs(args...)
+	matched := rewriteValuegeneric(constValue)
+	if matched {
+		if isConst(constValue) {
+			return lattice{constant, constValue}
+		}
+	}
+	// Either we cannot match generic rules for the given value or it does not
+	// satisfy the additional constraints (e.g. divide by zero); in these cases,
+	// clean up the temporary value immediately in case it is not dominated by
+	// its args.
+	constValue.reset(OpInvalid)
+	return lattice{bottom, nil}
+}
+
+func (t *worklist) visitValue(val *Value) {
+	if !possibleConst(val) {
+		// fast fail for always-worst Values, i.e. no lowering ever happens
+		// on them; their lattice is the initial worst value, Bottom.
+		return
+	}
+
+	oldLt := t.getLatticeCell(val)
+	defer func() {
+		// re-visit all uses of the value if its lattice is changed
+		newLt := t.getLatticeCell(val)
+		if !equals(newLt, oldLt) {
+			if int8(oldLt.tag) > int8(newLt.tag) {
+				t.f.Fatalf("Must lower lattice\n")
+			}
+			t.addUses(val)
+		}
+	}()
+
+	switch val.Op {
+	// they are constant values, aren't they?
+	case OpConst64, OpConst32, OpConst16, OpConst8,
+		OpConstBool, OpConst32F, OpConst64F: //TODO: support ConstNil ConstString etc
+		t.latticeCells[val] = lattice{constant, val}
+	// lattice value of copy(x) actually means lattice value of (x)
+	case OpCopy:
+		t.latticeCells[val] = t.getLatticeCell(val.Args[0])
+	// phi should be processed specially
+	case OpPhi:
+		t.latticeCells[val] = t.meet(val)
+	// fold 1-input operations:
+	case
+		// negate
+		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
+		OpCom8, OpCom16, OpCom32, OpCom64,
+		// math
+		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
+		// conversion
+		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
+		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
+		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
+		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
+		OpCvtBoolToUint8,
+		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
+		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
+		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
+		// bit
+		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
+		// mask
+		OpSlicemask,
+		// safety check
+		OpIsNonNil,
+		// not
+		OpNot:
+		lt1 := t.getLatticeCell(val.Args[0])
+
+		if lt1.tag == constant {
+			// here we take a shortcut by reusing generic rules to fold constants
+			t.latticeCells[val] = computeLattice(t.f, val, lt1.val)
+		} else {
+			t.latticeCells[val] = lattice{lt1.tag, nil}
+		}
+	// fold 2-input operations
+	case
+		// add
+		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
+		OpAdd32F, OpAdd64F,
+		// sub
+		OpSub64, OpSub32, OpSub16, OpSub8,
+		OpSub32F, OpSub64F,
+		// mul
+		OpMul64, OpMul32, OpMul16, OpMul8,
+		OpMul32F, OpMul64F,
+		// div
+		OpDiv32F, OpDiv64F,
+		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
+		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u, //TODO: support div128u
+		// mod
+		OpMod8, OpMod16, OpMod32, OpMod64,
+		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
+		// compare
+		OpEq64, OpEq32, OpEq16, OpEq8,
+		OpEq32F, OpEq64F,
+		OpLess64, OpLess32, OpLess16, OpLess8,
+		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
+		OpLess32F, OpLess64F,
+		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
+		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
+		OpLeq32F, OpLeq64F,
+		OpEqB, OpNeqB,
+		// shift
+		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
+		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
+		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
+		// safety check
+		OpIsInBounds, OpIsSliceInBounds,
+		// bit
+		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
+		OpOr8, OpOr16, OpOr32, OpOr64,
+		OpXor8, OpXor16, OpXor32, OpXor64:
+		lt1 := t.getLatticeCell(val.Args[0])
+		lt2 := t.getLatticeCell(val.Args[1])
+
+		if lt1.tag == constant && lt2.tag == constant {
+			// here we take a shortcut by reusing generic rules to fold constants
+			t.latticeCells[val] = computeLattice(t.f, val, lt1.val, lt2.val)
+		} else {
+			if lt1.tag == bottom || lt2.tag == bottom {
+				t.latticeCells[val] = lattice{bottom, nil}
+			} else {
+				t.latticeCells[val] = lattice{top, nil}
+			}
+		}
+	default:
+		// Any other kind of value cannot be a constant; it is always the worst (Bottom).
+	}
+}
+
+// propagate propagates constant facts through the CFG. If the block has a single
+// successor, add that successor unconditionally. If the block has multiple
+// successors, only add the branch destination corresponding to the lattice value
+// of the condition.
+func (t *worklist) propagate(block *Block) {
+	switch block.Kind {
+	case BlockExit, BlockRet, BlockRetJmp, BlockInvalid:
+		// control flow ends, do nothing then
+		break
+	case BlockDefer:
+		// we know nothing about control flow, add all branch destinations
+		t.edges = append(t.edges, block.Succs...)
+	case BlockFirst:
+		fallthrough // always takes the first branch
+	case BlockPlain:
+		t.edges = append(t.edges, block.Succs[0])
+	case BlockIf, BlockJumpTable:
+		cond := block.ControlValues()[0]
+		condLattice := t.getLatticeCell(cond)
+		if condLattice.tag == bottom {
+			// we know nothing about control flow, add all branch destinations
+			t.edges = append(t.edges, block.Succs...)
+		} else if condLattice.tag == constant {
+			// add the successor indicated by the condition value
+			var branchIdx int64
+			if block.Kind == BlockIf {
+				branchIdx = 1 - condLattice.val.AuxInt
+			} else {
+				branchIdx = condLattice.val.AuxInt
+			}
+			t.edges = append(t.edges, block.Succs[branchIdx])
+		} else {
+			// condition value is not visited yet, don't propagate it now
+		}
+	default:
+		t.f.Fatalf("All kinds of blocks should be processed above.")
+	}
+}
+
+// rewireSuccessor rewires the corresponding successors according to the constant
+// value discovered by the previous analysis. As a result, some successors become
+// unreachable and can be removed by the later deadcode phase.
+func rewireSuccessor(block *Block, constVal *Value) bool {
+	switch block.Kind {
+	case BlockIf:
+		block.removeEdge(int(constVal.AuxInt))
+		block.Kind = BlockPlain
+		block.Likely = BranchUnknown
+		block.ResetControls()
+		return true
+	case BlockJumpTable:
+		// Remove everything but the known taken branch.
+		idx := int(constVal.AuxInt)
+		if idx < 0 || idx >= len(block.Succs) {
+			// This can only happen in unreachable code,
+			// as an invariant of jump tables is that their
+			// input index is in range.
+			// See issue 64826.
+			return false
+		}
+		block.swapSuccessorsByIdx(0, idx)
+		for len(block.Succs) > 1 {
+			block.removeEdge(1)
+		}
+		block.Kind = BlockPlain
+		block.Likely = BranchUnknown
+		block.ResetControls()
+		return true
+	default:
+		return false
+	}
+}
+
+// replaceConst replaces non-constant values that have been proven by sccp
+// to be constants.
+func (t *worklist) replaceConst() (int, int) {
+	constCnt, rewireCnt := 0, 0
+	for val, lt := range t.latticeCells {
+		if lt.tag == constant {
+			if !isConst(val) {
+				if t.f.pass.debug > 0 {
+					fmt.Printf("Replace %v with %v\n", val.LongString(), lt.val.LongString())
+				}
+				val.reset(lt.val.Op)
+				val.AuxInt = lt.val.AuxInt
+				constCnt++
+			}
+			// If the constant value controls this block, rewire its successors
+			// according to the value.
+			ctrlBlock := t.defBlock[val]
+			for _, block := range ctrlBlock {
+				if rewireSuccessor(block, lt.val) {
+					rewireCnt++
+					if t.f.pass.debug > 0 {
+						fmt.Printf("Rewire %v %v successors\n", block.Kind, block)
+					}
+				}
+			}
+		}
+	}
+	return constCnt, rewireCnt
+}
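The header comment in the patch describes the three-level lattice and its meet rules only in prose. As a standalone illustration (not part of the change above; it uses a simplified cell type with an int64 payload instead of the compiler's *Value), the meet operation can be sketched in Go like this:

package main

import "fmt"

// tag mirrors the patch's three lattice levels.
type tag int8

const (
	top      tag = iota // undefined
	constant            // known constant
	bottom              // not a constant
)

// cell is a simplified lattice cell: a tag plus the constant payload when tag == constant.
type cell struct {
	tag tag
	val int64
}

// meet applies the rules quoted from the patch's header comment:
//
//	Top ∩ any = any
//	Bottom ∩ any = Bottom
//	ConstantA ∩ ConstantA = ConstantA
//	ConstantA ∩ ConstantB = Bottom
func meet(a, b cell) cell {
	switch {
	case a.tag == top:
		return b
	case b.tag == top:
		return a
	case a.tag == bottom || b.tag == bottom:
		return cell{tag: bottom}
	case a.val == b.val:
		return a // ConstantA ∩ ConstantA = ConstantA
	default:
		return cell{tag: bottom} // ConstantA ∩ ConstantB = Bottom
	}
}

func main() {
	c1 := cell{tag: constant, val: 1}
	c2 := cell{tag: constant, val: 2}
	fmt.Println(meet(cell{tag: top}, c1) == c1)            // true: Top ∩ any = any
	fmt.Println(meet(c1, c1) == c1)                        // true: ConstantA ∩ ConstantA = ConstantA
	fmt.Println(meet(c1, c2).tag == bottom)                // true: ConstantA ∩ ConstantB = Bottom
	fmt.Println(meet(cell{tag: bottom}, c1).tag == bottom) // true: Bottom ∩ any = Bottom
}

Because a cell can only move downward in this lattice, and at most twice (Top to constant, constant to bottom), the fixed-point iteration in sccp converges quickly, as the patch's header comment notes.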