diff options
Diffstat (limited to 'src/cmd/compile/internal/ssa/prove.go')
-rw-r--r-- | src/cmd/compile/internal/ssa/prove.go | 1426 |
1 files changed, 1426 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go new file mode 100644 index 0000000..8a2e7c0 --- /dev/null +++ b/src/cmd/compile/internal/ssa/prove.go @@ -0,0 +1,1426 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssa + +import ( + "cmd/internal/src" + "fmt" + "math" +) + +type branch int + +const ( + unknown branch = iota + positive + negative +) + +// relation represents the set of possible relations between +// pairs of variables (v, w). Without a priori knowledge the +// mask is lt | eq | gt meaning v can be less than, equal to or +// greater than w. When the execution path branches on the condition +// `v op w` the set of relations is updated to exclude any +// relation not possible due to `v op w` being true (or false). +// +// E.g. +// +// r := relation(...) +// +// if v < w { +// newR := r & lt +// } +// if v >= w { +// newR := r & (eq|gt) +// } +// if v != w { +// newR := r & (lt|gt) +// } +type relation uint + +const ( + lt relation = 1 << iota + eq + gt +) + +var relationStrings = [...]string{ + 0: "none", lt: "<", eq: "==", lt | eq: "<=", + gt: ">", gt | lt: "!=", gt | eq: ">=", gt | eq | lt: "any", +} + +func (r relation) String() string { + if r < relation(len(relationStrings)) { + return relationStrings[r] + } + return fmt.Sprintf("relation(%d)", uint(r)) +} + +// domain represents the domain of a variable pair in which a set +// of relations is known. For example, relations learned for unsigned +// pairs cannot be transferred to signed pairs because the same bit +// representation can mean something else. +type domain uint + +const ( + signed domain = 1 << iota + unsigned + pointer + boolean +) + +var domainStrings = [...]string{ + "signed", "unsigned", "pointer", "boolean", +} + +func (d domain) String() string { + s := "" + for i, ds := range domainStrings { + if d&(1<<uint(i)) != 0 { + if len(s) != 0 { + s += "|" + } + s += ds + d &^= 1 << uint(i) + } + } + if d != 0 { + if len(s) != 0 { + s += "|" + } + s += fmt.Sprintf("0x%x", uint(d)) + } + return s +} + +type pair struct { + v, w *Value // a pair of values, ordered by ID. + // v can be nil, to mean the zero value. + // for booleans the zero value (v == nil) is false. + d domain +} + +// fact is a pair plus a relation for that pair. +type fact struct { + p pair + r relation +} + +// a limit records known upper and lower bounds for a value. +type limit struct { + min, max int64 // min <= value <= max, signed + umin, umax uint64 // umin <= value <= umax, unsigned +} + +func (l limit) String() string { + return fmt.Sprintf("sm,SM,um,UM=%d,%d,%d,%d", l.min, l.max, l.umin, l.umax) +} + +func (l limit) intersect(l2 limit) limit { + if l.min < l2.min { + l.min = l2.min + } + if l.umin < l2.umin { + l.umin = l2.umin + } + if l.max > l2.max { + l.max = l2.max + } + if l.umax > l2.umax { + l.umax = l2.umax + } + return l +} + +var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64} + +// a limitFact is a limit known for a particular value. +type limitFact struct { + vid ID + limit limit +} + +// factsTable keeps track of relations between pairs of values. +// +// The fact table logic is sound, but incomplete. Outside of a few +// special cases, it performs no deduction or arithmetic. While there +// are known decision procedures for this, the ad hoc approach taken +// by the facts table is effective for real code while remaining very +// efficient. +type factsTable struct { + // unsat is true if facts contains a contradiction. + // + // Note that the factsTable logic is incomplete, so if unsat + // is false, the assertions in factsTable could be satisfiable + // *or* unsatisfiable. + unsat bool // true if facts contains a contradiction + unsatDepth int // number of unsat checkpoints + + facts map[pair]relation // current known set of relation + stack []fact // previous sets of relations + + // order is a couple of partial order sets that record information + // about relations between SSA values in the signed and unsigned + // domain. + orderS *poset + orderU *poset + + // known lower and upper bounds on individual values. + limits map[ID]limit + limitStack []limitFact // previous entries + + // For each slice s, a map from s to a len(s)/cap(s) value (if any) + // TODO: check if there are cases that matter where we have + // more than one len(s) for a slice. We could keep a list if necessary. + lens map[ID]*Value + caps map[ID]*Value + + // zero is a zero-valued constant + zero *Value +} + +// checkpointFact is an invalid value used for checkpointing +// and restoring factsTable. +var checkpointFact = fact{} +var checkpointBound = limitFact{} + +func newFactsTable(f *Func) *factsTable { + ft := &factsTable{} + ft.orderS = f.newPoset() + ft.orderU = f.newPoset() + ft.orderS.SetUnsigned(false) + ft.orderU.SetUnsigned(true) + ft.facts = make(map[pair]relation) + ft.stack = make([]fact, 4) + ft.limits = make(map[ID]limit) + ft.limitStack = make([]limitFact, 4) + ft.zero = f.ConstInt64(f.Config.Types.Int64, 0) + return ft +} + +// update updates the set of relations between v and w in domain d +// restricting it to r. +func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) { + if parent.Func.pass.debug > 2 { + parent.Func.Warnl(parent.Pos, "parent=%s, update %s %s %s", parent, v, w, r) + } + // No need to do anything else if we already found unsat. + if ft.unsat { + return + } + + // Self-fact. It's wasteful to register it into the facts + // table, so just note whether it's satisfiable + if v == w { + if r&eq == 0 { + ft.unsat = true + } + return + } + + if d == signed || d == unsigned { + var ok bool + order := ft.orderS + if d == unsigned { + order = ft.orderU + } + switch r { + case lt: + ok = order.SetOrder(v, w) + case gt: + ok = order.SetOrder(w, v) + case lt | eq: + ok = order.SetOrderOrEqual(v, w) + case gt | eq: + ok = order.SetOrderOrEqual(w, v) + case eq: + ok = order.SetEqual(v, w) + case lt | gt: + ok = order.SetNonEqual(v, w) + default: + panic("unknown relation") + } + if !ok { + if parent.Func.pass.debug > 2 { + parent.Func.Warnl(parent.Pos, "unsat %s %s %s", v, w, r) + } + ft.unsat = true + return + } + } else { + if lessByID(w, v) { + v, w = w, v + r = reverseBits[r] + } + + p := pair{v, w, d} + oldR, ok := ft.facts[p] + if !ok { + if v == w { + oldR = eq + } else { + oldR = lt | eq | gt + } + } + // No changes compared to information already in facts table. + if oldR == r { + return + } + ft.stack = append(ft.stack, fact{p, oldR}) + ft.facts[p] = oldR & r + // If this relation is not satisfiable, mark it and exit right away + if oldR&r == 0 { + if parent.Func.pass.debug > 2 { + parent.Func.Warnl(parent.Pos, "unsat %s %s %s", v, w, r) + } + ft.unsat = true + return + } + } + + // Extract bounds when comparing against constants + if v.isGenericIntConst() { + v, w = w, v + r = reverseBits[r] + } + if v != nil && w.isGenericIntConst() { + // Note: all the +1/-1 below could overflow/underflow. Either will + // still generate correct results, it will just lead to imprecision. + // In fact if there is overflow/underflow, the corresponding + // code is unreachable because the known range is outside the range + // of the value's type. + old, ok := ft.limits[v.ID] + if !ok { + old = noLimit + if v.isGenericIntConst() { + switch d { + case signed: + old.min, old.max = v.AuxInt, v.AuxInt + if v.AuxInt >= 0 { + old.umin, old.umax = uint64(v.AuxInt), uint64(v.AuxInt) + } + case unsigned: + old.umin = v.AuxUnsigned() + old.umax = old.umin + if int64(old.umin) >= 0 { + old.min, old.max = int64(old.umin), int64(old.umin) + } + } + } + } + lim := noLimit + switch d { + case signed: + c := w.AuxInt + switch r { + case lt: + lim.max = c - 1 + case lt | eq: + lim.max = c + case gt | eq: + lim.min = c + case gt: + lim.min = c + 1 + case lt | gt: + lim = old + if c == lim.min { + lim.min++ + } + if c == lim.max { + lim.max-- + } + case eq: + lim.min = c + lim.max = c + } + if lim.min >= 0 { + // int(x) >= 0 && int(x) >= N ⇒ uint(x) >= N + lim.umin = uint64(lim.min) + } + if lim.max != noLimit.max && old.min >= 0 && lim.max >= 0 { + // 0 <= int(x) <= N ⇒ 0 <= uint(x) <= N + // This is for a max update, so the lower bound + // comes from what we already know (old). + lim.umax = uint64(lim.max) + } + case unsigned: + uc := w.AuxUnsigned() + switch r { + case lt: + lim.umax = uc - 1 + case lt | eq: + lim.umax = uc + case gt | eq: + lim.umin = uc + case gt: + lim.umin = uc + 1 + case lt | gt: + lim = old + if uc == lim.umin { + lim.umin++ + } + if uc == lim.umax { + lim.umax-- + } + case eq: + lim.umin = uc + lim.umax = uc + } + // We could use the contrapositives of the + // signed implications to derive signed facts, + // but it turns out not to matter. + } + ft.limitStack = append(ft.limitStack, limitFact{v.ID, old}) + lim = old.intersect(lim) + ft.limits[v.ID] = lim + if v.Block.Func.pass.debug > 2 { + v.Block.Func.Warnl(parent.Pos, "parent=%s, new limits %s %s %s %s", parent, v, w, r, lim.String()) + } + if lim.min > lim.max || lim.umin > lim.umax { + ft.unsat = true + return + } + } + + // Derived facts below here are only about numbers. + if d != signed && d != unsigned { + return + } + + // Additional facts we know given the relationship between len and cap. + // + // TODO: Since prove now derives transitive relations, it + // should be sufficient to learn that len(w) <= cap(w) at the + // beginning of prove where we look for all len/cap ops. + if v.Op == OpSliceLen && r< == 0 && ft.caps[v.Args[0].ID] != nil { + // len(s) > w implies cap(s) > w + // len(s) >= w implies cap(s) >= w + // len(s) == w implies cap(s) >= w + ft.update(parent, ft.caps[v.Args[0].ID], w, d, r|gt) + } + if w.Op == OpSliceLen && r> == 0 && ft.caps[w.Args[0].ID] != nil { + // same, length on the RHS. + ft.update(parent, v, ft.caps[w.Args[0].ID], d, r|lt) + } + if v.Op == OpSliceCap && r> == 0 && ft.lens[v.Args[0].ID] != nil { + // cap(s) < w implies len(s) < w + // cap(s) <= w implies len(s) <= w + // cap(s) == w implies len(s) <= w + ft.update(parent, ft.lens[v.Args[0].ID], w, d, r|lt) + } + if w.Op == OpSliceCap && r< == 0 && ft.lens[w.Args[0].ID] != nil { + // same, capacity on the RHS. + ft.update(parent, v, ft.lens[w.Args[0].ID], d, r|gt) + } + + // Process fence-post implications. + // + // First, make the condition > or >=. + if r == lt || r == lt|eq { + v, w = w, v + r = reverseBits[r] + } + switch r { + case gt: + if x, delta := isConstDelta(v); x != nil && delta == 1 { + // x+1 > w ⇒ x >= w + // + // This is useful for eliminating the + // growslice branch of append. + ft.update(parent, x, w, d, gt|eq) + } else if x, delta := isConstDelta(w); x != nil && delta == -1 { + // v > x-1 ⇒ v >= x + ft.update(parent, v, x, d, gt|eq) + } + case gt | eq: + if x, delta := isConstDelta(v); x != nil && delta == -1 { + // x-1 >= w && x > min ⇒ x > w + // + // Useful for i > 0; s[i-1]. + lim, ok := ft.limits[x.ID] + if ok && ((d == signed && lim.min > opMin[v.Op]) || (d == unsigned && lim.umin > 0)) { + ft.update(parent, x, w, d, gt) + } + } else if x, delta := isConstDelta(w); x != nil && delta == 1 { + // v >= x+1 && x < max ⇒ v > x + lim, ok := ft.limits[x.ID] + if ok && ((d == signed && lim.max < opMax[w.Op]) || (d == unsigned && lim.umax < opUMax[w.Op])) { + ft.update(parent, v, x, d, gt) + } + } + } + + // Process: x+delta > w (with delta constant) + // Only signed domain for now (useful for accesses to slices in loops). + if r == gt || r == gt|eq { + if x, delta := isConstDelta(v); x != nil && d == signed { + if parent.Func.pass.debug > 1 { + parent.Func.Warnl(parent.Pos, "x+d %s w; x:%v %v delta:%v w:%v d:%v", r, x, parent.String(), delta, w.AuxInt, d) + } + if !w.isGenericIntConst() { + // If we know that x+delta > w but w is not constant, we can derive: + // if delta < 0 and x > MinInt - delta, then x > w (because x+delta cannot underflow) + // This is useful for loops with bounds "len(slice)-K" (delta = -K) + if l, has := ft.limits[x.ID]; has && delta < 0 { + if (x.Type.Size() == 8 && l.min >= math.MinInt64-delta) || + (x.Type.Size() == 4 && l.min >= math.MinInt32-delta) { + ft.update(parent, x, w, signed, r) + } + } + } else { + // With w,delta constants, we want to derive: x+delta > w ⇒ x > w-delta + // + // We compute (using integers of the correct size): + // min = w - delta + // max = MaxInt - delta + // + // And we prove that: + // if min<max: min < x AND x <= max + // if min>max: min < x OR x <= max + // + // This is always correct, even in case of overflow. + // + // If the initial fact is x+delta >= w instead, the derived conditions are: + // if min<max: min <= x AND x <= max + // if min>max: min <= x OR x <= max + // + // Notice the conditions for max are still <=, as they handle overflows. + var min, max int64 + var vmin, vmax *Value + switch x.Type.Size() { + case 8: + min = w.AuxInt - delta + max = int64(^uint64(0)>>1) - delta + + vmin = parent.NewValue0I(parent.Pos, OpConst64, parent.Func.Config.Types.Int64, min) + vmax = parent.NewValue0I(parent.Pos, OpConst64, parent.Func.Config.Types.Int64, max) + + case 4: + min = int64(int32(w.AuxInt) - int32(delta)) + max = int64(int32(^uint32(0)>>1) - int32(delta)) + + vmin = parent.NewValue0I(parent.Pos, OpConst32, parent.Func.Config.Types.Int32, min) + vmax = parent.NewValue0I(parent.Pos, OpConst32, parent.Func.Config.Types.Int32, max) + + default: + panic("unimplemented") + } + + if min < max { + // Record that x > min and max >= x + ft.update(parent, x, vmin, d, r) + ft.update(parent, vmax, x, d, r|eq) + } else { + // We know that either x>min OR x<=max. factsTable cannot record OR conditions, + // so let's see if we can already prove that one of them is false, in which case + // the other must be true + if l, has := ft.limits[x.ID]; has { + if l.max <= min { + if r&eq == 0 || l.max < min { + // x>min (x>=min) is impossible, so it must be x<=max + ft.update(parent, vmax, x, d, r|eq) + } + } else if l.min > max { + // x<=max is impossible, so it must be x>min + ft.update(parent, x, vmin, d, r) + } + } + } + } + } + } + + // Look through value-preserving extensions. + // If the domain is appropriate for the pre-extension Type, + // repeat the update with the pre-extension Value. + if isCleanExt(v) { + switch { + case d == signed && v.Args[0].Type.IsSigned(): + fallthrough + case d == unsigned && !v.Args[0].Type.IsSigned(): + ft.update(parent, v.Args[0], w, d, r) + } + } + if isCleanExt(w) { + switch { + case d == signed && w.Args[0].Type.IsSigned(): + fallthrough + case d == unsigned && !w.Args[0].Type.IsSigned(): + ft.update(parent, v, w.Args[0], d, r) + } + } +} + +var opMin = map[Op]int64{ + OpAdd64: math.MinInt64, OpSub64: math.MinInt64, + OpAdd32: math.MinInt32, OpSub32: math.MinInt32, +} + +var opMax = map[Op]int64{ + OpAdd64: math.MaxInt64, OpSub64: math.MaxInt64, + OpAdd32: math.MaxInt32, OpSub32: math.MaxInt32, +} + +var opUMax = map[Op]uint64{ + OpAdd64: math.MaxUint64, OpSub64: math.MaxUint64, + OpAdd32: math.MaxUint32, OpSub32: math.MaxUint32, +} + +// isNonNegative reports whether v is known to be non-negative. +func (ft *factsTable) isNonNegative(v *Value) bool { + if isNonNegative(v) { + return true + } + + var max int64 + switch v.Type.Size() { + case 1: + max = math.MaxInt8 + case 2: + max = math.MaxInt16 + case 4: + max = math.MaxInt32 + case 8: + max = math.MaxInt64 + default: + panic("unexpected integer size") + } + + // Check if the recorded limits can prove that the value is positive + + if l, has := ft.limits[v.ID]; has && (l.min >= 0 || l.umax <= uint64(max)) { + return true + } + + // Check if v = x+delta, and we can use x's limits to prove that it's positive + if x, delta := isConstDelta(v); x != nil { + if l, has := ft.limits[x.ID]; has { + if delta > 0 && l.min >= -delta && l.max <= max-delta { + return true + } + if delta < 0 && l.min >= -delta { + return true + } + } + } + + // Check if v is a value-preserving extension of a non-negative value. + if isCleanExt(v) && ft.isNonNegative(v.Args[0]) { + return true + } + + // Check if the signed poset can prove that the value is >= 0 + return ft.orderS.OrderedOrEqual(ft.zero, v) +} + +// checkpoint saves the current state of known relations. +// Called when descending on a branch. +func (ft *factsTable) checkpoint() { + if ft.unsat { + ft.unsatDepth++ + } + ft.stack = append(ft.stack, checkpointFact) + ft.limitStack = append(ft.limitStack, checkpointBound) + ft.orderS.Checkpoint() + ft.orderU.Checkpoint() +} + +// restore restores known relation to the state just +// before the previous checkpoint. +// Called when backing up on a branch. +func (ft *factsTable) restore() { + if ft.unsatDepth > 0 { + ft.unsatDepth-- + } else { + ft.unsat = false + } + for { + old := ft.stack[len(ft.stack)-1] + ft.stack = ft.stack[:len(ft.stack)-1] + if old == checkpointFact { + break + } + if old.r == lt|eq|gt { + delete(ft.facts, old.p) + } else { + ft.facts[old.p] = old.r + } + } + for { + old := ft.limitStack[len(ft.limitStack)-1] + ft.limitStack = ft.limitStack[:len(ft.limitStack)-1] + if old.vid == 0 { // checkpointBound + break + } + if old.limit == noLimit { + delete(ft.limits, old.vid) + } else { + ft.limits[old.vid] = old.limit + } + } + ft.orderS.Undo() + ft.orderU.Undo() +} + +func lessByID(v, w *Value) bool { + if v == nil && w == nil { + // Should not happen, but just in case. + return false + } + if v == nil { + return true + } + return w != nil && v.ID < w.ID +} + +var ( + reverseBits = [...]relation{0, 4, 2, 6, 1, 5, 3, 7} + + // maps what we learn when the positive branch is taken. + // For example: + // OpLess8: {signed, lt}, + // v1 = (OpLess8 v2 v3). + // If v1 branch is taken then we learn that the rangeMask + // can be at most lt. + domainRelationTable = map[Op]struct { + d domain + r relation + }{ + OpEq8: {signed | unsigned, eq}, + OpEq16: {signed | unsigned, eq}, + OpEq32: {signed | unsigned, eq}, + OpEq64: {signed | unsigned, eq}, + OpEqPtr: {pointer, eq}, + + OpNeq8: {signed | unsigned, lt | gt}, + OpNeq16: {signed | unsigned, lt | gt}, + OpNeq32: {signed | unsigned, lt | gt}, + OpNeq64: {signed | unsigned, lt | gt}, + OpNeqPtr: {pointer, lt | gt}, + + OpLess8: {signed, lt}, + OpLess8U: {unsigned, lt}, + OpLess16: {signed, lt}, + OpLess16U: {unsigned, lt}, + OpLess32: {signed, lt}, + OpLess32U: {unsigned, lt}, + OpLess64: {signed, lt}, + OpLess64U: {unsigned, lt}, + + OpLeq8: {signed, lt | eq}, + OpLeq8U: {unsigned, lt | eq}, + OpLeq16: {signed, lt | eq}, + OpLeq16U: {unsigned, lt | eq}, + OpLeq32: {signed, lt | eq}, + OpLeq32U: {unsigned, lt | eq}, + OpLeq64: {signed, lt | eq}, + OpLeq64U: {unsigned, lt | eq}, + + // For these ops, the negative branch is different: we can only + // prove signed/GE (signed/GT) if we can prove that arg0 is non-negative. + // See the special case in addBranchRestrictions. + OpIsInBounds: {signed | unsigned, lt}, // 0 <= arg0 < arg1 + OpIsSliceInBounds: {signed | unsigned, lt | eq}, // 0 <= arg0 <= arg1 + } +) + +// prove removes redundant BlockIf branches that can be inferred +// from previous dominating comparisons. +// +// By far, the most common redundant pair are generated by bounds checking. +// For example for the code: +// +// a[i] = 4 +// foo(a[i]) +// +// The compiler will generate the following code: +// +// if i >= len(a) { +// panic("not in bounds") +// } +// a[i] = 4 +// if i >= len(a) { +// panic("not in bounds") +// } +// foo(a[i]) +// +// The second comparison i >= len(a) is clearly redundant because if the +// else branch of the first comparison is executed, we already know that i < len(a). +// The code for the second panic can be removed. +// +// prove works by finding contradictions and trimming branches whose +// conditions are unsatisfiable given the branches leading up to them. +// It tracks a "fact table" of branch conditions. For each branching +// block, it asserts the branch conditions that uniquely dominate that +// block, and then separately asserts the block's branch condition and +// its negation. If either leads to a contradiction, it can trim that +// successor. +func prove(f *Func) { + ft := newFactsTable(f) + ft.checkpoint() + + var lensVars map[*Block][]*Value + + // Find length and capacity ops. + for _, b := range f.Blocks { + for _, v := range b.Values { + if v.Uses == 0 { + // We don't care about dead values. + // (There can be some that are CSEd but not removed yet.) + continue + } + switch v.Op { + case OpStringLen: + ft.update(b, v, ft.zero, signed, gt|eq) + case OpSliceLen: + if ft.lens == nil { + ft.lens = map[ID]*Value{} + } + ft.lens[v.Args[0].ID] = v + ft.update(b, v, ft.zero, signed, gt|eq) + if v.Args[0].Op == OpSliceMake { + if lensVars == nil { + lensVars = make(map[*Block][]*Value) + } + lensVars[b] = append(lensVars[b], v) + } + case OpSliceCap: + if ft.caps == nil { + ft.caps = map[ID]*Value{} + } + ft.caps[v.Args[0].ID] = v + ft.update(b, v, ft.zero, signed, gt|eq) + if v.Args[0].Op == OpSliceMake { + if lensVars == nil { + lensVars = make(map[*Block][]*Value) + } + lensVars[b] = append(lensVars[b], v) + } + } + } + } + + // Find induction variables. Currently, findIndVars + // is limited to one induction variable per block. + var indVars map[*Block]indVar + for _, v := range findIndVar(f) { + if indVars == nil { + indVars = make(map[*Block]indVar) + } + indVars[v.entry] = v + } + + // current node state + type walkState int + const ( + descend walkState = iota + simplify + ) + // work maintains the DFS stack. + type bp struct { + block *Block // current handled block + state walkState // what's to do + } + work := make([]bp, 0, 256) + work = append(work, bp{ + block: f.Entry, + state: descend, + }) + + idom := f.Idom() + sdom := f.Sdom() + + // DFS on the dominator tree. + // + // For efficiency, we consider only the dominator tree rather + // than the entire flow graph. On the way down, we consider + // incoming branches and accumulate conditions that uniquely + // dominate the current block. If we discover a contradiction, + // we can eliminate the entire block and all of its children. + // On the way back up, we consider outgoing branches that + // haven't already been considered. This way we consider each + // branch condition only once. + for len(work) > 0 { + node := work[len(work)-1] + work = work[:len(work)-1] + parent := idom[node.block.ID] + branch := getBranch(sdom, parent, node.block) + + switch node.state { + case descend: + ft.checkpoint() + + // Entering the block, add the block-depending facts that we collected + // at the beginning: induction variables and lens/caps of slices. + if iv, ok := indVars[node.block]; ok { + addIndVarRestrictions(ft, parent, iv) + } + if lens, ok := lensVars[node.block]; ok { + for _, v := range lens { + switch v.Op { + case OpSliceLen: + ft.update(node.block, v, v.Args[0].Args[1], signed, eq) + case OpSliceCap: + ft.update(node.block, v, v.Args[0].Args[2], signed, eq) + } + } + } + + if branch != unknown { + addBranchRestrictions(ft, parent, branch) + if ft.unsat { + // node.block is unreachable. + // Remove it and don't visit + // its children. + removeBranch(parent, branch) + ft.restore() + break + } + // Otherwise, we can now commit to + // taking this branch. We'll restore + // ft when we unwind. + } + + // Add inductive facts for phis in this block. + addLocalInductiveFacts(ft, node.block) + + work = append(work, bp{ + block: node.block, + state: simplify, + }) + for s := sdom.Child(node.block); s != nil; s = sdom.Sibling(s) { + work = append(work, bp{ + block: s, + state: descend, + }) + } + + case simplify: + simplifyBlock(sdom, ft, node.block) + ft.restore() + } + } + + ft.restore() + + // Return the posets to the free list + for _, po := range []*poset{ft.orderS, ft.orderU} { + // Make sure it's empty as it should be. A non-empty poset + // might cause errors and miscompilations if reused. + if checkEnabled { + if err := po.CheckEmpty(); err != nil { + f.Fatalf("prove poset not empty after function %s: %v", f.Name, err) + } + } + f.retPoset(po) + } +} + +// getBranch returns the range restrictions added by p +// when reaching b. p is the immediate dominator of b. +func getBranch(sdom SparseTree, p *Block, b *Block) branch { + if p == nil || p.Kind != BlockIf { + return unknown + } + // If p and p.Succs[0] are dominators it means that every path + // from entry to b passes through p and p.Succs[0]. We care that + // no path from entry to b passes through p.Succs[1]. If p.Succs[0] + // has one predecessor then (apart from the degenerate case), + // there is no path from entry that can reach b through p.Succs[1]. + // TODO: how about p->yes->b->yes, i.e. a loop in yes. + if sdom.IsAncestorEq(p.Succs[0].b, b) && len(p.Succs[0].b.Preds) == 1 { + return positive + } + if sdom.IsAncestorEq(p.Succs[1].b, b) && len(p.Succs[1].b.Preds) == 1 { + return negative + } + return unknown +} + +// addIndVarRestrictions updates the factsTables ft with the facts +// learned from the induction variable indVar which drives the loop +// starting in Block b. +func addIndVarRestrictions(ft *factsTable, b *Block, iv indVar) { + d := signed + if ft.isNonNegative(iv.min) && ft.isNonNegative(iv.max) { + d |= unsigned + } + + if iv.flags&indVarMinExc == 0 { + addRestrictions(b, ft, d, iv.min, iv.ind, lt|eq) + } else { + addRestrictions(b, ft, d, iv.min, iv.ind, lt) + } + + if iv.flags&indVarMaxInc == 0 { + addRestrictions(b, ft, d, iv.ind, iv.max, lt) + } else { + addRestrictions(b, ft, d, iv.ind, iv.max, lt|eq) + } +} + +// addBranchRestrictions updates the factsTables ft with the facts learned when +// branching from Block b in direction br. +func addBranchRestrictions(ft *factsTable, b *Block, br branch) { + c := b.Controls[0] + switch br { + case negative: + addRestrictions(b, ft, boolean, nil, c, eq) + case positive: + addRestrictions(b, ft, boolean, nil, c, lt|gt) + default: + panic("unknown branch") + } + if tr, has := domainRelationTable[c.Op]; has { + // When we branched from parent we learned a new set of + // restrictions. Update the factsTable accordingly. + d := tr.d + if d == signed && ft.isNonNegative(c.Args[0]) && ft.isNonNegative(c.Args[1]) { + d |= unsigned + } + switch c.Op { + case OpIsInBounds, OpIsSliceInBounds: + // 0 <= a0 < a1 (or 0 <= a0 <= a1) + // + // On the positive branch, we learn: + // signed: 0 <= a0 < a1 (or 0 <= a0 <= a1) + // unsigned: a0 < a1 (or a0 <= a1) + // + // On the negative branch, we learn (0 > a0 || + // a0 >= a1). In the unsigned domain, this is + // simply a0 >= a1 (which is the reverse of the + // positive branch, so nothing surprising). + // But in the signed domain, we can't express the || + // condition, so check if a0 is non-negative instead, + // to be able to learn something. + switch br { + case negative: + d = unsigned + if ft.isNonNegative(c.Args[0]) { + d |= signed + } + addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r^(lt|gt|eq)) + case positive: + addRestrictions(b, ft, signed, ft.zero, c.Args[0], lt|eq) + addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r) + } + default: + switch br { + case negative: + addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r^(lt|gt|eq)) + case positive: + addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r) + } + } + + } +} + +// addRestrictions updates restrictions from the immediate +// dominating block (p) using r. +func addRestrictions(parent *Block, ft *factsTable, t domain, v, w *Value, r relation) { + if t == 0 { + // Trivial case: nothing to do. + // Shoult not happen, but just in case. + return + } + for i := domain(1); i <= t; i <<= 1 { + if t&i == 0 { + continue + } + ft.update(parent, v, w, i, r) + } +} + +// addLocalInductiveFacts adds inductive facts when visiting b, where +// b is a join point in a loop. In contrast with findIndVar, this +// depends on facts established for b, which is why it happens when +// visiting b. addLocalInductiveFacts specifically targets the pattern +// created by OFORUNTIL, which isn't detected by findIndVar. +// +// TODO: It would be nice to combine this with findIndVar. +func addLocalInductiveFacts(ft *factsTable, b *Block) { + // This looks for a specific pattern of induction: + // + // 1. i1 = OpPhi(min, i2) in b + // 2. i2 = i1 + 1 + // 3. i2 < max at exit from b.Preds[1] + // 4. min < max + // + // If all of these conditions are true, then i1 < max and i1 >= min. + + // To ensure this is a loop header node. + if len(b.Preds) != 2 { + return + } + + for _, i1 := range b.Values { + if i1.Op != OpPhi { + continue + } + + // Check for conditions 1 and 2. This is easy to do + // and will throw out most phis. + min, i2 := i1.Args[0], i1.Args[1] + if i1q, delta := isConstDelta(i2); i1q != i1 || delta != 1 { + continue + } + + // Try to prove condition 3. We can't just query the + // fact table for this because we don't know what the + // facts of b.Preds[1] are (in general, b.Preds[1] is + // a loop-back edge, so we haven't even been there + // yet). As a conservative approximation, we look for + // this condition in the predecessor chain until we + // hit a join point. + uniquePred := func(b *Block) *Block { + if len(b.Preds) == 1 { + return b.Preds[0].b + } + return nil + } + pred, child := b.Preds[1].b, b + for ; pred != nil; pred, child = uniquePred(pred), pred { + if pred.Kind != BlockIf { + continue + } + control := pred.Controls[0] + + br := unknown + if pred.Succs[0].b == child { + br = positive + } + if pred.Succs[1].b == child { + if br != unknown { + continue + } + br = negative + } + if br == unknown { + continue + } + + tr, has := domainRelationTable[control.Op] + if !has { + continue + } + r := tr.r + if br == negative { + // Negative branch taken to reach b. + // Complement the relations. + r = (lt | eq | gt) ^ r + } + + // Check for i2 < max or max > i2. + var max *Value + if r == lt && control.Args[0] == i2 { + max = control.Args[1] + } else if r == gt && control.Args[1] == i2 { + max = control.Args[0] + } else { + continue + } + + // Check condition 4 now that we have a + // candidate max. For this we can query the + // fact table. We "prove" min < max by showing + // that min >= max is unsat. (This may simply + // compare two constants; that's fine.) + ft.checkpoint() + ft.update(b, min, max, tr.d, gt|eq) + proved := ft.unsat + ft.restore() + + if proved { + // We know that min <= i1 < max. + if b.Func.pass.debug > 0 { + printIndVar(b, i1, min, max, 1, 0) + } + ft.update(b, min, i1, tr.d, lt|eq) + ft.update(b, i1, max, tr.d, lt) + } + } + } +} + +var ctzNonZeroOp = map[Op]Op{OpCtz8: OpCtz8NonZero, OpCtz16: OpCtz16NonZero, OpCtz32: OpCtz32NonZero, OpCtz64: OpCtz64NonZero} +var mostNegativeDividend = map[Op]int64{ + OpDiv16: -1 << 15, + OpMod16: -1 << 15, + OpDiv32: -1 << 31, + OpMod32: -1 << 31, + OpDiv64: -1 << 63, + OpMod64: -1 << 63} + +// simplifyBlock simplifies some constant values in b and evaluates +// branches to non-uniquely dominated successors of b. +func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { + for _, v := range b.Values { + switch v.Op { + case OpSlicemask: + // Replace OpSlicemask operations in b with constants where possible. + x, delta := isConstDelta(v.Args[0]) + if x == nil { + continue + } + // slicemask(x + y) + // if x is larger than -y (y is negative), then slicemask is -1. + lim, ok := ft.limits[x.ID] + if !ok { + continue + } + if lim.umin > uint64(-delta) { + if v.Args[0].Op == OpAdd64 { + v.reset(OpConst64) + } else { + v.reset(OpConst32) + } + if b.Func.pass.debug > 0 { + b.Func.Warnl(v.Pos, "Proved slicemask not needed") + } + v.AuxInt = -1 + } + case OpCtz8, OpCtz16, OpCtz32, OpCtz64: + // On some architectures, notably amd64, we can generate much better + // code for CtzNN if we know that the argument is non-zero. + // Capture that information here for use in arch-specific optimizations. + x := v.Args[0] + lim, ok := ft.limits[x.ID] + if !ok { + continue + } + if lim.umin > 0 || lim.min > 0 || lim.max < 0 { + if b.Func.pass.debug > 0 { + b.Func.Warnl(v.Pos, "Proved %v non-zero", v.Op) + } + v.Op = ctzNonZeroOp[v.Op] + } + case OpRsh8x8, OpRsh8x16, OpRsh8x32, OpRsh8x64, + OpRsh16x8, OpRsh16x16, OpRsh16x32, OpRsh16x64, + OpRsh32x8, OpRsh32x16, OpRsh32x32, OpRsh32x64, + OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64: + // Check whether, for a >> b, we know that a is non-negative + // and b is all of a's bits except the MSB. If so, a is shifted to zero. + bits := 8 * v.Type.Size() + if v.Args[1].isGenericIntConst() && v.Args[1].AuxInt >= bits-1 && ft.isNonNegative(v.Args[0]) { + if b.Func.pass.debug > 0 { + b.Func.Warnl(v.Pos, "Proved %v shifts to zero", v.Op) + } + switch bits { + case 64: + v.reset(OpConst64) + case 32: + v.reset(OpConst32) + case 16: + v.reset(OpConst16) + case 8: + v.reset(OpConst8) + default: + panic("unexpected integer size") + } + v.AuxInt = 0 + continue // Be sure not to fallthrough - this is no longer OpRsh. + } + // If the Rsh hasn't been replaced with 0, still check if it is bounded. + fallthrough + case OpLsh8x8, OpLsh8x16, OpLsh8x32, OpLsh8x64, + OpLsh16x8, OpLsh16x16, OpLsh16x32, OpLsh16x64, + OpLsh32x8, OpLsh32x16, OpLsh32x32, OpLsh32x64, + OpLsh64x8, OpLsh64x16, OpLsh64x32, OpLsh64x64, + OpRsh8Ux8, OpRsh8Ux16, OpRsh8Ux32, OpRsh8Ux64, + OpRsh16Ux8, OpRsh16Ux16, OpRsh16Ux32, OpRsh16Ux64, + OpRsh32Ux8, OpRsh32Ux16, OpRsh32Ux32, OpRsh32Ux64, + OpRsh64Ux8, OpRsh64Ux16, OpRsh64Ux32, OpRsh64Ux64: + // Check whether, for a << b, we know that b + // is strictly less than the number of bits in a. + by := v.Args[1] + lim, ok := ft.limits[by.ID] + if !ok { + continue + } + bits := 8 * v.Args[0].Type.Size() + if lim.umax < uint64(bits) || (lim.max < bits && ft.isNonNegative(by)) { + v.AuxInt = 1 // see shiftIsBounded + if b.Func.pass.debug > 0 { + b.Func.Warnl(v.Pos, "Proved %v bounded", v.Op) + } + } + case OpDiv16, OpDiv32, OpDiv64, OpMod16, OpMod32, OpMod64: + // On amd64 and 386 fix-up code can be avoided if we know + // the divisor is not -1 or the dividend > MinIntNN. + // Don't modify AuxInt on other architectures, + // as that can interfere with CSE. + // TODO: add other architectures? + if b.Func.Config.arch != "386" && b.Func.Config.arch != "amd64" { + break + } + divr := v.Args[1] + divrLim, divrLimok := ft.limits[divr.ID] + divd := v.Args[0] + divdLim, divdLimok := ft.limits[divd.ID] + if (divrLimok && (divrLim.max < -1 || divrLim.min > -1)) || + (divdLimok && divdLim.min > mostNegativeDividend[v.Op]) { + // See DivisionNeedsFixUp in rewrite.go. + // v.AuxInt = 1 means we have proved both that the divisor is not -1 + // and that the dividend is not the most negative integer, + // so we do not need to add fix-up code. + v.AuxInt = 1 + if b.Func.pass.debug > 0 { + b.Func.Warnl(v.Pos, "Proved %v does not need fix-up", v.Op) + } + } + } + } + + if b.Kind != BlockIf { + return + } + + // Consider outgoing edges from this block. + parent := b + for i, branch := range [...]branch{positive, negative} { + child := parent.Succs[i].b + if getBranch(sdom, parent, child) != unknown { + // For edges to uniquely dominated blocks, we + // already did this when we visited the child. + continue + } + // For edges to other blocks, this can trim a branch + // even if we couldn't get rid of the child itself. + ft.checkpoint() + addBranchRestrictions(ft, parent, branch) + unsat := ft.unsat + ft.restore() + if unsat { + // This branch is impossible, so remove it + // from the block. + removeBranch(parent, branch) + // No point in considering the other branch. + // (It *is* possible for both to be + // unsatisfiable since the fact table is + // incomplete. We could turn this into a + // BlockExit, but it doesn't seem worth it.) + break + } + } +} + +func removeBranch(b *Block, branch branch) { + c := b.Controls[0] + if b.Func.pass.debug > 0 { + verb := "Proved" + if branch == positive { + verb = "Disproved" + } + if b.Func.pass.debug > 1 { + b.Func.Warnl(b.Pos, "%s %s (%s)", verb, c.Op, c) + } else { + b.Func.Warnl(b.Pos, "%s %s", verb, c.Op) + } + } + if c != nil && c.Pos.IsStmt() == src.PosIsStmt && c.Pos.SameFileAndLine(b.Pos) { + // attempt to preserve statement marker. + b.Pos = b.Pos.WithIsStmt() + } + b.Kind = BlockFirst + b.ResetControls() + if branch == positive { + b.swapSuccessors() + } +} + +// isNonNegative reports whether v is known to be greater or equal to zero. +func isNonNegative(v *Value) bool { + if !v.Type.IsInteger() { + v.Fatalf("isNonNegative bad type: %v", v.Type) + } + // TODO: return true if !v.Type.IsSigned() + // SSA isn't type-safe enough to do that now (issue 37753). + // The checks below depend only on the pattern of bits. + + switch v.Op { + case OpConst64: + return v.AuxInt >= 0 + + case OpConst32: + return int32(v.AuxInt) >= 0 + + case OpConst16: + return int16(v.AuxInt) >= 0 + + case OpConst8: + return int8(v.AuxInt) >= 0 + + case OpStringLen, OpSliceLen, OpSliceCap, + OpZeroExt8to64, OpZeroExt16to64, OpZeroExt32to64, + OpZeroExt8to32, OpZeroExt16to32, OpZeroExt8to16, + OpCtz64, OpCtz32, OpCtz16, OpCtz8: + return true + + case OpRsh64Ux64, OpRsh32Ux64: + by := v.Args[1] + return by.Op == OpConst64 && by.AuxInt > 0 + + case OpRsh64x64, OpRsh32x64, OpRsh8x64, OpRsh16x64, OpRsh32x32, OpRsh64x32, + OpSignExt32to64, OpSignExt16to64, OpSignExt8to64, OpSignExt16to32, OpSignExt8to32: + return isNonNegative(v.Args[0]) + + case OpAnd64, OpAnd32, OpAnd16, OpAnd8: + return isNonNegative(v.Args[0]) || isNonNegative(v.Args[1]) + + case OpMod64, OpMod32, OpMod16, OpMod8, + OpDiv64, OpDiv32, OpDiv16, OpDiv8, + OpOr64, OpOr32, OpOr16, OpOr8, + OpXor64, OpXor32, OpXor16, OpXor8: + return isNonNegative(v.Args[0]) && isNonNegative(v.Args[1]) + + // We could handle OpPhi here, but the improvements from doing + // so are very minor, and it is neither simple nor cheap. + } + return false +} + +// isConstDelta returns non-nil if v is equivalent to w+delta (signed). +func isConstDelta(v *Value) (w *Value, delta int64) { + cop := OpConst64 + switch v.Op { + case OpAdd32, OpSub32: + cop = OpConst32 + } + switch v.Op { + case OpAdd64, OpAdd32: + if v.Args[0].Op == cop { + return v.Args[1], v.Args[0].AuxInt + } + if v.Args[1].Op == cop { + return v.Args[0], v.Args[1].AuxInt + } + case OpSub64, OpSub32: + if v.Args[1].Op == cop { + aux := v.Args[1].AuxInt + if aux != -aux { // Overflow; too bad + return v.Args[0], -aux + } + } + } + return nil, 0 +} + +// isCleanExt reports whether v is the result of a value-preserving +// sign or zero extension +func isCleanExt(v *Value) bool { + switch v.Op { + case OpSignExt8to16, OpSignExt8to32, OpSignExt8to64, + OpSignExt16to32, OpSignExt16to64, OpSignExt32to64: + // signed -> signed is the only value-preserving sign extension + return v.Args[0].Type.IsSigned() && v.Type.IsSigned() + + case OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, + OpZeroExt16to32, OpZeroExt16to64, OpZeroExt32to64: + // unsigned -> signed/unsigned are value-preserving zero extensions + return !v.Args[0].Type.IsSigned() + } + return false +} |