author     Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-16 19:19:13 +0000
committer  Daniel Baumann <daniel.baumann@progress-linux.org>  2024-04-16 19:19:13 +0000
commit     ccd992355df7192993c666236047820244914598
tree       f00fea65147227b7743083c6148396f74cd66935 /src/cmd/compile/internal/ssa/memcombine.go
parent     Initial commit.
Adding upstream version 1.21.8.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/cmd/compile/internal/ssa/memcombine.go')
-rw-r--r--  src/cmd/compile/internal/ssa/memcombine.go | 731
1 file changed, 731 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/ssa/memcombine.go b/src/cmd/compile/internal/ssa/memcombine.go
new file mode 100644
index 0000000..26fb3f5
--- /dev/null
+++ b/src/cmd/compile/internal/ssa/memcombine.go
@@ -0,0 +1,731 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssa
+
+import (
+	"cmd/compile/internal/base"
+	"cmd/compile/internal/types"
+	"cmd/internal/src"
+	"sort"
+)
+
+// memcombine combines smaller loads and stores into larger ones.
+// We ensure this generates good code for encoding/binary operations.
+// It may help other cases also.
+func memcombine(f *Func) {
+	// This optimization requires that the architecture has
+	// unaligned loads and unaligned stores.
+	if !f.Config.unalignedOK {
+		return
+	}
+
+	memcombineLoads(f)
+	memcombineStores(f)
+}
+
+func memcombineLoads(f *Func) {
+	// Find "OR trees" to start with.
+	mark := f.newSparseSet(f.NumValues())
+	defer f.retSparseSet(mark)
+	var order []*Value
+
+	// Mark all values that are the argument of an OR.
+	for _, b := range f.Blocks {
+		for _, v := range b.Values {
+			if v.Op == OpOr16 || v.Op == OpOr32 || v.Op == OpOr64 {
+				mark.add(v.Args[0].ID)
+				mark.add(v.Args[1].ID)
+			}
+		}
+	}
+	for _, b := range f.Blocks {
+		order = order[:0]
+		for _, v := range b.Values {
+			if v.Op != OpOr16 && v.Op != OpOr32 && v.Op != OpOr64 {
+				continue
+			}
+			if mark.contains(v.ID) {
+				// marked - means it is not the root of an OR tree
+				continue
+			}
+			// Add the OR tree rooted at v to the order.
+			// We use BFS here, but any walk that puts roots before leaves would work.
+			i := len(order)
+			order = append(order, v)
+			for ; i < len(order); i++ {
+				x := order[i]
+				for j := 0; j < 2; j++ {
+					a := x.Args[j]
+					if a.Op == OpOr16 || a.Op == OpOr32 || a.Op == OpOr64 {
+						order = append(order, a)
+					}
+				}
+			}
+		}
+		for _, v := range order {
+			max := f.Config.RegSize
+			switch v.Op {
+			case OpOr64:
+			case OpOr32:
+				max = 4
+			case OpOr16:
+				max = 2
+			default:
+				continue
+			}
+			for n := max; n > 1; n /= 2 {
+				if combineLoads(v, n) {
+					break
+				}
+			}
+		}
+	}
+}
+
+// A BaseAddress represents the address ptr+idx, where
+// ptr is a pointer type and idx is an integer type.
+// idx may be nil, in which case it is treated as 0.
+type BaseAddress struct {
+	ptr *Value
+	idx *Value
+}
+
+// splitPtr returns the base address of ptr and any
+// constant offset from that base.
+// BaseAddress{ptr,nil},0 is always a valid result, but splitPtr
+// tries to peel away as many constants into off as possible.
+func splitPtr(ptr *Value) (BaseAddress, int64) {
+	var idx *Value
+	var off int64
+	for {
+		if ptr.Op == OpOffPtr {
+			off += ptr.AuxInt
+			ptr = ptr.Args[0]
+		} else if ptr.Op == OpAddPtr {
+			if idx != nil {
+				// We have two or more indexing values.
+				// Pick the first one we found.
+				return BaseAddress{ptr: ptr, idx: idx}, off
+			}
+			idx = ptr.Args[1]
+			if idx.Op == OpAdd32 || idx.Op == OpAdd64 {
+				if idx.Args[0].Op == OpConst32 || idx.Args[0].Op == OpConst64 {
+					off += idx.Args[0].AuxInt
+					idx = idx.Args[1]
+				} else if idx.Args[1].Op == OpConst32 || idx.Args[1].Op == OpConst64 {
+					off += idx.Args[1].AuxInt
+					idx = idx.Args[0]
+				}
+			}
+			ptr = ptr.Args[0]
+		} else {
+			return BaseAddress{ptr: ptr, idx: idx}, off
+		}
+	}
+}
+
+func combineLoads(root *Value, n int64) bool {
+	orOp := root.Op
+	var shiftOp Op
+	switch orOp {
+	case OpOr64:
+		shiftOp = OpLsh64x64
+	case OpOr32:
+		shiftOp = OpLsh32x64
+	case OpOr16:
+		shiftOp = OpLsh16x64
+	default:
+		return false
+	}
+
+	// Find n values that are ORed together with the above op.
+	a := make([]*Value, 0, 8)
+	a = append(a, root)
+	for i := 0; i < len(a) && int64(len(a)) < n; i++ {
+		v := a[i]
+		if v.Uses != 1 && v != root {
+			// Something in this subtree is used somewhere else.
+			return false
+		}
+		if v.Op == orOp {
+			a[i] = v.Args[0]
+			a = append(a, v.Args[1])
+			i--
+		}
+	}
+	if int64(len(a)) != n {
+		return false
+	}
+
+	// Check that the first entry to see what ops we're looking for.
+	// All the entries should be of the form shift(extend(load)), maybe with no shift.
+	v := a[0]
+	if v.Op == shiftOp {
+		v = v.Args[0]
+	}
+	var extOp Op
+	if orOp == OpOr64 && (v.Op == OpZeroExt8to64 || v.Op == OpZeroExt16to64 || v.Op == OpZeroExt32to64) ||
+		orOp == OpOr32 && (v.Op == OpZeroExt8to32 || v.Op == OpZeroExt16to32) ||
+		orOp == OpOr16 && v.Op == OpZeroExt8to16 {
+		extOp = v.Op
+		v = v.Args[0]
+	} else {
+		return false
+	}
+	if v.Op != OpLoad {
+		return false
+	}
+	base, _ := splitPtr(v.Args[0])
+	mem := v.Args[1]
+	size := v.Type.Size()
+
+	if root.Block.Func.Config.arch == "S390X" {
+		// s390x can't handle unaligned accesses to global variables.
+		if base.ptr.Op == OpAddr {
+			return false
+		}
+	}
+
+	// Check all the entries, extract useful info.
+	type LoadRecord struct {
+		load   *Value
+		offset int64 // offset of load address from base
+		shift  int64
+	}
+	r := make([]LoadRecord, n, 8)
+	for i := int64(0); i < n; i++ {
+		v := a[i]
+		if v.Uses != 1 {
+			return false
+		}
+		shift := int64(0)
+		if v.Op == shiftOp {
+			if v.Args[1].Op != OpConst64 {
+				return false
+			}
+			shift = v.Args[1].AuxInt
+			v = v.Args[0]
+			if v.Uses != 1 {
+				return false
+			}
+		}
+		if v.Op != extOp {
+			return false
+		}
+		load := v.Args[0]
+		if load.Op != OpLoad {
+			return false
+		}
+		if load.Uses != 1 {
+			return false
+		}
+		if load.Args[1] != mem {
+			return false
+		}
+		p, off := splitPtr(load.Args[0])
+		if p != base {
+			return false
+		}
+		r[i] = LoadRecord{load: load, offset: off, shift: shift}
+	}
+
+	// Sort in memory address order.
+	sort.Slice(r, func(i, j int) bool {
+		return r[i].offset < r[j].offset
+	})
+
+	// Check that we have contiguous offsets.
+	for i := int64(0); i < n; i++ {
+		if r[i].offset != r[0].offset+i*size {
+			return false
+		}
+	}
+
+	// Check for reads in little-endian or big-endian order.
+	shift0 := r[0].shift
+	isLittleEndian := true
+	for i := int64(0); i < n; i++ {
+		if r[i].shift != shift0+i*size*8 {
+			isLittleEndian = false
+			break
+		}
+	}
+	isBigEndian := true
+	for i := int64(0); i < n; i++ {
+		if r[i].shift != shift0-i*size*8 {
+			isBigEndian = false
+			break
+		}
+	}
+	if !isLittleEndian && !isBigEndian {
+		return false
+	}
+
+	// Find a place to put the new load.
+	// This is tricky, because it has to be at a point where
+	// its memory argument is live. We can't just put it in root.Block.
+	// We use the block of the latest load.
+	loads := make([]*Value, n, 8)
+	for i := int64(0); i < n; i++ {
+		loads[i] = r[i].load
+	}
+	loadBlock := mergePoint(root.Block, loads...)
+	if loadBlock == nil {
+		return false
+	}
+	// Find a source position to use.
+	pos := src.NoXPos
+	for _, load := range loads {
+		if load.Block == loadBlock {
+			pos = load.Pos
+			break
+		}
+	}
+	if pos == src.NoXPos {
+		return false
+	}
+
+	// Check to see if we need byte swap before storing.
+	needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
+		isBigEndian && !root.Block.Func.Config.BigEndian
+	if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
+		return false
+	}
+
+	// This is the commit point.
+
+	// First, issue load at lowest address.
+	v = loadBlock.NewValue2(pos, OpLoad, sizeType(n*size), r[0].load.Args[0], mem)
+
+	// Byte swap if needed,
+	if needSwap {
+		v = byteSwap(loadBlock, pos, v)
+	}
+
+	// Extend if needed.
+	if n*size < root.Type.Size() {
+		v = zeroExtend(loadBlock, pos, v, n*size, root.Type.Size())
+	}
+
+	// Shift if needed.
+	if isLittleEndian && shift0 != 0 {
+		v = leftShift(loadBlock, pos, v, shift0)
+	}
+	if isBigEndian && shift0-(n-1)*size*8 != 0 {
+		v = leftShift(loadBlock, pos, v, shift0-(n-1)*size*8)
+	}
+
+	// Install with (Copy v).
+	root.reset(OpCopy)
+	root.AddArg(v)
+
+	// Clobber the loads, just to prevent additional work being done on
+	// subtrees (which are now unreachable).
+	for i := int64(0); i < n; i++ {
+		clobber(r[i].load)
+	}
+	return true
+}
+
+func memcombineStores(f *Func) {
+	mark := f.newSparseSet(f.NumValues())
+	defer f.retSparseSet(mark)
+	var order []*Value
+
+	for _, b := range f.Blocks {
+		// Mark all stores which are not last in a store sequence.
+		mark.clear()
+		for _, v := range b.Values {
+			if v.Op == OpStore {
+				mark.add(v.MemoryArg().ID)
+			}
+		}
+
+		// pick an order for visiting stores such that
+		// later stores come earlier in the ordering.
+		order = order[:0]
+		for _, v := range b.Values {
+			if v.Op != OpStore {
+				continue
+			}
+			if mark.contains(v.ID) {
+				continue // not last in a chain of stores
+			}
+			for {
+				order = append(order, v)
+				v = v.Args[2]
+				if v.Block != b || v.Op != OpStore {
+					break
+				}
+			}
+		}
+
+		// Look for combining opportunities at each store in queue order.
+		for _, v := range order {
+			if v.Op != OpStore { // already rewritten
+				continue
+			}
+
+			size := v.Aux.(*types.Type).Size()
+			if size >= f.Config.RegSize || size == 0 {
+				continue
+			}
+
+			for n := f.Config.RegSize / size; n > 1; n /= 2 {
+				if combineStores(v, n) {
+					continue
+				}
+			}
+		}
+	}
+}
+
+// Try to combine the n stores ending in root.
+// Returns true if successful.
+func combineStores(root *Value, n int64) bool {
+	// Helper functions.
+	type StoreRecord struct {
+		store  *Value
+		offset int64
+	}
+	getShiftBase := func(a []StoreRecord) *Value {
+		x := a[0].store.Args[1]
+		y := a[1].store.Args[1]
+		switch x.Op {
+		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
+			x = x.Args[0]
+		default:
+			return nil
+		}
+		switch y.Op {
+		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
+			y = y.Args[0]
+		default:
+			return nil
+		}
+		var x2 *Value
+		switch x.Op {
+		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
+			x2 = x.Args[0]
+		default:
+		}
+		var y2 *Value
+		switch y.Op {
+		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
+			y2 = y.Args[0]
+		default:
+		}
+		if y2 == x {
+			// a shift of x and x itself.
+			return x
+		}
+		if x2 == y {
+			// a shift of y and y itself.
+			return y
+		}
+		if x2 == y2 {
+			// 2 shifts both of the same argument.
+			return x2
+		}
+		return nil
+	}
+	isShiftBase := func(v, base *Value) bool {
+		val := v.Args[1]
+		switch val.Op {
+		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
+			val = val.Args[0]
+		default:
+			return false
+		}
+		if val == base {
+			return true
+		}
+		switch val.Op {
+		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
+			val = val.Args[0]
+		default:
+			return false
+		}
+		return val == base
+	}
+	shift := func(v, base *Value) int64 {
+		val := v.Args[1]
+		switch val.Op {
+		case OpTrunc64to8, OpTrunc64to16, OpTrunc64to32, OpTrunc32to8, OpTrunc32to16, OpTrunc16to8:
+			val = val.Args[0]
+		default:
+			return -1
+		}
+		if val == base {
+			return 0
+		}
+		switch val.Op {
+		case OpRsh64Ux64, OpRsh32Ux64, OpRsh16Ux64:
+			val = val.Args[1]
+		default:
+			return -1
+		}
+		if val.Op != OpConst64 {
+			return -1
+		}
+		return val.AuxInt
+	}
+
+	// Element size of the individual stores.
+	size := root.Aux.(*types.Type).Size()
+	if size*n > root.Block.Func.Config.RegSize {
+		return false
+	}
+
+	// Gather n stores to look at. Check easy conditions we require.
+	a := make([]StoreRecord, 0, 8)
+	rbase, roff := splitPtr(root.Args[0])
+	if root.Block.Func.Config.arch == "S390X" {
+		// s390x can't handle unaligned accesses to global variables.
+		if rbase.ptr.Op == OpAddr {
+			return false
+		}
+	}
+	a = append(a, StoreRecord{root, roff})
+	for i, x := int64(1), root.Args[2]; i < n; i, x = i+1, x.Args[2] {
+		if x.Op != OpStore {
+			return false
+		}
+		if x.Block != root.Block {
+			return false
+		}
+		if x.Uses != 1 { // Note: root can have more than one use.
+			return false
+		}
+		if x.Aux.(*types.Type).Size() != size {
+			return false
+		}
+		base, off := splitPtr(x.Args[0])
+		if base != rbase {
+			return false
+		}
+		a = append(a, StoreRecord{x, off})
+	}
+	// Before we sort, grab the memory arg the result should have.
+	mem := a[n-1].store.Args[2]
+
+	// Sort stores in increasing address order.
+	sort.Slice(a, func(i, j int) bool {
+		return a[i].offset < a[j].offset
+	})
+
+	// Check that everything is written to sequential locations.
+	for i := int64(0); i < n; i++ {
+		if a[i].offset != a[0].offset+i*size {
+			return false
+		}
+	}
+
+	// Memory location we're going to write at (the lowest one).
+	ptr := a[0].store.Args[0]
+
+	// Check for constant stores
+	isConst := true
+	for i := int64(0); i < n; i++ {
+		switch a[i].store.Args[1].Op {
+		case OpConst32, OpConst16, OpConst8:
+		default:
+			isConst = false
+			break
+		}
+	}
+	if isConst {
+		// Modify root to do all the stores.
+		var c int64
+		mask := int64(1)<<(8*size) - 1
+		for i := int64(0); i < n; i++ {
+			s := 8 * size * int64(i)
+			if root.Block.Func.Config.BigEndian {
+				s = 8*size*(n-1) - s
+			}
+			c |= (a[i].store.Args[1].AuxInt & mask) << s
+		}
+		var cv *Value
+		switch size * n {
+		case 2:
+			cv = root.Block.Func.ConstInt16(types.Types[types.TUINT16], int16(c))
+		case 4:
+			cv = root.Block.Func.ConstInt32(types.Types[types.TUINT32], int32(c))
+		case 8:
+			cv = root.Block.Func.ConstInt64(types.Types[types.TUINT64], c)
+		}
+
+		// Move all the stores to the root.
+		for i := int64(0); i < n; i++ {
+			v := a[i].store
+			if v == root {
+				v.Aux = cv.Type // widen store type
+				v.SetArg(0, ptr)
+				v.SetArg(1, cv)
+				v.SetArg(2, mem)
+			} else {
+				clobber(v)
+				v.Type = types.Types[types.TBOOL] // erase memory type
+			}
+		}
+		return true
+	}
+
+	// Check that all the shift/trunc are of the same base value.
+	shiftBase := getShiftBase(a)
+	if shiftBase == nil {
+		return false
+	}
+	for i := int64(0); i < n; i++ {
+		if !isShiftBase(a[i].store, shiftBase) {
+			return false
+		}
+	}
+
+	// Check for writes in little-endian or big-endian order.
+	isLittleEndian := true
+	shift0 := shift(a[0].store, shiftBase)
+	for i := int64(1); i < n; i++ {
+		if shift(a[i].store, shiftBase) != shift0+i*size*8 {
+			isLittleEndian = false
+			break
+		}
+	}
+	isBigEndian := true
+	for i := int64(1); i < n; i++ {
+		if shift(a[i].store, shiftBase) != shift0-i*size*8 {
+			isBigEndian = false
+			break
+		}
+	}
+	if !isLittleEndian && !isBigEndian {
+		return false
+	}
+
+	// Check to see if we need byte swap before storing.
+	needSwap := isLittleEndian && root.Block.Func.Config.BigEndian ||
+		isBigEndian && !root.Block.Func.Config.BigEndian
+	if needSwap && (size != 1 || !root.Block.Func.Config.haveByteSwap(n)) {
+		return false
+	}
+
+	// This is the commit point.
+
+	// Modify root to do all the stores.
+	sv := shiftBase
+	if isLittleEndian && shift0 != 0 {
+		sv = rightShift(root.Block, root.Pos, sv, shift0)
+	}
+	if isBigEndian && shift0-(n-1)*size*8 != 0 {
+		sv = rightShift(root.Block, root.Pos, sv, shift0-(n-1)*size*8)
+	}
+	if sv.Type.Size() > size*n {
+		sv = truncate(root.Block, root.Pos, sv, sv.Type.Size(), size*n)
+	}
+	if needSwap {
+		sv = byteSwap(root.Block, root.Pos, sv)
+	}
+
+	// Move all the stores to the root.
+	for i := int64(0); i < n; i++ {
+		v := a[i].store
+		if v == root {
+			v.Aux = sv.Type // widen store type
+			v.SetArg(0, ptr)
+			v.SetArg(1, sv)
+			v.SetArg(2, mem)
+		} else {
+			clobber(v)
+			v.Type = types.Types[types.TBOOL] // erase memory type
+		}
+	}
+	return true
+}
+
+func sizeType(size int64) *types.Type {
+	switch size {
+	case 8:
+		return types.Types[types.TUINT64]
+	case 4:
+		return types.Types[types.TUINT32]
+	case 2:
+		return types.Types[types.TUINT16]
+	default:
+		base.Fatalf("bad size %d\n", size)
+		return nil
+	}
+}
+
+func truncate(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
+	switch from*10 + to {
+	case 82:
+		return b.NewValue1(pos, OpTrunc64to16, types.Types[types.TUINT16], v)
+	case 84:
+		return b.NewValue1(pos, OpTrunc64to32, types.Types[types.TUINT32], v)
+	case 42:
+		return b.NewValue1(pos, OpTrunc32to16, types.Types[types.TUINT16], v)
+	default:
+		base.Fatalf("bad sizes %d %d\n", from, to)
+		return nil
+	}
+}
+func zeroExtend(b *Block, pos src.XPos, v *Value, from, to int64) *Value {
+	switch from*10 + to {
+	case 24:
+		return b.NewValue1(pos, OpZeroExt16to32, types.Types[types.TUINT32], v)
+	case 28:
+		return b.NewValue1(pos, OpZeroExt16to64, types.Types[types.TUINT64], v)
+	case 48:
+		return b.NewValue1(pos, OpZeroExt32to64, types.Types[types.TUINT64], v)
+	default:
+		base.Fatalf("bad sizes %d %d\n", from, to)
+		return nil
+	}
+}
+
+func leftShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
+	s := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
+	size := v.Type.Size()
+	switch size {
+	case 8:
+		return b.NewValue2(pos, OpLsh64x64, v.Type, v, s)
+	case 4:
+		return b.NewValue2(pos, OpLsh32x64, v.Type, v, s)
+	case 2:
+		return b.NewValue2(pos, OpLsh16x64, v.Type, v, s)
+	default:
+		base.Fatalf("bad size %d\n", size)
+		return nil
+	}
+}
+func rightShift(b *Block, pos src.XPos, v *Value, shift int64) *Value {
+	s := b.Func.ConstInt64(types.Types[types.TUINT64], shift)
+	size := v.Type.Size()
+	switch size {
+	case 8:
+		return b.NewValue2(pos, OpRsh64Ux64, v.Type, v, s)
+	case 4:
+		return b.NewValue2(pos, OpRsh32Ux64, v.Type, v, s)
+	case 2:
+		return b.NewValue2(pos, OpRsh16Ux64, v.Type, v, s)
+	default:
+		base.Fatalf("bad size %d\n", size)
+		return nil
+	}
+}
+func byteSwap(b *Block, pos src.XPos, v *Value) *Value {
+	switch v.Type.Size() {
+	case 8:
+		return b.NewValue1(pos, OpBswap64, v.Type, v)
+	case 4:
+		return b.NewValue1(pos, OpBswap32, v.Type, v)
+	case 2:
+		return b.NewValue1(pos, OpBswap16, v.Type, v)
+
+	default:
+		v.Fatalf("bad size %d\n", v.Type.Size())
+		return nil
+	}
+}
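For context on the pattern this new pass targets (this example is not part of the commit; the helper names readU32LE and writeU32LE are illustrative only): the byte-by-byte load/OR chain below is the shape memcombineLoads collapses into a single wide load, and the shift-and-truncate store sequence is what memcombineStores merges into a single wide store, on targets where Config.unalignedOK is set (e.g. amd64). A minimal sketch:

```go
package main

import "fmt"

// readU32LE assembles a uint32 from four 1-byte loads joined by shifts and
// ORs -- the pattern encoding/binary.LittleEndian.Uint32 reduces to. After
// memcombine, the four loads can become one unaligned 32-bit load.
func readU32LE(b []byte) uint32 {
	_ = b[3] // bounds-check hint
	return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
}

// writeU32LE stores a uint32 as four 1-byte stores of shifted, truncated
// values -- the encoding/binary.LittleEndian.PutUint32 pattern. After
// memcombine, the four stores can become one unaligned 32-bit store.
func writeU32LE(b []byte, v uint32) {
	_ = b[3] // bounds-check hint
	b[0] = byte(v)
	b[1] = byte(v >> 8)
	b[2] = byte(v >> 16)
	b[3] = byte(v >> 24)
}

func main() {
	buf := []byte{0x78, 0x56, 0x34, 0x12}
	fmt.Printf("%#x\n", readU32LE(buf)) // 0x12345678
	writeU32LE(buf, 0xcafebabe)
	fmt.Printf("% x\n", buf) // be ba fe ca
}
```

On a little-endian machine the combined access needs no extra work; when the access order and the machine's endianness disagree, the needSwap checks above only combine byte-sized elements and only when the target has a byte-swap op (haveByteSwap), and s390x additionally refuses combined unaligned accesses to globals, as the OpAddr checks show.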