diff options
Diffstat (limited to 'src/cmd/compile/internal/devirtualize/pgo.go')
-rw-r--r-- | src/cmd/compile/internal/devirtualize/pgo.go | 542 |
1 files changed, 542 insertions, 0 deletions
diff --git a/src/cmd/compile/internal/devirtualize/pgo.go b/src/cmd/compile/internal/devirtualize/pgo.go new file mode 100644 index 0000000..068e0ef --- /dev/null +++ b/src/cmd/compile/internal/devirtualize/pgo.go @@ -0,0 +1,542 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package devirtualize + +import ( + "cmd/compile/internal/base" + "cmd/compile/internal/inline" + "cmd/compile/internal/ir" + "cmd/compile/internal/logopt" + "cmd/compile/internal/pgo" + "cmd/compile/internal/typecheck" + "cmd/compile/internal/types" + "encoding/json" + "fmt" + "os" + "strings" +) + +// CallStat summarizes a single call site. +// +// This is used only for debug logging. +type CallStat struct { + Pkg string // base.Ctxt.Pkgpath + Pos string // file:line:col of call. + + Caller string // Linker symbol name of calling function. + + // Direct or indirect call. + Direct bool + + // For indirect calls, interface call or other indirect function call. + Interface bool + + // Total edge weight from this call site. + Weight int64 + + // Hottest callee from this call site, regardless of type + // compatibility. + Hottest string + HottestWeight int64 + + // Devirtualized callee if != "". + // + // Note that this may be different than Hottest because we apply + // type-check restrictions, which helps distinguish multiple calls on + // the same line. + Devirtualized string + DevirtualizedWeight int64 +} + +// ProfileGuided performs call devirtualization of indirect calls based on +// profile information. +// +// Specifically, it performs conditional devirtualization of interface calls +// for the hottest callee. That is, it performs a transformation like: +// +// type Iface interface { +// Foo() +// } +// +// type Concrete struct{} +// +// func (Concrete) Foo() {} +// +// func foo(i Iface) { +// i.Foo() +// } +// +// to: +// +// func foo(i Iface) { +// if c, ok := i.(Concrete); ok { +// c.Foo() +// } else { +// i.Foo() +// } +// } +// +// The primary benefit of this transformation is enabling inlining of the +// direct call. +func ProfileGuided(fn *ir.Func, p *pgo.Profile) { + ir.CurFunc = fn + + name := ir.LinkFuncName(fn) + + // Can't devirtualize go/defer calls. See comment in Static. + goDeferCall := make(map[*ir.CallExpr]bool) + + var jsonW *json.Encoder + if base.Debug.PGODebug >= 3 { + jsonW = json.NewEncoder(os.Stdout) + } + + var edit func(n ir.Node) ir.Node + edit = func(n ir.Node) ir.Node { + if n == nil { + return n + } + + if gds, ok := n.(*ir.GoDeferStmt); ok { + if call, ok := gds.Call.(*ir.CallExpr); ok { + goDeferCall[call] = true + } + } + + ir.EditChildren(n, edit) + + call, ok := n.(*ir.CallExpr) + if !ok { + return n + } + + var stat *CallStat + if base.Debug.PGODebug >= 3 { + // Statistics about every single call. Handy for external data analysis. + // + // TODO(prattmic): Log via logopt? + stat = constructCallStat(p, fn, name, call) + if stat != nil { + defer func() { + jsonW.Encode(&stat) + }() + } + } + + if call.Op() != ir.OCALLINTER { + return n + } + + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v: PGO devirtualize considering call %v\n", ir.Line(call), call) + } + + if goDeferCall[call] { + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v: can't PGO devirtualize go/defer call %v\n", ir.Line(call), call) + } + return n + } + + // Bail if we do not have a hot callee. + callee, weight := findHotConcreteCallee(p, fn, call) + if callee == nil { + return n + } + // Bail if we do not have a Type node for the hot callee. + ctyp := methodRecvType(callee) + if ctyp == nil { + return n + } + // Bail if we know for sure it won't inline. + if !shouldPGODevirt(callee) { + return n + } + + if stat != nil { + stat.Devirtualized = ir.LinkFuncName(callee) + stat.DevirtualizedWeight = weight + } + + return rewriteCondCall(call, fn, callee, ctyp) + } + + ir.EditChildren(fn, edit) +} + +// shouldPGODevirt checks if we should perform PGO devirtualization to the +// target function. +// +// PGO devirtualization is most valuable when the callee is inlined, so if it +// won't inline we can skip devirtualizing. +func shouldPGODevirt(fn *ir.Func) bool { + var reason string + if base.Flag.LowerM > 1 || logopt.Enabled() { + defer func() { + if reason != "" { + if base.Flag.LowerM > 1 { + fmt.Printf("%v: should not PGO devirtualize %v: %s\n", ir.Line(fn), ir.FuncName(fn), reason) + } + if logopt.Enabled() { + logopt.LogOpt(fn.Pos(), ": should not PGO devirtualize function", "pgo-devirtualize", ir.FuncName(fn), reason) + } + } + }() + } + + reason = inline.InlineImpossible(fn) + if reason != "" { + return false + } + + // TODO(prattmic): checking only InlineImpossible is very conservative, + // primarily excluding only functions with pragmas. We probably want to + // move in either direction. Either: + // + // 1. Don't even bother to check InlineImpossible, as it affects so few + // functions. + // + // 2. Or consider the function body (notably cost) to better determine + // if the function will actually inline. + + return true +} + +// constructCallStat builds an initial CallStat describing this call, for +// logging. If the call is devirtualized, the devirtualization fields should be +// updated. +func constructCallStat(p *pgo.Profile, fn *ir.Func, name string, call *ir.CallExpr) *CallStat { + switch call.Op() { + case ir.OCALLFUNC, ir.OCALLINTER, ir.OCALLMETH: + default: + // We don't care about logging builtin functions. + return nil + } + + stat := CallStat{ + Pkg: base.Ctxt.Pkgpath, + Pos: ir.Line(call), + Caller: name, + } + + offset := pgo.NodeLineOffset(call, fn) + + // Sum of all edges from this callsite, regardless of callee. + // For direct calls, this should be the same as the single edge + // weight (except for multiple calls on one line, which we + // can't distinguish). + callerNode := p.WeightedCG.IRNodes[name] + for _, edge := range callerNode.OutEdges { + if edge.CallSiteOffset != offset { + continue + } + stat.Weight += edge.Weight + if edge.Weight > stat.HottestWeight { + stat.HottestWeight = edge.Weight + stat.Hottest = edge.Dst.Name() + } + } + + switch call.Op() { + case ir.OCALLFUNC: + stat.Interface = false + + callee := pgo.DirectCallee(call.X) + if callee != nil { + stat.Direct = true + if stat.Hottest == "" { + stat.Hottest = ir.LinkFuncName(callee) + } + } else { + stat.Direct = false + } + case ir.OCALLINTER: + stat.Direct = false + stat.Interface = true + case ir.OCALLMETH: + base.FatalfAt(call.Pos(), "OCALLMETH missed by typecheck") + } + + return &stat +} + +// rewriteCondCall devirtualizes the given call using a direct method call to +// concretetyp. +func rewriteCondCall(call *ir.CallExpr, curfn, callee *ir.Func, concretetyp *types.Type) ir.Node { + if base.Flag.LowerM != 0 { + fmt.Printf("%v: PGO devirtualizing %v to %v\n", ir.Line(call), call.X, callee) + } + + // We generate an OINCALL of: + // + // var recv Iface + // + // var arg1 A1 + // var argN AN + // + // var ret1 R1 + // var retN RN + // + // recv, arg1, argN = recv expr, arg1 expr, argN expr + // + // t, ok := recv.(Concrete) + // if ok { + // ret1, retN = t.Method(arg1, ... argN) + // } else { + // ret1, retN = recv.Method(arg1, ... argN) + // } + // + // OINCALL retvars: ret1, ... retN + // + // This isn't really an inlined call of course, but InlinedCallExpr + // makes handling reassignment of return values easier. + // + // TODO(prattmic): This increases the size of the AST in the caller, + // making it less like to inline. We may want to compensate for this + // somehow. + + var retvars []ir.Node + + sig := call.X.Type() + + for _, ret := range sig.Results().FieldSlice() { + retvars = append(retvars, typecheck.Temp(ret.Type)) + } + + sel := call.X.(*ir.SelectorExpr) + method := sel.Sel + pos := call.Pos() + init := ir.TakeInit(call) + + // Evaluate receiver and argument expressions. The receiver is used + // twice but we don't want to cause side effects twice. The arguments + // are used in two different calls and we can't trivially copy them. + // + // recv must be first in the assignment list as its side effects must + // be ordered before argument side effects. + var lhs, rhs []ir.Node + recv := typecheck.Temp(sel.X.Type()) + lhs = append(lhs, recv) + rhs = append(rhs, sel.X) + + // Move arguments to assignments prior to the if statement. We cannot + // simply copy the args' IR, as some IR constructs cannot be copied, + // such as labels (possible in InlinedCall nodes). + args := call.Args.Take() + for _, arg := range args { + argvar := typecheck.Temp(arg.Type()) + + lhs = append(lhs, argvar) + rhs = append(rhs, arg) + } + + asList := ir.NewAssignListStmt(pos, ir.OAS2, lhs, rhs) + init.Append(typecheck.Stmt(asList)) + + // Copy slice so edits in one location don't affect another. + argvars := append([]ir.Node(nil), lhs[1:]...) + call.Args = argvars + + tmpnode := typecheck.Temp(concretetyp) + tmpok := typecheck.Temp(types.Types[types.TBOOL]) + + assert := ir.NewTypeAssertExpr(pos, recv, concretetyp) + + assertAsList := ir.NewAssignListStmt(pos, ir.OAS2, []ir.Node{tmpnode, tmpok}, []ir.Node{typecheck.Expr(assert)}) + init.Append(typecheck.Stmt(assertAsList)) + + concreteCallee := typecheck.Callee(ir.NewSelectorExpr(pos, ir.OXDOT, tmpnode, method)) + // Copy slice so edits in one location don't affect another. + argvars = append([]ir.Node(nil), argvars...) + concreteCall := typecheck.Call(pos, concreteCallee, argvars, call.IsDDD) + + var thenBlock, elseBlock ir.Nodes + if len(retvars) == 0 { + thenBlock.Append(concreteCall) + elseBlock.Append(call) + } else { + // Copy slice so edits in one location don't affect another. + thenRet := append([]ir.Node(nil), retvars...) + thenAsList := ir.NewAssignListStmt(pos, ir.OAS2, thenRet, []ir.Node{concreteCall}) + thenBlock.Append(typecheck.Stmt(thenAsList)) + + elseRet := append([]ir.Node(nil), retvars...) + elseAsList := ir.NewAssignListStmt(pos, ir.OAS2, elseRet, []ir.Node{call}) + elseBlock.Append(typecheck.Stmt(elseAsList)) + } + + cond := ir.NewIfStmt(pos, nil, nil, nil) + cond.SetInit(init) + cond.Cond = tmpok + cond.Body = thenBlock + cond.Else = elseBlock + cond.Likely = true + + body := []ir.Node{typecheck.Stmt(cond)} + + res := ir.NewInlinedCallExpr(pos, body, retvars) + res.SetType(call.Type()) + res.SetTypecheck(1) + + if base.Debug.PGODebug >= 3 { + fmt.Printf("PGO devirtualizing call to %+v. After: %+v\n", concretetyp, res) + } + + return res +} + +// methodRecvType returns the type containing method fn. Returns nil if fn +// is not a method. +func methodRecvType(fn *ir.Func) *types.Type { + recv := fn.Nname.Type().Recv() + if recv == nil { + return nil + } + return recv.Type +} + +// interfaceCallRecvTypeAndMethod returns the type and the method of the interface +// used in an interface call. +func interfaceCallRecvTypeAndMethod(call *ir.CallExpr) (*types.Type, *types.Sym) { + if call.Op() != ir.OCALLINTER { + base.Fatalf("Call isn't OCALLINTER: %+v", call) + } + + sel, ok := call.X.(*ir.SelectorExpr) + if !ok { + base.Fatalf("OCALLINTER doesn't contain SelectorExpr: %+v", call) + } + + return sel.X.Type(), sel.Sel +} + +// findHotConcreteCallee returns the *ir.Func of the hottest callee of an +// indirect call, if available, and its edge weight. +func findHotConcreteCallee(p *pgo.Profile, caller *ir.Func, call *ir.CallExpr) (*ir.Func, int64) { + callerName := ir.LinkFuncName(caller) + callerNode := p.WeightedCG.IRNodes[callerName] + callOffset := pgo.NodeLineOffset(call, caller) + + inter, method := interfaceCallRecvTypeAndMethod(call) + + var hottest *pgo.IREdge + + // Returns true if e is hotter than hottest. + // + // Naively this is just e.Weight > hottest.Weight, but because OutEdges + // has arbitrary iteration order, we need to apply additional sort + // criteria when e.Weight == hottest.Weight to ensure we have stable + // selection. + hotter := func(e *pgo.IREdge) bool { + if hottest == nil { + return true + } + if e.Weight != hottest.Weight { + return e.Weight > hottest.Weight + } + + // Now e.Weight == hottest.Weight, we must select on other + // criteria. + + if hottest.Dst.AST == nil && e.Dst.AST != nil { + // Prefer the edge with IR available. + return true + } + + // Arbitrary, but the callee names will always differ. Select + // the lexicographically first callee. + return e.Dst.Name() < hottest.Dst.Name() + } + + for _, e := range callerNode.OutEdges { + if e.CallSiteOffset != callOffset { + continue + } + + if !hotter(e) { + // TODO(prattmic): consider total caller weight? i.e., + // if the hottest callee is only 10% of the weight, + // maybe don't devirtualize? Similarly, if this is call + // is globally very cold, there is not much value in + // devirtualizing. + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v: edge %s:%d -> %s (weight %d): too cold (hottest %d)\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, hottest.Weight) + } + continue + } + + if e.Dst.AST == nil { + // Destination isn't visible from this package + // compilation. + // + // We must assume it implements the interface. + // + // We still record this as the hottest callee so far + // because we only want to return the #1 hottest + // callee. If we skip this then we'd return the #2 + // hottest callee. + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v: edge %s:%d -> %s (weight %d) (missing IR): hottest so far\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight) + } + hottest = e + continue + } + + ctyp := methodRecvType(e.Dst.AST) + if ctyp == nil { + // Not a method. + // TODO(prattmic): Support non-interface indirect calls. + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v: edge %s:%d -> %s (weight %d): callee not a method\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight) + } + continue + } + + // If ctyp doesn't implement inter it is most likely from a + // different call on the same line + if !typecheck.Implements(ctyp, inter) { + // TODO(prattmic): this is overly strict. Consider if + // ctyp is a partial implementation of an interface + // that gets embedded in types that complete the + // interface. It would still be OK to devirtualize a + // call to this method. + // + // What we'd need to do is check that the function + // pointer in the itab matches the method we want, + // rather than doing a full type assertion. + if base.Debug.PGODebug >= 2 { + why := typecheck.ImplementsExplain(ctyp, inter) + fmt.Printf("%v: edge %s:%d -> %s (weight %d): %v doesn't implement %v (%s)\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight, ctyp, inter, why) + } + continue + } + + // If the method name is different it is most likely from a + // different call on the same line + if !strings.HasSuffix(e.Dst.Name(), "."+method.Name) { + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v: edge %s:%d -> %s (weight %d): callee is a different method\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight) + } + continue + } + + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v: edge %s:%d -> %s (weight %d): hottest so far\n", ir.Line(call), callerName, callOffset, e.Dst.Name(), e.Weight) + } + hottest = e + } + + if hottest == nil { + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v: call %s:%d: no hot callee\n", ir.Line(call), callerName, callOffset) + } + return nil, 0 + } + + if base.Debug.PGODebug >= 2 { + fmt.Printf("%v call %s:%d: hottest callee %s (weight %d)\n", ir.Line(call), callerName, callOffset, hottest.Dst.Name(), hottest.Weight) + } + return hottest.Dst.AST, hottest.Weight +} |