summaryrefslogtreecommitdiffstats
path: root/src/cmd/compile/internal/walk/compare.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/cmd/compile/internal/walk/compare.go')
-rw-r--r--src/cmd/compile/internal/walk/compare.go537
1 files changed, 537 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/walk/compare.go b/src/cmd/compile/internal/walk/compare.go
new file mode 100644
index 0000000..fe9c5d8
--- /dev/null
+++ b/src/cmd/compile/internal/walk/compare.go
@@ -0,0 +1,537 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package walk
+
+import (
+ "encoding/binary"
+ "fmt"
+ "go/constant"
+ "hash/fnv"
+ "io"
+
+ "cmd/compile/internal/base"
+ "cmd/compile/internal/compare"
+ "cmd/compile/internal/ir"
+ "cmd/compile/internal/reflectdata"
+ "cmd/compile/internal/ssagen"
+ "cmd/compile/internal/typecheck"
+ "cmd/compile/internal/types"
+)
+
+func fakePC(n ir.Node) ir.Node {
+ // In order to get deterministic IDs, we include the package path, absolute filename, line number, column number
+ // in the calculation of the fakePC for the IR node.
+ hash := fnv.New32()
+ // We ignore the errors here because the `io.Writer` in the `hash.Hash` interface never returns an error.
+ io.WriteString(hash, base.Ctxt.Pkgpath)
+ io.WriteString(hash, base.Ctxt.PosTable.Pos(n.Pos()).AbsFilename())
+ binary.Write(hash, binary.LittleEndian, int64(n.Pos().Line()))
+ binary.Write(hash, binary.LittleEndian, int64(n.Pos().Col()))
+ // We also include the string representation of the node to distinguish autogenerated expression since
+ // those get the same `src.XPos`
+ io.WriteString(hash, fmt.Sprintf("%v", n))
+
+ return ir.NewInt(int64(hash.Sum32()))
+}
+
+// The result of walkCompare MUST be assigned back to n, e.g.
+//
+// n.Left = walkCompare(n.Left, init)
+func walkCompare(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
+ if n.X.Type().IsInterface() && n.Y.Type().IsInterface() && n.X.Op() != ir.ONIL && n.Y.Op() != ir.ONIL {
+ return walkCompareInterface(n, init)
+ }
+
+ if n.X.Type().IsString() && n.Y.Type().IsString() {
+ return walkCompareString(n, init)
+ }
+
+ n.X = walkExpr(n.X, init)
+ n.Y = walkExpr(n.Y, init)
+
+ // Given mixed interface/concrete comparison,
+ // rewrite into types-equal && data-equal.
+ // This is efficient, avoids allocations, and avoids runtime calls.
+ //
+ // TODO(mdempsky): It would be more general and probably overall
+ // simpler to just extend walkCompareInterface to optimize when one
+ // operand is an OCONVIFACE.
+ if n.X.Type().IsInterface() != n.Y.Type().IsInterface() {
+ // Preserve side-effects in case of short-circuiting; see #32187.
+ l := cheapExpr(n.X, init)
+ r := cheapExpr(n.Y, init)
+ // Swap so that l is the interface value and r is the concrete value.
+ if n.Y.Type().IsInterface() {
+ l, r = r, l
+ }
+
+ // Handle both == and !=.
+ eq := n.Op()
+ andor := ir.OOROR
+ if eq == ir.OEQ {
+ andor = ir.OANDAND
+ }
+ // Check for types equal.
+ // For empty interface, this is:
+ // l.tab == type(r)
+ // For non-empty interface, this is:
+ // l.tab != nil && l.tab._type == type(r)
+ //
+ // TODO(mdempsky): For non-empty interface comparisons, just
+ // compare against the itab address directly?
+ var eqtype ir.Node
+ tab := ir.NewUnaryExpr(base.Pos, ir.OITAB, l)
+ rtyp := reflectdata.CompareRType(base.Pos, n)
+ if l.Type().IsEmptyInterface() {
+ tab.SetType(types.NewPtr(types.Types[types.TUINT8]))
+ tab.SetTypecheck(1)
+ eqtype = ir.NewBinaryExpr(base.Pos, eq, tab, rtyp)
+ } else {
+ nonnil := ir.NewBinaryExpr(base.Pos, brcom(eq), typecheck.NodNil(), tab)
+ match := ir.NewBinaryExpr(base.Pos, eq, itabType(tab), rtyp)
+ eqtype = ir.NewLogicalExpr(base.Pos, andor, nonnil, match)
+ }
+ // Check for data equal.
+ eqdata := ir.NewBinaryExpr(base.Pos, eq, ifaceData(n.Pos(), l, r.Type()), r)
+ // Put it all together.
+ expr := ir.NewLogicalExpr(base.Pos, andor, eqtype, eqdata)
+ return finishCompare(n, expr, init)
+ }
+
+ // Must be comparison of array or struct.
+ // Otherwise back end handles it.
+ // While we're here, decide whether to
+ // inline or call an eq alg.
+ t := n.X.Type()
+ var inline bool
+
+ maxcmpsize := int64(4)
+ unalignedLoad := ssagen.Arch.LinkArch.CanMergeLoads
+ if unalignedLoad {
+ // Keep this low enough to generate less code than a function call.
+ maxcmpsize = 2 * int64(ssagen.Arch.LinkArch.RegSize)
+ }
+
+ switch t.Kind() {
+ default:
+ if base.Debug.Libfuzzer != 0 && t.IsInteger() && (n.X.Name() == nil || !n.X.Name().Libfuzzer8BitCounter()) {
+ n.X = cheapExpr(n.X, init)
+ n.Y = cheapExpr(n.Y, init)
+
+ // If exactly one comparison operand is
+ // constant, invoke the constcmp functions
+ // instead, and arrange for the constant
+ // operand to be the first argument.
+ l, r := n.X, n.Y
+ if r.Op() == ir.OLITERAL {
+ l, r = r, l
+ }
+ constcmp := l.Op() == ir.OLITERAL && r.Op() != ir.OLITERAL
+
+ var fn string
+ var paramType *types.Type
+ switch t.Size() {
+ case 1:
+ fn = "libfuzzerTraceCmp1"
+ if constcmp {
+ fn = "libfuzzerTraceConstCmp1"
+ }
+ paramType = types.Types[types.TUINT8]
+ case 2:
+ fn = "libfuzzerTraceCmp2"
+ if constcmp {
+ fn = "libfuzzerTraceConstCmp2"
+ }
+ paramType = types.Types[types.TUINT16]
+ case 4:
+ fn = "libfuzzerTraceCmp4"
+ if constcmp {
+ fn = "libfuzzerTraceConstCmp4"
+ }
+ paramType = types.Types[types.TUINT32]
+ case 8:
+ fn = "libfuzzerTraceCmp8"
+ if constcmp {
+ fn = "libfuzzerTraceConstCmp8"
+ }
+ paramType = types.Types[types.TUINT64]
+ default:
+ base.Fatalf("unexpected integer size %d for %v", t.Size(), t)
+ }
+ init.Append(mkcall(fn, nil, init, tracecmpArg(l, paramType, init), tracecmpArg(r, paramType, init), fakePC(n)))
+ }
+ return n
+ case types.TARRAY:
+ // We can compare several elements at once with 2/4/8 byte integer compares
+ inline = t.NumElem() <= 1 || (types.IsSimple[t.Elem().Kind()] && (t.NumElem() <= 4 || t.Elem().Size()*t.NumElem() <= maxcmpsize))
+ case types.TSTRUCT:
+ inline = compare.EqStructCost(t) <= 4
+ }
+
+ cmpl := n.X
+ for cmpl != nil && cmpl.Op() == ir.OCONVNOP {
+ cmpl = cmpl.(*ir.ConvExpr).X
+ }
+ cmpr := n.Y
+ for cmpr != nil && cmpr.Op() == ir.OCONVNOP {
+ cmpr = cmpr.(*ir.ConvExpr).X
+ }
+
+ // Chose not to inline. Call equality function directly.
+ if !inline {
+ // eq algs take pointers; cmpl and cmpr must be addressable
+ if !ir.IsAddressable(cmpl) || !ir.IsAddressable(cmpr) {
+ base.Fatalf("arguments of comparison must be lvalues - %v %v", cmpl, cmpr)
+ }
+
+ fn, needsize := eqFor(t)
+ call := ir.NewCallExpr(base.Pos, ir.OCALL, fn, nil)
+ call.Args.Append(typecheck.NodAddr(cmpl))
+ call.Args.Append(typecheck.NodAddr(cmpr))
+ if needsize {
+ call.Args.Append(ir.NewInt(t.Size()))
+ }
+ res := ir.Node(call)
+ if n.Op() != ir.OEQ {
+ res = ir.NewUnaryExpr(base.Pos, ir.ONOT, res)
+ }
+ return finishCompare(n, res, init)
+ }
+
+ // inline: build boolean expression comparing element by element
+ andor := ir.OANDAND
+ if n.Op() == ir.ONE {
+ andor = ir.OOROR
+ }
+ var expr ir.Node
+ comp := func(el, er ir.Node) {
+ a := ir.NewBinaryExpr(base.Pos, n.Op(), el, er)
+ if expr == nil {
+ expr = a
+ } else {
+ expr = ir.NewLogicalExpr(base.Pos, andor, expr, a)
+ }
+ }
+ and := func(cond ir.Node) {
+ if expr == nil {
+ expr = cond
+ } else {
+ expr = ir.NewLogicalExpr(base.Pos, andor, expr, cond)
+ }
+ }
+ cmpl = safeExpr(cmpl, init)
+ cmpr = safeExpr(cmpr, init)
+ if t.IsStruct() {
+ conds := compare.EqStruct(t, cmpl, cmpr)
+ if n.Op() == ir.OEQ {
+ for _, cond := range conds {
+ and(cond)
+ }
+ } else {
+ for _, cond := range conds {
+ notCond := ir.NewUnaryExpr(base.Pos, ir.ONOT, cond)
+ and(notCond)
+ }
+ }
+ } else {
+ step := int64(1)
+ remains := t.NumElem() * t.Elem().Size()
+ combine64bit := unalignedLoad && types.RegSize == 8 && t.Elem().Size() <= 4 && t.Elem().IsInteger()
+ combine32bit := unalignedLoad && t.Elem().Size() <= 2 && t.Elem().IsInteger()
+ combine16bit := unalignedLoad && t.Elem().Size() == 1 && t.Elem().IsInteger()
+ for i := int64(0); remains > 0; {
+ var convType *types.Type
+ switch {
+ case remains >= 8 && combine64bit:
+ convType = types.Types[types.TINT64]
+ step = 8 / t.Elem().Size()
+ case remains >= 4 && combine32bit:
+ convType = types.Types[types.TUINT32]
+ step = 4 / t.Elem().Size()
+ case remains >= 2 && combine16bit:
+ convType = types.Types[types.TUINT16]
+ step = 2 / t.Elem().Size()
+ default:
+ step = 1
+ }
+ if step == 1 {
+ comp(
+ ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(i)),
+ ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(i)),
+ )
+ i++
+ remains -= t.Elem().Size()
+ } else {
+ elemType := t.Elem().ToUnsigned()
+ cmplw := ir.Node(ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(i)))
+ cmplw = typecheck.Conv(cmplw, elemType) // convert to unsigned
+ cmplw = typecheck.Conv(cmplw, convType) // widen
+ cmprw := ir.Node(ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(i)))
+ cmprw = typecheck.Conv(cmprw, elemType)
+ cmprw = typecheck.Conv(cmprw, convType)
+ // For code like this: uint32(s[0]) | uint32(s[1])<<8 | uint32(s[2])<<16 ...
+ // ssa will generate a single large load.
+ for offset := int64(1); offset < step; offset++ {
+ lb := ir.Node(ir.NewIndexExpr(base.Pos, cmpl, ir.NewInt(i+offset)))
+ lb = typecheck.Conv(lb, elemType)
+ lb = typecheck.Conv(lb, convType)
+ lb = ir.NewBinaryExpr(base.Pos, ir.OLSH, lb, ir.NewInt(8*t.Elem().Size()*offset))
+ cmplw = ir.NewBinaryExpr(base.Pos, ir.OOR, cmplw, lb)
+ rb := ir.Node(ir.NewIndexExpr(base.Pos, cmpr, ir.NewInt(i+offset)))
+ rb = typecheck.Conv(rb, elemType)
+ rb = typecheck.Conv(rb, convType)
+ rb = ir.NewBinaryExpr(base.Pos, ir.OLSH, rb, ir.NewInt(8*t.Elem().Size()*offset))
+ cmprw = ir.NewBinaryExpr(base.Pos, ir.OOR, cmprw, rb)
+ }
+ comp(cmplw, cmprw)
+ i += step
+ remains -= step * t.Elem().Size()
+ }
+ }
+ }
+ if expr == nil {
+ expr = ir.NewBool(n.Op() == ir.OEQ)
+ // We still need to use cmpl and cmpr, in case they contain
+ // an expression which might panic. See issue 23837.
+ a1 := typecheck.Stmt(ir.NewAssignStmt(base.Pos, ir.BlankNode, cmpl))
+ a2 := typecheck.Stmt(ir.NewAssignStmt(base.Pos, ir.BlankNode, cmpr))
+ init.Append(a1, a2)
+ }
+ return finishCompare(n, expr, init)
+}
+
+func walkCompareInterface(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
+ n.Y = cheapExpr(n.Y, init)
+ n.X = cheapExpr(n.X, init)
+ eqtab, eqdata := compare.EqInterface(n.X, n.Y)
+ var cmp ir.Node
+ if n.Op() == ir.OEQ {
+ cmp = ir.NewLogicalExpr(base.Pos, ir.OANDAND, eqtab, eqdata)
+ } else {
+ eqtab.SetOp(ir.ONE)
+ cmp = ir.NewLogicalExpr(base.Pos, ir.OOROR, eqtab, ir.NewUnaryExpr(base.Pos, ir.ONOT, eqdata))
+ }
+ return finishCompare(n, cmp, init)
+}
+
+func walkCompareString(n *ir.BinaryExpr, init *ir.Nodes) ir.Node {
+ if base.Debug.Libfuzzer != 0 {
+ if !ir.IsConst(n.X, constant.String) || !ir.IsConst(n.Y, constant.String) {
+ fn := "libfuzzerHookStrCmp"
+ n.X = cheapExpr(n.X, init)
+ n.Y = cheapExpr(n.Y, init)
+ paramType := types.Types[types.TSTRING]
+ init.Append(mkcall(fn, nil, init, tracecmpArg(n.X, paramType, init), tracecmpArg(n.Y, paramType, init), fakePC(n)))
+ }
+ }
+ // Rewrite comparisons to short constant strings as length+byte-wise comparisons.
+ var cs, ncs ir.Node // const string, non-const string
+ switch {
+ case ir.IsConst(n.X, constant.String) && ir.IsConst(n.Y, constant.String):
+ // ignore; will be constant evaluated
+ case ir.IsConst(n.X, constant.String):
+ cs = n.X
+ ncs = n.Y
+ case ir.IsConst(n.Y, constant.String):
+ cs = n.Y
+ ncs = n.X
+ }
+ if cs != nil {
+ cmp := n.Op()
+ // Our comparison below assumes that the non-constant string
+ // is on the left hand side, so rewrite "" cmp x to x cmp "".
+ // See issue 24817.
+ if ir.IsConst(n.X, constant.String) {
+ cmp = brrev(cmp)
+ }
+
+ // maxRewriteLen was chosen empirically.
+ // It is the value that minimizes cmd/go file size
+ // across most architectures.
+ // See the commit description for CL 26758 for details.
+ maxRewriteLen := 6
+ // Some architectures can load unaligned byte sequence as 1 word.
+ // So we can cover longer strings with the same amount of code.
+ canCombineLoads := ssagen.Arch.LinkArch.CanMergeLoads
+ combine64bit := false
+ if canCombineLoads {
+ // Keep this low enough to generate less code than a function call.
+ maxRewriteLen = 2 * ssagen.Arch.LinkArch.RegSize
+ combine64bit = ssagen.Arch.LinkArch.RegSize >= 8
+ }
+
+ var and ir.Op
+ switch cmp {
+ case ir.OEQ:
+ and = ir.OANDAND
+ case ir.ONE:
+ and = ir.OOROR
+ default:
+ // Don't do byte-wise comparisons for <, <=, etc.
+ // They're fairly complicated.
+ // Length-only checks are ok, though.
+ maxRewriteLen = 0
+ }
+ if s := ir.StringVal(cs); len(s) <= maxRewriteLen {
+ if len(s) > 0 {
+ ncs = safeExpr(ncs, init)
+ }
+ r := ir.Node(ir.NewBinaryExpr(base.Pos, cmp, ir.NewUnaryExpr(base.Pos, ir.OLEN, ncs), ir.NewInt(int64(len(s)))))
+ remains := len(s)
+ for i := 0; remains > 0; {
+ if remains == 1 || !canCombineLoads {
+ cb := ir.NewInt(int64(s[i]))
+ ncb := ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(int64(i)))
+ r = ir.NewLogicalExpr(base.Pos, and, r, ir.NewBinaryExpr(base.Pos, cmp, ncb, cb))
+ remains--
+ i++
+ continue
+ }
+ var step int
+ var convType *types.Type
+ switch {
+ case remains >= 8 && combine64bit:
+ convType = types.Types[types.TINT64]
+ step = 8
+ case remains >= 4:
+ convType = types.Types[types.TUINT32]
+ step = 4
+ case remains >= 2:
+ convType = types.Types[types.TUINT16]
+ step = 2
+ }
+ ncsubstr := typecheck.Conv(ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(int64(i))), convType)
+ csubstr := int64(s[i])
+ // Calculate large constant from bytes as sequence of shifts and ors.
+ // Like this: uint32(s[0]) | uint32(s[1])<<8 | uint32(s[2])<<16 ...
+ // ssa will combine this into a single large load.
+ for offset := 1; offset < step; offset++ {
+ b := typecheck.Conv(ir.NewIndexExpr(base.Pos, ncs, ir.NewInt(int64(i+offset))), convType)
+ b = ir.NewBinaryExpr(base.Pos, ir.OLSH, b, ir.NewInt(int64(8*offset)))
+ ncsubstr = ir.NewBinaryExpr(base.Pos, ir.OOR, ncsubstr, b)
+ csubstr |= int64(s[i+offset]) << uint8(8*offset)
+ }
+ csubstrPart := ir.NewInt(csubstr)
+ // Compare "step" bytes as once
+ r = ir.NewLogicalExpr(base.Pos, and, r, ir.NewBinaryExpr(base.Pos, cmp, csubstrPart, ncsubstr))
+ remains -= step
+ i += step
+ }
+ return finishCompare(n, r, init)
+ }
+ }
+
+ var r ir.Node
+ if n.Op() == ir.OEQ || n.Op() == ir.ONE {
+ // prepare for rewrite below
+ n.X = cheapExpr(n.X, init)
+ n.Y = cheapExpr(n.Y, init)
+ eqlen, eqmem := compare.EqString(n.X, n.Y)
+ // quick check of len before full compare for == or !=.
+ // memequal then tests equality up to length len.
+ if n.Op() == ir.OEQ {
+ // len(left) == len(right) && memequal(left, right, len)
+ r = ir.NewLogicalExpr(base.Pos, ir.OANDAND, eqlen, eqmem)
+ } else {
+ // len(left) != len(right) || !memequal(left, right, len)
+ eqlen.SetOp(ir.ONE)
+ r = ir.NewLogicalExpr(base.Pos, ir.OOROR, eqlen, ir.NewUnaryExpr(base.Pos, ir.ONOT, eqmem))
+ }
+ } else {
+ // sys_cmpstring(s1, s2) :: 0
+ r = mkcall("cmpstring", types.Types[types.TINT], init, typecheck.Conv(n.X, types.Types[types.TSTRING]), typecheck.Conv(n.Y, types.Types[types.TSTRING]))
+ r = ir.NewBinaryExpr(base.Pos, n.Op(), r, ir.NewInt(0))
+ }
+
+ return finishCompare(n, r, init)
+}
+
+// The result of finishCompare MUST be assigned back to n, e.g.
+//
+// n.Left = finishCompare(n.Left, x, r, init)
+func finishCompare(n *ir.BinaryExpr, r ir.Node, init *ir.Nodes) ir.Node {
+ r = typecheck.Expr(r)
+ r = typecheck.Conv(r, n.Type())
+ r = walkExpr(r, init)
+ return r
+}
+
+func eqFor(t *types.Type) (n ir.Node, needsize bool) {
+ // Should only arrive here with large memory or
+ // a struct/array containing a non-memory field/element.
+ // Small memory is handled inline, and single non-memory
+ // is handled by walkCompare.
+ switch a, _ := types.AlgType(t); a {
+ case types.AMEM:
+ n := typecheck.LookupRuntime("memequal")
+ n = typecheck.SubstArgTypes(n, t, t)
+ return n, true
+ case types.ASPECIAL:
+ sym := reflectdata.TypeSymPrefix(".eq", t)
+ // TODO(austin): This creates an ir.Name with a nil Func.
+ n := typecheck.NewName(sym)
+ ir.MarkFunc(n)
+ n.SetType(types.NewSignature(types.NoPkg, nil, nil, []*types.Field{
+ types.NewField(base.Pos, nil, types.NewPtr(t)),
+ types.NewField(base.Pos, nil, types.NewPtr(t)),
+ }, []*types.Field{
+ types.NewField(base.Pos, nil, types.Types[types.TBOOL]),
+ }))
+ return n, false
+ }
+ base.Fatalf("eqFor %v", t)
+ return nil, false
+}
+
+// brcom returns !(op).
+// For example, brcom(==) is !=.
+func brcom(op ir.Op) ir.Op {
+ switch op {
+ case ir.OEQ:
+ return ir.ONE
+ case ir.ONE:
+ return ir.OEQ
+ case ir.OLT:
+ return ir.OGE
+ case ir.OGT:
+ return ir.OLE
+ case ir.OLE:
+ return ir.OGT
+ case ir.OGE:
+ return ir.OLT
+ }
+ base.Fatalf("brcom: no com for %v\n", op)
+ return op
+}
+
+// brrev returns reverse(op).
+// For example, Brrev(<) is >.
+func brrev(op ir.Op) ir.Op {
+ switch op {
+ case ir.OEQ:
+ return ir.OEQ
+ case ir.ONE:
+ return ir.ONE
+ case ir.OLT:
+ return ir.OGT
+ case ir.OGT:
+ return ir.OLT
+ case ir.OLE:
+ return ir.OGE
+ case ir.OGE:
+ return ir.OLE
+ }
+ base.Fatalf("brrev: no rev for %v\n", op)
+ return op
+}
+
+func tracecmpArg(n ir.Node, t *types.Type, init *ir.Nodes) ir.Node {
+ // Ugly hack to avoid "constant -1 overflows uintptr" errors, etc.
+ if n.Op() == ir.OLITERAL && n.Type().IsSigned() && ir.Int64Val(n) < 0 {
+ n = copyExpr(n, n.Type(), init)
+ }
+
+ return typecheck.Conv(n, t)
+}