author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-16 19:19:13 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-16 19:19:13 +0000
commit     ccd992355df7192993c666236047820244914598 (patch)
tree       f00fea65147227b7743083c6148396f74cd66935 /src/cmd/compile/internal/ssa/memcombine.go
parent     Initial commit. (diff)
Adding upstream version 1.21.8. (upstream/1.21.8)
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/cmd/compile/internal/ssa/memcombine.go')
-rw-r--r--  src/cmd/compile/internal/ssa/memcombine.go | 731
1 file changed, 731 insertions(+), 0 deletions(-)
diff --git a/src/cmd/compile/internal/ssa/memcombine.go b/src/cmd/compile/internal/ssa/memcombine.go
new file mode 100644
index 0000000..26fb3f5
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/memcombine.go
@@ -0,0 +1,731 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+import (
+ "cmd/compile/internal/base"
+ "cmd/compile/internal/types"
+ "cmd/internal/src"
+ "sort"
+)
+
+// memcombine combines smaller loads and stores into larger ones.
+// We ensure this generates good code for encoding/binary operations.
+// It may help other cases also.
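+//
+// For example, a byte-by-byte load pattern such as
+//
+//	uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24
+//
+// can be rewritten as a single 32-bit load (plus a byte swap on
+// big-endian targets), and the corresponding byte-by-byte store
+// pattern as a single 32-bit store.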
+func memcombine(f *Func) {
+ // This optimization requires that the architecture has
+ // unaligned loads and unaligned stores.
+ if !f.Config.unalignedOK {
+ return
+ }
+
+ memcombineLoads(f)
+ memcombineStores(f)
+}
+
+func memcombineLoads(f *Func) {
+ // Find "OR trees" to start with.
+ mark := f.newSparseSet(f.NumValues())
+ defer f.retSparseSet(mark)
+ var order []*Value
+
+ // Mark all values that are the argument of an OR.
+ for _, b := range f.Blocks {
+ for _, v := range b.Values {
+ if v.Op == OpOr16 || v.Op == OpOr32 || v.Op == OpOr64 {
+ mark.add(v.Args[0].ID)
+ mark.add(v.Args[1].ID)
+ }
+ }
+ }
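+ // An OR that is not itself the argument of another OR is the root
+ // of an OR tree; collect and process each tree root-first.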
+ for _, b := range f.Blocks {
+ order = order[:0]
+ for _, v := range b.Values {
+ if v.Op != OpOr16 && v.Op != OpOr32 && v.Op != OpOr64 {
+ continue
+ }
+ if mark.contains(v.ID) {
+ // marked - means it is not the root of an OR tree
+ continue
+ }
+ // Add the OR tree rooted at v to the order.
+ // We use BFS here, but any walk that puts roots before leaves would work.
+ i := len(order)
+ order = append(order, v)
+ for ; i < len(order); i++ {
+ x := order[i]
+ for j := 0; j < 2; j++ {
+ a := x.Args[j]
+ if a.Op == OpOr16 || a.Op == OpOr32 || a.Op == OpOr64 {
+ order = append(order, a)
+ }
+ }
+ }
+ }
+ for _, v := range order {
+ max := f.Config.RegSize
+ switch v.Op {
+ case OpOr64:
+ case OpOr32:
+ max = 4
+ case OpOr16:
+ max = 2
+ default:
+ continue
+ }
+ for n := max; n > 1; n /= 2 {
+ if combineLoads(v, n) {
+ break
+ }
+ }
+ }
+ }
+}
+
+// A BaseAddress represents the address ptr+idx, where
+// ptr is a pointer type and idx is an integer type.
+// idx may be nil, in which case it is treated as 0.
+type BaseAddress struct {
+ ptr *Value
+ idx *Value
+}
+
+// splitPtr returns the base address of ptr and any
+// constant offset from that base.
+// BaseAddress{ptr,nil},0 is always a valid result, but splitPtr
+// tries to peel away as many constants into off as possible.
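+//
+// For example, splitPtr applied to (OffPtr [8] (AddPtr p i)) returns
+// BaseAddress{ptr: p, idx: i} and offset 8.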
+func splitPtr(ptr *Value) (BaseAddress, int64) {
+ var idx *Value
+ var off int64
+ for {
+ if ptr.Op == OpOffPtr {
+ off += ptr.AuxInt
+ ptr = ptr.Args[0]
+ } else if ptr.Op == OpAddPtr {
+ if idx != nil {
+ // We have two or more indexing values.
+ // Pick the first one we found.
+ return BaseAddress{ptr: ptr, idx: idx}, off
+ }
+ idx = ptr.Args[1]
+ if idx.Op == OpAdd32 || idx.Op == OpAdd64 {
+ if idx.Args[0].Op == OpConst32 || idx.Args[0].Op == OpConst64 {
+ off += idx.Args[0].AuxInt
+ idx = idx.Args[1]
+ } else if idx.Args[1].Op == OpConst32 || idx.Args[1].Op == OpConst64 {
+ off += idx.Args[1].AuxInt
+ idx = idx.Args[0]
+ }
+ }
+ ptr = ptr.Args[0]
+ } else {
+ return BaseAddress{ptr: ptr, idx: idx}, off
+ }
+ }
+}
+
+func combineLoads(root *Value, n int64) bool {
+ orOp := root.Op
+ var shiftOp Op
+ switch orOp {
+ case OpOr64:
+ shiftOp = OpLsh64x64
+ case OpOr32:
+ shiftOp = OpLsh32x64
+ case OpOr16:
+ shiftOp = OpLsh16x64
+ default:
+ return false
+ }
+
+ // Find n values that are ORed together with the above op.
+ a := make([]*Value, 0, 8)
+ a = append(a, root)
+ for i := 0; i < len(a) && int64(len(a)) < n; i++ {
+ v := a[i]
+ if v.Uses != 1 && v != root {
+ // Something in this subtree is used somewhere else.
+ return false
+ }
+ if v.Op == orOp {
+ a[i] = v.Args[0]
+ a = append(a, v.Args[1])
+ i--
+ }
+ }
+ if int64(len(a)) != n {
+ return false
+ }
+
+ // Check the first entry to see what ops we're looking for.
+ // All the entries should be of the form shift(extend(load)), maybe with no shift.
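+ // For a 4-byte little-endian combine, the entries typically look like
+ //	(ZeroExt8to32 (Load p mem))
+ //	(Lsh32x64 (ZeroExt8to32 (Load (OffPtr [1] p) mem)) (Const64 [8]))
+ // and so on, with offsets 2 and 3 shifted by 16 and 24.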
+ v := a[0]
+ if v.Op == shiftOp {
+ v = v.Args[0]
+ }
+ var extOp Op
+ if orOp == OpOr64 && (v.Op == OpZeroExt8to64 || v.Op == OpZeroExt16to64 || v.Op == OpZeroExt32to64) ||
+ orOp == OpOr32 && (v.Op == OpZeroExt8to32 || v.Op == OpZeroExt16to32) ||
+ orOp == OpOr16 && v.Op == OpZeroExt8to16 {
+ extOp = v.Op
+ v = v.Args[0]
+ } else {
+ return false
+ }
+ if v.Op != OpLoad {
+ return false
+ }
+ base, _ := splitPtr(v.Args[0])
+ mem := v.Args[1]
+ size := v.Type.Size()
+
+ if root.Block.Func.Config.arch == "S390X" {
+ // s390x can't handle unaligned accesses to global variables.
+ if base.ptr.Op == OpAddr {
+ return false
+ }
+ }
+
+ // Check all the entries, extract useful info.
+ type LoadRecord struct {
+ load *Value
+ offset int64 // offset of load address from base
+ shift int64
+ }
+ r := make([]LoadRecord, n, 8)
+ for i := int64(0); i < n; i++ {
+ v := a[i]
+ if v.Uses != 1 {
+ return false
+ }
+ shift := int64(0)
+ if v.Op == shiftOp {
+ if v.Args[1].Op != OpConst64 {
+ return false
+ }
+ shift = v.Args[1].AuxInt
+ v = v.Args[0]
+ if v.Uses != 1 {
+ return false
+ }
+ }
+ if v.Op != extOp {
+ return false
+ }
+ load := v.Args[0]
+ if load.Op != OpLoad {
+ return false
+ }
+ if load.Uses != 1 {
+ return false
+ }
+ if load.Args[1] != mem {
+ return false
+ }
+ p, off := splitPtr(load.Args[0])
+ if p != base {
+ return false
+ }
+ r[i] = LoadRecord{load: load, offset: off, shift: shift}
+ }
+
+ // Sort in memory address order.
+ sort.Slice(r, func(i, j int) bool {
+ return r[i].offset < r[j].offset
+ })
+
+ // Check that we have contiguous offsets.
+ for i := int64(0); i < n; i++ {
+ if r[i].offset != r[0].offset+i*size {
+ return false
+ }
+ }
+
+ // Check for reads in little-endian or big-endian order.
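+ // With r sorted by increasing offset, a little-endian read has the
+ // shift growing by size*8 as the offset grows, and a big-endian read
+ // has it shrinking by size*8.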
+ shift0 := r[0].shift
+ isLittleEndian := true
+ for i := int64(0); i < n; i++ {
+ if r[i].shift != shift0+i*size*8 {
+ isLittleEndian = false
+ break
+ }
+ }
+ isBigEndian := true
+ for i := int64(0); i < n; i++ {
+ if r[i].shift != shift0-i*size*8 {
+ isBigEndian = false
+ break
+ }
+ }
+ if !isLittleEndian && !isBigEndian {
+ return false
+ }
+
+ // Find a place to put the new load.
+ // This is tricky, because it has to be at a point where
+ // its memory argument is live. We can't just put it in root.Block.
+ // We use the block of the latest load.
+ loads := make([]*Value, n, 8)
+ for i := int64(0); i < n; i++ {
+ loads[i] = r[i].load
+ }
+ loadBlock := mergePoint(root.Block, loads...)
+ if loadBlock == nil {
+ return false
+ }
+ // Find a source position to use.
+ pos := src.NoXPos
+ for _, load := range loads {
+ if load.Block == loadBlock {
+ pos = load.Pos
+ break
+ }
+ }
+ if pos == src.NoXPos {
+ return false
+ }
+
+ // Check to see if we need a byte swap after loading.
+ needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
+ isBigEndian && !root.Block.Func.Config.BigEndian
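+ // A swap is only possible if the pieces are single bytes and the
+ // target has a byte-swap op of the combined width.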
+ if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
+ return false
+ }
+
+ // This is the commit point.
+
+ // First, issue load at lowest address.
+ v = loadBlock.NewValue2(pos, OpLoad, sizeType(n*size), r[0].load.Args[0], mem)
+
+ // Byte swap if needed.
+ if needSwap {
+ v = byteSwap(loadBlock, pos, v)
+ }
+
+ // Extend if needed.
+ if n*size < root.Type.Size() {
+ v = zeroExtend(loadBlock, pos, v, n*size, root.Type.Size())
+ }
+
+ // Shift if needed.
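+ // The low-order bits of the combined load correspond to the
+ // lowest-addressed piece (shift0) for little-endian, and to the
+ // highest-addressed piece (shift0-(n-1)*size*8) for big-endian,
+ // so shift by that piece's original shift amount.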
+ if isLittleEndian && shift0 != 0 {
+ v = leftShift(loadBlock, pos, v, shift0)
+ }
+ if isBigEndian && shift0-(n-1)*size*8 != 0 {
+ v = leftShift(loadBlock, pos, v, shift0-(n-1)*size*8)
+ }
+
+ // Install with (Copy v).
+ root.reset(OpCopy)
+ root.AddArg(v)
+
+ // Clobber the loads, just to prevent additional work being done on
+ // subtrees (which are now unreachable).
+ for i := int64(0); i < n; i++ {
+ clobber(r[i].load)
+ }
+ return true
+}
+
+func memcombineStores(f *Func) {
+ mark := f.newSparseSet(f.NumValues())
+ defer f.retSparseSet(mark)
+ var order []*Value
+
+ for _, b := range f.Blocks {
+ // Mark all stores which are not last in a store sequence.
+ mark.clear()
+ for _, v := range b.Values {
+ if v.Op == OpStore {
+ mark.add(v.MemoryArg().ID)
+ }
+ }
+
+ // Pick an order for visiting stores such that
+ // later stores come earlier in the ordering.
+ order = order[:0]
+ for _, v := range b.Values {
+ if v.Op != OpStore {
+ continue
+ }
+ if mark.contains(v.ID) {
+ continue // not last in a chain of stores
+ }
+ for {
+ order = append(order, v)
+ v = v.Args[2]
+ if v.Block != b || v.Op != OpStore {
+ break
+ }
+ }
+ }
+
+ // Look for combining opportunities at each store in queue order.
+ for _, v := range order {
+ if v.Op != OpStore { // already rewritten
+ continue
+ }
+
+ size := v.Aux.(*types.Type).Size()
+ if size >= f.Config.RegSize || size == 0 {
+ continue
+ }
+
+ for n := f.Config.RegSize / size; n > 1; n /= 2 {
+ if combineStores(v, n) {
+ continue
+ }
+ }
+ }
+ }
+}
+
+// Try to combine the n stores ending in root.
+// Returns true if successful.
+func combineStores(root *Value, n int64) bool {
+ // Helper functions.
+ type StoreRecord struct {
+ store *Value
+ offset int64
+ }
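+ // getShiftBase returns the value whose (possibly right-shifted)
+ // truncations are stored by the first two stores in a, or nil if
+ // there is no such common value.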
+ getShiftBase := func(a []StoreRecord) *Value {
+ x := a[0].store.Args[1]
+ y := a[1].store.Args[1]
+ switch x.Op {
+ case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
+ x = x.Args[0]
+ default:
+ return nil
+ }
+ switch y.Op {
+ case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
+ y = y.Args[0]
+ default:
+ return nil
+ }
+ var x2 *Value
+ switch x.Op {
+ case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
+ x2 = x.Args[0]
+ default:
+ }
+ var y2 *Value
+ switch y.Op {
+ case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
+ y2 = y.Args[0]
+ default:
+ }
+ if y2 == x {
+ // a shift of x and x itself.
+ return x
+ }
+ if x2 == y {
+ // a shift of y and y itself.
+ return y
+ }
+ if x2 == y2 {
+ // 2 shifts both of the same argument.
+ return x2
+ }
+ return nil
+ }
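+ // isShiftBase reports whether the value stored by v is a
+ // truncation of base or of a right shift of base.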
+ isShiftBase := func(v, base *Value) bool {
+ val := v.Args[1]
+ switch val.Op {
+ case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
+ val = val.Args[0]
+ default:
+ return false
+ }
+ if val == base {
+ return true
+ }
+ switch val.Op {
+ case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
+ val = val.Args[0]
+ default:
+ return false
+ }
+ return val == base
+ }
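+ // shift returns the right-shift amount applied to base in the value
+ // stored by v, or -1 if the stored value is not a (possibly shifted)
+ // truncation of base.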
+ shift := func(v, base *Value) int64 {
+ val := v.Args[1]
+ switch val.Op {
+ case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
+ val = val.Args[0]
+ default:
+ return -1
+ }
+ if val == base {
+ return 0
+ }
+ switch val.Op {
+ case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
+ val = val.Args[1]
+ default:
+ return -1
+ }
+ if val.Op != OpConst64 {
+ return -1
+ }
+ return val.AuxInt
+ }
+
+ // Element size of the individual stores.
+ size := root.Aux.(*types.Type).Size()
+ if size*n > root.Block.Func.Config.RegSize {
+ return false
+ }
+
+ // Gather n stores to look at. Check easy conditions we require.
+ a := make([]StoreRecord, 0, 8)
+ rbase, roff := splitPtr(root.Args[0])
+ if root.Block.Func.Config.arch == "S390X" {
+ // s390x can't handle unaligned accesses to global variables.
+ if rbase.ptr.Op == OpAddr {
+ return false
+ }
+ }
+ a = append(a, StoreRecord{root, roff})
+ for i, x := int64(1), root.Args[2]; i < n; i, x = i+1, x.Args[2] {
+ if x.Op != OpStore {
+ return false
+ }
+ if x.Block != root.Block {
+ return false
+ }
+ if x.Uses != 1 { // Note: root can have more than one use.
+ return false
+ }
+ if x.Aux.(*types.Type).Size() != size {
+ return false
+ }
+ base, off := splitPtr(x.Args[0])
+ if base != rbase {
+ return false
+ }
+ a = append(a, StoreRecord{x, off})
+ }
+ // Before we sort, grab the memory arg the result should have.
+ mem := a[n-1].store.Args[2]
+
+ // Sort stores in increasing address order.
+ sort.Slice(a, func(i, j int) bool {
+ return a[i].offset < a[j].offset
+ })
+
+ // Check that everything is written to sequential locations.
+ for i := int64(0); i < n; i++ {
+ if a[i].offset != a[0].offset+i*size {
+ return false
+ }
+ }
+
+ // Memory location we're going to write at (the lowest one).
+ ptr := a[0].store.Args[0]
+
+ // Check for constant stores.
+ isConst := true
+ for i := int64(0); i < n; i++ {
+ switch a[i].store.Args[1].Op {
+ case OpConst32, OpConst16, OpConst8:
+ default:
+ isConst = false
+ break
+ }
+ }
+ if isConst {
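+ // For example, byte stores of 0x01, 0x02, 0x03, 0x04 at
+ // p, p+1, p+2, p+3 fold into a single 32-bit store of
+ // 0x04030201 at p (0x01020304 on big-endian targets).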
+ // Modify root to do all the stores.
+ var c int64
+ mask := int64(1)<<(8*size) - 1
+ for i := int64(0); i < n; i++ {
+ s := 8 * size * int64(i)
+ if root.Block.Func.Config.BigEndian {
+ s = 8*size*(n-1) - s
+ }
+ c |= (a[i].store.Args[1].AuxInt & mask) << s
+ }
+ var cv *Value
+ switch size * n {
+ case 2:
+ cv = root.Block.Func.ConstInt16(types.Types[types.TUINT16], int16(c))
+ case 4:
+ cv = root.Block.Func.ConstInt32(types.Types[types.TUINT32], int32(c))
+ case 8:
+ cv = root.Block.Func.ConstInt64(types.Types[types.TUINT64], c)
+ }
+
+ // Move all the stores to the root.
+ for i := int64(0); i < n; i++ {
+ v := a[i].store
+ if v == root {
+ v.Aux = cv.Type // widen store type
+ v.SetArg(0, ptr)
+ v.SetArg(1, cv)
+ v.SetArg(2, mem)
+ } else {
+ clobber(v)
+ v.Type = types.Types[types.TBOOL] // erase memory type
+ }
+ }
+ return true
+ }
+
+ // Check that all the shift/trunc are of the same base value.
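+ // In source terms the expected pattern looks like
+ //	*p = byte(x), *(p+1) = byte(x>>8), *(p+2) = byte(x>>16), ...
+ // where x is the common shift base.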
+ shiftBase := getShiftBase(a)
+ if shiftBase == nil {
+ return false
+ }
+ for i := int64(0); i < n; i++ {
+ if !isShiftBase(a[i].store, shiftBase) {
+ return false
+ }
+ }
+
+ // Check for writes in little-endian or big-endian order.
+ isLittleEndian := true
+ shift0 := shift(a[0].store, shiftBase)
+ for i := int64(1); i < n; i++ {
+ if shift(a[i].store, shiftBase) != shift0+i*size*8 {
+ isLittleEndian = false
+ break
+ }
+ }
+ isBigEndian := true
+ for i := int64(1); i < n; i++ {
+ if shift(a[i].store, shiftBase) != shift0-i*size*8 {
+ isBigEndian = false
+ break
+ }
+ }
+ if !isLittleEndian && !isBigEndian {
+ return false
+ }
+
+ // Check to see if we need byte swap before storing.
+ needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
+ isBigEndian && !root.Block.Func.Config.BigEndian
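+ // A swap is only possible if each piece is a single byte and the
+ // target has a byte-swap op of the combined width.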
+ if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
+ return false
+ }
+
+ // This is the commit point.
+
+ // Modify root to do all the stores.
+ sv := shiftBase
+ if isLittleEndian && shift0 != 0 {
+ sv = rightShift(root.Block, root.Pos, sv, shift0)
+ }
+ if isBigEndian && shift0-(n-1)*size*8 != 0 {
+ sv = rightShift(root.Block, root.Pos, sv, shift0-(n-1)*size*8)
+ }
+ if sv.Type.Size() > size*n {
+ sv = truncate(root.Block, root.Pos, sv, sv.Type.Size(), size*n)
+ }
+ if needSwap {
+ sv = byteSwap(root.Block, root.Pos, sv)
+ }
+
+ // Move all the stores to the root.
+ for i := int64(0); i < n; i++ {
+ v := a[i].store
+ if v == root {
+ v.Aux = sv.Type // widen store type
+ v.SetArg(0, ptr)
+ v.SetArg(1, sv)
+ v.SetArg(2, mem)
+ } else {
+ clobber(v)
+ v.Type = types.Types[types.TBOOL] // erase memory type
+ }
+ }
+ return true
+}
+
+func sizeType(size int64) *types.Type {
+ switch size {
+ case 8:
+ return types.Types[types.TUINT64]
+ case 4:
+ return types.Types[types.TUINT32]
+ case 2:
+ return types.Types[types.TUINT16]
+ default:
+ base.Fatalf("bad size %d\n", size)
+ return nil
+ }
+}
+
+func truncate(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
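+ // The from/to byte sizes are encoded as a single switch key, from*10+to.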
+ switch from*10 + to {
+ case 82:
+ return b.NewValue1(pos, OpTrunc64to16, types.Types[types.TUINT16], v)
+ case 84:
+ return b.NewValue1(pos, OpTrunc64to32, types.Types[types.TUINT32], v)
+ case 42:
+ return b.NewValue1(pos, OpTrunc32to16, types.Types[types.TUINT16], v)
+ default:
+ base.Fatalf("bad sizes %d %d\n", from, to)
+ return nil
+ }
+}
+func zeroExtend(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
+ switch from*10 + to {
+ case 24:
+ return b.NewValue1(pos, OpZeroExt16to32, types.Types[types.TUINT32], v)
+ case 28:
+ return b.NewValue1(pos, OpZeroExt16to64, types.Types[types.TUINT64], v)
+ case 48:
+ return b.NewValue1(pos, OpZeroExt32to64, types.Types[types.TUINT64], v)
+ default:
+ base.Fatalf("bad sizes %d %d\n", from, to)
+ return nil
+ }
+}
+
+func leftShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
+ s := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
+ size := v.Type.Size()
+ switch size {
+ case 8:
+ return b.NewValue2(pos, OpLsh64x64, v.Type, v, s)
+ case 4:
+ return b.NewValue2(pos, OpLsh32x64, v.Type, v, s)
+ case 2:
+ return b.NewValue2(pos, OpLsh16x64, v.Type, v, s)
+ default:
+ base.Fatalf("bad size %d\n", size)
+ return nil
+ }
+}
+func rightShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
+ s := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
+ size := v.Type.Size()
+ switch size {
+ case 8:
+ return b.NewValue2(pos, OpRsh64Ux64, v.Type, v, s)
+ case 4:
+ return b.NewValue2(pos, OpRsh32Ux64, v.Type, v, s)
+ case 2:
+ return b.NewValue2(pos, OpRsh16Ux64, v.Type, v, s)
+ default:
+ base.Fatalf("bad size %d\n", size)
+ return nil
+ }
+}
+func byteSwap(b *Block, pos src.XPos, v *Value) *Value {
+ switch v.Type.Size() {
+ case 8:
+ return b.NewValue1(pos, OpBswap64, v.Type, v)
+ case 4:
+ return b.NewValue1(pos, OpBswap32, v.Type, v)
+ case 2:
+ return b.NewValue1(pos, OpBswap16, v.Type, v)
+
+ default:
+ v.Fatalf("bad size %d\n", v.Type.Size())
+ return nil
+ }
+}