diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 13:14:23 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-28 13:14:23 +0000 |
commit | 73df946d56c74384511a194dd01dbe099584fd1a (patch) | |
tree | fd0bcea490dd81327ddfbb31e215439672c9a068 /src/cmd/fix | |
parent | Initial commit. (diff) | |
download | golang-1.16-73df946d56c74384511a194dd01dbe099584fd1a.tar.xz golang-1.16-73df946d56c74384511a194dd01dbe099584fd1a.zip |
Adding upstream version 1.16.10.upstream/1.16.10upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to '')
-rw-r--r-- | src/cmd/fix/cftype.go | 143 | ||||
-rw-r--r-- | src/cmd/fix/cftype_test.go | 219 | ||||
-rw-r--r-- | src/cmd/fix/context.go | 25 | ||||
-rw-r--r-- | src/cmd/fix/context_test.go | 42 | ||||
-rw-r--r-- | src/cmd/fix/doc.go | 36 | ||||
-rw-r--r-- | src/cmd/fix/egltype.go | 52 | ||||
-rw-r--r-- | src/cmd/fix/egltype_test.go | 196 | ||||
-rw-r--r-- | src/cmd/fix/fix.go | 557 | ||||
-rw-r--r-- | src/cmd/fix/gotypes.go | 75 | ||||
-rw-r--r-- | src/cmd/fix/gotypes_test.go | 89 | ||||
-rw-r--r-- | src/cmd/fix/import_test.go | 458 | ||||
-rw-r--r-- | src/cmd/fix/jnitype.go | 65 | ||||
-rw-r--r-- | src/cmd/fix/jnitype_test.go | 185 | ||||
-rw-r--r-- | src/cmd/fix/main.go | 253 | ||||
-rw-r--r-- | src/cmd/fix/main_test.go | 135 | ||||
-rw-r--r-- | src/cmd/fix/netipv6zone.go | 68 | ||||
-rw-r--r-- | src/cmd/fix/netipv6zone_test.go | 43 | ||||
-rw-r--r-- | src/cmd/fix/printerconfig.go | 61 | ||||
-rw-r--r-- | src/cmd/fix/printerconfig_test.go | 37 | ||||
-rw-r--r-- | src/cmd/fix/typecheck.go | 800 |
20 files changed, 3539 insertions, 0 deletions
diff --git a/src/cmd/fix/cftype.go b/src/cmd/fix/cftype.go new file mode 100644 index 0000000..b47b066 --- /dev/null +++ b/src/cmd/fix/cftype.go @@ -0,0 +1,143 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "go/token" + "reflect" + "strings" +) + +func init() { + register(cftypeFix) +} + +var cftypeFix = fix{ + name: "cftype", + date: "2017-09-27", + f: cftypefix, + desc: `Fixes initializers and casts of C.*Ref and JNI types`, + disabled: false, +} + +// Old state: +// type CFTypeRef unsafe.Pointer +// New state: +// type CFTypeRef uintptr +// and similar for other *Ref types. +// This fix finds nils initializing these types and replaces the nils with 0s. +func cftypefix(f *ast.File) bool { + return typefix(f, func(s string) bool { + return strings.HasPrefix(s, "C.") && strings.HasSuffix(s, "Ref") && s != "C.CFAllocatorRef" + }) +} + +// typefix replaces nil with 0 for all nils whose type, when passed to badType, returns true. +func typefix(f *ast.File, badType func(string) bool) bool { + if !imports(f, "C") { + return false + } + typeof, _ := typecheck(&TypeConfig{}, f) + changed := false + + // step 1: Find all the nils with the offending types. + // Compute their replacement. + badNils := map[interface{}]ast.Expr{} + walk(f, func(n interface{}) { + if i, ok := n.(*ast.Ident); ok && i.Name == "nil" && badType(typeof[n]) { + badNils[n] = &ast.BasicLit{ValuePos: i.NamePos, Kind: token.INT, Value: "0"} + } + }) + + // step 2: find all uses of the bad nils, replace them with 0. + // There's no easy way to map from an ast.Expr to all the places that use them, so + // we use reflect to find all such references. + if len(badNils) > 0 { + exprType := reflect.TypeOf((*ast.Expr)(nil)).Elem() + exprSliceType := reflect.TypeOf(([]ast.Expr)(nil)) + walk(f, func(n interface{}) { + if n == nil { + return + } + v := reflect.ValueOf(n) + if v.Type().Kind() != reflect.Ptr { + return + } + if v.IsNil() { + return + } + v = v.Elem() + if v.Type().Kind() != reflect.Struct { + return + } + for i := 0; i < v.NumField(); i++ { + f := v.Field(i) + if f.Type() == exprType { + if r := badNils[f.Interface()]; r != nil { + f.Set(reflect.ValueOf(r)) + changed = true + } + } + if f.Type() == exprSliceType { + for j := 0; j < f.Len(); j++ { + e := f.Index(j) + if r := badNils[e.Interface()]; r != nil { + e.Set(reflect.ValueOf(r)) + changed = true + } + } + } + } + }) + } + + // step 3: fix up invalid casts. + // It used to be ok to cast between *unsafe.Pointer and *C.CFTypeRef in a single step. + // Now we need unsafe.Pointer as an intermediate cast. + // (*unsafe.Pointer)(x) where x is type *bad -> (*unsafe.Pointer)(unsafe.Pointer(x)) + // (*bad.type)(x) where x is type *unsafe.Pointer -> (*bad.type)(unsafe.Pointer(x)) + walk(f, func(n interface{}) { + if n == nil { + return + } + // Find pattern like (*a.b)(x) + c, ok := n.(*ast.CallExpr) + if !ok { + return + } + if len(c.Args) != 1 { + return + } + p, ok := c.Fun.(*ast.ParenExpr) + if !ok { + return + } + s, ok := p.X.(*ast.StarExpr) + if !ok { + return + } + t, ok := s.X.(*ast.SelectorExpr) + if !ok { + return + } + pkg, ok := t.X.(*ast.Ident) + if !ok { + return + } + dst := pkg.Name + "." + t.Sel.Name + src := typeof[c.Args[0]] + if badType(dst) && src == "*unsafe.Pointer" || + dst == "unsafe.Pointer" && strings.HasPrefix(src, "*") && badType(src[1:]) { + c.Args[0] = &ast.CallExpr{ + Fun: &ast.SelectorExpr{X: &ast.Ident{Name: "unsafe"}, Sel: &ast.Ident{Name: "Pointer"}}, + Args: []ast.Expr{c.Args[0]}, + } + changed = true + } + }) + + return changed +} diff --git a/src/cmd/fix/cftype_test.go b/src/cmd/fix/cftype_test.go new file mode 100644 index 0000000..a18eb25 --- /dev/null +++ b/src/cmd/fix/cftype_test.go @@ -0,0 +1,219 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(cftypeTests, cftypefix) +} + +var cftypeTests = []testCase{ + { + Name: "cftype.localVariable", + In: `package main + +import "C" + +func f() { + var x C.CFTypeRef = nil + x = nil + x, x = nil, nil +} +`, + Out: `package main + +import "C" + +func f() { + var x C.CFTypeRef = 0 + x = 0 + x, x = 0, 0 +} +`, + }, + { + Name: "cftype.globalVariable", + In: `package main + +import "C" + +var x C.CFTypeRef = nil + +func f() { + x = nil +} +`, + Out: `package main + +import "C" + +var x C.CFTypeRef = 0 + +func f() { + x = 0 +} +`, + }, + { + Name: "cftype.EqualArgument", + In: `package main + +import "C" + +var x C.CFTypeRef +var y = x == nil +var z = x != nil +`, + Out: `package main + +import "C" + +var x C.CFTypeRef +var y = x == 0 +var z = x != 0 +`, + }, + { + Name: "cftype.StructField", + In: `package main + +import "C" + +type T struct { + x C.CFTypeRef +} + +var t = T{x: nil} +`, + Out: `package main + +import "C" + +type T struct { + x C.CFTypeRef +} + +var t = T{x: 0} +`, + }, + { + Name: "cftype.FunctionArgument", + In: `package main + +import "C" + +func f(x C.CFTypeRef) { +} + +func g() { + f(nil) +} +`, + Out: `package main + +import "C" + +func f(x C.CFTypeRef) { +} + +func g() { + f(0) +} +`, + }, + { + Name: "cftype.ArrayElement", + In: `package main + +import "C" + +var x = [3]C.CFTypeRef{nil, nil, nil} +`, + Out: `package main + +import "C" + +var x = [3]C.CFTypeRef{0, 0, 0} +`, + }, + { + Name: "cftype.SliceElement", + In: `package main + +import "C" + +var x = []C.CFTypeRef{nil, nil, nil} +`, + Out: `package main + +import "C" + +var x = []C.CFTypeRef{0, 0, 0} +`, + }, + { + Name: "cftype.MapKey", + In: `package main + +import "C" + +var x = map[C.CFTypeRef]int{nil: 0} +`, + Out: `package main + +import "C" + +var x = map[C.CFTypeRef]int{0: 0} +`, + }, + { + Name: "cftype.MapValue", + In: `package main + +import "C" + +var x = map[int]C.CFTypeRef{0: nil} +`, + Out: `package main + +import "C" + +var x = map[int]C.CFTypeRef{0: 0} +`, + }, + { + Name: "cftype.Conversion1", + In: `package main + +import "C" + +var x C.CFTypeRef +var y = (*unsafe.Pointer)(&x) +`, + Out: `package main + +import "C" + +var x C.CFTypeRef +var y = (*unsafe.Pointer)(unsafe.Pointer(&x)) +`, + }, + { + Name: "cftype.Conversion2", + In: `package main + +import "C" + +var x unsafe.Pointer +var y = (*C.CFTypeRef)(&x) +`, + Out: `package main + +import "C" + +var x unsafe.Pointer +var y = (*C.CFTypeRef)(unsafe.Pointer(&x)) +`, + }, +} diff --git a/src/cmd/fix/context.go b/src/cmd/fix/context.go new file mode 100644 index 0000000..1107f4d --- /dev/null +++ b/src/cmd/fix/context.go @@ -0,0 +1,25 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" +) + +func init() { + register(contextFix) +} + +var contextFix = fix{ + name: "context", + date: "2016-09-09", + f: ctxfix, + desc: `Change imports of golang.org/x/net/context to context`, + disabled: false, +} + +func ctxfix(f *ast.File) bool { + return rewriteImport(f, "golang.org/x/net/context", "context") +} diff --git a/src/cmd/fix/context_test.go b/src/cmd/fix/context_test.go new file mode 100644 index 0000000..935d0d7 --- /dev/null +++ b/src/cmd/fix/context_test.go @@ -0,0 +1,42 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(contextTests, ctxfix) +} + +var contextTests = []testCase{ + { + Name: "context.0", + In: `package main + +import "golang.org/x/net/context" + +var _ = "golang.org/x/net/context" +`, + Out: `package main + +import "context" + +var _ = "golang.org/x/net/context" +`, + }, + { + Name: "context.1", + In: `package main + +import ctx "golang.org/x/net/context" + +var _ = ctx.Background() +`, + Out: `package main + +import ctx "context" + +var _ = ctx.Background() +`, + }, +} diff --git a/src/cmd/fix/doc.go b/src/cmd/fix/doc.go new file mode 100644 index 0000000..0570169 --- /dev/null +++ b/src/cmd/fix/doc.go @@ -0,0 +1,36 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Fix finds Go programs that use old APIs and rewrites them to use +newer ones. After you update to a new Go release, fix helps make +the necessary changes to your programs. + +Usage: + go tool fix [-r name,...] [path ...] + +Without an explicit path, fix reads standard input and writes the +result to standard output. + +If the named path is a file, fix rewrites the named files in place. +If the named path is a directory, fix rewrites all .go files in that +directory tree. When fix rewrites a file, it prints a line to standard +error giving the name of the file and the rewrite applied. + +If the -diff flag is set, no files are rewritten. Instead fix prints +the differences a rewrite would introduce. + +The -r flag restricts the set of rewrites considered to those in the +named list. By default fix considers all known rewrites. Fix's +rewrites are idempotent, so that it is safe to apply fix to updated +or partially updated code even without using the -r flag. + +Fix prints the full list of fixes it can apply in its help output; +to see them, run go tool fix -help. + +Fix does not make backup copies of the files that it edits. +Instead, use a version control system's ``diff'' functionality to inspect +the changes that fix makes before committing them. +*/ +package main diff --git a/src/cmd/fix/egltype.go b/src/cmd/fix/egltype.go new file mode 100644 index 0000000..cb0f7a7 --- /dev/null +++ b/src/cmd/fix/egltype.go @@ -0,0 +1,52 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" +) + +func init() { + register(eglFixDisplay) + register(eglFixConfig) +} + +var eglFixDisplay = fix{ + name: "egl", + date: "2018-12-15", + f: eglfixDisp, + desc: `Fixes initializers of EGLDisplay`, + disabled: false, +} + +// Old state: +// type EGLDisplay unsafe.Pointer +// New state: +// type EGLDisplay uintptr +// This fix finds nils initializing these types and replaces the nils with 0s. +func eglfixDisp(f *ast.File) bool { + return typefix(f, func(s string) bool { + return s == "C.EGLDisplay" + }) +} + +var eglFixConfig = fix{ + name: "eglconf", + date: "2020-05-30", + f: eglfixConfig, + desc: `Fixes initializers of EGLConfig`, + disabled: false, +} + +// Old state: +// type EGLConfig unsafe.Pointer +// New state: +// type EGLConfig uintptr +// This fix finds nils initializing these types and replaces the nils with 0s. +func eglfixConfig(f *ast.File) bool { + return typefix(f, func(s string) bool { + return s == "C.EGLConfig" + }) +} diff --git a/src/cmd/fix/egltype_test.go b/src/cmd/fix/egltype_test.go new file mode 100644 index 0000000..9b64a7c --- /dev/null +++ b/src/cmd/fix/egltype_test.go @@ -0,0 +1,196 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "strings" + +func init() { + addTestCases(eglTestsFor("EGLDisplay"), eglfixDisp) + addTestCases(eglTestsFor("EGLConfig"), eglfixConfig) +} + +func eglTestsFor(tname string) []testCase { + var eglTests = []testCase{ + { + Name: "egl.localVariable", + In: `package main + +import "C" + +func f() { + var x C.$EGLTYPE = nil + x = nil + x, x = nil, nil +} +`, + Out: `package main + +import "C" + +func f() { + var x C.$EGLTYPE = 0 + x = 0 + x, x = 0, 0 +} +`, + }, + { + Name: "egl.globalVariable", + In: `package main + +import "C" + +var x C.$EGLTYPE = nil + +func f() { + x = nil +} +`, + Out: `package main + +import "C" + +var x C.$EGLTYPE = 0 + +func f() { + x = 0 +} +`, + }, + { + Name: "egl.EqualArgument", + In: `package main + +import "C" + +var x C.$EGLTYPE +var y = x == nil +var z = x != nil +`, + Out: `package main + +import "C" + +var x C.$EGLTYPE +var y = x == 0 +var z = x != 0 +`, + }, + { + Name: "egl.StructField", + In: `package main + +import "C" + +type T struct { + x C.$EGLTYPE +} + +var t = T{x: nil} +`, + Out: `package main + +import "C" + +type T struct { + x C.$EGLTYPE +} + +var t = T{x: 0} +`, + }, + { + Name: "egl.FunctionArgument", + In: `package main + +import "C" + +func f(x C.$EGLTYPE) { +} + +func g() { + f(nil) +} +`, + Out: `package main + +import "C" + +func f(x C.$EGLTYPE) { +} + +func g() { + f(0) +} +`, + }, + { + Name: "egl.ArrayElement", + In: `package main + +import "C" + +var x = [3]C.$EGLTYPE{nil, nil, nil} +`, + Out: `package main + +import "C" + +var x = [3]C.$EGLTYPE{0, 0, 0} +`, + }, + { + Name: "egl.SliceElement", + In: `package main + +import "C" + +var x = []C.$EGLTYPE{nil, nil, nil} +`, + Out: `package main + +import "C" + +var x = []C.$EGLTYPE{0, 0, 0} +`, + }, + { + Name: "egl.MapKey", + In: `package main + +import "C" + +var x = map[C.$EGLTYPE]int{nil: 0} +`, + Out: `package main + +import "C" + +var x = map[C.$EGLTYPE]int{0: 0} +`, + }, + { + Name: "egl.MapValue", + In: `package main + +import "C" + +var x = map[int]C.$EGLTYPE{0: nil} +`, + Out: `package main + +import "C" + +var x = map[int]C.$EGLTYPE{0: 0} +`, + }, + } + for i := range eglTests { + t := &eglTests[i] + t.In = strings.ReplaceAll(t.In, "$EGLTYPE", tname) + t.Out = strings.ReplaceAll(t.Out, "$EGLTYPE", tname) + } + return eglTests +} diff --git a/src/cmd/fix/fix.go b/src/cmd/fix/fix.go new file mode 100644 index 0000000..b49db37 --- /dev/null +++ b/src/cmd/fix/fix.go @@ -0,0 +1,557 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "go/ast" + "go/token" + "path" + "strconv" +) + +type fix struct { + name string + date string // date that fix was introduced, in YYYY-MM-DD format + f func(*ast.File) bool + desc string + disabled bool // whether this fix should be disabled by default +} + +// main runs sort.Sort(byName(fixes)) before printing list of fixes. +type byName []fix + +func (f byName) Len() int { return len(f) } +func (f byName) Swap(i, j int) { f[i], f[j] = f[j], f[i] } +func (f byName) Less(i, j int) bool { return f[i].name < f[j].name } + +// main runs sort.Sort(byDate(fixes)) before applying fixes. +type byDate []fix + +func (f byDate) Len() int { return len(f) } +func (f byDate) Swap(i, j int) { f[i], f[j] = f[j], f[i] } +func (f byDate) Less(i, j int) bool { return f[i].date < f[j].date } + +var fixes []fix + +func register(f fix) { + fixes = append(fixes, f) +} + +// walk traverses the AST x, calling visit(y) for each node y in the tree but +// also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt, +// in a bottom-up traversal. +func walk(x interface{}, visit func(interface{})) { + walkBeforeAfter(x, nop, visit) +} + +func nop(interface{}) {} + +// walkBeforeAfter is like walk but calls before(x) before traversing +// x's children and after(x) afterward. +func walkBeforeAfter(x interface{}, before, after func(interface{})) { + before(x) + + switch n := x.(type) { + default: + panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x)) + + case nil: + + // pointers to interfaces + case *ast.Decl: + walkBeforeAfter(*n, before, after) + case *ast.Expr: + walkBeforeAfter(*n, before, after) + case *ast.Spec: + walkBeforeAfter(*n, before, after) + case *ast.Stmt: + walkBeforeAfter(*n, before, after) + + // pointers to struct pointers + case **ast.BlockStmt: + walkBeforeAfter(*n, before, after) + case **ast.CallExpr: + walkBeforeAfter(*n, before, after) + case **ast.FieldList: + walkBeforeAfter(*n, before, after) + case **ast.FuncType: + walkBeforeAfter(*n, before, after) + case **ast.Ident: + walkBeforeAfter(*n, before, after) + case **ast.BasicLit: + walkBeforeAfter(*n, before, after) + + // pointers to slices + case *[]ast.Decl: + walkBeforeAfter(*n, before, after) + case *[]ast.Expr: + walkBeforeAfter(*n, before, after) + case *[]*ast.File: + walkBeforeAfter(*n, before, after) + case *[]*ast.Ident: + walkBeforeAfter(*n, before, after) + case *[]ast.Spec: + walkBeforeAfter(*n, before, after) + case *[]ast.Stmt: + walkBeforeAfter(*n, before, after) + + // These are ordered and grouped to match ../../go/ast/ast.go + case *ast.Field: + walkBeforeAfter(&n.Names, before, after) + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Tag, before, after) + case *ast.FieldList: + for _, field := range n.List { + walkBeforeAfter(field, before, after) + } + case *ast.BadExpr: + case *ast.Ident: + case *ast.Ellipsis: + walkBeforeAfter(&n.Elt, before, after) + case *ast.BasicLit: + case *ast.FuncLit: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.CompositeLit: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Elts, before, after) + case *ast.ParenExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.SelectorExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.IndexExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Index, before, after) + case *ast.SliceExpr: + walkBeforeAfter(&n.X, before, after) + if n.Low != nil { + walkBeforeAfter(&n.Low, before, after) + } + if n.High != nil { + walkBeforeAfter(&n.High, before, after) + } + case *ast.TypeAssertExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Type, before, after) + case *ast.CallExpr: + walkBeforeAfter(&n.Fun, before, after) + walkBeforeAfter(&n.Args, before, after) + case *ast.StarExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.UnaryExpr: + walkBeforeAfter(&n.X, before, after) + case *ast.BinaryExpr: + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Y, before, after) + case *ast.KeyValueExpr: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + + case *ast.ArrayType: + walkBeforeAfter(&n.Len, before, after) + walkBeforeAfter(&n.Elt, before, after) + case *ast.StructType: + walkBeforeAfter(&n.Fields, before, after) + case *ast.FuncType: + walkBeforeAfter(&n.Params, before, after) + if n.Results != nil { + walkBeforeAfter(&n.Results, before, after) + } + case *ast.InterfaceType: + walkBeforeAfter(&n.Methods, before, after) + case *ast.MapType: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + case *ast.ChanType: + walkBeforeAfter(&n.Value, before, after) + + case *ast.BadStmt: + case *ast.DeclStmt: + walkBeforeAfter(&n.Decl, before, after) + case *ast.EmptyStmt: + case *ast.LabeledStmt: + walkBeforeAfter(&n.Stmt, before, after) + case *ast.ExprStmt: + walkBeforeAfter(&n.X, before, after) + case *ast.SendStmt: + walkBeforeAfter(&n.Chan, before, after) + walkBeforeAfter(&n.Value, before, after) + case *ast.IncDecStmt: + walkBeforeAfter(&n.X, before, after) + case *ast.AssignStmt: + walkBeforeAfter(&n.Lhs, before, after) + walkBeforeAfter(&n.Rhs, before, after) + case *ast.GoStmt: + walkBeforeAfter(&n.Call, before, after) + case *ast.DeferStmt: + walkBeforeAfter(&n.Call, before, after) + case *ast.ReturnStmt: + walkBeforeAfter(&n.Results, before, after) + case *ast.BranchStmt: + case *ast.BlockStmt: + walkBeforeAfter(&n.List, before, after) + case *ast.IfStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Cond, before, after) + walkBeforeAfter(&n.Body, before, after) + walkBeforeAfter(&n.Else, before, after) + case *ast.CaseClause: + walkBeforeAfter(&n.List, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.SwitchStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Tag, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.TypeSwitchStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Assign, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.CommClause: + walkBeforeAfter(&n.Comm, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.SelectStmt: + walkBeforeAfter(&n.Body, before, after) + case *ast.ForStmt: + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Cond, before, after) + walkBeforeAfter(&n.Post, before, after) + walkBeforeAfter(&n.Body, before, after) + case *ast.RangeStmt: + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Body, before, after) + + case *ast.ImportSpec: + case *ast.ValueSpec: + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Values, before, after) + walkBeforeAfter(&n.Names, before, after) + case *ast.TypeSpec: + walkBeforeAfter(&n.Type, before, after) + + case *ast.BadDecl: + case *ast.GenDecl: + walkBeforeAfter(&n.Specs, before, after) + case *ast.FuncDecl: + if n.Recv != nil { + walkBeforeAfter(&n.Recv, before, after) + } + walkBeforeAfter(&n.Type, before, after) + if n.Body != nil { + walkBeforeAfter(&n.Body, before, after) + } + + case *ast.File: + walkBeforeAfter(&n.Decls, before, after) + + case *ast.Package: + walkBeforeAfter(&n.Files, before, after) + + case []*ast.File: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Decl: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Expr: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []*ast.Ident: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Stmt: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + case []ast.Spec: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } + } + after(x) +} + +// imports reports whether f imports path. +func imports(f *ast.File, path string) bool { + return importSpec(f, path) != nil +} + +// importSpec returns the import spec if f imports path, +// or nil otherwise. +func importSpec(f *ast.File, path string) *ast.ImportSpec { + for _, s := range f.Imports { + if importPath(s) == path { + return s + } + } + return nil +} + +// importPath returns the unquoted import path of s, +// or "" if the path is not properly quoted. +func importPath(s *ast.ImportSpec) string { + t, err := strconv.Unquote(s.Path.Value) + if err == nil { + return t + } + return "" +} + +// declImports reports whether gen contains an import of path. +func declImports(gen *ast.GenDecl, path string) bool { + if gen.Tok != token.IMPORT { + return false + } + for _, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + if importPath(impspec) == path { + return true + } + } + return false +} + +// isTopName reports whether n is a top-level unresolved identifier with the given name. +func isTopName(n ast.Expr, name string) bool { + id, ok := n.(*ast.Ident) + return ok && id.Name == name && id.Obj == nil +} + +// renameTop renames all references to the top-level name old. +// It reports whether it makes any changes. +func renameTop(f *ast.File, old, new string) bool { + var fixed bool + + // Rename any conflicting imports + // (assuming package name is last element of path). + for _, s := range f.Imports { + if s.Name != nil { + if s.Name.Name == old { + s.Name.Name = new + fixed = true + } + } else { + _, thisName := path.Split(importPath(s)) + if thisName == old { + s.Name = ast.NewIdent(new) + fixed = true + } + } + } + + // Rename any top-level declarations. + for _, d := range f.Decls { + switch d := d.(type) { + case *ast.FuncDecl: + if d.Recv == nil && d.Name.Name == old { + d.Name.Name = new + d.Name.Obj.Name = new + fixed = true + } + case *ast.GenDecl: + for _, s := range d.Specs { + switch s := s.(type) { + case *ast.TypeSpec: + if s.Name.Name == old { + s.Name.Name = new + s.Name.Obj.Name = new + fixed = true + } + case *ast.ValueSpec: + for _, n := range s.Names { + if n.Name == old { + n.Name = new + n.Obj.Name = new + fixed = true + } + } + } + } + } + } + + // Rename top-level old to new, both unresolved names + // (probably defined in another file) and names that resolve + // to a declaration we renamed. + walk(f, func(n interface{}) { + id, ok := n.(*ast.Ident) + if ok && isTopName(id, old) { + id.Name = new + fixed = true + } + if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new { + id.Name = id.Obj.Name + fixed = true + } + }) + + return fixed +} + +// matchLen returns the length of the longest prefix shared by x and y. +func matchLen(x, y string) int { + i := 0 + for i < len(x) && i < len(y) && x[i] == y[i] { + i++ + } + return i +} + +// addImport adds the import path to the file f, if absent. +func addImport(f *ast.File, ipath string) (added bool) { + if imports(f, ipath) { + return false + } + + // Determine name of import. + // Assume added imports follow convention of using last element. + _, name := path.Split(ipath) + + // Rename any conflicting top-level references from name to name_. + renameTop(f, name, name+"_") + + newImport := &ast.ImportSpec{ + Path: &ast.BasicLit{ + Kind: token.STRING, + Value: strconv.Quote(ipath), + }, + } + + // Find an import decl to add to. + var ( + bestMatch = -1 + lastImport = -1 + impDecl *ast.GenDecl + impIndex = -1 + ) + for i, decl := range f.Decls { + gen, ok := decl.(*ast.GenDecl) + if ok && gen.Tok == token.IMPORT { + lastImport = i + // Do not add to import "C", to avoid disrupting the + // association with its doc comment, breaking cgo. + if declImports(gen, "C") { + continue + } + + // Compute longest shared prefix with imports in this block. + for j, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + n := matchLen(importPath(impspec), ipath) + if n > bestMatch { + bestMatch = n + impDecl = gen + impIndex = j + } + } + } + } + + // If no import decl found, add one after the last import. + if impDecl == nil { + impDecl = &ast.GenDecl{ + Tok: token.IMPORT, + } + f.Decls = append(f.Decls, nil) + copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:]) + f.Decls[lastImport+1] = impDecl + } + + // Ensure the import decl has parentheses, if needed. + if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() { + impDecl.Lparen = impDecl.Pos() + } + + insertAt := impIndex + 1 + if insertAt == 0 { + insertAt = len(impDecl.Specs) + } + impDecl.Specs = append(impDecl.Specs, nil) + copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:]) + impDecl.Specs[insertAt] = newImport + if insertAt > 0 { + // Assign same position as the previous import, + // so that the sorter sees it as being in the same block. + prev := impDecl.Specs[insertAt-1] + newImport.Path.ValuePos = prev.Pos() + newImport.EndPos = prev.Pos() + } + + f.Imports = append(f.Imports, newImport) + return true +} + +// deleteImport deletes the import path from the file f, if present. +func deleteImport(f *ast.File, path string) (deleted bool) { + oldImport := importSpec(f, path) + + // Find the import node that imports path, if any. + for i, decl := range f.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.IMPORT { + continue + } + for j, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + if oldImport != impspec { + continue + } + + // We found an import spec that imports path. + // Delete it. + deleted = true + copy(gen.Specs[j:], gen.Specs[j+1:]) + gen.Specs = gen.Specs[:len(gen.Specs)-1] + + // If this was the last import spec in this decl, + // delete the decl, too. + if len(gen.Specs) == 0 { + copy(f.Decls[i:], f.Decls[i+1:]) + f.Decls = f.Decls[:len(f.Decls)-1] + } else if len(gen.Specs) == 1 { + gen.Lparen = token.NoPos // drop parens + } + if j > 0 { + // We deleted an entry but now there will be + // a blank line-sized hole where the import was. + // Close the hole by making the previous + // import appear to "end" where this one did. + gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End() + } + break + } + } + + // Delete it from f.Imports. + for i, imp := range f.Imports { + if imp == oldImport { + copy(f.Imports[i:], f.Imports[i+1:]) + f.Imports = f.Imports[:len(f.Imports)-1] + break + } + } + + return +} + +// rewriteImport rewrites any import of path oldPath to path newPath. +func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) { + for _, imp := range f.Imports { + if importPath(imp) == oldPath { + rewrote = true + // record old End, because the default is to compute + // it using the length of imp.Path.Value. + imp.EndPos = imp.End() + imp.Path.Value = strconv.Quote(newPath) + } + } + return +} diff --git a/src/cmd/fix/gotypes.go b/src/cmd/fix/gotypes.go new file mode 100644 index 0000000..031f85c --- /dev/null +++ b/src/cmd/fix/gotypes.go @@ -0,0 +1,75 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "strconv" +) + +func init() { + register(gotypesFix) +} + +var gotypesFix = fix{ + name: "gotypes", + date: "2015-07-16", + f: gotypes, + desc: `Change imports of golang.org/x/tools/go/{exact,types} to go/{constant,types}`, +} + +func gotypes(f *ast.File) bool { + fixed := fixGoTypes(f) + if fixGoExact(f) { + fixed = true + } + return fixed +} + +func fixGoTypes(f *ast.File) bool { + return rewriteImport(f, "golang.org/x/tools/go/types", "go/types") +} + +func fixGoExact(f *ast.File) bool { + // This one is harder because the import name changes. + // First find the import spec. + var importSpec *ast.ImportSpec + walk(f, func(n interface{}) { + if importSpec != nil { + return + } + spec, ok := n.(*ast.ImportSpec) + if !ok { + return + } + path, err := strconv.Unquote(spec.Path.Value) + if err != nil { + return + } + if path == "golang.org/x/tools/go/exact" { + importSpec = spec + } + + }) + if importSpec == nil { + return false + } + + // We are about to rename exact.* to constant.*, but constant is a common + // name. See if it will conflict. This is a hack but it is effective. + exists := renameTop(f, "constant", "constant") + suffix := "" + if exists { + suffix = "_" + } + // Now we need to rename all the uses of the import. RewriteImport + // affects renameTop, but not vice versa, so do them in this order. + renameTop(f, "exact", "constant"+suffix) + rewriteImport(f, "golang.org/x/tools/go/exact", "go/constant") + // renameTop will also rewrite the imported package name. Fix that; + // we know it should be missing. + importSpec.Name = nil + return true +} diff --git a/src/cmd/fix/gotypes_test.go b/src/cmd/fix/gotypes_test.go new file mode 100644 index 0000000..9248fff --- /dev/null +++ b/src/cmd/fix/gotypes_test.go @@ -0,0 +1,89 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(gotypesTests, gotypes) +} + +var gotypesTests = []testCase{ + { + Name: "gotypes.0", + In: `package main + +import "golang.org/x/tools/go/types" +import "golang.org/x/tools/go/exact" + +var _ = exact.Kind + +func f() { + _ = exact.MakeBool(true) +} +`, + Out: `package main + +import "go/types" +import "go/constant" + +var _ = constant.Kind + +func f() { + _ = constant.MakeBool(true) +} +`, + }, + { + Name: "gotypes.1", + In: `package main + +import "golang.org/x/tools/go/types" +import foo "golang.org/x/tools/go/exact" + +var _ = foo.Kind + +func f() { + _ = foo.MakeBool(true) +} +`, + Out: `package main + +import "go/types" +import "go/constant" + +var _ = foo.Kind + +func f() { + _ = foo.MakeBool(true) +} +`, + }, + { + Name: "gotypes.0", + In: `package main + +import "golang.org/x/tools/go/types" +import "golang.org/x/tools/go/exact" + +var _ = exact.Kind +var constant = 23 // Use of new package name. + +func f() { + _ = exact.MakeBool(true) +} +`, + Out: `package main + +import "go/types" +import "go/constant" + +var _ = constant_.Kind +var constant = 23 // Use of new package name. + +func f() { + _ = constant_.MakeBool(true) +} +`, + }, +} diff --git a/src/cmd/fix/import_test.go b/src/cmd/fix/import_test.go new file mode 100644 index 0000000..8644e28 --- /dev/null +++ b/src/cmd/fix/import_test.go @@ -0,0 +1,458 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "go/ast" + +func init() { + addTestCases(importTests, nil) +} + +var importTests = []testCase{ + { + Name: "import.0", + Fn: addImportFn("os"), + In: `package main + +import ( + "os" +) +`, + Out: `package main + +import ( + "os" +) +`, + }, + { + Name: "import.1", + Fn: addImportFn("os"), + In: `package main +`, + Out: `package main + +import "os" +`, + }, + { + Name: "import.2", + Fn: addImportFn("os"), + In: `package main + +// Comment +import "C" +`, + Out: `package main + +// Comment +import "C" +import "os" +`, + }, + { + Name: "import.3", + Fn: addImportFn("os"), + In: `package main + +// Comment +import "C" + +import ( + "io" + "utf8" +) +`, + Out: `package main + +// Comment +import "C" + +import ( + "io" + "os" + "utf8" +) +`, + }, + { + Name: "import.4", + Fn: deleteImportFn("os"), + In: `package main + +import ( + "os" +) +`, + Out: `package main +`, + }, + { + Name: "import.5", + Fn: deleteImportFn("os"), + In: `package main + +// Comment +import "C" +import "os" +`, + Out: `package main + +// Comment +import "C" +`, + }, + { + Name: "import.6", + Fn: deleteImportFn("os"), + In: `package main + +// Comment +import "C" + +import ( + "io" + "os" + "utf8" +) +`, + Out: `package main + +// Comment +import "C" + +import ( + "io" + "utf8" +) +`, + }, + { + Name: "import.7", + Fn: deleteImportFn("io"), + In: `package main + +import ( + "io" // a + "os" // b + "utf8" // c +) +`, + Out: `package main + +import ( + // a + "os" // b + "utf8" // c +) +`, + }, + { + Name: "import.8", + Fn: deleteImportFn("os"), + In: `package main + +import ( + "io" // a + "os" // b + "utf8" // c +) +`, + Out: `package main + +import ( + "io" // a + // b + "utf8" // c +) +`, + }, + { + Name: "import.9", + Fn: deleteImportFn("utf8"), + In: `package main + +import ( + "io" // a + "os" // b + "utf8" // c +) +`, + Out: `package main + +import ( + "io" // a + "os" // b + // c +) +`, + }, + { + Name: "import.10", + Fn: deleteImportFn("io"), + In: `package main + +import ( + "io" + "os" + "utf8" +) +`, + Out: `package main + +import ( + "os" + "utf8" +) +`, + }, + { + Name: "import.11", + Fn: deleteImportFn("os"), + In: `package main + +import ( + "io" + "os" + "utf8" +) +`, + Out: `package main + +import ( + "io" + "utf8" +) +`, + }, + { + Name: "import.12", + Fn: deleteImportFn("utf8"), + In: `package main + +import ( + "io" + "os" + "utf8" +) +`, + Out: `package main + +import ( + "io" + "os" +) +`, + }, + { + Name: "import.13", + Fn: rewriteImportFn("utf8", "encoding/utf8"), + In: `package main + +import ( + "io" + "os" + "utf8" // thanks ken +) +`, + Out: `package main + +import ( + "encoding/utf8" // thanks ken + "io" + "os" +) +`, + }, + { + Name: "import.14", + Fn: rewriteImportFn("asn1", "encoding/asn1"), + In: `package main + +import ( + "asn1" + "crypto" + "crypto/rsa" + _ "crypto/sha1" + "crypto/x509" + "crypto/x509/pkix" + "time" +) + +var x = 1 +`, + Out: `package main + +import ( + "crypto" + "crypto/rsa" + _ "crypto/sha1" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "time" +) + +var x = 1 +`, + }, + { + Name: "import.15", + Fn: rewriteImportFn("url", "net/url"), + In: `package main + +import ( + "bufio" + "net" + "path" + "url" +) + +var x = 1 // comment on x, not on url +`, + Out: `package main + +import ( + "bufio" + "net" + "net/url" + "path" +) + +var x = 1 // comment on x, not on url +`, + }, + { + Name: "import.16", + Fn: rewriteImportFn("http", "net/http", "template", "text/template"), + In: `package main + +import ( + "flag" + "http" + "log" + "template" +) + +var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18 +`, + Out: `package main + +import ( + "flag" + "log" + "net/http" + "text/template" +) + +var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18 +`, + }, + { + Name: "import.17", + Fn: addImportFn("x/y/z", "x/a/c"), + In: `package main + +// Comment +import "C" + +import ( + "a" + "b" + + "x/w" + + "d/f" +) +`, + Out: `package main + +// Comment +import "C" + +import ( + "a" + "b" + + "x/a/c" + "x/w" + "x/y/z" + + "d/f" +) +`, + }, + { + Name: "import.18", + Fn: addDelImportFn("e", "o"), + In: `package main + +import ( + "f" + "o" + "z" +) +`, + Out: `package main + +import ( + "e" + "f" + "z" +) +`, + }, +} + +func addImportFn(path ...string) func(*ast.File) bool { + return func(f *ast.File) bool { + fixed := false + for _, p := range path { + if !imports(f, p) { + addImport(f, p) + fixed = true + } + } + return fixed + } +} + +func deleteImportFn(path string) func(*ast.File) bool { + return func(f *ast.File) bool { + if imports(f, path) { + deleteImport(f, path) + return true + } + return false + } +} + +func addDelImportFn(p1 string, p2 string) func(*ast.File) bool { + return func(f *ast.File) bool { + fixed := false + if !imports(f, p1) { + addImport(f, p1) + fixed = true + } + if imports(f, p2) { + deleteImport(f, p2) + fixed = true + } + return fixed + } +} + +func rewriteImportFn(oldnew ...string) func(*ast.File) bool { + return func(f *ast.File) bool { + fixed := false + for i := 0; i < len(oldnew); i += 2 { + if imports(f, oldnew[i]) { + rewriteImport(f, oldnew[i], oldnew[i+1]) + fixed = true + } + } + return fixed + } +} diff --git a/src/cmd/fix/jnitype.go b/src/cmd/fix/jnitype.go new file mode 100644 index 0000000..29abe0f --- /dev/null +++ b/src/cmd/fix/jnitype.go @@ -0,0 +1,65 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" +) + +func init() { + register(jniFix) +} + +var jniFix = fix{ + name: "jni", + date: "2017-12-04", + f: jnifix, + desc: `Fixes initializers of JNI's jobject and subtypes`, + disabled: false, +} + +// Old state: +// type jobject *_jobject +// New state: +// type jobject uintptr +// and similar for subtypes of jobject. +// This fix finds nils initializing these types and replaces the nils with 0s. +func jnifix(f *ast.File) bool { + return typefix(f, func(s string) bool { + switch s { + case "C.jobject": + return true + case "C.jclass": + return true + case "C.jthrowable": + return true + case "C.jstring": + return true + case "C.jarray": + return true + case "C.jbooleanArray": + return true + case "C.jbyteArray": + return true + case "C.jcharArray": + return true + case "C.jshortArray": + return true + case "C.jintArray": + return true + case "C.jlongArray": + return true + case "C.jfloatArray": + return true + case "C.jdoubleArray": + return true + case "C.jobjectArray": + return true + case "C.jweak": + return true + } + return false + }) +} diff --git a/src/cmd/fix/jnitype_test.go b/src/cmd/fix/jnitype_test.go new file mode 100644 index 0000000..a6420f7 --- /dev/null +++ b/src/cmd/fix/jnitype_test.go @@ -0,0 +1,185 @@ +// Copyright 2017 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(jniTests, jnifix) +} + +var jniTests = []testCase{ + { + Name: "jni.localVariable", + In: `package main + +import "C" + +func f() { + var x C.jobject = nil + x = nil + x, x = nil, nil +} +`, + Out: `package main + +import "C" + +func f() { + var x C.jobject = 0 + x = 0 + x, x = 0, 0 +} +`, + }, + { + Name: "jni.globalVariable", + In: `package main + +import "C" + +var x C.jobject = nil + +func f() { + x = nil +} +`, + Out: `package main + +import "C" + +var x C.jobject = 0 + +func f() { + x = 0 +} +`, + }, + { + Name: "jni.EqualArgument", + In: `package main + +import "C" + +var x C.jobject +var y = x == nil +var z = x != nil +`, + Out: `package main + +import "C" + +var x C.jobject +var y = x == 0 +var z = x != 0 +`, + }, + { + Name: "jni.StructField", + In: `package main + +import "C" + +type T struct { + x C.jobject +} + +var t = T{x: nil} +`, + Out: `package main + +import "C" + +type T struct { + x C.jobject +} + +var t = T{x: 0} +`, + }, + { + Name: "jni.FunctionArgument", + In: `package main + +import "C" + +func f(x C.jobject) { +} + +func g() { + f(nil) +} +`, + Out: `package main + +import "C" + +func f(x C.jobject) { +} + +func g() { + f(0) +} +`, + }, + { + Name: "jni.ArrayElement", + In: `package main + +import "C" + +var x = [3]C.jobject{nil, nil, nil} +`, + Out: `package main + +import "C" + +var x = [3]C.jobject{0, 0, 0} +`, + }, + { + Name: "jni.SliceElement", + In: `package main + +import "C" + +var x = []C.jobject{nil, nil, nil} +`, + Out: `package main + +import "C" + +var x = []C.jobject{0, 0, 0} +`, + }, + { + Name: "jni.MapKey", + In: `package main + +import "C" + +var x = map[C.jobject]int{nil: 0} +`, + Out: `package main + +import "C" + +var x = map[C.jobject]int{0: 0} +`, + }, + { + Name: "jni.MapValue", + In: `package main + +import "C" + +var x = map[int]C.jobject{0: nil} +`, + Out: `package main + +import "C" + +var x = map[int]C.jobject{0: 0} +`, + }, +} diff --git a/src/cmd/fix/main.go b/src/cmd/fix/main.go new file mode 100644 index 0000000..d055929 --- /dev/null +++ b/src/cmd/fix/main.go @@ -0,0 +1,253 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/scanner" + "go/token" + "io" + "io/fs" + "os" + "path/filepath" + "sort" + "strings" + + "cmd/internal/diff" +) + +var ( + fset = token.NewFileSet() + exitCode = 0 +) + +var allowedRewrites = flag.String("r", "", + "restrict the rewrites to this comma-separated list") + +var forceRewrites = flag.String("force", "", + "force these fixes to run even if the code looks updated") + +var allowed, force map[string]bool + +var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files") + +// enable for debugging fix failures +const debug = false // display incorrectly reformatted source and exit + +func usage() { + fmt.Fprintf(os.Stderr, "usage: go tool fix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") + sort.Sort(byName(fixes)) + for _, f := range fixes { + if f.disabled { + fmt.Fprintf(os.Stderr, "\n%s (disabled)\n", f.name) + } else { + fmt.Fprintf(os.Stderr, "\n%s\n", f.name) + } + desc := strings.TrimSpace(f.desc) + desc = strings.ReplaceAll(desc, "\n", "\n\t") + fmt.Fprintf(os.Stderr, "\t%s\n", desc) + } + os.Exit(2) +} + +func main() { + flag.Usage = usage + flag.Parse() + + sort.Sort(byDate(fixes)) + + if *allowedRewrites != "" { + allowed = make(map[string]bool) + for _, f := range strings.Split(*allowedRewrites, ",") { + allowed[f] = true + } + } + + if *forceRewrites != "" { + force = make(map[string]bool) + for _, f := range strings.Split(*forceRewrites, ",") { + force[f] = true + } + } + + if flag.NArg() == 0 { + if err := processFile("standard input", true); err != nil { + report(err) + } + os.Exit(exitCode) + } + + for i := 0; i < flag.NArg(); i++ { + path := flag.Arg(i) + switch dir, err := os.Stat(path); { + case err != nil: + report(err) + case dir.IsDir(): + walkDir(path) + default: + if err := processFile(path, false); err != nil { + report(err) + } + } + } + + os.Exit(exitCode) +} + +const parserMode = parser.ParseComments + +func gofmtFile(f *ast.File) ([]byte, error) { + var buf bytes.Buffer + if err := format.Node(&buf, fset, f); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +func processFile(filename string, useStdin bool) error { + var f *os.File + var err error + var fixlog bytes.Buffer + + if useStdin { + f = os.Stdin + } else { + f, err = os.Open(filename) + if err != nil { + return err + } + defer f.Close() + } + + src, err := io.ReadAll(f) + if err != nil { + return err + } + + file, err := parser.ParseFile(fset, filename, src, parserMode) + if err != nil { + return err + } + + // Make sure file is in canonical format. + // This "fmt" pseudo-fix cannot be disabled. + newSrc, err := gofmtFile(file) + if err != nil { + return err + } + if !bytes.Equal(newSrc, src) { + newFile, err := parser.ParseFile(fset, filename, newSrc, parserMode) + if err != nil { + return err + } + file = newFile + fmt.Fprintf(&fixlog, " fmt") + } + + // Apply all fixes to file. + newFile := file + fixed := false + for _, fix := range fixes { + if allowed != nil && !allowed[fix.name] { + continue + } + if fix.disabled && !force[fix.name] { + continue + } + if fix.f(newFile) { + fixed = true + fmt.Fprintf(&fixlog, " %s", fix.name) + + // AST changed. + // Print and parse, to update any missing scoping + // or position information for subsequent fixers. + newSrc, err := gofmtFile(newFile) + if err != nil { + return err + } + newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode) + if err != nil { + if debug { + fmt.Printf("%s", newSrc) + report(err) + os.Exit(exitCode) + } + return err + } + } + } + if !fixed { + return nil + } + fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:]) + + // Print AST. We did that after each fix, so this appears + // redundant, but it is necessary to generate gofmt-compatible + // source code in a few cases. The official gofmt style is the + // output of the printer run on a standard AST generated by the parser, + // but the source we generated inside the loop above is the + // output of the printer run on a mangled AST generated by a fixer. + newSrc, err = gofmtFile(newFile) + if err != nil { + return err + } + + if *doDiff { + data, err := diff.Diff("go-fix", src, newSrc) + if err != nil { + return fmt.Errorf("computing diff: %s", err) + } + fmt.Printf("diff %s fixed/%s\n", filename, filename) + os.Stdout.Write(data) + return nil + } + + if useStdin { + os.Stdout.Write(newSrc) + return nil + } + + return os.WriteFile(f.Name(), newSrc, 0) +} + +func gofmt(n interface{}) string { + var gofmtBuf bytes.Buffer + if err := format.Node(&gofmtBuf, fset, n); err != nil { + return "<" + err.Error() + ">" + } + return gofmtBuf.String() +} + +func report(err error) { + scanner.PrintError(os.Stderr, err) + exitCode = 2 +} + +func walkDir(path string) { + filepath.WalkDir(path, visitFile) +} + +func visitFile(path string, f fs.DirEntry, err error) error { + if err == nil && isGoFile(f) { + err = processFile(path, false) + } + if err != nil { + report(err) + } + return nil +} + +func isGoFile(f fs.DirEntry) bool { + // ignore non-Go files + name := f.Name() + return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go") +} diff --git a/src/cmd/fix/main_test.go b/src/cmd/fix/main_test.go new file mode 100644 index 0000000..af16bca --- /dev/null +++ b/src/cmd/fix/main_test.go @@ -0,0 +1,135 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "go/parser" + "strings" + "testing" + + "cmd/internal/diff" +) + +type testCase struct { + Name string + Fn func(*ast.File) bool + In string + Out string +} + +var testCases []testCase + +func addTestCases(t []testCase, fn func(*ast.File) bool) { + // Fill in fn to avoid repetition in definitions. + if fn != nil { + for i := range t { + if t[i].Fn == nil { + t[i].Fn = fn + } + } + } + testCases = append(testCases, t...) +} + +func fnop(*ast.File) bool { return false } + +func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) { + file, err := parser.ParseFile(fset, desc, in, parserMode) + if err != nil { + t.Errorf("parsing: %v", err) + return + } + + outb, err := gofmtFile(file) + if err != nil { + t.Errorf("printing: %v", err) + return + } + if s := string(outb); in != s && mustBeGofmt { + t.Errorf("not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s", + desc, in, desc, s) + tdiff(t, in, s) + return + } + + if fn == nil { + for _, fix := range fixes { + if fix.f(file) { + fixed = true + } + } + } else { + fixed = fn(file) + } + + outb, err = gofmtFile(file) + if err != nil { + t.Errorf("printing: %v", err) + return + } + + return string(outb), fixed, true +} + +func TestRewrite(t *testing.T) { + for _, tt := range testCases { + tt := tt + t.Run(tt.Name, func(t *testing.T) { + t.Parallel() + // Apply fix: should get tt.Out. + out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true) + if !ok { + return + } + + // reformat to get printing right + out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false) + if !ok { + return + } + + if out != tt.Out { + t.Errorf("incorrect output.\n") + if !strings.HasPrefix(tt.Name, "testdata/") { + t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out) + } + tdiff(t, out, tt.Out) + return + } + + if changed := out != tt.In; changed != fixed { + t.Errorf("changed=%v != fixed=%v", changed, fixed) + return + } + + // Should not change if run again. + out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true) + if !ok { + return + } + + if fixed2 { + t.Errorf("applied fixes during second round") + return + } + + if out2 != out { + t.Errorf("changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s", + out, out2) + tdiff(t, out, out2) + } + }) + } +} + +func tdiff(t *testing.T, a, b string) { + data, err := diff.Diff("go-fix-test", []byte(a), []byte(b)) + if err != nil { + t.Error(err) + return + } + t.Error(string(data)) +} diff --git a/src/cmd/fix/netipv6zone.go b/src/cmd/fix/netipv6zone.go new file mode 100644 index 0000000..3e502bd --- /dev/null +++ b/src/cmd/fix/netipv6zone.go @@ -0,0 +1,68 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "go/ast" + +func init() { + register(netipv6zoneFix) +} + +var netipv6zoneFix = fix{ + name: "netipv6zone", + date: "2012-11-26", + f: netipv6zone, + desc: `Adapt element key to IPAddr, UDPAddr or TCPAddr composite literals. + +https://codereview.appspot.com/6849045/ +`, +} + +func netipv6zone(f *ast.File) bool { + if !imports(f, "net") { + return false + } + + fixed := false + walk(f, func(n interface{}) { + cl, ok := n.(*ast.CompositeLit) + if !ok { + return + } + se, ok := cl.Type.(*ast.SelectorExpr) + if !ok { + return + } + if !isTopName(se.X, "net") || se.Sel == nil { + return + } + switch ss := se.Sel.String(); ss { + case "IPAddr", "UDPAddr", "TCPAddr": + for i, e := range cl.Elts { + if _, ok := e.(*ast.KeyValueExpr); ok { + break + } + switch i { + case 0: + cl.Elts[i] = &ast.KeyValueExpr{ + Key: ast.NewIdent("IP"), + Value: e, + } + case 1: + if elit, ok := e.(*ast.BasicLit); ok && elit.Value == "0" { + cl.Elts = append(cl.Elts[:i], cl.Elts[i+1:]...) + } else { + cl.Elts[i] = &ast.KeyValueExpr{ + Key: ast.NewIdent("Port"), + Value: e, + } + } + } + fixed = true + } + } + }) + return fixed +} diff --git a/src/cmd/fix/netipv6zone_test.go b/src/cmd/fix/netipv6zone_test.go new file mode 100644 index 0000000..5b8d964 --- /dev/null +++ b/src/cmd/fix/netipv6zone_test.go @@ -0,0 +1,43 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(netipv6zoneTests, netipv6zone) +} + +var netipv6zoneTests = []testCase{ + { + Name: "netipv6zone.0", + In: `package main + +import "net" + +func f() net.Addr { + a := &net.IPAddr{ip1} + sub(&net.UDPAddr{ip2, 12345}) + c := &net.TCPAddr{IP: ip3, Port: 54321} + d := &net.TCPAddr{ip4, 0} + p := 1234 + e := &net.TCPAddr{ip4, p} + return &net.TCPAddr{ip5}, nil +} +`, + Out: `package main + +import "net" + +func f() net.Addr { + a := &net.IPAddr{IP: ip1} + sub(&net.UDPAddr{IP: ip2, Port: 12345}) + c := &net.TCPAddr{IP: ip3, Port: 54321} + d := &net.TCPAddr{IP: ip4} + p := 1234 + e := &net.TCPAddr{IP: ip4, Port: p} + return &net.TCPAddr{IP: ip5}, nil +} +`, + }, +} diff --git a/src/cmd/fix/printerconfig.go b/src/cmd/fix/printerconfig.go new file mode 100644 index 0000000..6d93996 --- /dev/null +++ b/src/cmd/fix/printerconfig.go @@ -0,0 +1,61 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "go/ast" + +func init() { + register(printerconfigFix) +} + +var printerconfigFix = fix{ + name: "printerconfig", + date: "2012-12-11", + f: printerconfig, + desc: `Add element keys to Config composite literals.`, +} + +func printerconfig(f *ast.File) bool { + if !imports(f, "go/printer") { + return false + } + + fixed := false + walk(f, func(n interface{}) { + cl, ok := n.(*ast.CompositeLit) + if !ok { + return + } + se, ok := cl.Type.(*ast.SelectorExpr) + if !ok { + return + } + if !isTopName(se.X, "printer") || se.Sel == nil { + return + } + + if ss := se.Sel.String(); ss == "Config" { + for i, e := range cl.Elts { + if _, ok := e.(*ast.KeyValueExpr); ok { + break + } + switch i { + case 0: + cl.Elts[i] = &ast.KeyValueExpr{ + Key: ast.NewIdent("Mode"), + Value: e, + } + case 1: + cl.Elts[i] = &ast.KeyValueExpr{ + Key: ast.NewIdent("Tabwidth"), + Value: e, + } + } + fixed = true + } + } + }) + return fixed +} diff --git a/src/cmd/fix/printerconfig_test.go b/src/cmd/fix/printerconfig_test.go new file mode 100644 index 0000000..e485c13 --- /dev/null +++ b/src/cmd/fix/printerconfig_test.go @@ -0,0 +1,37 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(printerconfigTests, printerconfig) +} + +var printerconfigTests = []testCase{ + { + Name: "printerconfig.0", + In: `package main + +import "go/printer" + +func f() printer.Config { + b := printer.Config{0, 8} + c := &printer.Config{0} + d := &printer.Config{Tabwidth: 8, Mode: 0} + return printer.Config{0, 8} +} +`, + Out: `package main + +import "go/printer" + +func f() printer.Config { + b := printer.Config{Mode: 0, Tabwidth: 8} + c := &printer.Config{Mode: 0} + d := &printer.Config{Tabwidth: 8, Mode: 0} + return printer.Config{Mode: 0, Tabwidth: 8} +} +`, + }, +} diff --git a/src/cmd/fix/typecheck.go b/src/cmd/fix/typecheck.go new file mode 100644 index 0000000..39a5378 --- /dev/null +++ b/src/cmd/fix/typecheck.go @@ -0,0 +1,800 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "go/ast" + "go/parser" + "go/token" + exec "internal/execabs" + "os" + "path/filepath" + "reflect" + "runtime" + "strings" +) + +// Partial type checker. +// +// The fact that it is partial is very important: the input is +// an AST and a description of some type information to +// assume about one or more packages, but not all the +// packages that the program imports. The checker is +// expected to do as much as it can with what it has been +// given. There is not enough information supplied to do +// a full type check, but the type checker is expected to +// apply information that can be derived from variable +// declarations, function and method returns, and type switches +// as far as it can, so that the caller can still tell the types +// of expression relevant to a particular fix. +// +// TODO(rsc,gri): Replace with go/typechecker. +// Doing that could be an interesting test case for go/typechecker: +// the constraints about working with partial information will +// likely exercise it in interesting ways. The ideal interface would +// be to pass typecheck a map from importpath to package API text +// (Go source code), but for now we use data structures (TypeConfig, Type). +// +// The strings mostly use gofmt form. +// +// A Field or FieldList has as its type a comma-separated list +// of the types of the fields. For example, the field list +// x, y, z int +// has type "int, int, int". + +// The prefix "type " is the type of a type. +// For example, given +// var x int +// type T int +// x's type is "int" but T's type is "type int". +// mkType inserts the "type " prefix. +// getType removes it. +// isType tests for it. + +func mkType(t string) string { + return "type " + t +} + +func getType(t string) string { + if !isType(t) { + return "" + } + return t[len("type "):] +} + +func isType(t string) bool { + return strings.HasPrefix(t, "type ") +} + +// TypeConfig describes the universe of relevant types. +// For ease of creation, the types are all referred to by string +// name (e.g., "reflect.Value"). TypeByName is the only place +// where the strings are resolved. + +type TypeConfig struct { + Type map[string]*Type + Var map[string]string + Func map[string]string + + // External maps from a name to its type. + // It provides additional typings not present in the Go source itself. + // For now, the only additional typings are those generated by cgo. + External map[string]string +} + +// typeof returns the type of the given name, which may be of +// the form "x" or "p.X". +func (cfg *TypeConfig) typeof(name string) string { + if cfg.Var != nil { + if t := cfg.Var[name]; t != "" { + return t + } + } + if cfg.Func != nil { + if t := cfg.Func[name]; t != "" { + return "func()" + t + } + } + return "" +} + +// Type describes the Fields and Methods of a type. +// If the field or method cannot be found there, it is next +// looked for in the Embed list. +type Type struct { + Field map[string]string // map field name to type + Method map[string]string // map method name to comma-separated return types (should start with "func ") + Embed []string // list of types this type embeds (for extra methods) + Def string // definition of named type +} + +// dot returns the type of "typ.name", making its decision +// using the type information in cfg. +func (typ *Type) dot(cfg *TypeConfig, name string) string { + if typ.Field != nil { + if t := typ.Field[name]; t != "" { + return t + } + } + if typ.Method != nil { + if t := typ.Method[name]; t != "" { + return t + } + } + + for _, e := range typ.Embed { + etyp := cfg.Type[e] + if etyp != nil { + if t := etyp.dot(cfg, name); t != "" { + return t + } + } + } + + return "" +} + +// typecheck type checks the AST f assuming the information in cfg. +// It returns two maps with type information: +// typeof maps AST nodes to type information in gofmt string form. +// assign maps type strings to lists of expressions that were assigned +// to values of another type that were assigned to that type. +func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) { + typeof = make(map[interface{}]string) + assign = make(map[string][]interface{}) + cfg1 := &TypeConfig{} + *cfg1 = *cfg // make copy so we can add locally + copied := false + + // If we import "C", add types of cgo objects. + cfg.External = map[string]string{} + cfg1.External = cfg.External + if imports(f, "C") { + // Run cgo on gofmtFile(f) + // Parse, extract decls from _cgo_gotypes.go + // Map _Ctype_* types to C.* types. + err := func() error { + txt, err := gofmtFile(f) + if err != nil { + return err + } + dir, err := os.MkdirTemp(os.TempDir(), "fix_cgo_typecheck") + if err != nil { + return err + } + defer os.RemoveAll(dir) + err = os.WriteFile(filepath.Join(dir, "in.go"), txt, 0600) + if err != nil { + return err + } + cmd := exec.Command(filepath.Join(runtime.GOROOT(), "bin", "go"), "tool", "cgo", "-objdir", dir, "-srcdir", dir, "in.go") + err = cmd.Run() + if err != nil { + return err + } + out, err := os.ReadFile(filepath.Join(dir, "_cgo_gotypes.go")) + if err != nil { + return err + } + cgo, err := parser.ParseFile(token.NewFileSet(), "cgo.go", out, 0) + if err != nil { + return err + } + for _, decl := range cgo.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + if strings.HasPrefix(fn.Name.Name, "_Cfunc_") { + var params, results []string + for _, p := range fn.Type.Params.List { + t := gofmt(p.Type) + t = strings.ReplaceAll(t, "_Ctype_", "C.") + params = append(params, t) + } + for _, r := range fn.Type.Results.List { + t := gofmt(r.Type) + t = strings.ReplaceAll(t, "_Ctype_", "C.") + results = append(results, t) + } + cfg.External["C."+fn.Name.Name[7:]] = joinFunc(params, results) + } + } + return nil + }() + if err != nil { + fmt.Fprintf(os.Stderr, "go fix: warning: no cgo types: %s\n", err) + } + } + + // gather function declarations + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + typecheck1(cfg, fn.Type, typeof, assign) + t := typeof[fn.Type] + if fn.Recv != nil { + // The receiver must be a type. + rcvr := typeof[fn.Recv] + if !isType(rcvr) { + if len(fn.Recv.List) != 1 { + continue + } + rcvr = mkType(gofmt(fn.Recv.List[0].Type)) + typeof[fn.Recv.List[0].Type] = rcvr + } + rcvr = getType(rcvr) + if rcvr != "" && rcvr[0] == '*' { + rcvr = rcvr[1:] + } + typeof[rcvr+"."+fn.Name.Name] = t + } else { + if isType(t) { + t = getType(t) + } else { + t = gofmt(fn.Type) + } + typeof[fn.Name] = t + + // Record typeof[fn.Name.Obj] for future references to fn.Name. + typeof[fn.Name.Obj] = t + } + } + + // gather struct declarations + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if ok { + for _, s := range d.Specs { + switch s := s.(type) { + case *ast.TypeSpec: + if cfg1.Type[s.Name.Name] != nil { + break + } + if !copied { + copied = true + // Copy map lazily: it's time. + cfg1.Type = make(map[string]*Type) + for k, v := range cfg.Type { + cfg1.Type[k] = v + } + } + t := &Type{Field: map[string]string{}} + cfg1.Type[s.Name.Name] = t + switch st := s.Type.(type) { + case *ast.StructType: + for _, f := range st.Fields.List { + for _, n := range f.Names { + t.Field[n.Name] = gofmt(f.Type) + } + } + case *ast.ArrayType, *ast.StarExpr, *ast.MapType: + t.Def = gofmt(st) + } + } + } + } + } + + typecheck1(cfg1, f, typeof, assign) + return typeof, assign +} + +func makeExprList(a []*ast.Ident) []ast.Expr { + var b []ast.Expr + for _, x := range a { + b = append(b, x) + } + return b +} + +// Typecheck1 is the recursive form of typecheck. +// It is like typecheck but adds to the information in typeof +// instead of allocating a new map. +func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) { + // set sets the type of n to typ. + // If isDecl is true, n is being declared. + set := func(n ast.Expr, typ string, isDecl bool) { + if typeof[n] != "" || typ == "" { + if typeof[n] != typ { + assign[typ] = append(assign[typ], n) + } + return + } + typeof[n] = typ + + // If we obtained typ from the declaration of x + // propagate the type to all the uses. + // The !isDecl case is a cheat here, but it makes + // up in some cases for not paying attention to + // struct fields. The real type checker will be + // more accurate so we won't need the cheat. + if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") { + typeof[id.Obj] = typ + } + } + + // Type-check an assignment lhs = rhs. + // If isDecl is true, this is := so we can update + // the types of the objects that lhs refers to. + typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) { + if len(lhs) > 1 && len(rhs) == 1 { + if _, ok := rhs[0].(*ast.CallExpr); ok { + t := split(typeof[rhs[0]]) + // Lists should have same length but may not; pair what can be paired. + for i := 0; i < len(lhs) && i < len(t); i++ { + set(lhs[i], t[i], isDecl) + } + return + } + } + if len(lhs) == 1 && len(rhs) == 2 { + // x = y, ok + rhs = rhs[:1] + } else if len(lhs) == 2 && len(rhs) == 1 { + // x, ok = y + lhs = lhs[:1] + } + + // Match as much as we can. + for i := 0; i < len(lhs) && i < len(rhs); i++ { + x, y := lhs[i], rhs[i] + if typeof[y] != "" { + set(x, typeof[y], isDecl) + } else { + set(y, typeof[x], false) + } + } + } + + expand := func(s string) string { + typ := cfg.Type[s] + if typ != nil && typ.Def != "" { + return typ.Def + } + return s + } + + // The main type check is a recursive algorithm implemented + // by walkBeforeAfter(n, before, after). + // Most of it is bottom-up, but in a few places we need + // to know the type of the function we are checking. + // The before function records that information on + // the curfn stack. + var curfn []*ast.FuncType + + before := func(n interface{}) { + // push function type on stack + switch n := n.(type) { + case *ast.FuncDecl: + curfn = append(curfn, n.Type) + case *ast.FuncLit: + curfn = append(curfn, n.Type) + } + } + + // After is the real type checker. + after := func(n interface{}) { + if n == nil { + return + } + if false && reflect.TypeOf(n).Kind() == reflect.Ptr { // debugging trace + defer func() { + if t := typeof[n]; t != "" { + pos := fset.Position(n.(ast.Node).Pos()) + fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t) + } + }() + } + + switch n := n.(type) { + case *ast.FuncDecl, *ast.FuncLit: + // pop function type off stack + curfn = curfn[:len(curfn)-1] + + case *ast.FuncType: + typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results]))) + + case *ast.FieldList: + // Field list is concatenation of sub-lists. + t := "" + for _, field := range n.List { + if t != "" { + t += ", " + } + t += typeof[field] + } + typeof[n] = t + + case *ast.Field: + // Field is one instance of the type per name. + all := "" + t := typeof[n.Type] + if !isType(t) { + // Create a type, because it is typically *T or *p.T + // and we might care about that type. + t = mkType(gofmt(n.Type)) + typeof[n.Type] = t + } + t = getType(t) + if len(n.Names) == 0 { + all = t + } else { + for _, id := range n.Names { + if all != "" { + all += ", " + } + all += t + typeof[id.Obj] = t + typeof[id] = t + } + } + typeof[n] = all + + case *ast.ValueSpec: + // var declaration. Use type if present. + if n.Type != nil { + t := typeof[n.Type] + if !isType(t) { + t = mkType(gofmt(n.Type)) + typeof[n.Type] = t + } + t = getType(t) + for _, id := range n.Names { + set(id, t, true) + } + } + // Now treat same as assignment. + typecheckAssign(makeExprList(n.Names), n.Values, true) + + case *ast.AssignStmt: + typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE) + + case *ast.Ident: + // Identifier can take its type from underlying object. + if t := typeof[n.Obj]; t != "" { + typeof[n] = t + } + + case *ast.SelectorExpr: + // Field or method. + name := n.Sel.Name + if t := typeof[n.X]; t != "" { + t = strings.TrimPrefix(t, "*") // implicit * + if typ := cfg.Type[t]; typ != nil { + if t := typ.dot(cfg, name); t != "" { + typeof[n] = t + return + } + } + tt := typeof[t+"."+name] + if isType(tt) { + typeof[n] = getType(tt) + return + } + } + // Package selector. + if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil { + str := x.Name + "." + name + if cfg.Type[str] != nil { + typeof[n] = mkType(str) + return + } + if t := cfg.typeof(x.Name + "." + name); t != "" { + typeof[n] = t + return + } + } + + case *ast.CallExpr: + // make(T) has type T. + if isTopName(n.Fun, "make") && len(n.Args) >= 1 { + typeof[n] = gofmt(n.Args[0]) + return + } + // new(T) has type *T + if isTopName(n.Fun, "new") && len(n.Args) == 1 { + typeof[n] = "*" + gofmt(n.Args[0]) + return + } + // Otherwise, use type of function to determine arguments. + t := typeof[n.Fun] + if t == "" { + t = cfg.External[gofmt(n.Fun)] + } + in, out := splitFunc(t) + if in == nil && out == nil { + return + } + typeof[n] = join(out) + for i, arg := range n.Args { + if i >= len(in) { + break + } + if typeof[arg] == "" { + typeof[arg] = in[i] + } + } + + case *ast.TypeAssertExpr: + // x.(type) has type of x. + if n.Type == nil { + typeof[n] = typeof[n.X] + return + } + // x.(T) has type T. + if t := typeof[n.Type]; isType(t) { + typeof[n] = getType(t) + } else { + typeof[n] = gofmt(n.Type) + } + + case *ast.SliceExpr: + // x[i:j] has type of x. + typeof[n] = typeof[n.X] + + case *ast.IndexExpr: + // x[i] has key type of x's type. + t := expand(typeof[n.X]) + if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") { + // Lazy: assume there are no nested [] in the array + // length or map key type. + if i := strings.Index(t, "]"); i >= 0 { + typeof[n] = t[i+1:] + } + } + + case *ast.StarExpr: + // *x for x of type *T has type T when x is an expr. + // We don't use the result when *x is a type, but + // compute it anyway. + t := expand(typeof[n.X]) + if isType(t) { + typeof[n] = "type *" + getType(t) + } else if strings.HasPrefix(t, "*") { + typeof[n] = t[len("*"):] + } + + case *ast.UnaryExpr: + // &x for x of type T has type *T. + t := typeof[n.X] + if t != "" && n.Op == token.AND { + typeof[n] = "*" + t + } + + case *ast.CompositeLit: + // T{...} has type T. + typeof[n] = gofmt(n.Type) + + // Propagate types down to values used in the composite literal. + t := expand(typeof[n]) + if strings.HasPrefix(t, "[") { // array or slice + // Lazy: assume there are no nested [] in the array length. + if i := strings.Index(t, "]"); i >= 0 { + et := t[i+1:] + for _, e := range n.Elts { + if kv, ok := e.(*ast.KeyValueExpr); ok { + e = kv.Value + } + if typeof[e] == "" { + typeof[e] = et + } + } + } + } + if strings.HasPrefix(t, "map[") { // map + // Lazy: assume there are no nested [] in the map key type. + if i := strings.Index(t, "]"); i >= 0 { + kt, vt := t[4:i], t[i+1:] + for _, e := range n.Elts { + if kv, ok := e.(*ast.KeyValueExpr); ok { + if typeof[kv.Key] == "" { + typeof[kv.Key] = kt + } + if typeof[kv.Value] == "" { + typeof[kv.Value] = vt + } + } + } + } + } + if typ := cfg.Type[t]; typ != nil && len(typ.Field) > 0 { // struct + for _, e := range n.Elts { + if kv, ok := e.(*ast.KeyValueExpr); ok { + if ft := typ.Field[fmt.Sprintf("%s", kv.Key)]; ft != "" { + if typeof[kv.Value] == "" { + typeof[kv.Value] = ft + } + } + } + } + } + + case *ast.ParenExpr: + // (x) has type of x. + typeof[n] = typeof[n.X] + + case *ast.RangeStmt: + t := expand(typeof[n.X]) + if t == "" { + return + } + var key, value string + if t == "string" { + key, value = "int", "rune" + } else if strings.HasPrefix(t, "[") { + key = "int" + if i := strings.Index(t, "]"); i >= 0 { + value = t[i+1:] + } + } else if strings.HasPrefix(t, "map[") { + if i := strings.Index(t, "]"); i >= 0 { + key, value = t[4:i], t[i+1:] + } + } + changed := false + if n.Key != nil && key != "" { + changed = true + set(n.Key, key, n.Tok == token.DEFINE) + } + if n.Value != nil && value != "" { + changed = true + set(n.Value, value, n.Tok == token.DEFINE) + } + // Ugly failure of vision: already type-checked body. + // Do it again now that we have that type info. + if changed { + typecheck1(cfg, n.Body, typeof, assign) + } + + case *ast.TypeSwitchStmt: + // Type of variable changes for each case in type switch, + // but go/parser generates just one variable. + // Repeat type check for each case with more precise + // type information. + as, ok := n.Assign.(*ast.AssignStmt) + if !ok { + return + } + varx, ok := as.Lhs[0].(*ast.Ident) + if !ok { + return + } + t := typeof[varx] + for _, cas := range n.Body.List { + cas := cas.(*ast.CaseClause) + if len(cas.List) == 1 { + // Variable has specific type only when there is + // exactly one type in the case list. + if tt := typeof[cas.List[0]]; isType(tt) { + tt = getType(tt) + typeof[varx] = tt + typeof[varx.Obj] = tt + typecheck1(cfg, cas.Body, typeof, assign) + } + } + } + // Restore t. + typeof[varx] = t + typeof[varx.Obj] = t + + case *ast.ReturnStmt: + if len(curfn) == 0 { + // Probably can't happen. + return + } + f := curfn[len(curfn)-1] + res := n.Results + if f.Results != nil { + t := split(typeof[f.Results]) + for i := 0; i < len(res) && i < len(t); i++ { + set(res[i], t[i], false) + } + } + + case *ast.BinaryExpr: + // Propagate types across binary ops that require two args of the same type. + switch n.Op { + case token.EQL, token.NEQ: // TODO: more cases. This is enough for the cftype fix. + if typeof[n.X] != "" && typeof[n.Y] == "" { + typeof[n.Y] = typeof[n.X] + } + if typeof[n.X] == "" && typeof[n.Y] != "" { + typeof[n.X] = typeof[n.Y] + } + } + } + } + walkBeforeAfter(f, before, after) +} + +// Convert between function type strings and lists of types. +// Using strings makes this a little harder, but it makes +// a lot of the rest of the code easier. This will all go away +// when we can use go/typechecker directly. + +// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"]. +func splitFunc(s string) (in, out []string) { + if !strings.HasPrefix(s, "func(") { + return nil, nil + } + + i := len("func(") // index of beginning of 'in' arguments + nparen := 0 + for j := i; j < len(s); j++ { + switch s[j] { + case '(': + nparen++ + case ')': + nparen-- + if nparen < 0 { + // found end of parameter list + out := strings.TrimSpace(s[j+1:]) + if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' { + out = out[1 : len(out)-1] + } + return split(s[i:j]), split(out) + } + } + } + return nil, nil +} + +// joinFunc is the inverse of splitFunc. +func joinFunc(in, out []string) string { + outs := "" + if len(out) == 1 { + outs = " " + out[0] + } else if len(out) > 1 { + outs = " (" + join(out) + ")" + } + return "func(" + join(in) + ")" + outs +} + +// split splits "int, float" into ["int", "float"] and splits "" into []. +func split(s string) []string { + out := []string{} + i := 0 // current type being scanned is s[i:j]. + nparen := 0 + for j := 0; j < len(s); j++ { + switch s[j] { + case ' ': + if i == j { + i++ + } + case '(': + nparen++ + case ')': + nparen-- + if nparen < 0 { + // probably can't happen + return nil + } + case ',': + if nparen == 0 { + if i < j { + out = append(out, s[i:j]) + } + i = j + 1 + } + } + } + if nparen != 0 { + // probably can't happen + return nil + } + if i < len(s) { + out = append(out, s[i:]) + } + return out +} + +// join is the inverse of split. +func join(x []string) string { + return strings.Join(x, ", ") +} |