diff options
Diffstat (limited to 'src/database/sql')
-rw-r--r-- | src/database/sql/convert.go | 591 | ||||
-rw-r--r-- | src/database/sql/convert_test.go | 597 | ||||
-rw-r--r-- | src/database/sql/ctxutil.go | 146 | ||||
-rw-r--r-- | src/database/sql/doc.txt | 46 | ||||
-rw-r--r-- | src/database/sql/driver/driver.go | 552 | ||||
-rw-r--r-- | src/database/sql/driver/types.go | 297 | ||||
-rw-r--r-- | src/database/sql/driver/types_test.go | 95 | ||||
-rw-r--r-- | src/database/sql/example_cli_test.go | 84 | ||||
-rw-r--r-- | src/database/sql/example_service_test.go | 158 | ||||
-rw-r--r-- | src/database/sql/example_test.go | 369 | ||||
-rw-r--r-- | src/database/sql/fakedb_test.go | 1272 | ||||
-rw-r--r-- | src/database/sql/sql.go | 3406 | ||||
-rw-r--r-- | src/database/sql/sql_test.go | 4585 |
13 files changed, 12198 insertions, 0 deletions
diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go new file mode 100644 index 0000000..32941cb --- /dev/null +++ b/src/database/sql/convert.go @@ -0,0 +1,591 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Type conversions for Scan. + +package sql + +import ( + "bytes" + "database/sql/driver" + "errors" + "fmt" + "reflect" + "strconv" + "time" + "unicode" + "unicode/utf8" +) + +var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error + +func describeNamedValue(nv *driver.NamedValue) string { + if len(nv.Name) == 0 { + return fmt.Sprintf("$%d", nv.Ordinal) + } + return fmt.Sprintf("with name %q", nv.Name) +} + +func validateNamedValueName(name string) error { + if len(name) == 0 { + return nil + } + r, _ := utf8.DecodeRuneInString(name) + if unicode.IsLetter(r) { + return nil + } + return fmt.Errorf("name %q does not begin with a letter", name) +} + +// ccChecker wraps the driver.ColumnConverter and allows it to be used +// as if it were a NamedValueChecker. If the driver ColumnConverter +// is not present then the NamedValueChecker will return driver.ErrSkip. +type ccChecker struct { + cci driver.ColumnConverter + want int +} + +func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error { + if c.cci == nil { + return driver.ErrSkip + } + // The column converter shouldn't be called on any index + // it isn't expecting. The final error will be thrown + // in the argument converter loop. + index := nv.Ordinal - 1 + if c.want <= index { + return nil + } + + // First, see if the value itself knows how to convert + // itself to a driver type. For example, a NullString + // struct changing into a string or nil. + if vr, ok := nv.Value.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return err + } + if !driver.IsValue(sv) { + return fmt.Errorf("non-subset type %T returned from Value", sv) + } + nv.Value = sv + } + + // Second, ask the column to sanity check itself. For + // example, drivers might use this to make sure that + // an int64 values being inserted into a 16-bit + // integer field is in range (before getting + // truncated), or that a nil can't go into a NOT NULL + // column before going across the network to get the + // same error. + var err error + arg := nv.Value + nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg) + if err != nil { + return err + } + if !driver.IsValue(nv.Value) { + return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value) + } + return nil +} + +// defaultCheckNamedValue wraps the default ColumnConverter to have the same +// function signature as the CheckNamedValue in the driver.NamedValueChecker +// interface. +func defaultCheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value) + return err +} + +// driverArgsConnLocked converts arguments from callers of Stmt.Exec and +// Stmt.Query into driver Values. +// +// The statement ds may be nil, if no statement is available. +// +// ci must be locked. +func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) { + nvargs := make([]driver.NamedValue, len(args)) + + // -1 means the driver doesn't know how to count the number of + // placeholders, so we won't sanity check input here and instead let the + // driver deal with errors. + want := -1 + + var si driver.Stmt + var cc ccChecker + if ds != nil { + si = ds.si + want = ds.si.NumInput() + cc.want = want + } + + // Check all types of interfaces from the start. + // Drivers may opt to use the NamedValueChecker for special + // argument types, then return driver.ErrSkip to pass it along + // to the column converter. + nvc, ok := si.(driver.NamedValueChecker) + if !ok { + nvc, ok = ci.(driver.NamedValueChecker) + } + cci, ok := si.(driver.ColumnConverter) + if ok { + cc.cci = cci + } + + // Loop through all the arguments, checking each one. + // If no error is returned simply increment the index + // and continue. However if driver.ErrRemoveArgument + // is returned the argument is not included in the query + // argument list. + var err error + var n int + for _, arg := range args { + nv := &nvargs[n] + if np, ok := arg.(NamedArg); ok { + if err = validateNamedValueName(np.Name); err != nil { + return nil, err + } + arg = np.Value + nv.Name = np.Name + } + nv.Ordinal = n + 1 + nv.Value = arg + + // Checking sequence has four routes: + // A: 1. Default + // B: 1. NamedValueChecker 2. Column Converter 3. Default + // C: 1. NamedValueChecker 3. Default + // D: 1. Column Converter 2. Default + // + // The only time a Column Converter is called is first + // or after NamedValueConverter. If first it is handled before + // the nextCheck label. Thus for repeats tries only when the + // NamedValueConverter is selected should the Column Converter + // be used in the retry. + checker := defaultCheckNamedValue + nextCC := false + switch { + case nvc != nil: + nextCC = cci != nil + checker = nvc.CheckNamedValue + case cci != nil: + checker = cc.CheckNamedValue + } + + nextCheck: + err = checker(nv) + switch err { + case nil: + n++ + continue + case driver.ErrRemoveArgument: + nvargs = nvargs[:len(nvargs)-1] + continue + case driver.ErrSkip: + if nextCC { + nextCC = false + checker = cc.CheckNamedValue + } else { + checker = defaultCheckNamedValue + } + goto nextCheck + default: + return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err) + } + } + + // Check the length of arguments after conversion to allow for omitted + // arguments. + if want != -1 && len(nvargs) != want { + return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs)) + } + + return nvargs, nil + +} + +// convertAssign is the same as convertAssignRows, but without the optional +// rows argument. +func convertAssign(dest, src any) error { + return convertAssignRows(dest, src, nil) +} + +// convertAssignRows copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. If rows is passed in, the rows will +// be used as the parent for any cursor values converted from a +// driver.Rows to a *Rows. +func convertAssignRows(dest, src any, rows *Rows) error { + // Common cases, without reflect. + switch s := src.(type) { + case string: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = s + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s) + return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = append((*d)[:0], s...) + return nil + } + case []byte: + switch d := dest.(type) { + case *string: + if d == nil { + return errNilPtr + } + *d = string(s) + return nil + case *any: + if d == nil { + return errNilPtr + } + *d = bytes.Clone(s) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = bytes.Clone(s) + return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = s + return nil + } + case time.Time: + switch d := dest.(type) { + case *time.Time: + *d = s + return nil + case *string: + *d = s.Format(time.RFC3339Nano) + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = []byte(s.Format(time.RFC3339Nano)) + return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = s.AppendFormat((*d)[:0], time.RFC3339Nano) + return nil + } + case decimalDecompose: + switch d := dest.(type) { + case decimalCompose: + return d.Compose(s.Decompose(nil)) + } + case nil: + switch d := dest.(type) { + case *any: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *[]byte: + if d == nil { + return errNilPtr + } + *d = nil + return nil + case *RawBytes: + if d == nil { + return errNilPtr + } + *d = nil + return nil + } + // The driver is returning a cursor the client may iterate over. + case driver.Rows: + switch d := dest.(type) { + case *Rows: + if d == nil { + return errNilPtr + } + if rows == nil { + return errors.New("invalid context to convert cursor rows, missing parent *Rows") + } + rows.closemu.Lock() + *d = Rows{ + dc: rows.dc, + releaseConn: func(error) {}, + rowsi: s, + } + // Chain the cancel function. + parentCancel := rows.cancel + rows.cancel = func() { + // When Rows.cancel is called, the closemu will be locked as well. + // So we can access rs.lasterr. + d.close(rows.lasterr) + if parentCancel != nil { + parentCancel() + } + } + rows.closemu.Unlock() + return nil + } + } + + var sv reflect.Value + + switch d := dest.(type) { + case *string: + sv = reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Bool, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + *d = asString(src) + return nil + } + case *[]byte: + sv = reflect.ValueOf(src) + if b, ok := asBytes(nil, sv); ok { + *d = b + return nil + } + case *RawBytes: + sv = reflect.ValueOf(src) + if b, ok := asBytes([]byte(*d)[:0], sv); ok { + *d = RawBytes(b) + return nil + } + case *bool: + bv, err := driver.Bool.ConvertValue(src) + if err == nil { + *d = bv.(bool) + } + return err + case *any: + *d = src + return nil + } + + if scanner, ok := dest.(Scanner); ok { + return scanner.Scan(src) + } + + dpv := reflect.ValueOf(dest) + if dpv.Kind() != reflect.Pointer { + return errors.New("destination not a pointer") + } + if dpv.IsNil() { + return errNilPtr + } + + if !sv.IsValid() { + sv = reflect.ValueOf(src) + } + + dv := reflect.Indirect(dpv) + if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) { + switch b := src.(type) { + case []byte: + dv.Set(reflect.ValueOf(bytes.Clone(b))) + default: + dv.Set(sv) + } + return nil + } + + if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) { + dv.Set(sv.Convert(dv.Type())) + return nil + } + + // The following conversions use a string value as an intermediate representation + // to convert between various numeric types. + // + // This also allows scanning into user defined types such as "type Int int64". + // For symmetry, also check for string destination types. + switch dv.Kind() { + case reflect.Pointer: + if src == nil { + dv.Set(reflect.Zero(dv.Type())) + return nil + } + dv.Set(reflect.New(dv.Type().Elem())) + return convertAssignRows(dv.Interface(), src, rows) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetInt(i64) + return nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + u64, err := strconv.ParseUint(s, 10, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetUint(u64) + return nil + case reflect.Float32, reflect.Float64: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + s := asString(src) + f64, err := strconv.ParseFloat(s, dv.Type().Bits()) + if err != nil { + err = strconvErr(err) + return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err) + } + dv.SetFloat(f64) + return nil + case reflect.String: + if src == nil { + return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind()) + } + switch v := src.(type) { + case string: + dv.SetString(v) + return nil + case []byte: + dv.SetString(string(v)) + return nil + } + } + + return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest) +} + +func strconvErr(err error) error { + if ne, ok := err.(*strconv.NumError); ok { + return ne.Err + } + return err +} + +func asString(src any) string { + switch v := src.(type) { + case string: + return v + case []byte: + return string(v) + } + rv := reflect.ValueOf(src) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(rv.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.FormatUint(rv.Uint(), 10) + case reflect.Float64: + return strconv.FormatFloat(rv.Float(), 'g', -1, 64) + case reflect.Float32: + return strconv.FormatFloat(rv.Float(), 'g', -1, 32) + case reflect.Bool: + return strconv.FormatBool(rv.Bool()) + } + return fmt.Sprintf("%v", src) +} + +func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) { + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.AppendInt(buf, rv.Int(), 10), true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return strconv.AppendUint(buf, rv.Uint(), 10), true + case reflect.Float32: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true + case reflect.Float64: + return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true + case reflect.Bool: + return strconv.AppendBool(buf, rv.Bool()), true + case reflect.String: + s := rv.String() + return append(buf, s...), true + } + return +} + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// Issue 8415. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This function is mirrored in the database/sql/driver package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} + +// decimal composes or decomposes a decimal value to and from individual parts. +// There are four parts: a boolean negative flag, a form byte with three possible states +// (finite=0, infinite=1, NaN=2), a base-2 big-endian integer +// coefficient (also known as a significand) as a []byte, and an int32 exponent. +// These are composed into a final value as "decimal = (neg) (form=finite) coefficient * 10 ^ exponent". +// A zero length coefficient is a zero value. +// The big-endian integer coefficient stores the most significant byte first (at coefficient[0]). +// If the form is not finite the coefficient and exponent should be ignored. +// The negative parameter may be set to true for any form, although implementations are not required +// to respect the negative parameter in the non-finite form. +// +// Implementations may choose to set the negative parameter to true on a zero or NaN value, +// but implementations that do not differentiate between negative and positive +// zero or NaN values should ignore the negative parameter without error. +// If an implementation does not support Infinity it may be converted into a NaN without error. +// If a value is set that is larger than what is supported by an implementation, +// an error must be returned. +// Implementations must return an error if a NaN or Infinity is attempted to be set while neither +// are supported. +// +// NOTE(kardianos): This is an experimental interface. See https://golang.org/issue/30870 +type decimal interface { + decimalDecompose + decimalCompose +} + +type decimalDecompose interface { + // Decompose returns the internal decimal state in parts. + // If the provided buf has sufficient capacity, buf may be returned as the coefficient with + // the value set and length set as appropriate. + Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) +} + +type decimalCompose interface { + // Compose sets the internal decimal value from parts. If the value cannot be + // represented then an error should be returned. + Compose(form byte, negative bool, coefficient []byte, exponent int32) error +} diff --git a/src/database/sql/convert_test.go b/src/database/sql/convert_test.go new file mode 100644 index 0000000..6d09fa1 --- /dev/null +++ b/src/database/sql/convert_test.go @@ -0,0 +1,597 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "database/sql/driver" + "fmt" + "reflect" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +var someTime = time.Unix(123, 0) +var answer int64 = 42 + +type ( + userDefined float64 + userDefinedSlice []int + userDefinedString string +) + +type conversionTest struct { + s, d any // source and destination + + // following are used if they're non-zero + wantint int64 + wantuint uint64 + wantstr string + wantbytes []byte + wantraw RawBytes + wantf32 float32 + wantf64 float64 + wanttime time.Time + wantbool bool // used if d is of type *bool + wanterr string + wantiface any + wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr + wantnil bool // if true, *d must be *int64(nil) + wantusrdef userDefined + wantusrstr userDefinedString +} + +// Target variables for scanning into. +var ( + scanstr string + scanbytes []byte + scanraw RawBytes + scanint int + scanuint8 uint8 + scanuint16 uint16 + scanbool bool + scanf32 float32 + scanf64 float64 + scantime time.Time + scanptr *int64 + scaniface any +) + +func conversionTests() []conversionTest { + // Return a fresh instance to test so "go test -count 2" works correctly. + return []conversionTest{ + // Exact conversions (destination pointer type matches source type) + {s: "foo", d: &scanstr, wantstr: "foo"}, + {s: 123, d: &scanint, wantint: 123}, + {s: someTime, d: &scantime, wanttime: someTime}, + + // To strings + {s: "string", d: &scanstr, wantstr: "string"}, + {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"}, + {s: 123, d: &scanstr, wantstr: "123"}, + {s: int8(123), d: &scanstr, wantstr: "123"}, + {s: int64(123), d: &scanstr, wantstr: "123"}, + {s: uint8(123), d: &scanstr, wantstr: "123"}, + {s: uint16(123), d: &scanstr, wantstr: "123"}, + {s: uint32(123), d: &scanstr, wantstr: "123"}, + {s: uint64(123), d: &scanstr, wantstr: "123"}, + {s: 1.5, d: &scanstr, wantstr: "1.5"}, + + // From time.Time: + {s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"}, + {s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"}, + {s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"}, + {s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"}, + {s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")}, + {s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()}, + + // To []byte + {s: nil, d: &scanbytes, wantbytes: nil}, + {s: "string", d: &scanbytes, wantbytes: []byte("string")}, + {s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")}, + {s: 123, d: &scanbytes, wantbytes: []byte("123")}, + {s: int8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: int64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint8(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint16(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint32(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: uint64(123), d: &scanbytes, wantbytes: []byte("123")}, + {s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")}, + + // To RawBytes + {s: nil, d: &scanraw, wantraw: nil}, + {s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")}, + {s: "string", d: &scanraw, wantraw: RawBytes("string")}, + {s: 123, d: &scanraw, wantraw: RawBytes("123")}, + {s: int8(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: int64(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint8(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint16(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint32(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: uint64(123), d: &scanraw, wantraw: RawBytes("123")}, + {s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")}, + // time.Time has been placed here to check that the RawBytes slice gets + // correctly reset when calling time.Time.AppendFormat. + {s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")}, + + // Strings to integers + {s: "255", d: &scanuint8, wantuint: 255}, + {s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"}, + {s: "256", d: &scanuint16, wantuint: 256}, + {s: "-1", d: &scanint, wantint: -1}, + {s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"}, + + // int64 to smaller integers + {s: int64(5), d: &scanuint8, wantuint: 5}, + {s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"}, + {s: int64(256), d: &scanuint16, wantuint: 256}, + {s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"}, + + // True bools + {s: true, d: &scanbool, wantbool: true}, + {s: "True", d: &scanbool, wantbool: true}, + {s: "TRUE", d: &scanbool, wantbool: true}, + {s: "1", d: &scanbool, wantbool: true}, + {s: 1, d: &scanbool, wantbool: true}, + {s: int64(1), d: &scanbool, wantbool: true}, + {s: uint16(1), d: &scanbool, wantbool: true}, + + // False bools + {s: false, d: &scanbool, wantbool: false}, + {s: "false", d: &scanbool, wantbool: false}, + {s: "FALSE", d: &scanbool, wantbool: false}, + {s: "0", d: &scanbool, wantbool: false}, + {s: 0, d: &scanbool, wantbool: false}, + {s: int64(0), d: &scanbool, wantbool: false}, + {s: uint16(0), d: &scanbool, wantbool: false}, + + // Not bools + {s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`}, + {s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`}, + + // Floats + {s: float64(1.5), d: &scanf64, wantf64: float64(1.5)}, + {s: int64(1), d: &scanf64, wantf64: float64(1)}, + {s: float64(1.5), d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf32, wantf32: float32(1.5)}, + {s: "1.5", d: &scanf64, wantf64: float64(1.5)}, + + // Pointers + {s: any(nil), d: &scanptr, wantnil: true}, + {s: int64(42), d: &scanptr, wantptr: &answer}, + + // To interface{} + {s: float64(1.5), d: &scaniface, wantiface: float64(1.5)}, + {s: int64(1), d: &scaniface, wantiface: int64(1)}, + {s: "str", d: &scaniface, wantiface: "str"}, + {s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")}, + {s: true, d: &scaniface, wantiface: true}, + {s: nil, d: &scaniface}, + {s: []byte(nil), d: &scaniface, wantiface: []byte(nil)}, + + // To a user-defined type + {s: 1.5, d: new(userDefined), wantusrdef: 1.5}, + {s: int64(123), d: new(userDefined), wantusrdef: 123}, + {s: "1.5", d: new(userDefined), wantusrdef: 1.5}, + {s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`}, + {s: "str", d: new(userDefinedString), wantusrstr: "str"}, + + // Other errors + {s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`}, + } +} + +func intPtrValue(intptr any) any { + return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int() +} + +func intValue(intptr any) int64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Int() +} + +func uintValue(intptr any) uint64 { + return reflect.Indirect(reflect.ValueOf(intptr)).Uint() +} + +func float64Value(ptr any) float64 { + return *(ptr.(*float64)) +} + +func float32Value(ptr any) float32 { + return *(ptr.(*float32)) +} + +func timeValue(ptr any) time.Time { + return *(ptr.(*time.Time)) +} + +func TestConversions(t *testing.T) { + for n, ct := range conversionTests() { + err := convertAssign(ct.d, ct.s) + errstr := "" + if err != nil { + errstr = err.Error() + } + errf := func(format string, args ...any) { + base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d) + t.Errorf(base+format, args...) + } + if errstr != ct.wanterr { + errf("got error %q, want error %q", errstr, ct.wanterr) + } + if ct.wantstr != "" && ct.wantstr != scanstr { + errf("want string %q, got %q", ct.wantstr, scanstr) + } + if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) { + errf("want byte %q, got %q", ct.wantbytes, scanbytes) + } + if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) { + errf("want RawBytes %q, got %q", ct.wantraw, scanraw) + } + if ct.wantint != 0 && ct.wantint != intValue(ct.d) { + errf("want int %d, got %d", ct.wantint, intValue(ct.d)) + } + if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) { + errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d)) + } + if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d)) + } + if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) { + errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d)) + } + if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" { + errf("want bool %v, got %v", ct.wantbool, *bp) + } + if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) { + errf("want time %v, got %v", ct.wanttime, timeValue(ct.d)) + } + if ct.wantnil && *ct.d.(**int64) != nil { + errf("want nil, got %v", intPtrValue(ct.d)) + } + if ct.wantptr != nil { + if *ct.d.(**int64) == nil { + errf("want pointer to %v, got nil", *ct.wantptr) + } else if *ct.wantptr != intPtrValue(ct.d) { + errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d)) + } + } + if ifptr, ok := ct.d.(*any); ok { + if !reflect.DeepEqual(ct.wantiface, scaniface) { + errf("want interface %#v, got %#v", ct.wantiface, scaniface) + continue + } + if srcBytes, ok := ct.s.([]byte); ok { + dstBytes := (*ifptr).([]byte) + if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] { + errf("copy into interface{} didn't copy []byte data") + } + } + } + if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) { + errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined)) + } + if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) { + errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString)) + } + } +} + +func TestNullString(t *testing.T) { + var ns NullString + convertAssign(&ns, []byte("foo")) + if !ns.Valid { + t.Errorf("expecting not null") + } + if ns.String != "foo" { + t.Errorf("expecting foo; got %q", ns.String) + } + convertAssign(&ns, nil) + if ns.Valid { + t.Errorf("expecting null on nil") + } + if ns.String != "" { + t.Errorf("expecting blank on nil; got %q", ns.String) + } +} + +type valueConverterTest struct { + c driver.ValueConverter + in, out any + err string +} + +var valueConverterTests = []valueConverterTest{ + {driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""}, + {driver.DefaultParameterConverter, NullString{"", false}, nil, ""}, +} + +func TestValueConverters(t *testing.T) { + for i, tt := range valueConverterTests { + out, err := tt.c.ConvertValue(tt.in) + goterr := "" + if err != nil { + goterr = err.Error() + } + if goterr != tt.err { + t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q", + i, tt.c, tt.in, tt.in, goterr, tt.err) + } + if tt.err != "" { + continue + } + if !reflect.DeepEqual(out, tt.out) { + t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)", + i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) + } + } +} + +// Tests that assigning to RawBytes doesn't allocate (and also works). +func TestRawBytesAllocs(t *testing.T) { + var tests = []struct { + name string + in any + want string + }{ + {"uint64", uint64(12345678), "12345678"}, + {"uint32", uint32(1234), "1234"}, + {"uint16", uint16(12), "12"}, + {"uint8", uint8(1), "1"}, + {"uint", uint(123), "123"}, + {"int", int(123), "123"}, + {"int8", int8(1), "1"}, + {"int16", int16(12), "12"}, + {"int32", int32(1234), "1234"}, + {"int64", int64(12345678), "12345678"}, + {"float32", float32(1.5), "1.5"}, + {"float64", float64(64), "64"}, + {"bool", false, "false"}, + {"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"}, + } + + buf := make(RawBytes, 10) + test := func(name string, in any, want string) { + if err := convertAssign(&buf, in); err != nil { + t.Fatalf("%s: convertAssign = %v", name, err) + } + match := len(buf) == len(want) + if match { + for i, b := range buf { + if want[i] != b { + match = false + break + } + } + } + if !match { + t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want)) + } + } + + n := testing.AllocsPerRun(100, func() { + for _, tt := range tests { + test(tt.name, tt.in, tt.want) + } + }) + + // The numbers below are only valid for 64-bit interface word sizes, + // and gc. With 32-bit words there are more convT2E allocs, and + // with gccgo, only pointers currently go in interface data. + // So only care on amd64 gc for now. + measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc" + + if n > 0.5 && measureAllocs { + t.Fatalf("allocs = %v; want 0", n) + } + + // This one involves a convT2E allocation, string -> interface{} + n = testing.AllocsPerRun(100, func() { + test("string", "foo", "foo") + }) + if n > 1.5 && measureAllocs { + t.Fatalf("allocs = %v; want max 1", n) + } +} + +// https://golang.org/issues/13905 +func TestUserDefinedBytes(t *testing.T) { + type userDefinedBytes []byte + var u userDefinedBytes + v := []byte("foo") + + convertAssign(&u, v) + if &u[0] == &v[0] { + t.Fatal("userDefinedBytes got potentially dirty driver memory") + } +} + +type Valuer_V string + +func (v Valuer_V) Value() (driver.Value, error) { + return strings.ToUpper(string(v)), nil +} + +type Valuer_P string + +func (p *Valuer_P) Value() (driver.Value, error) { + if p == nil { + return "nil-to-str", nil + } + return strings.ToUpper(string(*p)), nil +} + +func TestDriverArgs(t *testing.T) { + var nilValuerVPtr *Valuer_V + var nilValuerPPtr *Valuer_P + var nilStrPtr *string + tests := []struct { + args []any + want []driver.NamedValue + }{ + 0: { + args: []any{Valuer_V("foo")}, + want: []driver.NamedValue{ + { + Ordinal: 1, + Value: "FOO", + }, + }, + }, + 1: { + args: []any{nilValuerVPtr}, + want: []driver.NamedValue{ + { + Ordinal: 1, + Value: nil, + }, + }, + }, + 2: { + args: []any{nilValuerPPtr}, + want: []driver.NamedValue{ + { + Ordinal: 1, + Value: "nil-to-str", + }, + }, + }, + 3: { + args: []any{"plain-str"}, + want: []driver.NamedValue{ + { + Ordinal: 1, + Value: "plain-str", + }, + }, + }, + 4: { + args: []any{nilStrPtr}, + want: []driver.NamedValue{ + { + Ordinal: 1, + Value: nil, + }, + }, + }, + } + for i, tt := range tests { + ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}} + got, err := driverArgsConnLocked(nil, ds, tt.args) + if err != nil { + t.Errorf("test[%d]: %v", i, err) + continue + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("test[%d]: got %v, want %v", i, got, tt.want) + } + } +} + +type dec struct { + form byte + neg bool + coefficient [16]byte + exponent int32 +} + +func (d dec) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) { + coef := make([]byte, 16) + copy(coef, d.coefficient[:]) + return d.form, d.neg, coef, d.exponent +} + +func (d *dec) Compose(form byte, negative bool, coefficient []byte, exponent int32) error { + switch form { + default: + return fmt.Errorf("unknown form %d", form) + case 1, 2: + d.form = form + d.neg = negative + return nil + case 0: + } + d.form = form + d.neg = negative + d.exponent = exponent + + // This isn't strictly correct, as the extra bytes could be all zero, + // ignore this for this test. + if len(coefficient) > 16 { + return fmt.Errorf("coefficient too large") + } + copy(d.coefficient[:], coefficient) + + return nil +} + +type decFinite struct { + neg bool + coefficient [16]byte + exponent int32 +} + +func (d decFinite) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) { + coef := make([]byte, 16) + copy(coef, d.coefficient[:]) + return 0, d.neg, coef, d.exponent +} + +func (d *decFinite) Compose(form byte, negative bool, coefficient []byte, exponent int32) error { + switch form { + default: + return fmt.Errorf("unknown form %d", form) + case 1, 2: + return fmt.Errorf("unsupported form %d", form) + case 0: + } + d.neg = negative + d.exponent = exponent + + // This isn't strictly correct, as the extra bytes could be all zero, + // ignore this for this test. + if len(coefficient) > 16 { + return fmt.Errorf("coefficient too large") + } + copy(d.coefficient[:], coefficient) + + return nil +} + +func TestDecimal(t *testing.T) { + list := []struct { + name string + in decimalDecompose + out dec + err bool + }{ + {name: "same", in: dec{exponent: -6}, out: dec{exponent: -6}}, + + // Ensure reflection is not used to assign the value by using different types. + {name: "diff", in: decFinite{exponent: -6}, out: dec{exponent: -6}}, + + {name: "bad-form", in: dec{form: 200}, err: true}, + } + for _, item := range list { + t.Run(item.name, func(t *testing.T) { + out := dec{} + err := convertAssign(&out, item.in) + if item.err { + if err == nil { + t.Fatalf("unexpected nil error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !reflect.DeepEqual(out, item.out) { + t.Fatalf("got %#v want %#v", out, item.out) + } + }) + } +} diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go new file mode 100644 index 0000000..4dbe6af --- /dev/null +++ b/src/database/sql/ctxutil.go @@ -0,0 +1,146 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "context" + "database/sql/driver" + "errors" +) + +func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) { + if ciCtx, is := ci.(driver.ConnPrepareContext); is { + return ciCtx.PrepareContext(ctx, query) + } + si, err := ci.Prepare(query) + if err == nil { + select { + default: + case <-ctx.Done(): + si.Close() + return nil, ctx.Err() + } + } + return si, err +} + +func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) { + if execerCtx != nil { + return execerCtx.ExecContext(ctx, query, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + return execer.Exec(query, dargs) +} + +func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) { + if queryerCtx != nil { + return queryerCtx.QueryContext(ctx, query, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + return queryer.Query(query, dargs) +} + +func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) { + if siCtx, is := si.(driver.StmtExecContext); is { + return siCtx.ExecContext(ctx, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + return si.Exec(dargs) +} + +func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) { + if siCtx, is := si.(driver.StmtQueryContext); is { + return siCtx.QueryContext(ctx, nvdargs) + } + dargs, err := namedValueToValue(nvdargs) + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + return si.Query(dargs) +} + +func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (driver.Tx, error) { + if ciCtx, is := ci.(driver.ConnBeginTx); is { + dopts := driver.TxOptions{} + if opts != nil { + dopts.Isolation = driver.IsolationLevel(opts.Isolation) + dopts.ReadOnly = opts.ReadOnly + } + return ciCtx.BeginTx(ctx, dopts) + } + + if opts != nil { + // Check the transaction level. If the transaction level is non-default + // then return an error here as the BeginTx driver value is not supported. + if opts.Isolation != LevelDefault { + return nil, errors.New("sql: driver does not support non-default isolation level") + } + + // If a read-only transaction is requested return an error as the + // BeginTx driver value is not supported. + if opts.ReadOnly { + return nil, errors.New("sql: driver does not support read-only transactions") + } + } + + if ctx.Done() == nil { + return ci.Begin() + } + + txi, err := ci.Begin() + if err == nil { + select { + default: + case <-ctx.Done(): + txi.Rollback() + return nil, ctx.Err() + } + } + return txi, err +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + return nil, errors.New("sql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} diff --git a/src/database/sql/doc.txt b/src/database/sql/doc.txt new file mode 100644 index 0000000..1341b57 --- /dev/null +++ b/src/database/sql/doc.txt @@ -0,0 +1,46 @@ +Goals of the sql and sql/driver packages: + +* Provide a generic database API for a variety of SQL or SQL-like + databases. There currently exist Go libraries for SQLite, MySQL, + and Postgres, but all with a very different feel, and often + a non-Go-like feel. + +* Feel like Go. + +* Care mostly about the common cases. Common SQL should be portable. + SQL edge cases or db-specific extensions can be detected and + conditionally used by the application. It is a non-goal to care + about every particular db's extension or quirk. + +* Separate out the basic implementation of a database driver + (implementing the sql/driver interfaces) vs the implementation + of all the user-level types and convenience methods. + In a nutshell: + + User Code ---> sql package (concrete types) ---> sql/driver (interfaces) + Database Driver -> sql (to register) + sql/driver (implement interfaces) + +* Make type casting/conversions consistent between all drivers. To + achieve this, most of the conversions are done in the sql package, + not in each driver. The drivers then only have to deal with a + smaller set of types. + +* Be flexible with type conversions, but be paranoid about silent + truncation or other loss of precision. + +* Handle concurrency well. Users shouldn't need to care about the + database's per-connection thread safety issues (or lack thereof), + and shouldn't have to maintain their own free pools of connections. + The 'sql' package should deal with that bookkeeping as needed. Given + an *sql.DB, it should be possible to share that instance between + multiple goroutines, without any extra synchronization. + +* Push complexity, where necessary, down into the sql+driver packages, + rather than exposing it to users. Said otherwise, the sql package + should expose an ideal database that's not finnicky about how it's + accessed, even if that's not true. + +* Provide optional interfaces in sql/driver for drivers to implement + for special cases or fastpaths. But the only party that knows about + those is the sql package. To user code, some stuff just might start + working or start working slightly faster. diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go new file mode 100644 index 0000000..daf282b --- /dev/null +++ b/src/database/sql/driver/driver.go @@ -0,0 +1,552 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package driver defines interfaces to be implemented by database +// drivers as used by package sql. +// +// Most code should use package sql. +// +// The driver interface has evolved over time. Drivers should implement +// Connector and DriverContext interfaces. +// The Connector.Connect and Driver.Open methods should never return ErrBadConn. +// ErrBadConn should only be returned from Validator, SessionResetter, or +// a query method if the connection is already in an invalid (e.g. closed) state. +// +// All Conn implementations should implement the following interfaces: +// Pinger, SessionResetter, and Validator. +// +// If named parameters or context are supported, the driver's Conn should implement: +// ExecerContext, QueryerContext, ConnPrepareContext, and ConnBeginTx. +// +// To support custom data types, implement NamedValueChecker. NamedValueChecker +// also allows queries to accept per-query options as a parameter by returning +// ErrRemoveArgument from CheckNamedValue. +// +// If multiple result sets are supported, Rows should implement RowsNextResultSet. +// If the driver knows how to describe the types present in the returned result +// it should implement the following interfaces: RowsColumnTypeScanType, +// RowsColumnTypeDatabaseTypeName, RowsColumnTypeLength, RowsColumnTypeNullable, +// and RowsColumnTypePrecisionScale. A given row value may also return a Rows +// type, which may represent a database cursor value. +// +// Before a connection is returned to the connection pool after use, IsValid is +// called if implemented. Before a connection is reused for another query, +// ResetSession is called if implemented. If a connection is never returned to the +// connection pool but immediately reused, then ResetSession is called prior to +// reuse but IsValid is not called. +package driver + +import ( + "context" + "errors" + "reflect" +) + +// Value is a value that drivers must be able to handle. +// It is either nil, a type handled by a database driver's NamedValueChecker +// interface, or an instance of one of these types: +// +// int64 +// float64 +// bool +// []byte +// string +// time.Time +// +// If the driver supports cursors, a returned Value may also implement the Rows interface +// in this package. This is used, for example, when a user selects a cursor +// such as "select cursor(select * from my_table) from dual". If the Rows +// from the select is closed, the cursor Rows will also be closed. +type Value any + +// NamedValue holds both the value name and value. +type NamedValue struct { + // If the Name is not empty it should be used for the parameter identifier and + // not the ordinal position. + // + // Name will not have a symbol prefix. + Name string + + // Ordinal position of the parameter starting from one and is always set. + Ordinal int + + // Value is the parameter value. + Value Value +} + +// Driver is the interface that must be implemented by a database +// driver. +// +// Database drivers may implement DriverContext for access +// to contexts and to parse the name only once for a pool of connections, +// instead of once per connection. +type Driver interface { + // Open returns a new connection to the database. + // The name is a string in a driver-specific format. + // + // Open may return a cached connection (one previously + // closed), but doing so is unnecessary; the sql package + // maintains a pool of idle connections for efficient re-use. + // + // The returned connection is only used by one goroutine at a + // time. + Open(name string) (Conn, error) +} + +// If a Driver implements DriverContext, then sql.DB will call +// OpenConnector to obtain a Connector and then invoke +// that Connector's Connect method to obtain each needed connection, +// instead of invoking the Driver's Open method for each connection. +// The two-step sequence allows drivers to parse the name just once +// and also provides access to per-Conn contexts. +type DriverContext interface { + // OpenConnector must parse the name in the same format that Driver.Open + // parses the name parameter. + OpenConnector(name string) (Connector, error) +} + +// A Connector represents a driver in a fixed configuration +// and can create any number of equivalent Conns for use +// by multiple goroutines. +// +// A Connector can be passed to sql.OpenDB, to allow drivers +// to implement their own sql.DB constructors, or returned by +// DriverContext's OpenConnector method, to allow drivers +// access to context and to avoid repeated parsing of driver +// configuration. +// +// If a Connector implements io.Closer, the sql package's DB.Close +// method will call Close and return error (if any). +type Connector interface { + // Connect returns a connection to the database. + // Connect may return a cached connection (one previously + // closed), but doing so is unnecessary; the sql package + // maintains a pool of idle connections for efficient re-use. + // + // The provided context.Context is for dialing purposes only + // (see net.DialContext) and should not be stored or used for + // other purposes. A default timeout should still be used + // when dialing as a connection pool may call Connect + // asynchronously to any query. + // + // The returned connection is only used by one goroutine at a + // time. + Connect(context.Context) (Conn, error) + + // Driver returns the underlying Driver of the Connector, + // mainly to maintain compatibility with the Driver method + // on sql.DB. + Driver() Driver +} + +// ErrSkip may be returned by some optional interfaces' methods to +// indicate at runtime that the fast path is unavailable and the sql +// package should continue as if the optional interface was not +// implemented. ErrSkip is only supported where explicitly +// documented. +var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented") + +// ErrBadConn should be returned by a driver to signal to the sql +// package that a driver.Conn is in a bad state (such as the server +// having earlier closed the connection) and the sql package should +// retry on a new connection. +// +// To prevent duplicate operations, ErrBadConn should NOT be returned +// if there's a possibility that the database server might have +// performed the operation. Even if the server sends back an error, +// you shouldn't return ErrBadConn. +// +// Errors will be checked using errors.Is. An error may +// wrap ErrBadConn or implement the Is(error) bool method. +var ErrBadConn = errors.New("driver: bad connection") + +// Pinger is an optional interface that may be implemented by a Conn. +// +// If a Conn does not implement Pinger, the sql package's DB.Ping and +// DB.PingContext will check if there is at least one Conn available. +// +// If Conn.Ping returns ErrBadConn, DB.Ping and DB.PingContext will remove +// the Conn from pool. +type Pinger interface { + Ping(ctx context.Context) error +} + +// Execer is an optional interface that may be implemented by a Conn. +// +// If a Conn implements neither ExecerContext nor Execer, +// the sql package's DB.Exec will first prepare a query, execute the statement, +// and then close the statement. +// +// Exec may return ErrSkip. +// +// Deprecated: Drivers should implement ExecerContext instead. +type Execer interface { + Exec(query string, args []Value) (Result, error) +} + +// ExecerContext is an optional interface that may be implemented by a Conn. +// +// If a Conn does not implement ExecerContext, the sql package's DB.Exec +// will fall back to Execer; if the Conn does not implement Execer either, +// DB.Exec will first prepare a query, execute the statement, and then +// close the statement. +// +// ExecContext may return ErrSkip. +// +// ExecContext must honor the context timeout and return when the context is canceled. +type ExecerContext interface { + ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error) +} + +// Queryer is an optional interface that may be implemented by a Conn. +// +// If a Conn implements neither QueryerContext nor Queryer, +// the sql package's DB.Query will first prepare a query, execute the statement, +// and then close the statement. +// +// Query may return ErrSkip. +// +// Deprecated: Drivers should implement QueryerContext instead. +type Queryer interface { + Query(query string, args []Value) (Rows, error) +} + +// QueryerContext is an optional interface that may be implemented by a Conn. +// +// If a Conn does not implement QueryerContext, the sql package's DB.Query +// will fall back to Queryer; if the Conn does not implement Queryer either, +// DB.Query will first prepare a query, execute the statement, and then +// close the statement. +// +// QueryContext may return ErrSkip. +// +// QueryContext must honor the context timeout and return when the context is canceled. +type QueryerContext interface { + QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error) +} + +// Conn is a connection to a database. It is not used concurrently +// by multiple goroutines. +// +// Conn is assumed to be stateful. +type Conn interface { + // Prepare returns a prepared statement, bound to this connection. + Prepare(query string) (Stmt, error) + + // Close invalidates and potentially stops any current + // prepared statements and transactions, marking this + // connection as no longer in use. + // + // Because the sql package maintains a free pool of + // connections and only calls Close when there's a surplus of + // idle connections, it shouldn't be necessary for drivers to + // do their own connection caching. + // + // Drivers must ensure all network calls made by Close + // do not block indefinitely (e.g. apply a timeout). + Close() error + + // Begin starts and returns a new transaction. + // + // Deprecated: Drivers should implement ConnBeginTx instead (or additionally). + Begin() (Tx, error) +} + +// ConnPrepareContext enhances the Conn interface with context. +type ConnPrepareContext interface { + // PrepareContext returns a prepared statement, bound to this connection. + // context is for the preparation of the statement, + // it must not store the context within the statement itself. + PrepareContext(ctx context.Context, query string) (Stmt, error) +} + +// IsolationLevel is the transaction isolation level stored in TxOptions. +// +// This type should be considered identical to sql.IsolationLevel along +// with any values defined on it. +type IsolationLevel int + +// TxOptions holds the transaction options. +// +// This type should be considered identical to sql.TxOptions. +type TxOptions struct { + Isolation IsolationLevel + ReadOnly bool +} + +// ConnBeginTx enhances the Conn interface with context and TxOptions. +type ConnBeginTx interface { + // BeginTx starts and returns a new transaction. + // If the context is canceled by the user the sql package will + // call Tx.Rollback before discarding and closing the connection. + // + // This must check opts.Isolation to determine if there is a set + // isolation level. If the driver does not support a non-default + // level and one is set or if there is a non-default isolation level + // that is not supported, an error must be returned. + // + // This must also check opts.ReadOnly to determine if the read-only + // value is true to either set the read-only transaction property if supported + // or return an error if it is not supported. + BeginTx(ctx context.Context, opts TxOptions) (Tx, error) +} + +// SessionResetter may be implemented by Conn to allow drivers to reset the +// session state associated with the connection and to signal a bad connection. +type SessionResetter interface { + // ResetSession is called prior to executing a query on the connection + // if the connection has been used before. If the driver returns ErrBadConn + // the connection is discarded. + ResetSession(ctx context.Context) error +} + +// Validator may be implemented by Conn to allow drivers to +// signal if a connection is valid or if it should be discarded. +// +// If implemented, drivers may return the underlying error from queries, +// even if the connection should be discarded by the connection pool. +type Validator interface { + // IsValid is called prior to placing the connection into the + // connection pool. The connection will be discarded if false is returned. + IsValid() bool +} + +// Result is the result of a query execution. +type Result interface { + // LastInsertId returns the database's auto-generated ID + // after, for example, an INSERT into a table with primary + // key. + LastInsertId() (int64, error) + + // RowsAffected returns the number of rows affected by the + // query. + RowsAffected() (int64, error) +} + +// Stmt is a prepared statement. It is bound to a Conn and not +// used by multiple goroutines concurrently. +type Stmt interface { + // Close closes the statement. + // + // As of Go 1.1, a Stmt will not be closed if it's in use + // by any queries. + // + // Drivers must ensure all network calls made by Close + // do not block indefinitely (e.g. apply a timeout). + Close() error + + // NumInput returns the number of placeholder parameters. + // + // If NumInput returns >= 0, the sql package will sanity check + // argument counts from callers and return errors to the caller + // before the statement's Exec or Query methods are called. + // + // NumInput may also return -1, if the driver doesn't know + // its number of placeholders. In that case, the sql package + // will not sanity check Exec or Query argument counts. + NumInput() int + + // Exec executes a query that doesn't return rows, such + // as an INSERT or UPDATE. + // + // Deprecated: Drivers should implement StmtExecContext instead (or additionally). + Exec(args []Value) (Result, error) + + // Query executes a query that may return rows, such as a + // SELECT. + // + // Deprecated: Drivers should implement StmtQueryContext instead (or additionally). + Query(args []Value) (Rows, error) +} + +// StmtExecContext enhances the Stmt interface by providing Exec with context. +type StmtExecContext interface { + // ExecContext executes a query that doesn't return rows, such + // as an INSERT or UPDATE. + // + // ExecContext must honor the context timeout and return when it is canceled. + ExecContext(ctx context.Context, args []NamedValue) (Result, error) +} + +// StmtQueryContext enhances the Stmt interface by providing Query with context. +type StmtQueryContext interface { + // QueryContext executes a query that may return rows, such as a + // SELECT. + // + // QueryContext must honor the context timeout and return when it is canceled. + QueryContext(ctx context.Context, args []NamedValue) (Rows, error) +} + +// ErrRemoveArgument may be returned from NamedValueChecker to instruct the +// sql package to not pass the argument to the driver query interface. +// Return when accepting query specific options or structures that aren't +// SQL query arguments. +var ErrRemoveArgument = errors.New("driver: remove argument from query") + +// NamedValueChecker may be optionally implemented by Conn or Stmt. It provides +// the driver more control to handle Go and database types beyond the default +// Values types allowed. +// +// The sql package checks for value checkers in the following order, +// stopping at the first found match: Stmt.NamedValueChecker, Conn.NamedValueChecker, +// Stmt.ColumnConverter, DefaultParameterConverter. +// +// If CheckNamedValue returns ErrRemoveArgument, the NamedValue will not be included in +// the final query arguments. This may be used to pass special options to +// the query itself. +// +// If ErrSkip is returned the column converter error checking +// path is used for the argument. Drivers may wish to return ErrSkip after +// they have exhausted their own special cases. +type NamedValueChecker interface { + // CheckNamedValue is called before passing arguments to the driver + // and is called in place of any ColumnConverter. CheckNamedValue must do type + // validation and conversion as appropriate for the driver. + CheckNamedValue(*NamedValue) error +} + +// ColumnConverter may be optionally implemented by Stmt if the +// statement is aware of its own columns' types and can convert from +// any type to a driver Value. +// +// Deprecated: Drivers should implement NamedValueChecker. +type ColumnConverter interface { + // ColumnConverter returns a ValueConverter for the provided + // column index. If the type of a specific column isn't known + // or shouldn't be handled specially, DefaultValueConverter + // can be returned. + ColumnConverter(idx int) ValueConverter +} + +// Rows is an iterator over an executed query's results. +type Rows interface { + // Columns returns the names of the columns. The number of + // columns of the result is inferred from the length of the + // slice. If a particular column name isn't known, an empty + // string should be returned for that entry. + Columns() []string + + // Close closes the rows iterator. + Close() error + + // Next is called to populate the next row of data into + // the provided slice. The provided slice will be the same + // size as the Columns() are wide. + // + // Next should return io.EOF when there are no more rows. + // + // The dest should not be written to outside of Next. Care + // should be taken when closing Rows not to modify + // a buffer held in dest. + Next(dest []Value) error +} + +// RowsNextResultSet extends the Rows interface by providing a way to signal +// the driver to advance to the next result set. +type RowsNextResultSet interface { + Rows + + // HasNextResultSet is called at the end of the current result set and + // reports whether there is another result set after the current one. + HasNextResultSet() bool + + // NextResultSet advances the driver to the next result set even + // if there are remaining rows in the current result set. + // + // NextResultSet should return io.EOF when there are no more result sets. + NextResultSet() error +} + +// RowsColumnTypeScanType may be implemented by Rows. It should return +// the value type that can be used to scan types into. For example, the database +// column type "bigint" this should return "reflect.TypeOf(int64(0))". +type RowsColumnTypeScanType interface { + Rows + ColumnTypeScanType(index int) reflect.Type +} + +// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the +// database system type name without the length. Type names should be uppercase. +// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT", +// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML", +// "TIMESTAMP". +type RowsColumnTypeDatabaseTypeName interface { + Rows + ColumnTypeDatabaseTypeName(index int) string +} + +// RowsColumnTypeLength may be implemented by Rows. It should return the length +// of the column type if the column is a variable length type. If the column is +// not a variable length type ok should return false. +// If length is not limited other than system limits, it should return math.MaxInt64. +// The following are examples of returned values for various types: +// +// TEXT (math.MaxInt64, true) +// varchar(10) (10, true) +// nvarchar(10) (10, true) +// decimal (0, false) +// int (0, false) +// bytea(30) (30, true) +type RowsColumnTypeLength interface { + Rows + ColumnTypeLength(index int) (length int64, ok bool) +} + +// RowsColumnTypeNullable may be implemented by Rows. The nullable value should +// be true if it is known the column may be null, or false if the column is known +// to be not nullable. +// If the column nullability is unknown, ok should be false. +type RowsColumnTypeNullable interface { + Rows + ColumnTypeNullable(index int) (nullable, ok bool) +} + +// RowsColumnTypePrecisionScale may be implemented by Rows. It should return +// the precision and scale for decimal types. If not applicable, ok should be false. +// The following are examples of returned values for various types: +// +// decimal(38, 4) (38, 4, true) +// int (0, 0, false) +// decimal (math.MaxInt64, math.MaxInt64, true) +type RowsColumnTypePrecisionScale interface { + Rows + ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) +} + +// Tx is a transaction. +type Tx interface { + Commit() error + Rollback() error +} + +// RowsAffected implements Result for an INSERT or UPDATE operation +// which mutates a number of rows. +type RowsAffected int64 + +var _ Result = RowsAffected(0) + +func (RowsAffected) LastInsertId() (int64, error) { + return 0, errors.New("LastInsertId is not supported by this driver") +} + +func (v RowsAffected) RowsAffected() (int64, error) { + return int64(v), nil +} + +// ResultNoRows is a pre-defined Result for drivers to return when a DDL +// command (such as a CREATE TABLE) succeeds. It returns an error for both +// LastInsertId and RowsAffected. +var ResultNoRows noRows + +type noRows struct{} + +var _ Result = noRows{} + +func (noRows) LastInsertId() (int64, error) { + return 0, errors.New("no LastInsertId available after DDL statement") +} + +func (noRows) RowsAffected() (int64, error) { + return 0, errors.New("no RowsAffected available after DDL statement") +} diff --git a/src/database/sql/driver/types.go b/src/database/sql/driver/types.go new file mode 100644 index 0000000..fa98df7 --- /dev/null +++ b/src/database/sql/driver/types.go @@ -0,0 +1,297 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package driver + +import ( + "fmt" + "reflect" + "strconv" + "time" +) + +// ValueConverter is the interface providing the ConvertValue method. +// +// Various implementations of ValueConverter are provided by the +// driver package to provide consistent implementations of conversions +// between drivers. The ValueConverters have several uses: +// +// - converting from the Value types as provided by the sql package +// into a database table's specific column type and making sure it +// fits, such as making sure a particular int64 fits in a +// table's uint16 column. +// +// - converting a value as given from the database into one of the +// driver Value types. +// +// - by the sql package, for converting from a driver's Value type +// to a user's type in a scan. +type ValueConverter interface { + // ConvertValue converts a value to a driver Value. + ConvertValue(v any) (Value, error) +} + +// Valuer is the interface providing the Value method. +// +// Types implementing Valuer interface are able to convert +// themselves to a driver Value. +type Valuer interface { + // Value returns a driver Value. + // Value must not panic. + Value() (Value, error) +} + +// Bool is a ValueConverter that converts input values to bools. +// +// The conversion rules are: +// - booleans are returned unchanged +// - for integer types, +// 1 is true +// 0 is false, +// other integers are an error +// - for strings and []byte, same rules as strconv.ParseBool +// - all other types are an error +var Bool boolType + +type boolType struct{} + +var _ ValueConverter = boolType{} + +func (boolType) String() string { return "Bool" } + +func (boolType) ConvertValue(src any) (Value, error) { + switch s := src.(type) { + case bool: + return s, nil + case string: + b, err := strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s) + } + return b, nil + case []byte: + b, err := strconv.ParseBool(string(s)) + if err != nil { + return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s) + } + return b, nil + } + + sv := reflect.ValueOf(src) + switch sv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + iv := sv.Int() + if iv == 1 || iv == 0 { + return iv == 1, nil + } + return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + uv := sv.Uint() + if uv == 1 || uv == 0 { + return uv == 1, nil + } + return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv) + } + + return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src) +} + +// Int32 is a ValueConverter that converts input values to int64, +// respecting the limits of an int32 value. +var Int32 int32Type + +type int32Type struct{} + +var _ ValueConverter = int32Type{} + +func (int32Type) ConvertValue(v any) (Value, error) { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i64 := rv.Int() + if i64 > (1<<31)-1 || i64 < -(1<<31) { + return nil, fmt.Errorf("sql/driver: value %d overflows int32", v) + } + return i64, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + u64 := rv.Uint() + if u64 > (1<<31)-1 { + return nil, fmt.Errorf("sql/driver: value %d overflows int32", v) + } + return int64(u64), nil + case reflect.String: + i, err := strconv.Atoi(rv.String()) + if err != nil { + return nil, fmt.Errorf("sql/driver: value %q can't be converted to int32", v) + } + return int64(i), nil + } + return nil, fmt.Errorf("sql/driver: unsupported value %v (type %T) converting to int32", v, v) +} + +// String is a ValueConverter that converts its input to a string. +// If the value is already a string or []byte, it's unchanged. +// If the value is of another type, conversion to string is done +// with fmt.Sprintf("%v", v). +var String stringType + +type stringType struct{} + +func (stringType) ConvertValue(v any) (Value, error) { + switch v.(type) { + case string, []byte: + return v, nil + } + return fmt.Sprintf("%v", v), nil +} + +// Null is a type that implements ValueConverter by allowing nil +// values but otherwise delegating to another ValueConverter. +type Null struct { + Converter ValueConverter +} + +func (n Null) ConvertValue(v any) (Value, error) { + if v == nil { + return nil, nil + } + return n.Converter.ConvertValue(v) +} + +// NotNull is a type that implements ValueConverter by disallowing nil +// values but otherwise delegating to another ValueConverter. +type NotNull struct { + Converter ValueConverter +} + +func (n NotNull) ConvertValue(v any) (Value, error) { + if v == nil { + return nil, fmt.Errorf("nil value not allowed") + } + return n.Converter.ConvertValue(v) +} + +// IsValue reports whether v is a valid Value parameter type. +func IsValue(v any) bool { + if v == nil { + return true + } + switch v.(type) { + case []byte, bool, float64, int64, string, time.Time: + return true + case decimalDecompose: + return true + } + return false +} + +// IsScanValue is equivalent to IsValue. +// It exists for compatibility. +func IsScanValue(v any) bool { + return IsValue(v) +} + +// DefaultParameterConverter is the default implementation of +// ValueConverter that's used when a Stmt doesn't implement +// ColumnConverter. +// +// DefaultParameterConverter returns its argument directly if +// IsValue(arg). Otherwise, if the argument implements Valuer, its +// Value method is used to return a Value. As a fallback, the provided +// argument's underlying type is used to convert it to a Value: +// underlying integer types are converted to int64, floats to float64, +// bool, string, and []byte to themselves. If the argument is a nil +// pointer, ConvertValue returns a nil Value. If the argument is a +// non-nil pointer, it is dereferenced and ConvertValue is called +// recursively. Other types are an error. +var DefaultParameterConverter defaultConverter + +type defaultConverter struct{} + +var _ ValueConverter = defaultConverter{} + +var valuerReflectType = reflect.TypeOf((*Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// Issue 8415. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This function is mirrored in the database/sql package. +func callValuerValue(vr Valuer) (v Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} + +func (defaultConverter) ConvertValue(v any) (Value, error) { + if IsValue(v) { + return v, nil + } + + switch vr := v.(type) { + case Valuer: + sv, err := callValuerValue(vr) + if err != nil { + return nil, err + } + if !IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + return sv, nil + + // For now, continue to prefer the Valuer interface over the decimal decompose interface. + case decimalDecompose: + return vr, nil + } + + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Pointer: + // indirect pointers + if rv.IsNil() { + return nil, nil + } else { + return defaultConverter{}.ConvertValue(rv.Elem().Interface()) + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return rv.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + return int64(rv.Uint()), nil + case reflect.Uint64: + u64 := rv.Uint() + if u64 >= 1<<63 { + return nil, fmt.Errorf("uint64 values with high bit set are not supported") + } + return int64(u64), nil + case reflect.Float32, reflect.Float64: + return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) + case reflect.String: + return rv.String(), nil + } + return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) +} + +type decimalDecompose interface { + // Decompose returns the internal decimal state into parts. + // If the provided buf has sufficient capacity, buf may be returned as the coefficient with + // the value set and length set as appropriate. + Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) +} diff --git a/src/database/sql/driver/types_test.go b/src/database/sql/driver/types_test.go new file mode 100644 index 0000000..80e5e05 --- /dev/null +++ b/src/database/sql/driver/types_test.go @@ -0,0 +1,95 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package driver + +import ( + "reflect" + "testing" + "time" +) + +type valueConverterTest struct { + c ValueConverter + in any + out any + err string +} + +var now = time.Now() +var answer int64 = 42 + +type ( + i int64 + f float64 + b bool + bs []byte + s string + t time.Time + is []int +) + +var valueConverterTests = []valueConverterTest{ + {Bool, "true", true, ""}, + {Bool, "True", true, ""}, + {Bool, []byte("t"), true, ""}, + {Bool, true, true, ""}, + {Bool, "1", true, ""}, + {Bool, 1, true, ""}, + {Bool, int64(1), true, ""}, + {Bool, uint16(1), true, ""}, + {Bool, "false", false, ""}, + {Bool, false, false, ""}, + {Bool, "0", false, ""}, + {Bool, 0, false, ""}, + {Bool, int64(0), false, ""}, + {Bool, uint16(0), false, ""}, + {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"}, + {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"}, + {DefaultParameterConverter, now, now, ""}, + {DefaultParameterConverter, (*int64)(nil), nil, ""}, + {DefaultParameterConverter, &answer, answer, ""}, + {DefaultParameterConverter, &now, now, ""}, + {DefaultParameterConverter, i(9), int64(9), ""}, + {DefaultParameterConverter, f(0.1), float64(0.1), ""}, + {DefaultParameterConverter, b(true), true, ""}, + {DefaultParameterConverter, bs{1}, []byte{1}, ""}, + {DefaultParameterConverter, s("a"), "a", ""}, + {DefaultParameterConverter, is{1}, nil, "unsupported type driver.is, a slice of int"}, + {DefaultParameterConverter, dec{exponent: -6}, dec{exponent: -6}, ""}, +} + +func TestValueConverters(t *testing.T) { + for i, tt := range valueConverterTests { + out, err := tt.c.ConvertValue(tt.in) + goterr := "" + if err != nil { + goterr = err.Error() + } + if goterr != tt.err { + t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q", + i, tt.c, tt.in, tt.in, goterr, tt.err) + } + if tt.err != "" { + continue + } + if !reflect.DeepEqual(out, tt.out) { + t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)", + i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out) + } + } +} + +type dec struct { + form byte + neg bool + coefficient [16]byte + exponent int32 +} + +func (d dec) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) { + coef := make([]byte, 16) + copy(coef, d.coefficient[:]) + return d.form, d.neg, coef, d.exponent +} diff --git a/src/database/sql/example_cli_test.go b/src/database/sql/example_cli_test.go new file mode 100644 index 0000000..1e297af --- /dev/null +++ b/src/database/sql/example_cli_test.go @@ -0,0 +1,84 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql_test + +import ( + "context" + "database/sql" + "flag" + "log" + "os" + "os/signal" + "time" +) + +var pool *sql.DB // Database connection pool. + +func Example_openDBCLI() { + id := flag.Int64("id", 0, "person ID to find") + dsn := flag.String("dsn", os.Getenv("DSN"), "connection data source name") + flag.Parse() + + if len(*dsn) == 0 { + log.Fatal("missing dsn flag") + } + if *id == 0 { + log.Fatal("missing person ID") + } + var err error + + // Opening a driver typically will not attempt to connect to the database. + pool, err = sql.Open("driver-name", *dsn) + if err != nil { + // This will not be a connection error, but a DSN parse error or + // another initialization error. + log.Fatal("unable to use data source name", err) + } + defer pool.Close() + + pool.SetConnMaxLifetime(0) + pool.SetMaxIdleConns(3) + pool.SetMaxOpenConns(3) + + ctx, stop := context.WithCancel(context.Background()) + defer stop() + + appSignal := make(chan os.Signal, 3) + signal.Notify(appSignal, os.Interrupt) + + go func() { + <-appSignal + stop() + }() + + Ping(ctx) + + Query(ctx, *id) +} + +// Ping the database to verify DSN provided by the user is valid and the +// server accessible. If the ping fails exit the program with an error. +func Ping(ctx context.Context) { + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + + if err := pool.PingContext(ctx); err != nil { + log.Fatalf("unable to connect to database: %v", err) + } +} + +// Query the database for the information requested and prints the results. +// If the query fails exit the program with an error. +func Query(ctx context.Context, id int64) { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + var name string + err := pool.QueryRowContext(ctx, "select p.name from people as p where p.id = :id;", sql.Named("id", id)).Scan(&name) + if err != nil { + log.Fatal("unable to execute search query", err) + } + log.Println("name=", name) +} diff --git a/src/database/sql/example_service_test.go b/src/database/sql/example_service_test.go new file mode 100644 index 0000000..768307c --- /dev/null +++ b/src/database/sql/example_service_test.go @@ -0,0 +1,158 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql_test + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "time" +) + +func Example_openDBService() { + // Opening a driver typically will not attempt to connect to the database. + db, err := sql.Open("driver-name", "database=test1") + if err != nil { + // This will not be a connection error, but a DSN parse error or + // another initialization error. + log.Fatal(err) + } + db.SetConnMaxLifetime(0) + db.SetMaxIdleConns(50) + db.SetMaxOpenConns(50) + + s := &Service{db: db} + + http.ListenAndServe(":8080", s) +} + +type Service struct { + db *sql.DB +} + +func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db := s.db + switch r.URL.Path { + default: + http.Error(w, "not found", http.StatusNotFound) + return + case "/healthz": + ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second) + defer cancel() + + err := s.db.PingContext(ctx) + if err != nil { + http.Error(w, fmt.Sprintf("db down: %v", err), http.StatusFailedDependency) + return + } + w.WriteHeader(http.StatusOK) + return + case "/quick-action": + // This is a short SELECT. Use the request context as the base of + // the context timeout. + ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + defer cancel() + + id := 5 + org := 10 + var name string + err := db.QueryRowContext(ctx, ` +select + p.name +from + people as p + join organization as o on p.organization = o.id +where + p.id = :id + and o.id = :org +;`, + sql.Named("id", id), + sql.Named("org", org), + ).Scan(&name) + if err != nil { + if err == sql.ErrNoRows { + http.Error(w, "not found", http.StatusNotFound) + return + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + io.WriteString(w, name) + return + case "/long-action": + // This is a long SELECT. Use the request context as the base of + // the context timeout, but give it some time to finish. If + // the client cancels before the query is done the query will also + // be canceled. + ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + defer cancel() + + var names []string + rows, err := db.QueryContext(ctx, "select p.name from people as p where p.active = true;") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + for rows.Next() { + var name string + err = rows.Scan(&name) + if err != nil { + break + } + names = append(names, name) + } + // Check for errors during rows "Close". + // This may be more important if multiple statements are executed + // in a single batch and rows were written as well as read. + if closeErr := rows.Close(); closeErr != nil { + http.Error(w, closeErr.Error(), http.StatusInternalServerError) + return + } + + // Check for row scan error. + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Check for errors during row iteration. + if err = rows.Err(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + json.NewEncoder(w).Encode(names) + return + case "/async-action": + // This action has side effects that we want to preserve + // even if the client cancels the HTTP request part way through. + // For this we do not use the http request context as a base for + // the timeout. + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var orderRef = "ABC123" + tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) + _, err = tx.ExecContext(ctx, "stored_proc_name", orderRef) + + if err != nil { + tx.Rollback() + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + err = tx.Commit() + if err != nil { + http.Error(w, "action in unknown state, check state before attempting again", http.StatusInternalServerError) + return + } + w.WriteHeader(http.StatusOK) + return + } +} diff --git a/src/database/sql/example_test.go b/src/database/sql/example_test.go new file mode 100644 index 0000000..aafb0e3 --- /dev/null +++ b/src/database/sql/example_test.go @@ -0,0 +1,369 @@ +// Copyright 2013 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql_test + +import ( + "context" + "database/sql" + "fmt" + "log" + "strings" + "time" +) + +var ( + ctx context.Context + db *sql.DB +) + +func ExampleDB_QueryContext() { + age := 27 + rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age) + if err != nil { + log.Fatal(err) + } + defer rows.Close() + names := make([]string, 0) + + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + // Check for a scan error. + // Query rows will be closed with defer. + log.Fatal(err) + } + names = append(names, name) + } + // If the database is being written to ensure to check for Close + // errors that may be returned from the driver. The query may + // encounter an auto-commit error and be forced to rollback changes. + rerr := rows.Close() + if rerr != nil { + log.Fatal(rerr) + } + + // Rows.Err will report the last error encountered by Rows.Scan. + if err := rows.Err(); err != nil { + log.Fatal(err) + } + fmt.Printf("%s are %d years old", strings.Join(names, ", "), age) +} + +func ExampleDB_QueryRowContext() { + id := 123 + var username string + var created time.Time + err := db.QueryRowContext(ctx, "SELECT username, created_at FROM users WHERE id=?", id).Scan(&username, &created) + switch { + case err == sql.ErrNoRows: + log.Printf("no user with id %d\n", id) + case err != nil: + log.Fatalf("query error: %v\n", err) + default: + log.Printf("username is %q, account created on %s\n", username, created) + } +} + +func ExampleDB_ExecContext() { + id := 47 + result, err := db.ExecContext(ctx, "UPDATE balances SET balance = balance + 10 WHERE user_id = ?", id) + if err != nil { + log.Fatal(err) + } + rows, err := result.RowsAffected() + if err != nil { + log.Fatal(err) + } + if rows != 1 { + log.Fatalf("expected to affect 1 row, affected %d", rows) + } +} + +func ExampleDB_Query_multipleResultSets() { + age := 27 + q := ` +create temp table uid (id bigint); -- Create temp table for queries. +insert into uid +select id from users where age < ?; -- Populate temp table. + +-- First result set. +select + users.id, name +from + users + join uid on users.id = uid.id +; + +-- Second result set. +select + ur.user, ur.role +from + user_roles as ur + join uid on uid.id = ur.user +; + ` + rows, err := db.Query(q, age) + if err != nil { + log.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var ( + id int64 + name string + ) + if err := rows.Scan(&id, &name); err != nil { + log.Fatal(err) + } + log.Printf("id %d name is %s\n", id, name) + } + if !rows.NextResultSet() { + log.Fatalf("expected more result sets: %v", rows.Err()) + } + var roleMap = map[int64]string{ + 1: "user", + 2: "admin", + 3: "gopher", + } + for rows.Next() { + var ( + id int64 + role int64 + ) + if err := rows.Scan(&id, &role); err != nil { + log.Fatal(err) + } + log.Printf("id %d has role %s\n", id, roleMap[role]) + } + if err := rows.Err(); err != nil { + log.Fatal(err) + } +} + +func ExampleDB_PingContext() { + // Ping and PingContext may be used to determine if communication with + // the database server is still possible. + // + // When used in a command line application Ping may be used to establish + // that further queries are possible; that the provided DSN is valid. + // + // When used in long running service Ping may be part of the health + // checking system. + + ctx, cancel := context.WithTimeout(ctx, 1*time.Second) + defer cancel() + + status := "up" + if err := db.PingContext(ctx); err != nil { + status = "down" + } + log.Println(status) +} + +func ExampleDB_Prepare() { + projects := []struct { + mascot string + release int + }{ + {"tux", 1991}, + {"duke", 1996}, + {"gopher", 2009}, + {"moby dock", 2013}, + } + + stmt, err := db.Prepare("INSERT INTO projects(id, mascot, release, category) VALUES( ?, ?, ?, ? )") + if err != nil { + log.Fatal(err) + } + defer stmt.Close() // Prepared statements take up server resources and should be closed after use. + + for id, project := range projects { + if _, err := stmt.Exec(id+1, project.mascot, project.release, "open source"); err != nil { + log.Fatal(err) + } + } +} + +func ExampleTx_Prepare() { + projects := []struct { + mascot string + release int + }{ + {"tux", 1991}, + {"duke", 1996}, + {"gopher", 2009}, + {"moby dock", 2013}, + } + + tx, err := db.Begin() + if err != nil { + log.Fatal(err) + } + defer tx.Rollback() // The rollback will be ignored if the tx has been committed later in the function. + + stmt, err := tx.Prepare("INSERT INTO projects(id, mascot, release, category) VALUES( ?, ?, ?, ? )") + if err != nil { + log.Fatal(err) + } + defer stmt.Close() // Prepared statements take up server resources and should be closed after use. + + for id, project := range projects { + if _, err := stmt.Exec(id+1, project.mascot, project.release, "open source"); err != nil { + log.Fatal(err) + } + } + if err := tx.Commit(); err != nil { + log.Fatal(err) + } +} + +func ExampleDB_BeginTx() { + tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) + if err != nil { + log.Fatal(err) + } + id := 37 + _, execErr := tx.Exec(`UPDATE users SET status = ? WHERE id = ?`, "paid", id) + if execErr != nil { + _ = tx.Rollback() + log.Fatal(execErr) + } + if err := tx.Commit(); err != nil { + log.Fatal(err) + } +} + +func ExampleConn_ExecContext() { + // A *DB is a pool of connections. Call Conn to reserve a connection for + // exclusive use. + conn, err := db.Conn(ctx) + if err != nil { + log.Fatal(err) + } + defer conn.Close() // Return the connection to the pool. + id := 41 + result, err := conn.ExecContext(ctx, `UPDATE balances SET balance = balance + 10 WHERE user_id = ?;`, id) + if err != nil { + log.Fatal(err) + } + rows, err := result.RowsAffected() + if err != nil { + log.Fatal(err) + } + if rows != 1 { + log.Fatalf("expected single row affected, got %d rows affected", rows) + } +} + +func ExampleTx_ExecContext() { + tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) + if err != nil { + log.Fatal(err) + } + id := 37 + _, execErr := tx.ExecContext(ctx, "UPDATE users SET status = ? WHERE id = ?", "paid", id) + if execErr != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("update failed: %v, unable to rollback: %v\n", execErr, rollbackErr) + } + log.Fatalf("update failed: %v", execErr) + } + if err := tx.Commit(); err != nil { + log.Fatal(err) + } +} + +func ExampleTx_Rollback() { + tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable}) + if err != nil { + log.Fatal(err) + } + id := 53 + _, err = tx.ExecContext(ctx, "UPDATE drivers SET status = ? WHERE id = ?;", "assigned", id) + if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("update drivers: unable to rollback: %v", rollbackErr) + } + log.Fatal(err) + } + _, err = tx.ExecContext(ctx, "UPDATE pickups SET driver_id = $1;", id) + if err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + log.Fatalf("update failed: %v, unable to back: %v", err, rollbackErr) + } + log.Fatal(err) + } + if err := tx.Commit(); err != nil { + log.Fatal(err) + } +} + +func ExampleStmt() { + // In normal use, create one Stmt when your process starts. + stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?") + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + // Then reuse it each time you need to issue the query. + id := 43 + var username string + err = stmt.QueryRowContext(ctx, id).Scan(&username) + switch { + case err == sql.ErrNoRows: + log.Fatalf("no user with id %d", id) + case err != nil: + log.Fatal(err) + default: + log.Printf("username is %s\n", username) + } +} + +func ExampleStmt_QueryRowContext() { + // In normal use, create one Stmt when your process starts. + stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?") + if err != nil { + log.Fatal(err) + } + defer stmt.Close() + + // Then reuse it each time you need to issue the query. + id := 43 + var username string + err = stmt.QueryRowContext(ctx, id).Scan(&username) + switch { + case err == sql.ErrNoRows: + log.Fatalf("no user with id %d", id) + case err != nil: + log.Fatal(err) + default: + log.Printf("username is %s\n", username) + } +} + +func ExampleRows() { + age := 27 + rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age) + if err != nil { + log.Fatal(err) + } + defer rows.Close() + + names := make([]string, 0) + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + log.Fatal(err) + } + names = append(names, name) + } + // Check for errors from iterating over rows. + if err := rows.Err(); err != nil { + log.Fatal(err) + } + log.Printf("%s are %d years old", strings.Join(names, ", "), age) +} diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go new file mode 100644 index 0000000..2fe5ea4 --- /dev/null +++ b/src/database/sql/fakedb_test.go @@ -0,0 +1,1272 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "io" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +// fakeDriver is a fake database that implements Go's driver.Driver +// interface, just for testing. +// +// It speaks a query language that's semantically similar to but +// syntactically different and simpler than SQL. The syntax is as +// follows: +// +// WIPE +// CREATE|<tablename>|<col>=<type>,<col>=<type>,... +// where types are: "string", [u]int{8,16,32,64}, "bool" +// INSERT|<tablename>|col=val,col2=val2,col3=? +// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=? +// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2 +// +// Any of these can be preceded by PANIC|<method>|, to cause the +// named method on fakeStmt to panic. +// +// Any of these can be proceeded by WAIT|<duration>|, to cause the +// named method on fakeStmt to sleep for the specified duration. +// +// Multiple of these can be combined when separated with a semicolon. +// +// When opening a fakeDriver's database, it starts empty with no +// tables. All tables and data are stored in memory only. +type fakeDriver struct { + mu sync.Mutex // guards 3 following fields + openCount int // conn opens + closeCount int // conn closes + waitCh chan struct{} + waitingCh chan struct{} + dbs map[string]*fakeDB +} + +type fakeConnector struct { + name string + + waiter func(context.Context) + closed bool +} + +func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) { + conn, err := fdriver.Open(c.name) + conn.(*fakeConn).waiter = c.waiter + return conn, err +} + +func (c *fakeConnector) Driver() driver.Driver { + return fdriver +} + +func (c *fakeConnector) Close() error { + if c.closed { + return errors.New("fakedb: connector is closed") + } + c.closed = true + return nil +} + +type fakeDriverCtx struct { + fakeDriver +} + +var _ driver.DriverContext = &fakeDriverCtx{} + +func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) { + return &fakeConnector{name: name}, nil +} + +type fakeDB struct { + name string + + mu sync.Mutex + tables map[string]*table + badConn bool + allowAny bool +} + +type fakeError struct { + Message string + Wrapped error +} + +func (err fakeError) Error() string { + return err.Message +} + +func (err fakeError) Unwrap() error { + return err.Wrapped +} + +type table struct { + mu sync.Mutex + colname []string + coltype []string + rows []*row +} + +func (t *table) columnIndex(name string) int { + for n, nname := range t.colname { + if name == nname { + return n + } + } + return -1 +} + +type row struct { + cols []any // must be same size as its table colname + coltype +} + +type memToucher interface { + // touchMem reads & writes some memory, to help find data races. + touchMem() +} + +type fakeConn struct { + db *fakeDB // where to return ourselves to + + currTx *fakeTx + + // Every operation writes to line to enable the race detector + // check for data races. + line int64 + + // Stats for tests: + mu sync.Mutex + stmtsMade int + stmtsClosed int + numPrepare int + + // bad connection tests; see isBad() + bad bool + stickyBad bool + + skipDirtySession bool // tests that use Conn should set this to true. + + // dirtySession tests ResetSession, true if a query has executed + // until ResetSession is called. + dirtySession bool + + // The waiter is called before each query. May be used in place of the "WAIT" + // directive. + waiter func(context.Context) +} + +func (c *fakeConn) touchMem() { + c.line++ +} + +func (c *fakeConn) incrStat(v *int) { + c.mu.Lock() + *v++ + c.mu.Unlock() +} + +type fakeTx struct { + c *fakeConn +} + +type boundCol struct { + Column string + Placeholder string + Ordinal int +} + +type fakeStmt struct { + memToucher + c *fakeConn + q string // just for debugging + + cmd string + table string + panic string + wait time.Duration + + next *fakeStmt // used for returning multiple results. + + closed bool + + colName []string // used by CREATE, INSERT, SELECT (selected columns) + colType []string // used by CREATE + colValue []any // used by INSERT (mix of strings and "?" for bound params) + placeholders int // used by INSERT/SELECT: number of ? params + + whereCol []boundCol // used by SELECT (all placeholders) + + placeholderConverter []driver.ValueConverter // used by INSERT +} + +var fdriver driver.Driver = &fakeDriver{} + +func init() { + Register("test", fdriver) +} + +func contains(list []string, y string) bool { + for _, x := range list { + if x == y { + return true + } + } + return false +} + +type Dummy struct { + driver.Driver +} + +func TestDrivers(t *testing.T) { + unregisterAllDrivers() + Register("test", fdriver) + Register("invalid", Dummy{}) + all := Drivers() + if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") { + t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all) + } +} + +// hook to simulate connection failures +var hookOpenErr struct { + sync.Mutex + fn func() error +} + +func setHookOpenErr(fn func() error) { + hookOpenErr.Lock() + defer hookOpenErr.Unlock() + hookOpenErr.fn = fn +} + +// Supports dsn forms: +// +// <dbname> +// <dbname>;<opts> (only currently supported option is `badConn`, +// which causes driver.ErrBadConn to be returned on +// every other conn.Begin()) +func (d *fakeDriver) Open(dsn string) (driver.Conn, error) { + hookOpenErr.Lock() + fn := hookOpenErr.fn + hookOpenErr.Unlock() + if fn != nil { + if err := fn(); err != nil { + return nil, err + } + } + parts := strings.Split(dsn, ";") + if len(parts) < 1 { + return nil, errors.New("fakedb: no database name") + } + name := parts[0] + + db := d.getDB(name) + + d.mu.Lock() + d.openCount++ + d.mu.Unlock() + conn := &fakeConn{db: db} + + if len(parts) >= 2 && parts[1] == "badConn" { + conn.bad = true + } + if d.waitCh != nil { + d.waitingCh <- struct{}{} + <-d.waitCh + d.waitCh = nil + d.waitingCh = nil + } + return conn, nil +} + +func (d *fakeDriver) getDB(name string) *fakeDB { + d.mu.Lock() + defer d.mu.Unlock() + if d.dbs == nil { + d.dbs = make(map[string]*fakeDB) + } + db, ok := d.dbs[name] + if !ok { + db = &fakeDB{name: name} + d.dbs[name] = db + } + return db +} + +func (db *fakeDB) wipe() { + db.mu.Lock() + defer db.mu.Unlock() + db.tables = nil +} + +func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error { + db.mu.Lock() + defer db.mu.Unlock() + if db.tables == nil { + db.tables = make(map[string]*table) + } + if _, exist := db.tables[name]; exist { + return fmt.Errorf("fakedb: table %q already exists", name) + } + if len(columnNames) != len(columnTypes) { + return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d", + name, len(columnNames), len(columnTypes)) + } + db.tables[name] = &table{colname: columnNames, coltype: columnTypes} + return nil +} + +// must be called with db.mu lock held +func (db *fakeDB) table(table string) (*table, bool) { + if db.tables == nil { + return nil, false + } + t, ok := db.tables[table] + return t, ok +} + +func (db *fakeDB) columnType(table, column string) (typ string, ok bool) { + db.mu.Lock() + defer db.mu.Unlock() + t, ok := db.table(table) + if !ok { + return + } + for n, cname := range t.colname { + if cname == column { + return t.coltype[n], true + } + } + return "", false +} + +func (c *fakeConn) isBad() bool { + if c.stickyBad { + return true + } else if c.bad { + if c.db == nil { + return false + } + // alternate between bad conn and not bad conn + c.db.badConn = !c.db.badConn + return c.db.badConn + } else { + return false + } +} + +func (c *fakeConn) isDirtyAndMark() bool { + if c.skipDirtySession { + return false + } + if c.currTx != nil { + c.dirtySession = true + return false + } + if c.dirtySession { + return true + } + c.dirtySession = true + return false +} + +func (c *fakeConn) Begin() (driver.Tx, error) { + if c.isBad() { + return nil, fakeError{Wrapped: driver.ErrBadConn} + } + if c.currTx != nil { + return nil, errors.New("fakedb: already in a transaction") + } + c.touchMem() + c.currTx = &fakeTx{c: c} + return c.currTx, nil +} + +var hookPostCloseConn struct { + sync.Mutex + fn func(*fakeConn, error) +} + +func setHookpostCloseConn(fn func(*fakeConn, error)) { + hookPostCloseConn.Lock() + defer hookPostCloseConn.Unlock() + hookPostCloseConn.fn = fn +} + +var testStrictClose *testing.T + +// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close +// fails to close. If nil, the check is disabled. +func setStrictFakeConnClose(t *testing.T) { + testStrictClose = t +} + +func (c *fakeConn) ResetSession(ctx context.Context) error { + c.dirtySession = false + c.currTx = nil + if c.isBad() { + return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn} + } + return nil +} + +var _ driver.Validator = (*fakeConn)(nil) + +func (c *fakeConn) IsValid() bool { + return !c.isBad() +} + +func (c *fakeConn) Close() (err error) { + drv := fdriver.(*fakeDriver) + defer func() { + if err != nil && testStrictClose != nil { + testStrictClose.Errorf("failed to close a test fakeConn: %v", err) + } + hookPostCloseConn.Lock() + fn := hookPostCloseConn.fn + hookPostCloseConn.Unlock() + if fn != nil { + fn(c, err) + } + if err == nil { + drv.mu.Lock() + drv.closeCount++ + drv.mu.Unlock() + } + }() + c.touchMem() + if c.currTx != nil { + return errors.New("fakedb: can't close fakeConn; in a Transaction") + } + if c.db == nil { + return errors.New("fakedb: can't close fakeConn; already closed") + } + if c.stmtsMade > c.stmtsClosed { + return errors.New("fakedb: can't close; dangling statement(s)") + } + c.db = nil + return nil +} + +func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error { + for _, arg := range args { + switch arg.Value.(type) { + case int64, float64, bool, nil, []byte, string, time.Time: + default: + if !allowAny { + return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) + } + } + } + return nil +} + +func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) { + // Ensure that ExecContext is called if available. + panic("ExecContext was not called.") +} + +func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args are of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := checkSubsetTypes(c.db.allowAny, args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) { + // Ensure that ExecContext is called if available. + panic("QueryContext was not called.") +} + +func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + // This is an optional interface, but it's implemented here + // just to check that all the args are of the proper types. + // ErrSkip is returned so the caller acts as if we didn't + // implement this at all. + err := checkSubsetTypes(c.db.allowAny, args) + if err != nil { + return nil, err + } + return nil, driver.ErrSkip +} + +func errf(msg string, args ...any) error { + return errors.New("fakedb: " + fmt.Sprintf(msg, args...)) +} + +// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? +// (note that where columns must always contain ? marks, +// just a limitation for fakedb) +func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { + if len(parts) != 3 { + stmt.Close() + return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) + } + stmt.table = parts[0] + + stmt.colName = strings.Split(parts[1], ",") + for n, colspec := range strings.Split(parts[2], ",") { + if colspec == "" { + continue + } + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + _, ok := c.db.columnType(stmt.table, column) + if !ok { + stmt.Close() + return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column) + } + if !strings.HasPrefix(value, "?") { + stmt.Close() + return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark", + stmt.table, column) + } + stmt.placeholders++ + stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders}) + } + return stmt, nil +} + +// parts are table|col=type,col2=type2 +func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameType := strings.Split(colspec, "=") + if len(nameType) != 2 { + stmt.Close() + return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + stmt.colName = append(stmt.colName, nameType[0]) + stmt.colType = append(stmt.colType, nameType[1]) + } + return stmt, nil +} + +// parts are table|col=?,col2=val +func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { + if len(parts) != 2 { + stmt.Close() + return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) + } + stmt.table = parts[0] + for n, colspec := range strings.Split(parts[1], ",") { + nameVal := strings.Split(colspec, "=") + if len(nameVal) != 2 { + stmt.Close() + return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n) + } + column, value := nameVal[0], nameVal[1] + ctype, ok := c.db.columnType(stmt.table, column) + if !ok { + stmt.Close() + return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column) + } + stmt.colName = append(stmt.colName, column) + + if !strings.HasPrefix(value, "?") { + var subsetVal any + // Convert to driver subset type + switch ctype { + case "string": + subsetVal = []byte(value) + case "blob": + subsetVal = []byte(value) + case "int32": + i, err := strconv.Atoi(value) + if err != nil { + stmt.Close() + return nil, errf("invalid conversion to int32 from %q", value) + } + subsetVal = int64(i) // int64 is a subset type, but not int32 + case "table": // For testing cursor reads. + c.skipDirtySession = true + vparts := strings.Split(value, "!") + + substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) + if err != nil { + return nil, err + } + cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) + substmt.Close() + if err != nil { + return nil, err + } + subsetVal = cursor + default: + stmt.Close() + return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) + } + stmt.colValue = append(stmt.colValue, subsetVal) + } else { + stmt.placeholders++ + stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype)) + stmt.colValue = append(stmt.colValue, value) + } + } + return stmt, nil +} + +// hook to simulate broken connections +var hookPrepareBadConn func() bool + +func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { + panic("use PrepareContext") +} + +func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + c.numPrepare++ + if c.db == nil { + panic("nil c.db; conn = " + fmt.Sprintf("%#v", c)) + } + + if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) { + return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn} + } + + c.touchMem() + var firstStmt, prev *fakeStmt + for _, query := range strings.Split(query, ";") { + parts := strings.Split(query, "|") + if len(parts) < 1 { + return nil, errf("empty query") + } + stmt := &fakeStmt{q: query, c: c, memToucher: c} + if firstStmt == nil { + firstStmt = stmt + } + if len(parts) >= 3 { + switch parts[0] { + case "PANIC": + stmt.panic = parts[1] + parts = parts[2:] + case "WAIT": + wait, err := time.ParseDuration(parts[1]) + if err != nil { + return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err) + } + parts = parts[2:] + stmt.wait = wait + } + } + cmd := parts[0] + stmt.cmd = cmd + parts = parts[1:] + + if c.waiter != nil { + c.waiter(ctx) + if err := ctx.Err(); err != nil { + return nil, err + } + } + + if stmt.wait > 0 { + wait := time.NewTimer(stmt.wait) + select { + case <-wait.C: + case <-ctx.Done(): + wait.Stop() + return nil, ctx.Err() + } + } + + c.incrStat(&c.stmtsMade) + var err error + switch cmd { + case "WIPE": + // Nothing + case "SELECT": + stmt, err = c.prepareSelect(stmt, parts) + case "CREATE": + stmt, err = c.prepareCreate(stmt, parts) + case "INSERT": + stmt, err = c.prepareInsert(ctx, stmt, parts) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + stmt, err = c.prepareInsert(ctx, stmt, parts) + default: + stmt.Close() + return nil, errf("unsupported command type %q", cmd) + } + if err != nil { + return nil, err + } + if prev != nil { + prev.next = stmt + } + prev = stmt + } + return firstStmt, nil +} + +func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { + if s.panic == "ColumnConverter" { + panic(s.panic) + } + if len(s.placeholderConverter) == 0 { + return driver.DefaultParameterConverter + } + return s.placeholderConverter[idx] +} + +func (s *fakeStmt) Close() error { + if s.panic == "Close" { + panic(s.panic) + } + if s.c == nil { + panic("nil conn in fakeStmt.Close") + } + if s.c.db == nil { + panic("in fakeStmt.Close, conn's db is nil (already closed)") + } + s.touchMem() + if !s.closed { + s.c.incrStat(&s.c.stmtsClosed) + s.closed = true + } + if s.next != nil { + s.next.Close() + } + return nil +} + +var errClosed = errors.New("fakedb: statement has been closed") + +// hook to simulate broken connections +var hookExecBadConn func() bool + +func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) { + panic("Using ExecContext") +} + +var errFakeConnSessionDirty = errors.New("fakedb: session is dirty") + +func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + if s.panic == "Exec" { + panic(s.panic) + } + if s.closed { + return nil, errClosed + } + + if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) { + return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn} + } + if s.c.isDirtyAndMark() { + return nil, errFakeConnSessionDirty + } + + err := checkSubsetTypes(s.c.db.allowAny, args) + if err != nil { + return nil, err + } + s.touchMem() + + if s.wait > 0 { + time.Sleep(s.wait) + } + + select { + default: + case <-ctx.Done(): + return nil, ctx.Err() + } + + db := s.c.db + switch s.cmd { + case "WIPE": + db.wipe() + return driver.ResultNoRows, nil + case "CREATE": + if err := db.createTable(s.table, s.colName, s.colType); err != nil { + return nil, err + } + return driver.ResultNoRows, nil + case "INSERT": + return s.execInsert(args, true) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + return s.execInsert(args, false) + } + return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd) +} + +// When doInsert is true, add the row to the table. +// When doInsert is false do prep-work and error checking, but don't +// actually add the row to the table. +func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) { + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + + t.mu.Lock() + defer t.mu.Unlock() + + var cols []any + if doInsert { + cols = make([]any, len(t.colname)) + } + argPos := 0 + for n, colname := range s.colName { + colidx := t.columnIndex(colname) + if colidx == -1 { + return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname) + } + var val any + if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") { + if strvalue == "?" { + val = args[argPos].Value + } else { + // Assign value from argument placeholder name. + for _, a := range args { + if a.Name == strvalue[1:] { + val = a.Value + break + } + } + } + argPos++ + } else { + val = s.colValue[n] + } + if doInsert { + cols[colidx] = val + } + } + + if doInsert { + t.rows = append(t.rows, &row{cols: cols}) + } + return driver.RowsAffected(1), nil +} + +// hook to simulate broken connections +var hookQueryBadConn func() bool + +func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { + panic("Use QueryContext") +} + +func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if s.panic == "Query" { + panic(s.panic) + } + if s.closed { + return nil, errClosed + } + + if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) { + return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn} + } + if s.c.isDirtyAndMark() { + return nil, errFakeConnSessionDirty + } + + err := checkSubsetTypes(s.c.db.allowAny, args) + if err != nil { + return nil, err + } + + s.touchMem() + db := s.c.db + if len(args) != s.placeholders { + panic("error in pkg db; should only get here if size is correct") + } + + setMRows := make([][]*row, 0, 1) + setColumns := make([][]string, 0, 1) + setColType := make([][]string, 0, 1) + + for { + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) + } + + if s.table == "magicquery" { + if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" { + if args[0].Value == "sleep" { + time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond) + } + } + } + if s.table == "tx_status" && s.colName[0] == "tx_status" { + txStatus := "autocommit" + if s.c.currTx != nil { + txStatus = "transaction" + } + cursor := &rowsCursor{ + parentMem: s.c, + posRow: -1, + rows: [][]*row{ + { + { + cols: []any{ + txStatus, + }, + }, + }, + }, + cols: [][]string{ + { + "tx_status", + }, + }, + colType: [][]string{ + { + "string", + }, + }, + errPos: -1, + } + return cursor, nil + } + + t.mu.Lock() + + colIdx := make(map[string]int) // select column name -> column index in table + for _, name := range s.colName { + idx := t.columnIndex(name) + if idx == -1 { + t.mu.Unlock() + return nil, fmt.Errorf("fakedb: unknown column name %q", name) + } + colIdx[name] = idx + } + + mrows := []*row{} + rows: + for _, trow := range t.rows { + // Process the where clause, skipping non-match rows. This is lazy + // and just uses fmt.Sprintf("%v") to test equality. Good enough + // for test code. + for _, wcol := range s.whereCol { + idx := t.columnIndex(wcol.Column) + if idx == -1 { + t.mu.Unlock() + return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol) + } + tcol := trow.cols[idx] + if bs, ok := tcol.([]byte); ok { + // lazy hack to avoid sprintf %v on a []byte + tcol = string(bs) + } + var argValue any + if wcol.Placeholder == "?" { + argValue = args[wcol.Ordinal-1].Value + } else { + // Assign arg value from placeholder name. + for _, a := range args { + if a.Name == wcol.Placeholder[1:] { + argValue = a.Value + break + } + } + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) { + continue rows + } + } + mrow := &row{cols: make([]any, len(s.colName))} + for seli, name := range s.colName { + mrow.cols[seli] = trow.cols[colIdx[name]] + } + mrows = append(mrows, mrow) + } + + var colType []string + for _, column := range s.colName { + colType = append(colType, t.coltype[t.columnIndex(column)]) + } + + t.mu.Unlock() + + setMRows = append(setMRows, mrows) + setColumns = append(setColumns, s.colName) + setColType = append(setColType, colType) + + if s.next == nil { + break + } + s = s.next + } + + cursor := &rowsCursor{ + parentMem: s.c, + posRow: -1, + rows: setMRows, + cols: setColumns, + colType: setColType, + errPos: -1, + } + return cursor, nil +} + +func (s *fakeStmt) NumInput() int { + if s.panic == "NumInput" { + panic(s.panic) + } + return s.placeholders +} + +// hook to simulate broken connections +var hookCommitBadConn func() bool + +func (tx *fakeTx) Commit() error { + tx.c.currTx = nil + if hookCommitBadConn != nil && hookCommitBadConn() { + return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn} + } + tx.c.touchMem() + return nil +} + +// hook to simulate broken connections +var hookRollbackBadConn func() bool + +func (tx *fakeTx) Rollback() error { + tx.c.currTx = nil + if hookRollbackBadConn != nil && hookRollbackBadConn() { + return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn} + } + tx.c.touchMem() + return nil +} + +type rowsCursor struct { + parentMem memToucher + cols [][]string + colType [][]string + posSet int + posRow int + rows [][]*row + closed bool + + // errPos and err are for making Next return early with error. + errPos int + err error + + // a clone of slices to give out to clients, indexed by the + // original slice's first byte address. we clone them + // just so we're able to corrupt them on close. + bytesClone map[*byte][]byte + + // Every operation writes to line to enable the race detector + // check for data races. + // This is separate from the fakeConn.line to allow for drivers that + // can start multiple queries on the same transaction at the same time. + line int64 + + // closeErr is returned when rowsCursor.Close + closeErr error +} + +func (rc *rowsCursor) touchMem() { + rc.parentMem.touchMem() + rc.line++ +} + +func (rc *rowsCursor) Close() error { + rc.touchMem() + rc.parentMem.touchMem() + rc.closed = true + return rc.closeErr +} + +func (rc *rowsCursor) Columns() []string { + return rc.cols[rc.posSet] +} + +func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type { + return colTypeToReflectType(rc.colType[rc.posSet][index]) +} + +var rowsCursorNextHook func(dest []driver.Value) error + +func (rc *rowsCursor) Next(dest []driver.Value) error { + if rowsCursorNextHook != nil { + return rowsCursorNextHook(dest) + } + + if rc.closed { + return errors.New("fakedb: cursor is closed") + } + rc.touchMem() + rc.posRow++ + if rc.posRow == rc.errPos { + return rc.err + } + if rc.posRow >= len(rc.rows[rc.posSet]) { + return io.EOF // per interface spec + } + for i, v := range rc.rows[rc.posSet][rc.posRow].cols { + // TODO(bradfitz): convert to subset types? naah, I + // think the subset types should only be input to + // driver, but the sql package should be able to handle + // a wider range of types coming out of drivers. all + // for ease of drivers, and to prevent drivers from + // messing up conversions or doing them differently. + dest[i] = v + + if bs, ok := v.([]byte); ok { + if rc.bytesClone == nil { + rc.bytesClone = make(map[*byte][]byte) + } + clone, ok := rc.bytesClone[&bs[0]] + if !ok { + clone = make([]byte, len(bs)) + copy(clone, bs) + rc.bytesClone[&bs[0]] = clone + } + dest[i] = clone + } + } + return nil +} + +func (rc *rowsCursor) HasNextResultSet() bool { + rc.touchMem() + return rc.posSet < len(rc.rows)-1 +} + +func (rc *rowsCursor) NextResultSet() error { + rc.touchMem() + if rc.HasNextResultSet() { + rc.posSet++ + rc.posRow = -1 + return nil + } + return io.EOF // Per interface spec. +} + +// fakeDriverString is like driver.String, but indirects pointers like +// DefaultValueConverter. +// +// This could be surprising behavior to retroactively apply to +// driver.String now that Go1 is out, but this is convenient for +// our TestPointerParamsAndScans. +type fakeDriverString struct{} + +func (fakeDriverString) ConvertValue(v any) (driver.Value, error) { + switch c := v.(type) { + case string, []byte: + return v, nil + case *string: + if c == nil { + return nil, nil + } + return *c, nil + } + return fmt.Sprintf("%v", v), nil +} + +type anyTypeConverter struct{} + +func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) { + return v, nil +} + +func converterForType(typ string) driver.ValueConverter { + switch typ { + case "bool": + return driver.Bool + case "nullbool": + return driver.Null{Converter: driver.Bool} + case "byte", "int16": + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "int32": + return driver.Int32 + case "nullbyte", "nullint32", "nullint16": + return driver.Null{Converter: driver.DefaultParameterConverter} + case "string": + return driver.NotNull{Converter: fakeDriverString{}} + case "nullstring": + return driver.Null{Converter: fakeDriverString{}} + case "int64": + // TODO(coopernurse): add type-specific converter + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nullint64": + // TODO(coopernurse): add type-specific converter + return driver.Null{Converter: driver.DefaultParameterConverter} + case "float64": + // TODO(coopernurse): add type-specific converter + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nullfloat64": + // TODO(coopernurse): add type-specific converter + return driver.Null{Converter: driver.DefaultParameterConverter} + case "datetime": + return driver.NotNull{Converter: driver.DefaultParameterConverter} + case "nulldatetime": + return driver.Null{Converter: driver.DefaultParameterConverter} + case "any": + return anyTypeConverter{} + } + panic("invalid fakedb column type of " + typ) +} + +func colTypeToReflectType(typ string) reflect.Type { + switch typ { + case "bool": + return reflect.TypeOf(false) + case "nullbool": + return reflect.TypeOf(NullBool{}) + case "int16": + return reflect.TypeOf(int16(0)) + case "nullint16": + return reflect.TypeOf(NullInt16{}) + case "int32": + return reflect.TypeOf(int32(0)) + case "nullint32": + return reflect.TypeOf(NullInt32{}) + case "string": + return reflect.TypeOf("") + case "nullstring": + return reflect.TypeOf(NullString{}) + case "int64": + return reflect.TypeOf(int64(0)) + case "nullint64": + return reflect.TypeOf(NullInt64{}) + case "float64": + return reflect.TypeOf(float64(0)) + case "nullfloat64": + return reflect.TypeOf(NullFloat64{}) + case "datetime": + return reflect.TypeOf(time.Time{}) + case "any": + return reflect.TypeOf(new(any)).Elem() + } + panic("invalid fakedb column type of " + typ) +} diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go new file mode 100644 index 0000000..ad17eb3 --- /dev/null +++ b/src/database/sql/sql.go @@ -0,0 +1,3406 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sql provides a generic interface around SQL (or SQL-like) +// databases. +// +// The sql package must be used in conjunction with a database driver. +// See https://golang.org/s/sqldrivers for a list of drivers. +// +// Drivers that do not support context cancellation will not return until +// after the query is completed. +// +// For usage examples, see the wiki page at +// https://golang.org/s/sqlwiki. +package sql + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "io" + "reflect" + "runtime" + "sort" + "strconv" + "sync" + "sync/atomic" + "time" +) + +var ( + driversMu sync.RWMutex + drivers = make(map[string]driver.Driver) +) + +// nowFunc returns the current time; it's overridden in tests. +var nowFunc = time.Now + +// Register makes a database driver available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, driver driver.Driver) { + driversMu.Lock() + defer driversMu.Unlock() + if driver == nil { + panic("sql: Register driver is nil") + } + if _, dup := drivers[name]; dup { + panic("sql: Register called twice for driver " + name) + } + drivers[name] = driver +} + +func unregisterAllDrivers() { + driversMu.Lock() + defer driversMu.Unlock() + // For tests. + drivers = make(map[string]driver.Driver) +} + +// Drivers returns a sorted list of the names of the registered drivers. +func Drivers() []string { + driversMu.RLock() + defer driversMu.RUnlock() + list := make([]string, 0, len(drivers)) + for name := range drivers { + list = append(list, name) + } + sort.Strings(list) + return list +} + +// A NamedArg is a named argument. NamedArg values may be used as +// arguments to Query or Exec and bind to the corresponding named +// parameter in the SQL statement. +// +// For a more concise way to create NamedArg values, see +// the Named function. +type NamedArg struct { + _NamedFieldsRequired struct{} + + // Name is the name of the parameter placeholder. + // + // If empty, the ordinal position in the argument list will be + // used. + // + // Name must omit any symbol prefix. + Name string + + // Value is the value of the parameter. + // It may be assigned the same value types as the query + // arguments. + Value any +} + +// Named provides a more concise way to create NamedArg values. +// +// Example usage: +// +// db.ExecContext(ctx, ` +// delete from Invoice +// where +// TimeCreated < @end +// and TimeCreated >= @start;`, +// sql.Named("start", startTime), +// sql.Named("end", endTime), +// ) +func Named(name string, value any) NamedArg { + // This method exists because the go1compat promise + // doesn't guarantee that structs don't grow more fields, + // so unkeyed struct literals are a vet error. Thus, we don't + // want to allow sql.NamedArg{name, value}. + return NamedArg{Name: name, Value: value} +} + +// IsolationLevel is the transaction isolation level used in TxOptions. +type IsolationLevel int + +// Various isolation levels that drivers may support in BeginTx. +// If a driver does not support a given isolation level an error may be returned. +// +// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels. +const ( + LevelDefault IsolationLevel = iota + LevelReadUncommitted + LevelReadCommitted + LevelWriteCommitted + LevelRepeatableRead + LevelSnapshot + LevelSerializable + LevelLinearizable +) + +// String returns the name of the transaction isolation level. +func (i IsolationLevel) String() string { + switch i { + case LevelDefault: + return "Default" + case LevelReadUncommitted: + return "Read Uncommitted" + case LevelReadCommitted: + return "Read Committed" + case LevelWriteCommitted: + return "Write Committed" + case LevelRepeatableRead: + return "Repeatable Read" + case LevelSnapshot: + return "Snapshot" + case LevelSerializable: + return "Serializable" + case LevelLinearizable: + return "Linearizable" + default: + return "IsolationLevel(" + strconv.Itoa(int(i)) + ")" + } +} + +var _ fmt.Stringer = LevelDefault + +// TxOptions holds the transaction options to be used in DB.BeginTx. +type TxOptions struct { + // Isolation is the transaction isolation level. + // If zero, the driver or database's default level is used. + Isolation IsolationLevel + ReadOnly bool +} + +// RawBytes is a byte slice that holds a reference to memory owned by +// the database itself. After a Scan into a RawBytes, the slice is only +// valid until the next call to Next, Scan, or Close. +type RawBytes []byte + +// NullString represents a string that may be null. +// NullString implements the Scanner interface so +// it can be used as a scan destination: +// +// var s NullString +// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s) +// ... +// if s.Valid { +// // use s.String +// } else { +// // NULL value +// } +type NullString struct { + String string + Valid bool // Valid is true if String is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullString) Scan(value any) error { + if value == nil { + ns.String, ns.Valid = "", false + return nil + } + ns.Valid = true + return convertAssign(&ns.String, value) +} + +// Value implements the driver Valuer interface. +func (ns NullString) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return ns.String, nil +} + +// NullInt64 represents an int64 that may be null. +// NullInt64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt64) Scan(value any) error { + if value == nil { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Int64, value) +} + +// Value implements the driver Valuer interface. +func (n NullInt64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Int64, nil +} + +// NullInt32 represents an int32 that may be null. +// NullInt32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullInt32 struct { + Int32 int32 + Valid bool // Valid is true if Int32 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt32) Scan(value any) error { + if value == nil { + n.Int32, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Int32, value) +} + +// Value implements the driver Valuer interface. +func (n NullInt32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Int32), nil +} + +// NullInt16 represents an int16 that may be null. +// NullInt16 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullInt16 struct { + Int16 int16 + Valid bool // Valid is true if Int16 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt16) Scan(value any) error { + if value == nil { + n.Int16, n.Valid = 0, false + return nil + } + err := convertAssign(&n.Int16, value) + n.Valid = err == nil + return err +} + +// Value implements the driver Valuer interface. +func (n NullInt16) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Int16), nil +} + +// NullByte represents a byte that may be null. +// NullByte implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullByte struct { + Byte byte + Valid bool // Valid is true if Byte is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullByte) Scan(value any) error { + if value == nil { + n.Byte, n.Valid = 0, false + return nil + } + err := convertAssign(&n.Byte, value) + n.Valid = err == nil + return err +} + +// Value implements the driver Valuer interface. +func (n NullByte) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Byte), nil +} + +// NullFloat64 represents a float64 that may be null. +// NullFloat64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullFloat64 struct { + Float64 float64 + Valid bool // Valid is true if Float64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullFloat64) Scan(value any) error { + if value == nil { + n.Float64, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Float64, value) +} + +// Value implements the driver Valuer interface. +func (n NullFloat64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Float64, nil +} + +// NullBool represents a bool that may be null. +// NullBool implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullBool struct { + Bool bool + Valid bool // Valid is true if Bool is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullBool) Scan(value any) error { + if value == nil { + n.Bool, n.Valid = false, false + return nil + } + n.Valid = true + return convertAssign(&n.Bool, value) +} + +// Value implements the driver Valuer interface. +func (n NullBool) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Bool, nil +} + +// NullTime represents a time.Time that may be null. +// NullTime implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullTime) Scan(value any) error { + if value == nil { + n.Time, n.Valid = time.Time{}, false + return nil + } + n.Valid = true + return convertAssign(&n.Time, value) +} + +// Value implements the driver Valuer interface. +func (n NullTime) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time, nil +} + +// Scanner is an interface used by Scan. +type Scanner interface { + // Scan assigns a value from a database driver. + // + // The src value will be of one of the following types: + // + // int64 + // float64 + // bool + // []byte + // string + // time.Time + // nil - for NULL values + // + // An error should be returned if the value cannot be stored + // without loss of information. + // + // Reference types such as []byte are only valid until the next call to Scan + // and should not be retained. Their underlying memory is owned by the driver. + // If retention is necessary, copy their values before the next call to Scan. + Scan(src any) error +} + +// Out may be used to retrieve OUTPUT value parameters from stored procedures. +// +// Not all drivers and databases support OUTPUT value parameters. +// +// Example usage: +// +// var outArg string +// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg})) +type Out struct { + _NamedFieldsRequired struct{} + + // Dest is a pointer to the value that will be set to the result of the + // stored procedure's OUTPUT parameter. + Dest any + + // In is whether the parameter is an INOUT parameter. If so, the input value to the stored + // procedure is the dereferenced value of Dest's pointer, which is then replaced with + // the output value. + In bool +} + +// ErrNoRows is returned by Scan when QueryRow doesn't return a +// row. In such a case, QueryRow returns a placeholder *Row value that +// defers this error until a Scan. +var ErrNoRows = errors.New("sql: no rows in result set") + +// DB is a database handle representing a pool of zero or more +// underlying connections. It's safe for concurrent use by multiple +// goroutines. +// +// The sql package creates and frees connections automatically; it +// also maintains a free pool of idle connections. If the database has +// a concept of per-connection state, such state can be reliably observed +// within a transaction (Tx) or connection (Conn). Once DB.Begin is called, the +// returned Tx is bound to a single connection. Once Commit or +// Rollback is called on the transaction, that transaction's +// connection is returned to DB's idle connection pool. The pool size +// can be controlled with SetMaxIdleConns. +type DB struct { + // Total time waited for new connections. + waitDuration atomic.Int64 + + connector driver.Connector + // numClosed is an atomic counter which represents a total number of + // closed connections. Stmt.openStmt checks it before cleaning closed + // connections in Stmt.css. + numClosed atomic.Uint64 + + mu sync.Mutex // protects following fields + freeConn []*driverConn // free connections ordered by returnedAt oldest to newest + connRequests map[uint64]chan connRequest + nextRequest uint64 // Next key to use in connRequests. + numOpen int // number of opened and pending open connections + // Used to signal the need for new connections + // a goroutine running connectionOpener() reads on this chan and + // maybeOpenNewConnections sends on the chan (one send per needed connection) + // It is closed during db.Close(). The close tells the connectionOpener + // goroutine to exit. + openerCh chan struct{} + closed bool + dep map[finalCloser]depSet + lastPut map[*driverConn]string // stacktrace of last conn's put; debug only + maxIdleCount int // zero means defaultMaxIdleConns; negative means 0 + maxOpen int // <= 0 means unlimited + maxLifetime time.Duration // maximum amount of time a connection may be reused + maxIdleTime time.Duration // maximum amount of time a connection may be idle before being closed + cleanerCh chan struct{} + waitCount int64 // Total number of connections waited for. + maxIdleClosed int64 // Total number of connections closed due to idle count. + maxIdleTimeClosed int64 // Total number of connections closed due to idle time. + maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit. + + stop func() // stop cancels the connection opener. +} + +// connReuseStrategy determines how (*DB).conn returns database connections. +type connReuseStrategy uint8 + +const ( + // alwaysNewConn forces a new connection to the database. + alwaysNewConn connReuseStrategy = iota + // cachedOrNewConn returns a cached connection, if available, else waits + // for one to become available (if MaxOpenConns has been reached) or + // creates a new database connection. + cachedOrNewConn +) + +// driverConn wraps a driver.Conn with a mutex, to +// be held during all calls into the Conn. (including any calls onto +// interfaces returned via that Conn, such as calls on Tx, Stmt, +// Result, Rows) +type driverConn struct { + db *DB + createdAt time.Time + + sync.Mutex // guards following + ci driver.Conn + needReset bool // The connection session should be reset before use if true. + closed bool + finalClosed bool // ci.Close has been called + openStmt map[*driverStmt]bool + + // guarded by db.mu + inUse bool + returnedAt time.Time // Time the connection was created or returned. + onPut []func() // code (with db.mu held) run when conn is next returned + dbmuClosed bool // same as closed, but guarded by db.mu, for removeClosedStmtLocked +} + +func (dc *driverConn) releaseConn(err error) { + dc.db.putConn(dc, err, true) +} + +func (dc *driverConn) removeOpenStmt(ds *driverStmt) { + dc.Lock() + defer dc.Unlock() + delete(dc.openStmt, ds) +} + +func (dc *driverConn) expired(timeout time.Duration) bool { + if timeout <= 0 { + return false + } + return dc.createdAt.Add(timeout).Before(nowFunc()) +} + +// resetSession checks if the driver connection needs the +// session to be reset and if required, resets it. +func (dc *driverConn) resetSession(ctx context.Context) error { + dc.Lock() + defer dc.Unlock() + + if !dc.needReset { + return nil + } + if cr, ok := dc.ci.(driver.SessionResetter); ok { + return cr.ResetSession(ctx) + } + return nil +} + +// validateConnection checks if the connection is valid and can +// still be used. It also marks the session for reset if required. +func (dc *driverConn) validateConnection(needsReset bool) bool { + dc.Lock() + defer dc.Unlock() + + if needsReset { + dc.needReset = true + } + if cv, ok := dc.ci.(driver.Validator); ok { + return cv.IsValid() + } + return true +} + +// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of +// the prepared statements in a pool. +func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) { + si, err := ctxDriverPrepare(ctx, dc.ci, query) + if err != nil { + return nil, err + } + ds := &driverStmt{Locker: dc, si: si} + + // No need to manage open statements if there is a single connection grabber. + if cg != nil { + return ds, nil + } + + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once. + if dc.openStmt == nil { + dc.openStmt = make(map[*driverStmt]bool) + } + dc.openStmt[ds] = true + return ds, nil +} + +// the dc.db's Mutex is held. +func (dc *driverConn) closeDBLocked() func() error { + dc.Lock() + defer dc.Unlock() + if dc.closed { + return func() error { return errors.New("sql: duplicate driverConn close") } + } + dc.closed = true + return dc.db.removeDepLocked(dc, dc) +} + +func (dc *driverConn) Close() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + + // And now updates that require holding dc.mu.Lock. + dc.db.mu.Lock() + dc.dbmuClosed = true + fn := dc.db.removeDepLocked(dc, dc) + dc.db.mu.Unlock() + return fn() +} + +func (dc *driverConn) finalClose() error { + var err error + + // Each *driverStmt has a lock to the dc. Copy the list out of the dc + // before calling close on each stmt. + var openStmt []*driverStmt + withLock(dc, func() { + openStmt = make([]*driverStmt, 0, len(dc.openStmt)) + for ds := range dc.openStmt { + openStmt = append(openStmt, ds) + } + dc.openStmt = nil + }) + for _, ds := range openStmt { + ds.Close() + } + withLock(dc, func() { + dc.finalClosed = true + err = dc.ci.Close() + dc.ci = nil + }) + + dc.db.mu.Lock() + dc.db.numOpen-- + dc.db.maybeOpenNewConnections() + dc.db.mu.Unlock() + + dc.db.numClosed.Add(1) + return err +} + +// driverStmt associates a driver.Stmt with the +// *driverConn from which it came, so the driverConn's lock can be +// held during calls. +type driverStmt struct { + sync.Locker // the *driverConn + si driver.Stmt + closed bool + closeErr error // return value of previous Close call +} + +// Close ensures driver.Stmt is only closed once and always returns the same +// result. +func (ds *driverStmt) Close() error { + ds.Lock() + defer ds.Unlock() + if ds.closed { + return ds.closeErr + } + ds.closed = true + ds.closeErr = ds.si.Close() + return ds.closeErr +} + +// depSet is a finalCloser's outstanding dependencies +type depSet map[any]bool // set of true bools + +// The finalCloser interface is used by (*DB).addDep and related +// dependency reference counting. +type finalCloser interface { + // finalClose is called when the reference count of an object + // goes to zero. (*DB).mu is not held while calling it. + finalClose() error +} + +// addDep notes that x now depends on dep, and x's finalClose won't be +// called until all of x's dependencies are removed with removeDep. +func (db *DB) addDep(x finalCloser, dep any) { + db.mu.Lock() + defer db.mu.Unlock() + db.addDepLocked(x, dep) +} + +func (db *DB) addDepLocked(x finalCloser, dep any) { + if db.dep == nil { + db.dep = make(map[finalCloser]depSet) + } + xdep := db.dep[x] + if xdep == nil { + xdep = make(depSet) + db.dep[x] = xdep + } + xdep[dep] = true +} + +// removeDep notes that x no longer depends on dep. +// If x still has dependencies, nil is returned. +// If x no longer has any dependencies, its finalClose method will be +// called and its error value will be returned. +func (db *DB) removeDep(x finalCloser, dep any) error { + db.mu.Lock() + fn := db.removeDepLocked(x, dep) + db.mu.Unlock() + return fn() +} + +func (db *DB) removeDepLocked(x finalCloser, dep any) func() error { + xdep, ok := db.dep[x] + if !ok { + panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x)) + } + + l0 := len(xdep) + delete(xdep, dep) + + switch len(xdep) { + case l0: + // Nothing removed. Shouldn't happen. + panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x)) + case 0: + // No more dependencies. + delete(db.dep, x) + return x.finalClose + default: + // Dependencies remain. + return func() error { return nil } + } +} + +// This is the size of the connectionOpener request chan (DB.openerCh). +// This value should be larger than the maximum typical value +// used for db.maxOpen. If maxOpen is significantly larger than +// connectionRequestQueueSize then it is possible for ALL calls into the *DB +// to block until the connectionOpener can satisfy the backlog of requests. +var connectionRequestQueueSize = 1000000 + +type dsnConnector struct { + dsn string + driver driver.Driver +} + +func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return t.driver.Open(t.dsn) +} + +func (t dsnConnector) Driver() driver.Driver { + return t.driver +} + +// OpenDB opens a database using a Connector, allowing drivers to +// bypass a string based data source name. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See https://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// OpenDB may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. +// +// The returned DB is safe for concurrent use by multiple goroutines +// and maintains its own pool of idle connections. Thus, the OpenDB +// function should be called just once. It is rarely necessary to +// close a DB. +func OpenDB(c driver.Connector) *DB { + ctx, cancel := context.WithCancel(context.Background()) + db := &DB{ + connector: c, + openerCh: make(chan struct{}, connectionRequestQueueSize), + lastPut: make(map[*driverConn]string), + connRequests: make(map[uint64]chan connRequest), + stop: cancel, + } + + go db.connectionOpener(ctx) + + return db +} + +// Open opens a database specified by its database driver name and a +// driver-specific data source name, usually consisting of at least a +// database name and connection information. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See https://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// Open may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. +// +// The returned DB is safe for concurrent use by multiple goroutines +// and maintains its own pool of idle connections. Thus, the Open +// function should be called just once. It is rarely necessary to +// close a DB. +func Open(driverName, dataSourceName string) (*DB, error) { + driversMu.RLock() + driveri, ok := drivers[driverName] + driversMu.RUnlock() + if !ok { + return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) + } + + if driverCtx, ok := driveri.(driver.DriverContext); ok { + connector, err := driverCtx.OpenConnector(dataSourceName) + if err != nil { + return nil, err + } + return OpenDB(connector), nil + } + + return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil +} + +func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error { + var err error + if pinger, ok := dc.ci.(driver.Pinger); ok { + withLock(dc, func() { + err = pinger.Ping(ctx) + }) + } + release(err) + return err +} + +// PingContext verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *DB) PingContext(ctx context.Context) error { + var dc *driverConn + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + dc, err = db.conn(ctx, strategy) + return err + }) + + if err != nil { + return err + } + + return db.pingDC(ctx, dc, dc.releaseConn) +} + +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +// +// Ping uses context.Background internally; to specify the context, use +// PingContext. +func (db *DB) Ping() error { + return db.PingContext(context.Background()) +} + +// Close closes the database and prevents new queries from starting. +// Close then waits for all queries that have started processing on the server +// to finish. +// +// It is rare to Close a DB, as the DB handle is meant to be +// long-lived and shared between many goroutines. +func (db *DB) Close() error { + db.mu.Lock() + if db.closed { // Make DB.Close idempotent + db.mu.Unlock() + return nil + } + if db.cleanerCh != nil { + close(db.cleanerCh) + } + var err error + fns := make([]func() error, 0, len(db.freeConn)) + for _, dc := range db.freeConn { + fns = append(fns, dc.closeDBLocked()) + } + db.freeConn = nil + db.closed = true + for _, req := range db.connRequests { + close(req) + } + db.mu.Unlock() + for _, fn := range fns { + err1 := fn() + if err1 != nil { + err = err1 + } + } + db.stop() + if c, ok := db.connector.(io.Closer); ok { + err1 := c.Close() + if err1 != nil { + err = err1 + } + } + return err +} + +const defaultMaxIdleConns = 2 + +func (db *DB) maxIdleConnsLocked() int { + n := db.maxIdleCount + switch { + case n == 0: + // TODO(bradfitz): ask driver, if supported, for its default preference + return defaultMaxIdleConns + case n < 0: + return 0 + default: + return n + } +} + +func (db *DB) shortestIdleTimeLocked() time.Duration { + if db.maxIdleTime <= 0 { + return db.maxLifetime + } + if db.maxLifetime <= 0 { + return db.maxIdleTime + } + + min := db.maxIdleTime + if min > db.maxLifetime { + min = db.maxLifetime + } + return min +} + +// SetMaxIdleConns sets the maximum number of connections in the idle +// connection pool. +// +// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns, +// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit. +// +// If n <= 0, no idle connections are retained. +// +// The default max idle connections is currently 2. This may change in +// a future release. +func (db *DB) SetMaxIdleConns(n int) { + db.mu.Lock() + if n > 0 { + db.maxIdleCount = n + } else { + // No idle connections. + db.maxIdleCount = -1 + } + // Make sure maxIdle doesn't exceed maxOpen + if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen { + db.maxIdleCount = db.maxOpen + } + var closing []*driverConn + idleCount := len(db.freeConn) + maxIdle := db.maxIdleConnsLocked() + if idleCount > maxIdle { + closing = db.freeConn[maxIdle:] + db.freeConn = db.freeConn[:maxIdle] + } + db.maxIdleClosed += int64(len(closing)) + db.mu.Unlock() + for _, c := range closing { + c.Close() + } +} + +// SetMaxOpenConns sets the maximum number of open connections to the database. +// +// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than +// MaxIdleConns, then MaxIdleConns will be reduced to match the new +// MaxOpenConns limit. +// +// If n <= 0, then there is no limit on the number of open connections. +// The default is 0 (unlimited). +func (db *DB) SetMaxOpenConns(n int) { + db.mu.Lock() + db.maxOpen = n + if n < 0 { + db.maxOpen = 0 + } + syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen + db.mu.Unlock() + if syncMaxIdle { + db.SetMaxIdleConns(n) + } +} + +// SetConnMaxLifetime sets the maximum amount of time a connection may be reused. +// +// Expired connections may be closed lazily before reuse. +// +// If d <= 0, connections are not closed due to a connection's age. +func (db *DB) SetConnMaxLifetime(d time.Duration) { + if d < 0 { + d = 0 + } + db.mu.Lock() + // Wake cleaner up when lifetime is shortened. + if d > 0 && d < db.maxLifetime && db.cleanerCh != nil { + select { + case db.cleanerCh <- struct{}{}: + default: + } + } + db.maxLifetime = d + db.startCleanerLocked() + db.mu.Unlock() +} + +// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle. +// +// Expired connections may be closed lazily before reuse. +// +// If d <= 0, connections are not closed due to a connection's idle time. +func (db *DB) SetConnMaxIdleTime(d time.Duration) { + if d < 0 { + d = 0 + } + db.mu.Lock() + defer db.mu.Unlock() + + // Wake cleaner up when idle time is shortened. + if d > 0 && d < db.maxIdleTime && db.cleanerCh != nil { + select { + case db.cleanerCh <- struct{}{}: + default: + } + } + db.maxIdleTime = d + db.startCleanerLocked() +} + +// startCleanerLocked starts connectionCleaner if needed. +func (db *DB) startCleanerLocked() { + if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil { + db.cleanerCh = make(chan struct{}, 1) + go db.connectionCleaner(db.shortestIdleTimeLocked()) + } +} + +func (db *DB) connectionCleaner(d time.Duration) { + const minInterval = time.Second + + if d < minInterval { + d = minInterval + } + t := time.NewTimer(d) + + for { + select { + case <-t.C: + case <-db.cleanerCh: // maxLifetime was changed or db was closed. + } + + db.mu.Lock() + + d = db.shortestIdleTimeLocked() + if db.closed || db.numOpen == 0 || d <= 0 { + db.cleanerCh = nil + db.mu.Unlock() + return + } + + d, closing := db.connectionCleanerRunLocked(d) + db.mu.Unlock() + for _, c := range closing { + c.Close() + } + + if d < minInterval { + d = minInterval + } + + if !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(d) + } +} + +// connectionCleanerRunLocked removes connections that should be closed from +// freeConn and returns them along side an updated duration to the next check +// if a quicker check is required to ensure connections are checked appropriately. +func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) { + var idleClosing int64 + var closing []*driverConn + if db.maxIdleTime > 0 { + // As freeConn is ordered by returnedAt process + // in reverse order to minimise the work needed. + idleSince := nowFunc().Add(-db.maxIdleTime) + last := len(db.freeConn) - 1 + for i := last; i >= 0; i-- { + c := db.freeConn[i] + if c.returnedAt.Before(idleSince) { + i++ + closing = db.freeConn[:i:i] + db.freeConn = db.freeConn[i:] + idleClosing = int64(len(closing)) + db.maxIdleTimeClosed += idleClosing + break + } + } + + if len(db.freeConn) > 0 { + c := db.freeConn[0] + if d2 := c.returnedAt.Sub(idleSince); d2 < d { + // Ensure idle connections are cleaned up as soon as + // possible. + d = d2 + } + } + } + + if db.maxLifetime > 0 { + expiredSince := nowFunc().Add(-db.maxLifetime) + for i := 0; i < len(db.freeConn); i++ { + c := db.freeConn[i] + if c.createdAt.Before(expiredSince) { + closing = append(closing, c) + + last := len(db.freeConn) - 1 + // Use slow delete as order is required to ensure + // connections are reused least idle time first. + copy(db.freeConn[i:], db.freeConn[i+1:]) + db.freeConn[last] = nil + db.freeConn = db.freeConn[:last] + i-- + } else if d2 := c.createdAt.Sub(expiredSince); d2 < d { + // Prevent connections sitting the freeConn when they + // have expired by updating our next deadline d. + d = d2 + } + } + db.maxLifetimeClosed += int64(len(closing)) - idleClosing + } + + return d, closing +} + +// DBStats contains database statistics. +type DBStats struct { + MaxOpenConnections int // Maximum number of open connections to the database. + + // Pool Status + OpenConnections int // The number of established connections both in use and idle. + InUse int // The number of connections currently in use. + Idle int // The number of idle connections. + + // Counters + WaitCount int64 // The total number of connections waited for. + WaitDuration time.Duration // The total time blocked waiting for a new connection. + MaxIdleClosed int64 // The total number of connections closed due to SetMaxIdleConns. + MaxIdleTimeClosed int64 // The total number of connections closed due to SetConnMaxIdleTime. + MaxLifetimeClosed int64 // The total number of connections closed due to SetConnMaxLifetime. +} + +// Stats returns database statistics. +func (db *DB) Stats() DBStats { + wait := db.waitDuration.Load() + + db.mu.Lock() + defer db.mu.Unlock() + + stats := DBStats{ + MaxOpenConnections: db.maxOpen, + + Idle: len(db.freeConn), + OpenConnections: db.numOpen, + InUse: db.numOpen - len(db.freeConn), + + WaitCount: db.waitCount, + WaitDuration: time.Duration(wait), + MaxIdleClosed: db.maxIdleClosed, + MaxIdleTimeClosed: db.maxIdleTimeClosed, + MaxLifetimeClosed: db.maxLifetimeClosed, + } + return stats +} + +// Assumes db.mu is locked. +// If there are connRequests and the connection limit hasn't been reached, +// then tell the connectionOpener to open new connections. +func (db *DB) maybeOpenNewConnections() { + numRequests := len(db.connRequests) + if db.maxOpen > 0 { + numCanOpen := db.maxOpen - db.numOpen + if numRequests > numCanOpen { + numRequests = numCanOpen + } + } + for numRequests > 0 { + db.numOpen++ // optimistically + numRequests-- + if db.closed { + return + } + db.openerCh <- struct{}{} + } +} + +// Runs in a separate goroutine, opens new connections when requested. +func (db *DB) connectionOpener(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-db.openerCh: + db.openNewConnection(ctx) + } + } +} + +// Open one new connection +func (db *DB) openNewConnection(ctx context.Context) { + // maybeOpenNewConnections has already executed db.numOpen++ before it sent + // on db.openerCh. This function must execute db.numOpen-- if the + // connection fails or is closed before returning. + ci, err := db.connector.Connect(ctx) + db.mu.Lock() + defer db.mu.Unlock() + if db.closed { + if err == nil { + ci.Close() + } + db.numOpen-- + return + } + if err != nil { + db.numOpen-- + db.putConnDBLocked(nil, err) + db.maybeOpenNewConnections() + return + } + dc := &driverConn{ + db: db, + createdAt: nowFunc(), + returnedAt: nowFunc(), + ci: ci, + } + if db.putConnDBLocked(dc, err) { + db.addDepLocked(dc, dc) + } else { + db.numOpen-- + ci.Close() + } +} + +// connRequest represents one request for a new connection +// When there are no idle connections available, DB.conn will create +// a new connRequest and put it on the db.connRequests list. +type connRequest struct { + conn *driverConn + err error +} + +var errDBClosed = errors.New("sql: database is closed") + +// nextRequestKeyLocked returns the next connection request key. +// It is assumed that nextRequest will not overflow. +func (db *DB) nextRequestKeyLocked() uint64 { + next := db.nextRequest + db.nextRequest++ + return next +} + +// conn returns a newly-opened or cached *driverConn. +func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) { + db.mu.Lock() + if db.closed { + db.mu.Unlock() + return nil, errDBClosed + } + // Check if the context is expired. + select { + default: + case <-ctx.Done(): + db.mu.Unlock() + return nil, ctx.Err() + } + lifetime := db.maxLifetime + + // Prefer a free connection, if possible. + last := len(db.freeConn) - 1 + if strategy == cachedOrNewConn && last >= 0 { + // Reuse the lowest idle time connection so we can close + // connections which remain idle as soon as possible. + conn := db.freeConn[last] + db.freeConn = db.freeConn[:last] + conn.inUse = true + if conn.expired(lifetime) { + db.maxLifetimeClosed++ + db.mu.Unlock() + conn.Close() + return nil, driver.ErrBadConn + } + db.mu.Unlock() + + // Reset the session if required. + if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) { + conn.Close() + return nil, err + } + + return conn, nil + } + + // Out of free connections or we were asked not to use one. If we're not + // allowed to open any more connections, make a request and wait. + if db.maxOpen > 0 && db.numOpen >= db.maxOpen { + // Make the connRequest channel. It's buffered so that the + // connectionOpener doesn't block while waiting for the req to be read. + req := make(chan connRequest, 1) + reqKey := db.nextRequestKeyLocked() + db.connRequests[reqKey] = req + db.waitCount++ + db.mu.Unlock() + + waitStart := nowFunc() + + // Timeout the connection request with the context. + select { + case <-ctx.Done(): + // Remove the connection request and ensure no value has been sent + // on it after removing. + db.mu.Lock() + delete(db.connRequests, reqKey) + db.mu.Unlock() + + db.waitDuration.Add(int64(time.Since(waitStart))) + + select { + default: + case ret, ok := <-req: + if ok && ret.conn != nil { + db.putConn(ret.conn, ret.err, false) + } + } + return nil, ctx.Err() + case ret, ok := <-req: + db.waitDuration.Add(int64(time.Since(waitStart))) + + if !ok { + return nil, errDBClosed + } + // Only check if the connection is expired if the strategy is cachedOrNewConns. + // If we require a new connection, just re-use the connection without looking + // at the expiry time. If it is expired, it will be checked when it is placed + // back into the connection pool. + // This prioritizes giving a valid connection to a client over the exact connection + // lifetime, which could expire exactly after this point anyway. + if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) { + db.mu.Lock() + db.maxLifetimeClosed++ + db.mu.Unlock() + ret.conn.Close() + return nil, driver.ErrBadConn + } + if ret.conn == nil { + return nil, ret.err + } + + // Reset the session if required. + if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) { + ret.conn.Close() + return nil, err + } + return ret.conn, ret.err + } + } + + db.numOpen++ // optimistically + db.mu.Unlock() + ci, err := db.connector.Connect(ctx) + if err != nil { + db.mu.Lock() + db.numOpen-- // correct for earlier optimism + db.maybeOpenNewConnections() + db.mu.Unlock() + return nil, err + } + db.mu.Lock() + dc := &driverConn{ + db: db, + createdAt: nowFunc(), + returnedAt: nowFunc(), + ci: ci, + inUse: true, + } + db.addDepLocked(dc, dc) + db.mu.Unlock() + return dc, nil +} + +// putConnHook is a hook for testing. +var putConnHook func(*DB, *driverConn) + +// noteUnusedDriverStatement notes that ds is no longer used and should +// be closed whenever possible (when c is next not in use), unless c is +// already closed. +func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) { + db.mu.Lock() + defer db.mu.Unlock() + if c.inUse { + c.onPut = append(c.onPut, func() { + ds.Close() + }) + } else { + c.Lock() + fc := c.finalClosed + c.Unlock() + if !fc { + ds.Close() + } + } +} + +// debugGetPut determines whether getConn & putConn calls' stack traces +// are returned for more verbose crashes. +const debugGetPut = false + +// putConn adds a connection to the db's free pool. +// err is optionally the last error that occurred on this connection. +func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { + if !errors.Is(err, driver.ErrBadConn) { + if !dc.validateConnection(resetSession) { + err = driver.ErrBadConn + } + } + db.mu.Lock() + if !dc.inUse { + db.mu.Unlock() + if debugGetPut { + fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc]) + } + panic("sql: connection returned that was never out") + } + + if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) { + db.maxLifetimeClosed++ + err = driver.ErrBadConn + } + if debugGetPut { + db.lastPut[dc] = stack() + } + dc.inUse = false + dc.returnedAt = nowFunc() + + for _, fn := range dc.onPut { + fn() + } + dc.onPut = nil + + if errors.Is(err, driver.ErrBadConn) { + // Don't reuse bad connections. + // Since the conn is considered bad and is being discarded, treat it + // as closed. Don't decrement the open count here, finalClose will + // take care of that. + db.maybeOpenNewConnections() + db.mu.Unlock() + dc.Close() + return + } + if putConnHook != nil { + putConnHook(db, dc) + } + added := db.putConnDBLocked(dc, nil) + db.mu.Unlock() + + if !added { + dc.Close() + return + } +} + +// Satisfy a connRequest or put the driverConn in the idle pool and return true +// or return false. +// putConnDBLocked will satisfy a connRequest if there is one, or it will +// return the *driverConn to the freeConn list if err == nil and the idle +// connection limit will not be exceeded. +// If err != nil, the value of dc is ignored. +// If err == nil, then dc must not equal nil. +// If a connRequest was fulfilled or the *driverConn was placed in the +// freeConn list, then true is returned, otherwise false is returned. +func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { + if db.closed { + return false + } + if db.maxOpen > 0 && db.numOpen > db.maxOpen { + return false + } + if c := len(db.connRequests); c > 0 { + var req chan connRequest + var reqKey uint64 + for reqKey, req = range db.connRequests { + break + } + delete(db.connRequests, reqKey) // Remove from pending requests. + if err == nil { + dc.inUse = true + } + req <- connRequest{ + conn: dc, + err: err, + } + return true + } else if err == nil && !db.closed { + if db.maxIdleConnsLocked() > len(db.freeConn) { + db.freeConn = append(db.freeConn, dc) + db.startCleanerLocked() + return true + } + db.maxIdleClosed++ + } + return false +} + +// maxBadConnRetries is the number of maximum retries if the driver returns +// driver.ErrBadConn to signal a broken connection before forcing a new +// connection to be opened. +const maxBadConnRetries = 2 + +func (db *DB) retry(fn func(strategy connReuseStrategy) error) error { + for i := int64(0); i < maxBadConnRetries; i++ { + err := fn(cachedOrNewConn) + // retry if err is driver.ErrBadConn + if err == nil || !errors.Is(err, driver.ErrBadConn) { + return err + } + } + + return fn(alwaysNewConn) +} + +// PrepareContext creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + var stmt *Stmt + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + stmt, err = db.prepare(ctx, query, strategy) + return err + }) + + return stmt, err +} + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +// +// Prepare uses context.Background internally; to specify the context, use +// PrepareContext. +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) { + // TODO: check if db.driver supports an optional + // driver.Preparer interface and call that instead, if so, + // otherwise we make a prepared statement that's bound + // to a connection, and to execute this prepared statement + // we either need to use this connection (if it's free), else + // get a new connection + re-prepare + execute on that one. + dc, err := db.conn(ctx, strategy) + if err != nil { + return nil, err + } + return db.prepareDC(ctx, dc, dc.releaseConn, nil, query) +} + +// prepareDC prepares a query on the driverConn and calls release before +// returning. When cg == nil it implies that a connection pool is used, and +// when cg != nil only a single driver connection is used. +func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) { + var ds *driverStmt + var err error + defer func() { + release(err) + }() + withLock(dc, func() { + ds, err = dc.prepareLocked(ctx, cg, query) + }) + if err != nil { + return nil, err + } + stmt := &Stmt{ + db: db, + query: query, + cg: cg, + cgds: ds, + } + + // When cg == nil this statement will need to keep track of various + // connections they are prepared on and record the stmt dependency on + // the DB. + if cg == nil { + stmt.css = []connStmt{{dc, ds}} + stmt.lastNumClosed = db.numClosed.Load() + db.addDep(stmt, stmt) + } + return stmt, nil +} + +// ExecContext executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + var res Result + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + res, err = db.exec(ctx, query, args, strategy) + return err + }) + + return res, err +} + +// Exec executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +// +// Exec uses context.Background internally; to specify the context, use +// ExecContext. +func (db *DB) Exec(query string, args ...any) (Result, error) { + return db.ExecContext(context.Background(), query, args...) +} + +func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) { + dc, err := db.conn(ctx, strategy) + if err != nil { + return nil, err + } + return db.execDC(ctx, dc, dc.releaseConn, query, args) +} + +func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) { + defer func() { + release(err) + }() + execerCtx, ok := dc.ci.(driver.ExecerContext) + var execer driver.Execer + if !ok { + execer, ok = dc.ci.(driver.Execer) + } + if ok { + var nvdargs []driver.NamedValue + var resi driver.Result + withLock(dc, func() { + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs) + }) + if err != driver.ErrSkip { + if err != nil { + return nil, err + } + return driverResult{dc, resi}, nil + } + } + + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) + if err != nil { + return nil, err + } + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() + return resultFromStatement(ctx, dc.ci, ds, args...) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + var rows *Rows + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + rows, err = db.query(ctx, query, args, strategy) + return err + }) + + return rows, err +} + +// Query executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +// +// Query uses context.Background internally; to specify the context, use +// QueryContext. +func (db *DB) Query(query string, args ...any) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) { + dc, err := db.conn(ctx, strategy) + if err != nil { + return nil, err + } + + return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args) +} + +// queryDC executes a query on the given connection. +// The connection gets released by the releaseConn function. +// The ctx context is from a query method and the txctx context is from an +// optional transaction context. +func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) { + queryerCtx, ok := dc.ci.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = dc.ci.(driver.Queryer) + } + if ok { + var nvdargs []driver.NamedValue + var rowsi driver.Rows + var err error + withLock(dc, func() { + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs) + }) + if err != driver.ErrSkip { + if err != nil { + releaseConn(err) + return nil, err + } + // Note: ownership of dc passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + releaseConn: releaseConn, + rowsi: rowsi, + } + rows.initContextClose(ctx, txctx) + return rows, nil + } + } + + var si driver.Stmt + var err error + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) + if err != nil { + releaseConn(err) + return nil, err + } + + ds := &driverStmt{Locker: dc, si: si} + rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...) + if err != nil { + ds.Close() + releaseConn(err) + return nil, err + } + + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + releaseConn: releaseConn, + rowsi: rowsi, + closeStmt: ds, + } + rows.initContextClose(ctx, txctx) + return rows, nil +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + rows, err := db.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// QueryRow uses context.Background internally; to specify the context, use +// QueryRowContext. +func (db *DB) QueryRow(query string, args ...any) *Row { + return db.QueryRowContext(context.Background(), query, args...) +} + +// BeginTx starts a transaction. +// +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginTx is canceled. +// +// The provided TxOptions is optional and may be nil if defaults should be used. +// If a non-default isolation level is used that the driver doesn't support, +// an error will be returned. +func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { + var tx *Tx + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + tx, err = db.begin(ctx, opts, strategy) + return err + }) + + return tx, err +} + +// Begin starts a transaction. The default isolation level is dependent on +// the driver. +// +// Begin uses context.Background internally; to specify the context, use +// BeginTx. +func (db *DB) Begin() (*Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) { + dc, err := db.conn(ctx, strategy) + if err != nil { + return nil, err + } + return db.beginDC(ctx, dc, dc.releaseConn, opts) +} + +// beginDC starts a transaction. The provided dc must be valid and ready to use. +func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) { + var txi driver.Tx + keepConnOnRollback := false + withLock(dc, func() { + _, hasSessionResetter := dc.ci.(driver.SessionResetter) + _, hasConnectionValidator := dc.ci.(driver.Validator) + keepConnOnRollback = hasSessionResetter && hasConnectionValidator + txi, err = ctxDriverBegin(ctx, opts, dc.ci) + }) + if err != nil { + release(err) + return nil, err + } + + // Schedule the transaction to rollback when the context is canceled. + // The cancel function in Tx will be called after done is set to true. + ctx, cancel := context.WithCancel(ctx) + tx = &Tx{ + db: db, + dc: dc, + releaseConn: release, + txi: txi, + cancel: cancel, + keepConnOnRollback: keepConnOnRollback, + ctx: ctx, + } + go tx.awaitDone() + return tx, nil +} + +// Driver returns the database's underlying driver. +func (db *DB) Driver() driver.Driver { + return db.connector.Driver() +} + +// ErrConnDone is returned by any operation that is performed on a connection +// that has already been returned to the connection pool. +var ErrConnDone = errors.New("sql: connection is already closed") + +// Conn returns a single connection by either opening a new connection +// or returning an existing connection from the connection pool. Conn will +// block until either a connection is returned or ctx is canceled. +// Queries run on the same Conn will be run in the same database session. +// +// Every Conn must be returned to the database pool after use by +// calling Conn.Close. +func (db *DB) Conn(ctx context.Context) (*Conn, error) { + var dc *driverConn + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + dc, err = db.conn(ctx, strategy) + return err + }) + + if err != nil { + return nil, err + } + + conn := &Conn{ + db: db, + dc: dc, + } + return conn, nil +} + +type releaseConn func(error) + +// Conn represents a single database connection rather than a pool of database +// connections. Prefer running queries from DB unless there is a specific +// need for a continuous single database connection. +// +// A Conn must call Close to return the connection to the database pool +// and may do so concurrently with a running query. +// +// After a call to Close, all operations on the +// connection fail with ErrConnDone. +type Conn struct { + db *DB + + // closemu prevents the connection from closing while there + // is an active query. It is held for read during queries + // and exclusively during close. + closemu sync.RWMutex + + // dc is owned until close, at which point + // it's returned to the connection pool. + dc *driverConn + + // done transitions from 0 to 1 exactly once, on close. + // Once done, all operations fail with ErrConnDone. + // Use atomic operations on value when checking value. + done int32 +} + +// grabConn takes a context to implement stmtConnGrabber +// but the context is not used. +func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) { + if atomic.LoadInt32(&c.done) != 0 { + return nil, nil, ErrConnDone + } + c.closemu.RLock() + return c.dc, c.closemuRUnlockCondReleaseConn, nil +} + +// PingContext verifies the connection to the database is still alive. +func (c *Conn) PingContext(ctx context.Context) error { + dc, release, err := c.grabConn(ctx) + if err != nil { + return err + } + return c.db.pingDC(ctx, dc, release) +} + +// ExecContext executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.execDC(ctx, dc, release, query, args) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.queryDC(ctx, nil, dc, release, query, args) +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + rows, err := c.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + +// PrepareContext creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.prepareDC(ctx, dc, release, c, query) +} + +// Raw executes f exposing the underlying driver connection for the +// duration of f. The driverConn must not be used outside of f. +// +// Once f returns and err is not driver.ErrBadConn, the Conn will continue to be usable +// until Conn.Close is called. +func (c *Conn) Raw(f func(driverConn any) error) (err error) { + var dc *driverConn + var release releaseConn + + // grabConn takes a context to implement stmtConnGrabber, but the context is not used. + dc, release, err = c.grabConn(nil) + if err != nil { + return + } + fPanic := true + dc.Mutex.Lock() + defer func() { + dc.Mutex.Unlock() + + // If f panics fPanic will remain true. + // Ensure an error is passed to release so the connection + // may be discarded. + if fPanic { + err = driver.ErrBadConn + } + release(err) + }() + err = f(dc.ci) + fPanic = false + + return +} + +// BeginTx starts a transaction. +// +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginTx is canceled. +// +// The provided TxOptions is optional and may be nil if defaults should be used. +// If a non-default isolation level is used that the driver doesn't support, +// an error will be returned. +func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.beginDC(ctx, dc, release, opts) +} + +// closemuRUnlockCondReleaseConn read unlocks closemu +// as the sql operation is done with the dc. +func (c *Conn) closemuRUnlockCondReleaseConn(err error) { + c.closemu.RUnlock() + if errors.Is(err, driver.ErrBadConn) { + c.close(err) + } +} + +func (c *Conn) txCtx() context.Context { + return nil +} + +func (c *Conn) close(err error) error { + if !atomic.CompareAndSwapInt32(&c.done, 0, 1) { + return ErrConnDone + } + + // Lock around releasing the driver connection + // to ensure all queries have been stopped before doing so. + c.closemu.Lock() + defer c.closemu.Unlock() + + c.dc.releaseConn(err) + c.dc = nil + c.db = nil + return err +} + +// Close returns the connection to the connection pool. +// All operations after a Close will return with ErrConnDone. +// Close is safe to call concurrently with other operations and will +// block until all other operations finish. It may be useful to first +// cancel any used context and then call close directly after. +func (c *Conn) Close() error { + return c.close(nil) +} + +// Tx is an in-progress database transaction. +// +// A transaction must end with a call to Commit or Rollback. +// +// After a call to Commit or Rollback, all operations on the +// transaction fail with ErrTxDone. +// +// The statements prepared for a transaction by calling +// the transaction's Prepare or Stmt methods are closed +// by the call to Commit or Rollback. +type Tx struct { + db *DB + + // closemu prevents the transaction from closing while there + // is an active query. It is held for read during queries + // and exclusively during close. + closemu sync.RWMutex + + // dc is owned exclusively until Commit or Rollback, at which point + // it's returned with putConn. + dc *driverConn + txi driver.Tx + + // releaseConn is called once the Tx is closed to release + // any held driverConn back to the pool. + releaseConn func(error) + + // done transitions from false to true exactly once, on Commit + // or Rollback. once done, all operations fail with + // ErrTxDone. + done atomic.Bool + + // keepConnOnRollback is true if the driver knows + // how to reset the connection's session and if need be discard + // the connection. + keepConnOnRollback bool + + // All Stmts prepared for this transaction. These will be closed after the + // transaction has been committed or rolled back. + stmts struct { + sync.Mutex + v []*Stmt + } + + // cancel is called after done transitions from 0 to 1. + cancel func() + + // ctx lives for the life of the transaction. + ctx context.Context +} + +// awaitDone blocks until the context in Tx is canceled and rolls back +// the transaction if it's not already done. +func (tx *Tx) awaitDone() { + // Wait for either the transaction to be committed or rolled + // back, or for the associated context to be closed. + <-tx.ctx.Done() + + // Discard and close the connection used to ensure the + // transaction is closed and the resources are released. This + // rollback does nothing if the transaction has already been + // committed or rolled back. + // Do not discard the connection if the connection knows + // how to reset the session. + discardConnection := !tx.keepConnOnRollback + tx.rollback(discardConnection) +} + +func (tx *Tx) isDone() bool { + return tx.done.Load() +} + +// ErrTxDone is returned by any operation that is performed on a transaction +// that has already been committed or rolled back. +var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back") + +// close returns the connection to the pool and +// must only be called by Tx.rollback or Tx.Commit while +// tx is already canceled and won't be executed concurrently. +func (tx *Tx) close(err error) { + tx.releaseConn(err) + tx.dc = nil + tx.txi = nil +} + +// hookTxGrabConn specifies an optional hook to be called on +// a successful call to (*Tx).grabConn. For tests. +var hookTxGrabConn func() + +func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) { + select { + default: + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + + // closemu.RLock must come before the check for isDone to prevent the Tx from + // closing while a query is executing. + tx.closemu.RLock() + if tx.isDone() { + tx.closemu.RUnlock() + return nil, nil, ErrTxDone + } + if hookTxGrabConn != nil { // test hook + hookTxGrabConn() + } + return tx.dc, tx.closemuRUnlockRelease, nil +} + +func (tx *Tx) txCtx() context.Context { + return tx.ctx +} + +// closemuRUnlockRelease is used as a func(error) method value in +// ExecContext and QueryContext. Unlocking in the releaseConn keeps +// the driver conn from being returned to the connection pool until +// the Rows has been closed. +func (tx *Tx) closemuRUnlockRelease(error) { + tx.closemu.RUnlock() +} + +// Closes all Stmts prepared for this transaction. +func (tx *Tx) closePrepared() { + tx.stmts.Lock() + defer tx.stmts.Unlock() + for _, stmt := range tx.stmts.v { + stmt.Close() + } +} + +// Commit commits the transaction. +func (tx *Tx) Commit() error { + // Check context first to avoid transaction leak. + // If put it behind tx.done CompareAndSwap statement, we can't ensure + // the consistency between tx.done and the real COMMIT operation. + select { + default: + case <-tx.ctx.Done(): + if tx.done.Load() { + return ErrTxDone + } + return tx.ctx.Err() + } + if !tx.done.CompareAndSwap(false, true) { + return ErrTxDone + } + + // Cancel the Tx to release any active R-closemu locks. + // This is safe to do because tx.done has already transitioned + // from 0 to 1. Hold the W-closemu lock prior to rollback + // to ensure no other connection has an active query. + tx.cancel() + tx.closemu.Lock() + tx.closemu.Unlock() + + var err error + withLock(tx.dc, func() { + err = tx.txi.Commit() + }) + if !errors.Is(err, driver.ErrBadConn) { + tx.closePrepared() + } + tx.close(err) + return err +} + +var rollbackHook func() + +// rollback aborts the transaction and optionally forces the pool to discard +// the connection. +func (tx *Tx) rollback(discardConn bool) error { + if !tx.done.CompareAndSwap(false, true) { + return ErrTxDone + } + + if rollbackHook != nil { + rollbackHook() + } + + // Cancel the Tx to release any active R-closemu locks. + // This is safe to do because tx.done has already transitioned + // from 0 to 1. Hold the W-closemu lock prior to rollback + // to ensure no other connection has an active query. + tx.cancel() + tx.closemu.Lock() + tx.closemu.Unlock() + + var err error + withLock(tx.dc, func() { + err = tx.txi.Rollback() + }) + if !errors.Is(err, driver.ErrBadConn) { + tx.closePrepared() + } + if discardConn { + err = driver.ErrBadConn + } + tx.close(err) + return err +} + +// Rollback aborts the transaction. +func (tx *Tx) Rollback() error { + return tx.rollback(false) +} + +// PrepareContext creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +// +// The provided context will be used for the preparation of the context, not +// for the execution of the returned statement. The returned statement +// will run in the transaction context. +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return nil, err + } + + stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query) + if err != nil { + return nil, err + } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, stmt) + tx.stmts.Unlock() + return stmt, nil +} + +// Prepare creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +// +// Prepare uses context.Background internally; to specify the context, use +// PrepareContext. +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} + +// StmtContext returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203) +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return &Stmt{stickyErr: err} + } + defer release(nil) + + if tx.db != stmt.db { + return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} + } + var si driver.Stmt + var parentStmt *Stmt + stmt.mu.Lock() + if stmt.closed || stmt.cg != nil { + // If the statement has been closed or already belongs to a + // transaction, we can't reuse it in this connection. + // Since tx.StmtContext should never need to be called with a + // Stmt already belonging to tx, we ignore this edge case and + // re-prepare the statement in this case. No need to add + // code-complexity for this. + stmt.mu.Unlock() + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) + }) + if err != nil { + return &Stmt{stickyErr: err} + } + } else { + stmt.removeClosedStmtLocked() + // See if the statement has already been prepared on this connection, + // and reuse it if possible. + for _, v := range stmt.css { + if v.dc == dc { + si = v.ds.si + break + } + } + + stmt.mu.Unlock() + + if si == nil { + var ds *driverStmt + withLock(dc, func() { + ds, err = stmt.prepareOnConnLocked(ctx, dc) + }) + if err != nil { + return &Stmt{stickyErr: err} + } + si = ds.si + } + parentStmt = stmt + } + + txs := &Stmt{ + db: tx.db, + cg: tx, + cgds: &driverStmt{ + Locker: dc, + si: si, + }, + parentStmt: parentStmt, + query: stmt.query, + } + if parentStmt != nil { + tx.db.addDep(parentStmt, txs) + } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, txs) + tx.stmts.Unlock() + return txs +} + +// Stmt returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +// +// Stmt uses context.Background internally; to specify the context, use +// StmtContext. +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + return tx.StmtContext(context.Background(), stmt) +} + +// ExecContext executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return nil, err + } + return tx.db.execDC(ctx, dc, release, query, args) +} + +// Exec executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +// +// Exec uses context.Background internally; to specify the context, use +// ExecContext. +func (tx *Tx) Exec(query string, args ...any) (Result, error) { + return tx.ExecContext(context.Background(), query, args...) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return nil, err + } + + return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args) +} + +// Query executes a query that returns rows, typically a SELECT. +// +// Query uses context.Background internally; to specify the context, use +// QueryContext. +func (tx *Tx) Query(query string, args ...any) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + rows, err := tx.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// QueryRow uses context.Background internally; to specify the context, use +// QueryRowContext. +func (tx *Tx) QueryRow(query string, args ...any) *Row { + return tx.QueryRowContext(context.Background(), query, args...) +} + +// connStmt is a prepared statement on a particular connection. +type connStmt struct { + dc *driverConn + ds *driverStmt +} + +// stmtConnGrabber represents a Tx or Conn that will return the underlying +// driverConn and release function. +type stmtConnGrabber interface { + // grabConn returns the driverConn and the associated release function + // that must be called when the operation completes. + grabConn(context.Context) (*driverConn, releaseConn, error) + + // txCtx returns the transaction context if available. + // The returned context should be selected on along with + // any query context when awaiting a cancel. + txCtx() context.Context +} + +var ( + _ stmtConnGrabber = &Tx{} + _ stmtConnGrabber = &Conn{} +) + +// Stmt is a prepared statement. +// A Stmt is safe for concurrent use by multiple goroutines. +// +// If a Stmt is prepared on a Tx or Conn, it will be bound to a single +// underlying connection forever. If the Tx or Conn closes, the Stmt will +// become unusable and all operations will return an error. +// If a Stmt is prepared on a DB, it will remain usable for the lifetime of the +// DB. When the Stmt needs to execute on a new underlying connection, it will +// prepare itself on the new connection automatically. +type Stmt struct { + // Immutable: + db *DB // where we came from + query string // that created the Stmt + stickyErr error // if non-nil, this error is returned for all operations + + closemu sync.RWMutex // held exclusively during close, for read otherwise. + + // If Stmt is prepared on a Tx or Conn then cg is present and will + // only ever grab a connection from cg. + // If cg is nil then the Stmt must grab an arbitrary connection + // from db and determine if it must prepare the stmt again by + // inspecting css. + cg stmtConnGrabber + cgds *driverStmt + + // parentStmt is set when a transaction-specific statement + // is requested from an identical statement prepared on the same + // conn. parentStmt is used to track the dependency of this statement + // on its originating ("parent") statement so that parentStmt may + // be closed by the user without them having to know whether or not + // any transactions are still using it. + parentStmt *Stmt + + mu sync.Mutex // protects the rest of the fields + closed bool + + // css is a list of underlying driver statement interfaces + // that are valid on particular connections. This is only + // used if cg == nil and one is found that has idle + // connections. If cg != nil, cgds is always used. + css []connStmt + + // lastNumClosed is copied from db.numClosed when Stmt is created + // without tx and closed connections in css are removed. + lastNumClosed uint64 +} + +// ExecContext executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) { + s.closemu.RLock() + defer s.closemu.RUnlock() + + var res Result + err := s.db.retry(func(strategy connReuseStrategy) error { + dc, releaseConn, ds, err := s.connStmt(ctx, strategy) + if err != nil { + return err + } + + res, err = resultFromStatement(ctx, dc.ci, ds, args...) + releaseConn(err) + return err + }) + + return res, err +} + +// Exec executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +// +// Exec uses context.Background internally; to specify the context, use +// ExecContext. +func (s *Stmt) Exec(args ...any) (Result, error) { + return s.ExecContext(context.Background(), args...) +} + +func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) { + ds.Lock() + defer ds.Unlock() + + dargs, err := driverArgsConnLocked(ci, ds, args) + if err != nil { + return nil, err + } + + resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) + if err != nil { + return nil, err + } + return driverResult{ds.Locker, resi}, nil +} + +// removeClosedStmtLocked removes closed conns in s.css. +// +// To avoid lock contention on DB.mu, we do it only when +// s.db.numClosed - s.lastNum is large enough. +func (s *Stmt) removeClosedStmtLocked() { + t := len(s.css)/2 + 1 + if t > 10 { + t = 10 + } + dbClosed := s.db.numClosed.Load() + if dbClosed-s.lastNumClosed < uint64(t) { + return + } + + s.db.mu.Lock() + for i := 0; i < len(s.css); i++ { + if s.css[i].dc.dbmuClosed { + s.css[i] = s.css[len(s.css)-1] + s.css = s.css[:len(s.css)-1] + i-- + } + } + s.db.mu.Unlock() + s.lastNumClosed = dbClosed +} + +// connStmt returns a free driver connection on which to execute the +// statement, a function to call to release the connection, and a +// statement bound to that connection. +func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) { + if err = s.stickyErr; err != nil { + return + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + err = errors.New("sql: statement is closed") + return + } + + // In a transaction or connection, we always use the connection that the + // stmt was created on. + if s.cg != nil { + s.mu.Unlock() + dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection. + if err != nil { + return + } + return dc, releaseConn, s.cgds, nil + } + + s.removeClosedStmtLocked() + s.mu.Unlock() + + dc, err = s.db.conn(ctx, strategy) + if err != nil { + return nil, nil, nil, err + } + + s.mu.Lock() + for _, v := range s.css { + if v.dc == dc { + s.mu.Unlock() + return dc, dc.releaseConn, v.ds, nil + } + } + s.mu.Unlock() + + // No luck; we need to prepare the statement on this connection + withLock(dc, func() { + ds, err = s.prepareOnConnLocked(ctx, dc) + }) + if err != nil { + dc.releaseConn(err) + return nil, nil, nil, err + } + + return dc, dc.releaseConn, ds, nil +} + +// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of +// open connStmt on the statement. It assumes the caller is holding the lock on dc. +func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) { + si, err := dc.prepareLocked(ctx, s.cg, s.query) + if err != nil { + return nil, err + } + cs := connStmt{dc, si} + s.mu.Lock() + s.css = append(s.css, cs) + s.mu.Unlock() + return cs.ds, nil +} + +// QueryContext executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) { + s.closemu.RLock() + defer s.closemu.RUnlock() + + var rowsi driver.Rows + var rows *Rows + + err := s.db.retry(func(strategy connReuseStrategy) error { + dc, releaseConn, ds, err := s.connStmt(ctx, strategy) + if err != nil { + return err + } + + rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...) + if err == nil { + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows = &Rows{ + dc: dc, + rowsi: rowsi, + // releaseConn set below + } + // addDep must be added before initContextClose or it could attempt + // to removeDep before it has been added. + s.db.addDep(s, rows) + + // releaseConn must be set before initContextClose or it could + // release the connection before it is set. + rows.releaseConn = func(err error) { + releaseConn(err) + s.db.removeDep(s, rows) + } + var txctx context.Context + if s.cg != nil { + txctx = s.cg.txCtx() + } + rows.initContextClose(ctx, txctx) + return nil + } + + releaseConn(err) + return err + }) + + return rows, err +} + +// Query executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +// +// Query uses context.Background internally; to specify the context, use +// QueryContext. +func (s *Stmt) Query(args ...any) (*Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) { + ds.Lock() + defer ds.Unlock() + dargs, err := driverArgsConnLocked(ci, ds, args) + if err != nil { + return nil, err + } + return ctxDriverStmtQuery(ctx, ds.si, dargs) +} + +// QueryRowContext executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row { + rows, err := s.QueryContext(ctx, args...) + if err != nil { + return &Row{err: err} + } + return &Row{rows: rows} +} + +// QueryRow executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// Example usage: +// +// var name string +// err := nameByUseridStmt.QueryRow(id).Scan(&name) +// +// QueryRow uses context.Background internally; to specify the context, use +// QueryRowContext. +func (s *Stmt) QueryRow(args ...any) *Row { + return s.QueryRowContext(context.Background(), args...) +} + +// Close closes the statement. +func (s *Stmt) Close() error { + s.closemu.Lock() + defer s.closemu.Unlock() + + if s.stickyErr != nil { + return s.stickyErr + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return nil + } + s.closed = true + txds := s.cgds + s.cgds = nil + + s.mu.Unlock() + + if s.cg == nil { + return s.db.removeDep(s, s) + } + + if s.parentStmt != nil { + // If parentStmt is set, we must not close s.txds since it's stored + // in the css array of the parentStmt. + return s.db.removeDep(s.parentStmt, s) + } + return txds.Close() +} + +func (s *Stmt) finalClose() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.css != nil { + for _, v := range s.css { + s.db.noteUnusedDriverStatement(v.dc, v.ds) + v.dc.removeOpenStmt(v.ds) + } + s.css = nil + } + return nil +} + +// Rows is the result of a query. Its cursor starts before the first row +// of the result set. Use Next to advance from row to row. +type Rows struct { + dc *driverConn // owned; must call releaseConn when closed to release + releaseConn func(error) + rowsi driver.Rows + cancel func() // called when Rows is closed, may be nil. + closeStmt *driverStmt // if non-nil, statement to Close on close + + // closemu prevents Rows from closing while there + // is an active streaming result. It is held for read during non-close operations + // and exclusively during close. + // + // closemu guards lasterr and closed. + closemu sync.RWMutex + closed bool + lasterr error // non-nil only if closed is true + + // lastcols is only used in Scan, Next, and NextResultSet which are expected + // not to be called concurrently. + lastcols []driver.Value +} + +// lasterrOrErrLocked returns either lasterr or the provided err. +// rs.closemu must be read-locked. +func (rs *Rows) lasterrOrErrLocked(err error) error { + if rs.lasterr != nil && rs.lasterr != io.EOF { + return rs.lasterr + } + return err +} + +// bypassRowsAwaitDone is only used for testing. +// If true, it will not close the Rows automatically from the context. +var bypassRowsAwaitDone = false + +func (rs *Rows) initContextClose(ctx, txctx context.Context) { + if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) { + return + } + if bypassRowsAwaitDone { + return + } + ctx, rs.cancel = context.WithCancel(ctx) + go rs.awaitDone(ctx, txctx) +} + +// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided +// from the query context and is canceled when the query Rows is closed. +// If the query was issued in a transaction, the transaction's context +// is also provided in txctx to ensure Rows is closed if the Tx is closed. +func (rs *Rows) awaitDone(ctx, txctx context.Context) { + var txctxDone <-chan struct{} + if txctx != nil { + txctxDone = txctx.Done() + } + select { + case <-ctx.Done(): + case <-txctxDone: + } + rs.close(ctx.Err()) +} + +// Next prepares the next result row for reading with the Scan method. It +// returns true on success, or false if there is no next result row or an error +// happened while preparing it. Err should be consulted to distinguish between +// the two cases. +// +// Every call to Scan, even the first one, must be preceded by a call to Next. +func (rs *Rows) Next() bool { + var doClose, ok bool + withLock(rs.closemu.RLocker(), func() { + doClose, ok = rs.nextLocked() + }) + if doClose { + rs.Close() + } + return ok +} + +func (rs *Rows) nextLocked() (doClose, ok bool) { + if rs.closed { + return false, false + } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + + if rs.lastcols == nil { + rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) + } + + rs.lasterr = rs.rowsi.Next(rs.lastcols) + if rs.lasterr != nil { + // Close the connection if there is a driver error. + if rs.lasterr != io.EOF { + return true, false + } + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + return true, false + } + // The driver is at the end of the current result set. + // Test to see if there is another result set after the current one. + // Only close Rows if there is no further result sets to read. + if !nextResultSet.HasNextResultSet() { + doClose = true + } + return doClose, false + } + return false, true +} + +// NextResultSet prepares the next result set for reading. It reports whether +// there is further result sets, or false if there is no further result set +// or if there is an error advancing to it. The Err method should be consulted +// to distinguish between the two cases. +// +// After calling NextResultSet, the Next method should always be called before +// scanning. If there are further result sets they may not have rows in the result +// set. +func (rs *Rows) NextResultSet() bool { + var doClose bool + defer func() { + if doClose { + rs.Close() + } + }() + rs.closemu.RLock() + defer rs.closemu.RUnlock() + + if rs.closed { + return false + } + + rs.lastcols = nil + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + doClose = true + return false + } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + + rs.lasterr = nextResultSet.NextResultSet() + if rs.lasterr != nil { + doClose = true + return false + } + return true +} + +// Err returns the error, if any, that was encountered during iteration. +// Err may be called after an explicit or implicit Close. +func (rs *Rows) Err() error { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + return rs.lasterrOrErrLocked(nil) +} + +var errRowsClosed = errors.New("sql: Rows are closed") +var errNoRows = errors.New("sql: no Rows available") + +// Columns returns the column names. +// Columns returns an error if the rows are closed. +func (rs *Rows) Columns() ([]string, error) { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + if rs.closed { + return nil, rs.lasterrOrErrLocked(errRowsClosed) + } + if rs.rowsi == nil { + return nil, rs.lasterrOrErrLocked(errNoRows) + } + rs.dc.Lock() + defer rs.dc.Unlock() + + return rs.rowsi.Columns(), nil +} + +// ColumnTypes returns column information such as column type, length, +// and nullable. Some information may not be available from some drivers. +func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + if rs.closed { + return nil, rs.lasterrOrErrLocked(errRowsClosed) + } + if rs.rowsi == nil { + return nil, rs.lasterrOrErrLocked(errNoRows) + } + rs.dc.Lock() + defer rs.dc.Unlock() + + return rowsColumnInfoSetupConnLocked(rs.rowsi), nil +} + +// ColumnType contains the name and type of a column. +type ColumnType struct { + name string + + hasNullable bool + hasLength bool + hasPrecisionScale bool + + nullable bool + length int64 + databaseType string + precision int64 + scale int64 + scanType reflect.Type +} + +// Name returns the name or alias of the column. +func (ci *ColumnType) Name() string { + return ci.name +} + +// Length returns the column type length for variable length column types such +// as text and binary field types. If the type length is unbounded the value will +// be math.MaxInt64 (any database limits will still apply). +// If the column type is not variable length, such as an int, or if not supported +// by the driver ok is false. +func (ci *ColumnType) Length() (length int64, ok bool) { + return ci.length, ci.hasLength +} + +// DecimalSize returns the scale and precision of a decimal type. +// If not applicable or if not supported ok is false. +func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) { + return ci.precision, ci.scale, ci.hasPrecisionScale +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +// If a driver does not support this property ScanType will return +// the type of an empty interface. +func (ci *ColumnType) ScanType() reflect.Type { + return ci.scanType +} + +// Nullable reports whether the column may be null. +// If a driver does not support this property ok will be false. +func (ci *ColumnType) Nullable() (nullable, ok bool) { + return ci.nullable, ci.hasNullable +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned, then the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", +// "INT", and "BIGINT". +func (ci *ColumnType) DatabaseTypeName() string { + return ci.databaseType +} + +func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType { + names := rowsi.Columns() + + list := make([]*ColumnType, len(names)) + for i := range list { + ci := &ColumnType{ + name: names[i], + } + list[i] = ci + + if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok { + ci.scanType = prop.ColumnTypeScanType(i) + } else { + ci.scanType = reflect.TypeOf(new(any)).Elem() + } + if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok { + ci.databaseType = prop.ColumnTypeDatabaseTypeName(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok { + ci.length, ci.hasLength = prop.ColumnTypeLength(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok { + ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok { + ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i) + } + } + return list +} + +// Scan copies the columns in the current row into the values pointed +// at by dest. The number of values in dest must be the same as the +// number of columns in Rows. +// +// Scan converts columns read from the database into the following +// common Go types and special types provided by the sql package: +// +// *string +// *[]byte +// *int, *int8, *int16, *int32, *int64 +// *uint, *uint8, *uint16, *uint32, *uint64 +// *bool +// *float32, *float64 +// *interface{} +// *RawBytes +// *Rows (cursor value) +// any type implementing Scanner (see Scanner docs) +// +// In the most simple case, if the type of the value from the source +// column is an integer, bool or string type T and dest is of type *T, +// Scan simply assigns the value through the pointer. +// +// Scan also converts between string and numeric types, as long as no +// information would be lost. While Scan stringifies all numbers +// scanned from numeric database columns into *string, scans into +// numeric types are checked for overflow. For example, a float64 with +// value 300 or a string with value "300" can scan into a uint16, but +// not into a uint8, though float64(255) or "255" can scan into a +// uint8. One exception is that scans of some float64 numbers to +// strings may lose information when stringifying. In general, scan +// floating point columns into *float64. +// +// If a dest argument has type *[]byte, Scan saves in that argument a +// copy of the corresponding data. The copy is owned by the caller and +// can be modified and held indefinitely. The copy can be avoided by +// using an argument of type *RawBytes instead; see the documentation +// for RawBytes for restrictions on its use. +// +// If an argument has type *interface{}, Scan copies the value +// provided by the underlying driver without conversion. When scanning +// from a source value of type []byte to *interface{}, a copy of the +// slice is made and the caller owns the result. +// +// Source values of type time.Time may be scanned into values of type +// *time.Time, *interface{}, *string, or *[]byte. When converting to +// the latter two, time.RFC3339Nano is used. +// +// Source values of type bool may be scanned into types *bool, +// *interface{}, *string, *[]byte, or *RawBytes. +// +// For scanning into *bool, the source may be true, false, 1, 0, or +// string inputs parseable by strconv.ParseBool. +// +// Scan can also convert a cursor returned from a query, such as +// "select cursor(select * from my_table) from dual", into a +// *Rows value that can itself be scanned from. The parent +// select query will close any cursor *Rows if the parent *Rows is closed. +// +// If any of the first arguments implementing Scanner returns an error, +// that error will be wrapped in the returned error. +func (rs *Rows) Scan(dest ...any) error { + rs.closemu.RLock() + + if rs.lasterr != nil && rs.lasterr != io.EOF { + rs.closemu.RUnlock() + return rs.lasterr + } + if rs.closed { + err := rs.lasterrOrErrLocked(errRowsClosed) + rs.closemu.RUnlock() + return err + } + rs.closemu.RUnlock() + + if rs.lastcols == nil { + return errors.New("sql: Scan called without calling Next") + } + if len(dest) != len(rs.lastcols) { + return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) + } + for i, sv := range rs.lastcols { + err := convertAssignRows(dest[i], sv, rs) + if err != nil { + return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err) + } + } + return nil +} + +// rowsCloseHook returns a function so tests may install the +// hook through a test only mutex. +var rowsCloseHook = func() func(*Rows, *error) { return nil } + +// Close closes the Rows, preventing further enumeration. If Next is called +// and returns false and there are no further result sets, +// the Rows are closed automatically and it will suffice to check the +// result of Err. Close is idempotent and does not affect the result of Err. +func (rs *Rows) Close() error { + return rs.close(nil) +} + +func (rs *Rows) close(err error) error { + rs.closemu.Lock() + defer rs.closemu.Unlock() + + if rs.closed { + return nil + } + rs.closed = true + + if rs.lasterr == nil { + rs.lasterr = err + } + + withLock(rs.dc, func() { + err = rs.rowsi.Close() + }) + if fn := rowsCloseHook(); fn != nil { + fn(rs, &err) + } + if rs.cancel != nil { + rs.cancel() + } + + if rs.closeStmt != nil { + rs.closeStmt.Close() + } + rs.releaseConn(err) + + rs.lasterr = rs.lasterrOrErrLocked(err) + return err +} + +// Row is the result of calling QueryRow to select a single row. +type Row struct { + // One of these two will be non-nil: + err error // deferred error for easy chaining + rows *Rows +} + +// Scan copies the columns from the matched row into the values +// pointed at by dest. See the documentation on Rows.Scan for details. +// If more than one row matches the query, +// Scan uses the first row and discards the rest. If no row matches +// the query, Scan returns ErrNoRows. +func (r *Row) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned (not permitting + // *RawBytes in Rows.Scan), since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + defer r.rows.Close() + for _, dp := range dest { + if _, ok := dp.(*RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !r.rows.Next() { + if err := r.rows.Err(); err != nil { + return err + } + return ErrNoRows + } + err := r.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return r.rows.Close() +} + +// Err provides a way for wrapping packages to check for +// query errors without calling Scan. +// Err returns the error, if any, that was encountered while running the query. +// If this error is not nil, this error will also be returned from Scan. +func (r *Row) Err() error { + return r.err +} + +// A Result summarizes an executed SQL command. +type Result interface { + // LastInsertId returns the integer generated by the database + // in response to a command. Typically this will be from an + // "auto increment" column when inserting a new row. Not all + // databases support this feature, and the syntax of such + // statements varies. + LastInsertId() (int64, error) + + // RowsAffected returns the number of rows affected by an + // update, insert, or delete. Not every database or database + // driver may support this. + RowsAffected() (int64, error) +} + +type driverResult struct { + sync.Locker // the *driverConn + resi driver.Result +} + +func (dr driverResult) LastInsertId() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.LastInsertId() +} + +func (dr driverResult) RowsAffected() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.RowsAffected() +} + +func stack() string { + var buf [2 << 10]byte + return string(buf[:runtime.Stack(buf[:], false)]) +} + +// withLock runs while holding lk. +func withLock(lk sync.Locker, fn func()) { + lk.Lock() + defer lk.Unlock() // in case fn panics + fn() +} diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go new file mode 100644 index 0000000..8c58723 --- /dev/null +++ b/src/database/sql/sql_test.go @@ -0,0 +1,4585 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "math/rand" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +func init() { + type dbConn struct { + db *DB + c *driverConn + } + freedFrom := make(map[dbConn]string) + var mu sync.Mutex + getFreedFrom := func(c dbConn) string { + mu.Lock() + defer mu.Unlock() + return freedFrom[c] + } + setFreedFrom := func(c dbConn, s string) { + mu.Lock() + defer mu.Unlock() + freedFrom[c] = s + } + putConnHook = func(db *DB, c *driverConn) { + idx := -1 + for i, v := range db.freeConn { + if v == c { + idx = i + break + } + } + if idx >= 0 { + // print before panic, as panic may get lost due to conflicting panic + // (all goroutines asleep) elsewhere, since we might not unlock + // the mutex in freeConn here. + println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack()) + panic("double free of conn.") + } + setFreedFrom(dbConn{db, c}, stack()) + } +} + +// pollDuration is an arbitrary interval to wait between checks when polling for +// a condition to occur. +const pollDuration = 5 * time.Millisecond + +const fakeDBName = "foo" + +var chrisBirthday = time.Unix(123456789, 0) + +func newTestDB(t testing.TB, name string) *DB { + return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name) +} + +func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB { + fc.name = fakeDBName + db := OpenDB(fc) + if _, err := db.Exec("WIPE"); err != nil { + t.Fatalf("exec wipe: %v", err) + } + if name == "people" { + exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime") + exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1) + exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2) + exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + } + if name == "magicquery" { + // Magic table name and column, known by fakedb_test.go. + exec(t, db, "CREATE|magicquery|op=string,millis=int32") + exec(t, db, "INSERT|magicquery|op=sleep,millis=10") + } + if name == "tx_status" { + // Magic table name and column, known by fakedb_test.go. + exec(t, db, "CREATE|tx_status|tx_status=string") + exec(t, db, "INSERT|tx_status|tx_status=invalid") + } + return db +} + +func TestOpenDB(t *testing.T) { + db := OpenDB(dsnConnector{dsn: fakeDBName, driver: fdriver}) + if db.Driver() != fdriver { + t.Fatalf("OpenDB should return the driver of the Connector") + } +} + +func TestDriverPanic(t *testing.T) { + // Test that if driver panics, database/sql does not deadlock. + db, err := Open("test", fakeDBName) + if err != nil { + t.Fatalf("Open: %v", err) + } + expectPanic := func(name string, f func()) { + defer func() { + err := recover() + if err == nil { + t.Fatalf("%s did not panic", name) + } + }() + f() + } + + expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") }) + exec(t, db, "WIPE") // check not deadlocked + exec(t, db, "PANIC|Query|WIPE") // should run successfully: Exec does not call Query + exec(t, db, "WIPE") // check not deadlocked + + exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime") + + expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") }) + expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") }) + expectPanic("Query Close", func() { + rows, err := db.Query("PANIC|Close|SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + rows.Close() + }) + db.Query("PANIC|Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec + exec(t, db, "WIPE") // check not deadlocked +} + +func exec(t testing.TB, db *DB, query string, args ...any) { + t.Helper() + _, err := db.Exec(query, args...) + if err != nil { + t.Fatalf("Exec of %q: %v", query, err) + } +} + +func closeDB(t testing.TB, db *DB) { + if e := recover(); e != nil { + fmt.Printf("Panic: %v\n", e) + panic(e) + } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + db.mu.Lock() + for i, dc := range db.freeConn { + if n := len(dc.openStmt); n > 0 { + // Just a sanity check. This is legal in + // general, but if we make the tests clean up + // their statements first, then we can safely + // verify this is always zero here, and any + // other value is a leak. + t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n) + } + } + db.mu.Unlock() + + err := db.Close() + if err != nil { + t.Fatalf("error closing DB: %v", err) + } + + var numOpen int + if !waitCondition(t, func() bool { + numOpen = db.numOpenConns() + return numOpen == 0 + }) { + t.Fatalf("%d connections still open after closing DB", numOpen) + } +} + +// numPrepares assumes that db has exactly 1 idle conn and returns +// its count of calls to Prepare +func numPrepares(t *testing.T, db *DB) int { + if n := len(db.freeConn); n != 1 { + t.Fatalf("free conns = %d; want 1", n) + } + return db.freeConn[0].ci.(*fakeConn).numPrepare +} + +func (db *DB) numDeps() int { + db.mu.Lock() + defer db.mu.Unlock() + return len(db.dep) +} + +// Dependencies are closed via a goroutine, so this polls waiting for +// numDeps to fall to want, waiting up to nearly the test's deadline. +func (db *DB) numDepsPoll(t *testing.T, want int) int { + var n int + waitCondition(t, func() bool { + n = db.numDeps() + return n <= want + }) + return n +} + +func (db *DB) numFreeConns() int { + db.mu.Lock() + defer db.mu.Unlock() + return len(db.freeConn) +} + +func (db *DB) numOpenConns() int { + db.mu.Lock() + defer db.mu.Unlock() + return db.numOpen +} + +// clearAllConns closes all connections in db. +func (db *DB) clearAllConns(t *testing.T) { + db.SetMaxIdleConns(0) + + if g, w := db.numFreeConns(), 0; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPoll(t, 0); n > 0 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } +} + +func (db *DB) dumpDeps(t *testing.T) { + for fc := range db.dep { + db.dumpDep(t, 0, fc, map[finalCloser]bool{}) + } +} + +func (db *DB) dumpDep(t *testing.T, depth int, dep finalCloser, seen map[finalCloser]bool) { + seen[dep] = true + indent := strings.Repeat(" ", depth) + ds := db.dep[dep] + for k := range ds { + t.Logf("%s%T (%p) waiting for -> %T (%p)", indent, dep, dep, k, k) + if fc, ok := k.(finalCloser); ok { + if !seen[fc] { + db.dumpDep(t, depth+1, fc, seen) + } + } + } +} + +func TestQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + {age: 3, name: "Chris"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +// TestQueryContext tests canceling the context while scanning the rows. +func TestQueryContext(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := db.QueryContext(ctx, "SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + index := 0 + for rows.Next() { + if index == 2 { + cancel() + waitForRowsClose(t, rows) + } + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + if index == 2 { + break + } + t.Fatalf("Scan: %v", err) + } + if index == 2 && err != context.Canceled { + t.Fatalf("Scan: %v; want context.Canceled", err) + } + got = append(got, r) + index++ + } + select { + case <-ctx.Done(): + if err := ctx.Err(); err != context.Canceled { + t.Fatalf("context err = %v; want context.Canceled", err) + } + default: + t.Fatalf("context err = nil; want context.Canceled") + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + waitForRowsClose(t, rows) + waitForFree(t, db, 1) + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func waitCondition(t testing.TB, fn func() bool) bool { + timeout := 5 * time.Second + + type deadliner interface { + Deadline() (time.Time, bool) + } + if td, ok := t.(deadliner); ok { + if deadline, ok := td.Deadline(); ok { + timeout = time.Until(deadline) + timeout = timeout * 19 / 20 // Give 5% headroom for cleanup and error-reporting. + } + } + + deadline := time.Now().Add(timeout) + for { + if fn() { + return true + } + if time.Until(deadline) < pollDuration { + return false + } + time.Sleep(pollDuration) + } +} + +// waitForFree checks db.numFreeConns until either it equals want or +// the maxWait time elapses. +func waitForFree(t *testing.T, db *DB, want int) { + var numFree int + if !waitCondition(t, func() bool { + numFree = db.numFreeConns() + return numFree == want + }) { + t.Fatalf("free conns after hitting EOF = %d; want %d", numFree, want) + } +} + +func waitForRowsClose(t *testing.T, rows *Rows) { + if !waitCondition(t, func() bool { + rows.closemu.RLock() + defer rows.closemu.RUnlock() + return rows.closed + }) { + t.Fatal("failed to close rows") + } +} + +// TestQueryContextWait ensures that rows and all internal statements are closed when +// a query context is closed during execution. +func TestQueryContextWait(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + c, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + c.dc.ci.(*fakeConn).waiter = func(c context.Context) { + cancel() + <-ctx.Done() + } + _, err = c.QueryContext(ctx, "SELECT|people|age,name|") + c.Close() + if err != context.Canceled { + t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err) + } + + // Verify closed rows connection after error condition. + waitForFree(t, db, 1) + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Fatalf("executed %d Prepare statements; want 1", prepares) + } +} + +// TestTxContextWait tests the transaction behavior when the tx context is canceled +// during execution of the query. +func TestTxContextWait(t *testing.T) { + testContextWait(t, false) +} + +// TestTxContextWaitNoDiscard is the same as TestTxContextWait, but should not discard +// the final connection. +func TestTxContextWaitNoDiscard(t *testing.T) { + testContextWait(t, true) +} + +func testContextWait(t *testing.T, keepConnOnRollback bool) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + tx.keepConnOnRollback = keepConnOnRollback + + tx.dc.ci.(*fakeConn).waiter = func(c context.Context) { + cancel() + <-ctx.Done() + } + // This will trigger the *fakeConn.Prepare method which will take time + // performing the query. The ctxDriverPrepare func will check the context + // after this and close the rows and return an error. + _, err = tx.QueryContext(ctx, "SELECT|people|age,name|") + if err != context.Canceled { + t.Fatalf("expected QueryContext to error with context canceled but returned %v", err) + } + + if keepConnOnRollback { + waitForFree(t, db, 1) + } else { + waitForFree(t, db, 0) + } +} + +// TestUnsupportedOptions checks that the database fails when a driver that +// doesn't implement ConnBeginTx is used with non-default options and an +// un-cancellable context. +func TestUnsupportedOptions(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + _, err := db.BeginTx(context.Background(), &TxOptions{ + Isolation: LevelSerializable, ReadOnly: true, + }) + if err == nil { + t.Fatal("expected error when using unsupported options, got nil") + } +} + +func TestMultiResultSetQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row1 struct { + age int + name string + } + type row2 struct { + name string + } + got1 := []row1{} + for rows.Next() { + var r row1 + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got1 = append(got1, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want1 := []row1{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + {age: 3, name: "Chris"}, + } + if !reflect.DeepEqual(got1, want1) { + t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1) + } + + if !rows.NextResultSet() { + t.Errorf("expected another result set") + } + + got2 := []row2{} + for rows.Next() { + var r row2 + err = rows.Scan(&r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got2 = append(got2, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want2 := []row2{ + {name: "Alice"}, + {name: "Bob"}, + {name: "Chris"}, + } + if !reflect.DeepEqual(got2, want2) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2) + } + if rows.NextResultSet() { + t.Errorf("expected no more result sets") + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + waitForFree(t, db, 1) + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestQueryNamedArg(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query( + // Ensure the name and age parameters only match on placeholder name, not position. + "SELECT|people|age,name|name=?name,age=?age", + Named("age", 2), + Named("name", "Bob"), + ) + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestPoolExhaustOnCancel(t *testing.T) { + if testing.Short() { + t.Skip("long test") + } + + max := 3 + var saturate, saturateDone sync.WaitGroup + saturate.Add(max) + saturateDone.Add(max) + + donePing := make(chan bool) + state := 0 + + // waiter will be called for all queries, including + // initial setup queries. The state is only assigned when + // no queries are made. + // + // Only allow the first batch of queries to finish once the + // second batch of Ping queries have finished. + waiter := func(ctx context.Context) { + switch state { + case 0: + // Nothing. Initial database setup. + case 1: + saturate.Done() + select { + case <-ctx.Done(): + case <-donePing: + } + case 2: + } + } + db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people") + defer closeDB(t, db) + + db.SetMaxOpenConns(max) + + // First saturate the connection pool. + // Then start new requests for a connection that is canceled after it is requested. + + state = 1 + for i := 0; i < max; i++ { + go func() { + rows, err := db.Query("SELECT|people|name,photo|") + if err != nil { + t.Errorf("Query: %v", err) + return + } + rows.Close() + saturateDone.Done() + }() + } + + saturate.Wait() + if t.Failed() { + t.FailNow() + } + state = 2 + + // Now cancel the request while it is waiting. + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + for i := 0; i < max; i++ { + ctxReq, cancelReq := context.WithCancel(ctx) + go func() { + time.Sleep(100 * time.Millisecond) + cancelReq() + }() + err := db.PingContext(ctxReq) + if err != context.Canceled { + t.Fatalf("PingContext (Exhaust): %v", err) + } + } + close(donePing) + saturateDone.Wait() + + // Now try to open a normal connection. + err := db.PingContext(ctx) + if err != nil { + t.Fatalf("PingContext (Normal): %v", err) + } +} + +func TestRowsColumns(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + cols, err := rows.Columns() + if err != nil { + t.Fatalf("Columns: %v", err) + } + want := []string{"age", "name"} + if !reflect.DeepEqual(cols, want) { + t.Errorf("got %#v; want %#v", cols, want) + } + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } +} + +func TestRowsColumnTypes(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + st := tp.ScanType() + if st == nil { + t.Errorf("scantype is null for column %q", tp.Name()) + continue + } + types[i] = st + } + values := make([]any, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + ct := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + if ct == 1 { + if age := *values[0].(*int32); age != 2 { + t.Errorf("Expected 2, got %v", age) + } + if name := *values[1].(*string); name != "Bob" { + t.Errorf("Expected Bob, got %v", name) + } + } + ct++ + } + if ct != 3 { + t.Errorf("expected 3 rows, got %d", ct) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } +} + +func TestQueryRow(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name string + var age int + var birthday time.Time + + err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age) + if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") { + t.Errorf("expected error from wrong number of arguments; actually got: %v", err) + } + + err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday) + if err != nil || !birthday.Equal(chrisBirthday) { + t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday) + } + + err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name) + if err != nil { + t.Fatalf("age QueryRow+Scan: %v", err) + } + if name != "Bob" { + t.Errorf("expected name Bob, got %q", name) + } + if age != 2 { + t.Errorf("expected age 2, got %d", age) + } + + err = db.QueryRow("SELECT|people|age,name|name=?", "Alice").Scan(&age, &name) + if err != nil { + t.Fatalf("name QueryRow+Scan: %v", err) + } + if name != "Alice" { + t.Errorf("expected name Alice, got %q", name) + } + if age != 1 { + t.Errorf("expected age 1, got %d", age) + } + + var photo []byte + err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo) + if err != nil { + t.Fatalf("photo QueryRow+Scan: %v", err) + } + want := []byte("APHOTO") + if !reflect.DeepEqual(photo, want) { + t.Errorf("photo = %q; want %q", photo, want) + } +} + +func TestRowErr(t *testing.T) { + db := newTestDB(t, "people") + + err := db.QueryRowContext(context.Background(), "SELECT|people|bdate|age=?", 3).Err() + if err != nil { + t.Errorf("Unexpected err = %v; want %v", err, nil) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + err = db.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 3).Err() + exp := "context canceled" + if err == nil || !strings.Contains(err.Error(), exp) { + t.Errorf("Expected err = %v; got %v", exp, err) + } +} + +func TestTxRollbackCommitErr(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + err = tx.Rollback() + if err != nil { + t.Errorf("expected nil error from Rollback; got %v", err) + } + err = tx.Commit() + if err != ErrTxDone { + t.Errorf("expected %q from Commit; got %q", ErrTxDone, err) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + err = tx.Commit() + if err != nil { + t.Errorf("expected nil error from Commit; got %v", err) + } + err = tx.Rollback() + if err != ErrTxDone { + t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err) + } +} + +func TestStatementErrorAfterClose(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + err = stmt.Close() + if err != nil { + t.Fatalf("Close: %v", err) + } + var name string + err = stmt.QueryRow("foo").Scan(&name) + if err == nil { + t.Errorf("expected error from QueryRow.Scan after Stmt.Close") + } +} + +func TestStatementQueryRow(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + defer stmt.Close() + var age int + for n, tt := range []struct { + name string + want int + }{ + {"Alice", 1}, + {"Bob", 2}, + {"Chris", 3}, + } { + if err := stmt.QueryRow(tt.name).Scan(&age); err != nil { + t.Errorf("%d: on %q, QueryRow/Scan: %v", n, tt.name, err) + } else if age != tt.want { + t.Errorf("%d: age=%d, want %d", n, age, tt.want) + } + } +} + +type stubDriverStmt struct { + err error +} + +func (s stubDriverStmt) Close() error { + return s.err +} + +func (s stubDriverStmt) NumInput() int { + return -1 +} + +func (s stubDriverStmt) Exec(args []driver.Value) (driver.Result, error) { + return nil, nil +} + +func (s stubDriverStmt) Query(args []driver.Value) (driver.Rows, error) { + return nil, nil +} + +// golang.org/issue/12798 +func TestStatementClose(t *testing.T) { + want := errors.New("STMT ERROR") + + tests := []struct { + stmt *Stmt + msg string + }{ + {&Stmt{stickyErr: want}, "stickyErr not propagated"}, + {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, + } + for _, test := range tests { + if err := test.stmt.Close(); err != want { + t.Errorf("%s. Got stmt.Close() = %v, want = %v", test.msg, err, want) + } + } +} + +// golang.org/issue/3734 +func TestStatementQueryRowConcurrent(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age|name=?") + if err != nil { + t.Fatalf("Prepare: %v", err) + } + defer stmt.Close() + + const n = 10 + ch := make(chan error, n) + for i := 0; i < n; i++ { + go func() { + var age int + err := stmt.QueryRow("Alice").Scan(&age) + if err == nil && age != 1 { + err = fmt.Errorf("unexpected age %d", age) + } + ch <- err + }() + } + for i := 0; i < n; i++ { + if err := <-ch; err != nil { + t.Error(err) + } + } +} + +// just a test of fakedb itself +func TestBogusPreboundParameters(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion") + if err == nil { + t.Fatalf("expected error") + } + if err.Error() != `fakedb: invalid conversion to int32 from "bogusconversion"` { + t.Errorf("unexpected error: %v", err) + } +} + +func TestExec(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Errorf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + type execTest struct { + args []any + wantErr string + } + execTests := []execTest{ + // Okay: + {[]any{"Brad", 31}, ""}, + {[]any{"Brad", int64(31)}, ""}, + {[]any{"Bob", "32"}, ""}, + {[]any{7, 9}, ""}, + + // Invalid conversions: + {[]any{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"}, + {[]any{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`}, + + // Wrong number of args: + {[]any{}, "sql: expected 2 arguments, got 0"}, + {[]any{1, 2, 3}, "sql: expected 2 arguments, got 3"}, + } + for n, et := range execTests { + _, err := stmt.Exec(et.args...) + errStr := "" + if err != nil { + errStr = err.Error() + } + if errStr != et.wantErr { + t.Errorf("stmt.Execute #%d: for %v, got error %q, want error %q", + n, et.args, errStr, et.wantErr) + } + } +} + +func TestTxPrepare(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + stmt, err := tx.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + _, err = stmt.Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } + // Commit() should have closed the statement + if !stmt.closed { + t.Fatal("Stmt not closed after Commit") + } +} + +func TestTxStmt(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs := tx.Stmt(stmt) + defer txs.Close() + _, err = txs.Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } + // Commit() should have closed the statement + if !txs.closed { + t.Fatal("Stmt not closed after Commit") + } +} + +func TestTxStmtPreparedOnce(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + + txs1 := tx.Stmt(stmt) + txs2 := tx.Stmt(stmt) + + _, err = txs1.Exec("Go", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + txs1.Close() + + _, err = txs2.Exec("Gopher", 8) + if err != nil { + t.Fatalf("Exec = %v", err) + } + txs2.Close() + + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +func TestTxStmtClosedRePrepares(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + err = stmt.Close() + if err != nil { + t.Fatalf("stmt.Close() = %v", err) + } + // tx.Stmt increments numPrepares because stmt is closed. + txs := tx.Stmt(stmt) + if txs.stickyErr != nil { + t.Fatal(txs.stickyErr) + } + if txs.parentStmt != nil { + t.Fatal("expected nil parentStmt") + } + _, err = txs.Exec(`Eric`, 82) + if err != nil { + t.Fatalf("txs.Exec = %v", err) + } + + err = txs.Close() + if err != nil { + t.Fatalf("txs.Close = %v", err) + } + + tx.Rollback() + + if prepares := numPrepares(t, db) - prepares0; prepares != 2 { + t.Errorf("executed %d Prepare statements; want 2", prepares) + } +} + +func TestParentStmtOutlivesTxStmt(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + + // Make sure everything happens on the same connection. + db.SetMaxOpenConns(1) + + prepares0 := numPrepares(t, db) + + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs := tx.Stmt(stmt) + if len(stmt.css) != 1 { + t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css)) + } + err = txs.Close() + if err != nil { + t.Fatalf("txs.Close() = %v", err) + } + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback() = %v", err) + } + // txs must not be valid. + _, err = txs.Exec("Suzan", 30) + if err == nil { + t.Fatalf("txs.Exec(), expected err") + } + // Stmt must still be valid. + _, err = stmt.Exec("Janina", 25) + if err != nil { + t.Fatalf("stmt.Exec() = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + +// Test that tx.Stmt called with a statement already +// associated with tx as argument re-prepares the same +// statement again. +func TestTxStmtFromTxStmtRePrepares(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32") + prepares0 := numPrepares(t, db) + // db.Prepare increments numPrepares. + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs1 := tx.Stmt(stmt) + + // tx.Stmt(txs1) increments numPrepares because txs1 already + // belongs to a transaction (albeit the same transaction). + txs2 := tx.Stmt(txs1) + if txs2.stickyErr != nil { + t.Fatal(txs2.stickyErr) + } + if txs2.parentStmt != nil { + t.Fatal("expected nil parentStmt") + } + _, err = txs2.Exec(`Eric`, 82) + if err != nil { + t.Fatal(err) + } + + err = txs1.Close() + if err != nil { + t.Fatalf("txs1.Close = %v", err) + } + err = txs2.Close() + if err != nil { + t.Fatalf("txs1.Close = %v", err) + } + err = tx.Rollback() + if err != nil { + t.Fatalf("tx.Rollback = %v", err) + } + + if prepares := numPrepares(t, db) - prepares0; prepares != 2 { + t.Errorf("executed %d Prepare statements; want 2", prepares) + } +} + +// Issue: https://golang.org/issue/2784 +// This test didn't fail before because we got lucky with the fakedb driver. +// It was failing, and now not, in github.com/bradfitz/go-sql-test +func TestTxQuery(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + exec(t, db, "INSERT|t1|name=Alice") + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + r, err := tx.Query("SELECT|t1|name|") + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("expected one row") + } + + var x string + err = r.Scan(&x) + if err != nil { + t.Fatal(err) + } +} + +func TestTxQueryInvalid(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = tx.Query("SELECT|t1|name|") + if err == nil { + t.Fatal("Error expected") + } +} + +// Tests fix for issue 4433, that retries in Begin happen when +// conn.Begin() returns ErrBadConn +func TestTxErrBadConn(t *testing.T) { + db, err := Open("test", fakeDBName+";badConn") + if err != nil { + t.Fatalf("Open: %v", err) + } + if _, err := db.Exec("WIPE"); err != nil { + t.Fatalf("exec wipe: %v", err) + } + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + stmt, err := db.Prepare("INSERT|t1|name=?,age=?") + if err != nil { + t.Fatalf("Stmt, err = %v, %v", stmt, err) + } + defer stmt.Close() + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin = %v", err) + } + txs := tx.Stmt(stmt) + defer txs.Close() + _, err = txs.Exec("Bobby", 7) + if err != nil { + t.Fatalf("Exec = %v", err) + } + err = tx.Commit() + if err != nil { + t.Fatalf("Commit = %v", err) + } +} + +func TestConnQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conn.dc.ci.(*fakeConn).skipDirtySession = true + defer conn.Close() + + var name string + err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", 3).Scan(&name) + if err != nil { + t.Fatal(err) + } + if name != "Chris" { + t.Fatalf("unexpected result, got %q want Chris", name) + } + + err = conn.PingContext(ctx) + if err != nil { + t.Fatal(err) + } +} + +func TestConnRaw(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conn.dc.ci.(*fakeConn).skipDirtySession = true + defer conn.Close() + + sawFunc := false + err = conn.Raw(func(dc any) error { + sawFunc = true + if _, ok := dc.(*fakeConn); !ok { + return fmt.Errorf("got %T want *fakeConn", dc) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + if !sawFunc { + t.Fatal("Raw func not called") + } + + func() { + defer func() { + x := recover() + if x == nil { + t.Fatal("expected panic") + } + conn.closemu.Lock() + closed := conn.dc == nil + conn.closemu.Unlock() + if !closed { + t.Fatal("expected connection to be closed after panic") + } + }() + err = conn.Raw(func(dc any) error { + panic("Conn.Raw panic should return an error") + }) + t.Fatal("expected panic from Raw func") + }() +} + +func TestCursorFake(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + exec(t, db, "CREATE|peoplecursor|list=table") + exec(t, db, "INSERT|peoplecursor|list=people!name!age") + + rows, err := db.QueryContext(ctx, `SELECT|peoplecursor|list|`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatal("no rows") + } + var cursor = &Rows{} + err = rows.Scan(cursor) + if err != nil { + t.Fatal(err) + } + defer cursor.Close() + + const expectedRows = 3 + var currentRow int64 + + var n int64 + var s string + for cursor.Next() { + currentRow++ + err = cursor.Scan(&s, &n) + if err != nil { + t.Fatal(err) + } + if n != currentRow { + t.Errorf("expected number(Age)=%d, got %d", currentRow, n) + } + } + if currentRow != expectedRows { + t.Errorf("expected %d rows, got %d rows", expectedRows, currentRow) + } +} + +func TestInvalidNilValues(t *testing.T) { + var date1 time.Time + var date2 int + + tests := []struct { + name string + input any + expectedError string + }{ + { + name: "time.Time", + input: &date1, + expectedError: `sql: Scan error on column index 0, name "bdate": unsupported Scan, storing driver.Value type <nil> into type *time.Time`, + }, + { + name: "int", + input: &date2, + expectedError: `sql: Scan error on column index 0, name "bdate": converting NULL to int is unsupported`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conn.dc.ci.(*fakeConn).skipDirtySession = true + defer conn.Close() + + err = conn.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 1).Scan(tt.input) + if err == nil { + t.Fatal("expected error when querying nil column, but succeeded") + } + if err.Error() != tt.expectedError { + t.Fatalf("Expected error: %s\nReceived: %s", tt.expectedError, err.Error()) + } + + err = conn.PingContext(ctx) + if err != nil { + t.Fatal(err) + } + }) + } +} + +func TestConnTx(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conn.dc.ci.(*fakeConn).skipDirtySession = true + defer conn.Close() + + tx, err := conn.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + insertName, insertAge := "Nancy", 33 + _, err = tx.ExecContext(ctx, "INSERT|people|name=?,age=?,photo=APHOTO", insertName, insertAge) + if err != nil { + t.Fatal(err) + } + err = tx.Commit() + if err != nil { + t.Fatal(err) + } + + var selectName string + err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", insertAge).Scan(&selectName) + if err != nil { + t.Fatal(err) + } + if selectName != insertName { + t.Fatalf("got %q want %q", selectName, insertName) + } +} + +// TestConnIsValid verifies that a database connection that should be discarded, +// is actually discarded and does not re-enter the connection pool. +// If the IsValid method from *fakeConn is removed, this test will fail. +func TestConnIsValid(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxOpenConns(1) + + ctx := context.Background() + + c, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + err = c.Raw(func(raw any) error { + dc := raw.(*fakeConn) + dc.stickyBad = true + return nil + }) + if err != nil { + t.Fatal(err) + } + c.Close() + + if len(db.freeConn) > 0 && db.freeConn[0].ci.(*fakeConn).stickyBad { + t.Fatal("bad connection returned to pool; expected bad connection to be discarded") + } +} + +// Tests fix for issue 2542, that we release a lock when querying on +// a closed connection. +func TestIssue2542Deadlock(t *testing.T) { + db := newTestDB(t, "people") + closeDB(t, db) + for i := 0; i < 2; i++ { + _, err := db.Query("SELECT|people|age,name|") + if err == nil { + t.Fatalf("expected error") + } + } +} + +// From golang.org/issue/3865 +func TestCloseStmtBeforeRows(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + s, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + + r, err := s.Query() + if err != nil { + s.Close() + t.Fatal(err) + } + + err = s.Close() + if err != nil { + t.Fatal(err) + } + + r.Close() +} + +// Tests fix for issue 2788, that we bind nil to a []byte if the +// value in the column is sql null +func TestNullByteSlice(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t|id=int32,name=nullstring") + exec(t, db, "INSERT|t|id=10,name=?", nil) + + var name []byte + + err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name) + if err != nil { + t.Fatal(err) + } + if name != nil { + t.Fatalf("name []byte should be nil for null column value, got: %#v", name) + } + + exec(t, db, "INSERT|t|id=11,name=?", "bob") + err = db.QueryRow("SELECT|t|name|id=?", 11).Scan(&name) + if err != nil { + t.Fatal(err) + } + if string(name) != "bob" { + t.Fatalf("name []byte should be bob, got: %q", string(name)) + } +} + +func TestPointerParamsAndScans(t *testing.T) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, "CREATE|t|id=int32,name=nullstring") + + bob := "bob" + var name *string + + name = &bob + exec(t, db, "INSERT|t|id=10,name=?", name) + name = nil + exec(t, db, "INSERT|t|id=20,name=?", name) + + err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name) + if err != nil { + t.Fatalf("querying id 10: %v", err) + } + if name == nil { + t.Errorf("id 10's name = nil; want bob") + } else if *name != "bob" { + t.Errorf("id 10's name = %q; want bob", *name) + } + + err = db.QueryRow("SELECT|t|name|id=?", 20).Scan(&name) + if err != nil { + t.Fatalf("querying id 20: %v", err) + } + if name != nil { + t.Errorf("id 20 = %q; want nil", *name) + } +} + +func TestQueryRowClosingStmt(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name string + var age int + err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name) + if err != nil { + t.Fatal(err) + } + if len(db.freeConn) != 1 { + t.Fatalf("expected 1 free conn") + } + fakeConn := db.freeConn[0].ci.(*fakeConn) + if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed { + t.Errorf("statement close mismatch: made %d, closed %d", made, closed) + } +} + +var atomicRowsCloseHook atomic.Value // of func(*Rows, *error) + +func init() { + rowsCloseHook = func() func(*Rows, *error) { + fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error)) + return fn + } +} + +func setRowsCloseHook(fn func(*Rows, *error)) { + if fn == nil { + // Can't change an atomic.Value back to nil, so set it to this + // no-op func instead. + fn = func(*Rows, *error) {} + } + atomicRowsCloseHook.Store(fn) +} + +// Test issue 6651 +func TestIssue6651(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + var v string + + want := "error in rows.Next" + rowsCursorNextHook = func(dest []driver.Value) error { + return fmt.Errorf(want) + } + defer func() { rowsCursorNextHook = nil }() + + err := db.QueryRow("SELECT|people|name|").Scan(&v) + if err == nil || err.Error() != want { + t.Errorf("error = %q; want %q", err, want) + } + rowsCursorNextHook = nil + + want = "error in rows.Close" + setRowsCloseHook(func(rows *Rows, err *error) { + *err = fmt.Errorf(want) + }) + defer setRowsCloseHook(nil) + err = db.QueryRow("SELECT|people|name|").Scan(&v) + if err == nil || err.Error() != want { + t.Errorf("error = %q; want %q", err, want) + } +} + +type nullTestRow struct { + nullParam any + notNullParam any + scanNullVal any +} + +type nullTestSpec struct { + nullType string + notNullType string + rows [6]nullTestRow +} + +func TestNullStringParam(t *testing.T) { + spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{ + {NullString{"aqua", true}, "", NullString{"aqua", true}}, + {NullString{"brown", false}, "", NullString{"", false}}, + {"chartreuse", "", NullString{"chartreuse", true}}, + {NullString{"darkred", true}, "", NullString{"darkred", true}}, + {NullString{"eel", false}, "", NullString{"", false}}, + {"foo", NullString{"black", false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullInt64Param(t *testing.T) { + spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{ + {NullInt64{31, true}, 1, NullInt64{31, true}}, + {NullInt64{-22, false}, 1, NullInt64{0, false}}, + {22, 1, NullInt64{22, true}}, + {NullInt64{33, true}, 1, NullInt64{33, true}}, + {NullInt64{222, false}, 1, NullInt64{0, false}}, + {0, NullInt64{31, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullInt32Param(t *testing.T) { + spec := nullTestSpec{"nullint32", "int32", [6]nullTestRow{ + {NullInt32{31, true}, 1, NullInt32{31, true}}, + {NullInt32{-22, false}, 1, NullInt32{0, false}}, + {22, 1, NullInt32{22, true}}, + {NullInt32{33, true}, 1, NullInt32{33, true}}, + {NullInt32{222, false}, 1, NullInt32{0, false}}, + {0, NullInt32{31, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullInt16Param(t *testing.T) { + spec := nullTestSpec{"nullint16", "int16", [6]nullTestRow{ + {NullInt16{31, true}, 1, NullInt16{31, true}}, + {NullInt16{-22, false}, 1, NullInt16{0, false}}, + {22, 1, NullInt16{22, true}}, + {NullInt16{33, true}, 1, NullInt16{33, true}}, + {NullInt16{222, false}, 1, NullInt16{0, false}}, + {0, NullInt16{31, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullByteParam(t *testing.T) { + spec := nullTestSpec{"nullbyte", "byte", [6]nullTestRow{ + {NullByte{31, true}, 1, NullByte{31, true}}, + {NullByte{0, false}, 1, NullByte{0, false}}, + {22, 1, NullByte{22, true}}, + {NullByte{33, true}, 1, NullByte{33, true}}, + {NullByte{222, false}, 1, NullByte{0, false}}, + {0, NullByte{31, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullFloat64Param(t *testing.T) { + spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{ + {NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}}, + {NullFloat64{13.1, false}, 1, NullFloat64{0, false}}, + {-22.9, 1, NullFloat64{-22.9, true}}, + {NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}}, + {NullFloat64{222, false}, 1, NullFloat64{0, false}}, + {10, NullFloat64{31.2, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullBoolParam(t *testing.T) { + spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{ + {NullBool{false, true}, true, NullBool{false, true}}, + {NullBool{true, false}, false, NullBool{false, false}}, + {true, true, NullBool{true, true}}, + {NullBool{true, true}, false, NullBool{true, true}}, + {NullBool{true, false}, true, NullBool{false, false}}, + {true, NullBool{true, false}, nil}, + }} + nullTestRun(t, spec) +} + +func TestNullTimeParam(t *testing.T) { + t0 := time.Time{} + t1 := time.Date(2000, 1, 1, 8, 9, 10, 11, time.UTC) + t2 := time.Date(2010, 1, 1, 8, 9, 10, 11, time.UTC) + spec := nullTestSpec{"nulldatetime", "datetime", [6]nullTestRow{ + {NullTime{t1, true}, t2, NullTime{t1, true}}, + {NullTime{t1, false}, t2, NullTime{t0, false}}, + {t1, t2, NullTime{t1, true}}, + {NullTime{t1, true}, t2, NullTime{t1, true}}, + {NullTime{t1, false}, t2, NullTime{t0, false}}, + {t2, NullTime{t1, false}, nil}, + }} + nullTestRun(t, spec) +} + +func nullTestRun(t *testing.T, spec nullTestSpec) { + db := newTestDB(t, "") + defer closeDB(t, db) + exec(t, db, fmt.Sprintf("CREATE|t|id=int32,name=string,nullf=%s,notnullf=%s", spec.nullType, spec.notNullType)) + + // Inserts with db.Exec: + exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 1, "alice", spec.rows[0].nullParam, spec.rows[0].notNullParam) + exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 2, "bob", spec.rows[1].nullParam, spec.rows[1].notNullParam) + + // Inserts with a prepared statement: + stmt, err := db.Prepare("INSERT|t|id=?,name=?,nullf=?,notnullf=?") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt.Close() + if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil { + t.Errorf("exec insert chris: %v", err) + } + if _, err := stmt.Exec(4, "dave", spec.rows[3].nullParam, spec.rows[3].notNullParam); err != nil { + t.Errorf("exec insert dave: %v", err) + } + if _, err := stmt.Exec(5, "eleanor", spec.rows[4].nullParam, spec.rows[4].notNullParam); err != nil { + t.Errorf("exec insert eleanor: %v", err) + } + + // Can't put null val into non-null col + if _, err := stmt.Exec(6, "bob", spec.rows[5].nullParam, spec.rows[5].notNullParam); err == nil { + t.Errorf("expected error inserting nil val with prepared statement Exec") + } + + _, err = db.Exec("INSERT|t|id=?,name=?,nullf=?", 999, nil, nil) + if err == nil { + // TODO: this test fails, but it's just because + // fakeConn implements the optional Execer interface, + // so arguably this is the correct behavior. But + // maybe I should flesh out the fakeConn.Exec + // implementation so this properly fails. + // t.Errorf("expected error inserting nil name with Exec") + } + + paramtype := reflect.TypeOf(spec.rows[0].nullParam) + bindVal := reflect.New(paramtype).Interface() + + for i := 0; i < 5; i++ { + id := i + 1 + if err := db.QueryRow("SELECT|t|nullf|id=?", id).Scan(bindVal); err != nil { + t.Errorf("id=%d Scan: %v", id, err) + } + bindValDeref := reflect.ValueOf(bindVal).Elem().Interface() + if !reflect.DeepEqual(bindValDeref, spec.rows[i].scanNullVal) { + t.Errorf("id=%d got %#v, want %#v", id, bindValDeref, spec.rows[i].scanNullVal) + } + } +} + +// golang.org/issue/4859 +func TestQueryRowNilScanDest(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + var name *string // nil pointer + err := db.QueryRow("SELECT|people|name|").Scan(name) + want := `sql: Scan error on column index 0, name "name": destination pointer is nil` + if err == nil || err.Error() != want { + t.Errorf("error = %q; want %q", err.Error(), want) + } +} + +func TestIssue4902(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + driver := db.Driver().(*fakeDriver) + opens0 := driver.openCount + + var stmt *Stmt + var err error + for i := 0; i < 10; i++ { + stmt, err = db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + } + + opens := driver.openCount - opens0 + if opens > 1 { + t.Errorf("opens = %d; want <= 1", opens) + t.Logf("db = %#v", db) + t.Logf("driver = %#v", driver) + t.Logf("stmt = %#v", stmt) + } +} + +// Issue 3857 +// This used to deadlock. +func TestSimultaneousQueries(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + r1, err := tx.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer r1.Close() + + r2, err := tx.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer r2.Close() +} + +func TestMaxIdleConns(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Commit() + if got := len(db.freeConn); got != 1 { + t.Errorf("freeConns = %d; want 1", got) + } + + db.SetMaxIdleConns(0) + + if got := len(db.freeConn); got != 0 { + t.Errorf("freeConns after set to zero = %d; want 0", got) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Commit() + if got := len(db.freeConn); got != 0 { + t.Errorf("freeConns = %d; want 0", got) + } +} + +func TestMaxOpenConns(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + driver := db.Driver().(*fakeDriver) + + // Force the number of open connections to 0 so we can get an accurate + // count for the test + db.clearAllConns(t) + + driver.mu.Lock() + opens0 := driver.openCount + closes0 := driver.closeCount + driver.mu.Unlock() + + db.SetMaxIdleConns(10) + db.SetMaxOpenConns(10) + + stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?") + if err != nil { + t.Fatal(err) + } + + // Start 50 parallel slow queries. + const ( + nquery = 50 + sleepMillis = 25 + nbatch = 2 + ) + var wg sync.WaitGroup + for batch := 0; batch < nbatch; batch++ { + for i := 0; i < nquery; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var op string + if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows { + t.Error(err) + } + }() + } + // Wait for the batch of queries above to finish before starting the next round. + wg.Wait() + } + + if g, w := db.numFreeConns(), 10; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPoll(t, 20); n > 20 { + t.Errorf("number of dependencies = %d; expected <= 20", n) + db.dumpDeps(t) + } + + driver.mu.Lock() + opens := driver.openCount - opens0 + closes := driver.closeCount - closes0 + driver.mu.Unlock() + + if opens > 10 { + t.Logf("open calls = %d", opens) + t.Logf("close calls = %d", closes) + t.Errorf("db connections opened = %d; want <= 10", opens) + db.dumpDeps(t) + } + + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + + if g, w := db.numFreeConns(), 10; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPoll(t, 10); n > 10 { + t.Errorf("number of dependencies = %d; expected <= 10", n) + db.dumpDeps(t) + } + + db.SetMaxOpenConns(5) + + if g, w := db.numFreeConns(), 5; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPoll(t, 5); n > 5 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } + + db.SetMaxOpenConns(0) + + if g, w := db.numFreeConns(), 5; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPoll(t, 5); n > 5 { + t.Errorf("number of dependencies = %d; expected 0", n) + db.dumpDeps(t) + } + + db.clearAllConns(t) +} + +// Issue 9453: tests that SetMaxOpenConns can be lowered at runtime +// and affects the subsequent release of connections. +func TestMaxOpenConnsOnBusy(t *testing.T) { + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + db.SetMaxOpenConns(3) + + ctx := context.Background() + + conn0, err := db.conn(ctx, cachedOrNewConn) + if err != nil { + t.Fatalf("db open conn fail: %v", err) + } + + conn1, err := db.conn(ctx, cachedOrNewConn) + if err != nil { + t.Fatalf("db open conn fail: %v", err) + } + + conn2, err := db.conn(ctx, cachedOrNewConn) + if err != nil { + t.Fatalf("db open conn fail: %v", err) + } + + if g, w := db.numOpen, 3; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + db.SetMaxOpenConns(2) + if g, w := db.numOpen, 3; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + conn0.releaseConn(nil) + conn1.releaseConn(nil) + if g, w := db.numOpen, 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + conn2.releaseConn(nil) + if g, w := db.numOpen, 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } +} + +// Issue 10886: tests that all connection attempts return when more than +// DB.maxOpen connections are in flight and the first DB.maxOpen fail. +func TestPendingConnsAfterErr(t *testing.T) { + const ( + maxOpen = 2 + tryOpen = maxOpen*2 + 2 + ) + + // No queries will be run. + db, err := Open("test", fakeDBName) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer closeDB(t, db) + defer func() { + for k, v := range db.lastPut { + t.Logf("%p: %v", k, v) + } + }() + + db.SetMaxOpenConns(maxOpen) + db.SetMaxIdleConns(0) + + errOffline := errors.New("db offline") + + defer func() { setHookOpenErr(nil) }() + + errs := make(chan error, tryOpen) + + var opening sync.WaitGroup + opening.Add(tryOpen) + + setHookOpenErr(func() error { + // Wait for all connections to enqueue. + opening.Wait() + return errOffline + }) + + for i := 0; i < tryOpen; i++ { + go func() { + opening.Done() // signal one connection is in flight + _, err := db.Exec("will never run") + errs <- err + }() + } + + opening.Wait() // wait for all workers to begin running + + const timeout = 5 * time.Second + to := time.NewTimer(timeout) + defer to.Stop() + + // check that all connections fail without deadlock + for i := 0; i < tryOpen; i++ { + select { + case err := <-errs: + if got, want := err, errOffline; got != want { + t.Errorf("unexpected err: got %v, want %v", got, want) + } + case <-to.C: + t.Fatalf("orphaned connection request(s), still waiting after %v", timeout) + } + } + + // Wait a reasonable time for the database to close all connections. + tick := time.NewTicker(3 * time.Millisecond) + defer tick.Stop() + for { + select { + case <-tick.C: + db.mu.Lock() + if db.numOpen == 0 { + db.mu.Unlock() + return + } + db.mu.Unlock() + case <-to.C: + // Closing the database will check for numOpen and fail the test. + return + } + } +} + +func TestSingleOpenConn(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxOpenConns(1) + + rows, err := db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } + // shouldn't deadlock + rows, err = db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } +} + +func TestStats(t *testing.T) { + db := newTestDB(t, "people") + stats := db.Stats() + if got := stats.OpenConnections; got != 1 { + t.Errorf("stats.OpenConnections = %d; want 1", got) + } + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Commit() + + closeDB(t, db) + stats = db.Stats() + if got := stats.OpenConnections; got != 0 { + t.Errorf("stats.OpenConnections = %d; want 0", got) + } +} + +func TestConnMaxLifetime(t *testing.T) { + t0 := time.Unix(1000000, 0) + offset := time.Duration(0) + + nowFunc = func() time.Time { return t0.Add(offset) } + defer func() { nowFunc = time.Now }() + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + driver := db.Driver().(*fakeDriver) + + // Force the number of open connections to 0 so we can get an accurate + // count for the test + db.clearAllConns(t) + + driver.mu.Lock() + opens0 := driver.openCount + closes0 := driver.closeCount + driver.mu.Unlock() + + db.SetMaxIdleConns(10) + db.SetMaxOpenConns(10) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + offset = time.Second + tx2, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + tx.Commit() + tx2.Commit() + + driver.mu.Lock() + opens := driver.openCount - opens0 + closes := driver.closeCount - closes0 + driver.mu.Unlock() + + if opens != 2 { + t.Errorf("opens = %d; want 2", opens) + } + if closes != 0 { + t.Errorf("closes = %d; want 0", closes) + } + if g, w := db.numFreeConns(), 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + // Expire first conn + offset = 11 * time.Second + db.SetConnMaxLifetime(10 * time.Second) + if err != nil { + t.Fatal(err) + } + + tx, err = db.Begin() + if err != nil { + t.Fatal(err) + } + tx2, err = db.Begin() + if err != nil { + t.Fatal(err) + } + tx.Commit() + tx2.Commit() + + // Give connectionCleaner chance to run. + waitCondition(t, func() bool { + driver.mu.Lock() + opens = driver.openCount - opens0 + closes = driver.closeCount - closes0 + driver.mu.Unlock() + + return closes == 1 + }) + + if opens != 3 { + t.Errorf("opens = %d; want 3", opens) + } + if closes != 1 { + t.Errorf("closes = %d; want 1", closes) + } + + if s := db.Stats(); s.MaxLifetimeClosed != 1 { + t.Errorf("MaxLifetimeClosed = %d; want 1 %#v", s.MaxLifetimeClosed, s) + } +} + +// golang.org/issue/5323 +func TestStmtCloseDeps(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + driver := db.Driver().(*fakeDriver) + + driver.mu.Lock() + opens0 := driver.openCount + closes0 := driver.closeCount + driver.mu.Unlock() + openDelta0 := opens0 - closes0 + + stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?") + if err != nil { + t.Fatal(err) + } + + // Start 50 parallel slow queries. + const ( + nquery = 50 + sleepMillis = 25 + nbatch = 2 + ) + var wg sync.WaitGroup + for batch := 0; batch < nbatch; batch++ { + for i := 0; i < nquery; i++ { + wg.Add(1) + go func() { + defer wg.Done() + var op string + if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows { + t.Error(err) + } + }() + } + // Wait for the batch of queries above to finish before starting the next round. + wg.Wait() + } + + if g, w := db.numFreeConns(), 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPoll(t, 4); n > 4 { + t.Errorf("number of dependencies = %d; expected <= 4", n) + db.dumpDeps(t) + } + + driver.mu.Lock() + opens := driver.openCount - opens0 + closes := driver.closeCount - closes0 + openDelta := (driver.openCount - driver.closeCount) - openDelta0 + driver.mu.Unlock() + + if openDelta > 2 { + t.Logf("open calls = %d", opens) + t.Logf("close calls = %d", closes) + t.Logf("open delta = %d", openDelta) + t.Errorf("db connections opened = %d; want <= 2", openDelta) + db.dumpDeps(t) + } + + if !waitCondition(t, func() bool { + return len(stmt.css) <= nquery + }) { + t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery) + } + + if err := stmt.Close(); err != nil { + t.Fatal(err) + } + + if g, w := db.numFreeConns(), 2; g != w { + t.Errorf("free conns = %d; want %d", g, w) + } + + if n := db.numDepsPoll(t, 2); n > 2 { + t.Errorf("number of dependencies = %d; expected <= 2", n) + db.dumpDeps(t) + } + + db.clearAllConns(t) +} + +// golang.org/issue/5046 +func TestCloseConnBeforeStmts(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v; from %s", err, stack()) + db.dumpDeps(t) + t.Errorf("DB = %#v", db) + } + }) + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + + if len(db.freeConn) != 1 { + t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn)) + } + dc := db.freeConn[0] + if dc.closed { + t.Errorf("conn shouldn't be closed") + } + + if n := len(dc.openStmt); n != 1 { + t.Errorf("driverConn num openStmt = %d; want 1", n) + } + err = db.Close() + if err != nil { + t.Errorf("db Close = %v", err) + } + if !dc.closed { + t.Errorf("after db.Close, driverConn should be closed") + } + if n := len(dc.openStmt); n != 0 { + t.Errorf("driverConn num openStmt = %d; want 0", n) + } + + err = stmt.Close() + if err != nil { + t.Errorf("Stmt close = %v", err) + } + + if !dc.closed { + t.Errorf("conn should be closed") + } + if dc.ci != nil { + t.Errorf("after Stmt Close, driverConn's Conn interface should be nil") + } +} + +// golang.org/issue/5283: don't release the Rows' connection in Close +// before calling Stmt.Close. +func TestRowsCloseOrder(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxIdleConns(0) + setStrictFakeConnClose(t) + defer setStrictFakeConnClose(nil) + + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + err = rows.Close() + if err != nil { + t.Fatal(err) + } +} + +func TestRowsImplicitClose(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + + want, fail := 2, errors.New("fail") + r := rows.rowsi.(*rowsCursor) + r.errPos, r.err = want, fail + + got := 0 + for rows.Next() { + got++ + } + if got != want { + t.Errorf("got %d rows, want %d", got, want) + } + if err := rows.Err(); err != fail { + t.Errorf("got error %v, want %v", err, fail) + } + if !r.closed { + t.Errorf("r.closed is false, want true") + } +} + +func TestRowsCloseError(t *testing.T) { + db := newTestDB(t, "people") + defer db.Close() + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + + rc, ok := rows.rowsi.(*rowsCursor) + if !ok { + t.Fatal("not using *rowsCursor") + } + rc.closeErr = errors.New("rowsCursor: failed to close") + + for rows.Next() { + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got = append(got, r) + } + err = rows.Err() + if err != rc.closeErr { + t.Fatalf("unexpected err: got %v, want %v", err, rc.closeErr) + } +} + +func TestStmtCloseOrder(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxIdleConns(0) + setStrictFakeConnClose(t) + defer setStrictFakeConnClose(nil) + + _, err := db.Query("SELECT|non_existent|name|") + if err == nil { + t.Fatal("Querying non-existent table should fail") + } +} + +// Test cases where there's more than maxBadConnRetries bad connections in the +// pool (issue 8834) +func TestManyErrBadConn(t *testing.T) { + manyErrBadConnSetup := func(first ...func(db *DB)) *DB { + db := newTestDB(t, "people") + + for _, f := range first { + f(db) + } + + nconn := maxBadConnRetries + 1 + db.SetMaxIdleConns(nconn) + db.SetMaxOpenConns(nconn) + // open enough connections + func() { + for i := 0; i < nconn; i++ { + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + } + }() + + db.mu.Lock() + defer db.mu.Unlock() + if db.numOpen != nconn { + t.Fatalf("unexpected numOpen %d (was expecting %d)", db.numOpen, nconn) + } else if len(db.freeConn) != nconn { + t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn) + } + for _, conn := range db.freeConn { + conn.Lock() + conn.ci.(*fakeConn).stickyBad = true + conn.Unlock() + } + return db + } + + // Query + db := manyErrBadConnSetup() + defer closeDB(t, db) + rows, err := db.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } + + // Exec + db = manyErrBadConnSetup() + defer closeDB(t, db) + _, err = db.Exec("INSERT|people|name=Julia,age=19") + if err != nil { + t.Fatal(err) + } + + // Begin + db = manyErrBadConnSetup() + defer closeDB(t, db) + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + if err = tx.Rollback(); err != nil { + t.Fatal(err) + } + + // Prepare + db = manyErrBadConnSetup() + defer closeDB(t, db) + stmt, err := db.Prepare("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + if err = stmt.Close(); err != nil { + t.Fatal(err) + } + + // Stmt.Exec + db = manyErrBadConnSetup(func(db *DB) { + stmt, err = db.Prepare("INSERT|people|name=Julia,age=19") + if err != nil { + t.Fatal(err) + } + }) + defer closeDB(t, db) + _, err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + if err = stmt.Close(); err != nil { + t.Fatal(err) + } + + // Stmt.Query + db = manyErrBadConnSetup(func(db *DB) { + stmt, err = db.Prepare("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + }) + defer closeDB(t, db) + rows, err = stmt.Query() + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } + if err = stmt.Close(); err != nil { + t.Fatal(err) + } + + // Conn + db = manyErrBadConnSetup() + defer closeDB(t, db) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conn.dc.ci.(*fakeConn).skipDirtySession = true + err = conn.Close() + if err != nil { + t.Fatal(err) + } + + // Ping + db = manyErrBadConnSetup() + defer closeDB(t, db) + err = db.PingContext(ctx) + if err != nil { + t.Fatal(err) + } +} + +// Issue 34775: Ensure that a Tx cannot commit after a rollback. +func TestTxCannotCommitAfterRollback(t *testing.T) { + db := newTestDB(t, "tx_status") + defer closeDB(t, db) + + // First check query reporting is correct. + var txStatus string + err := db.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus) + if err != nil { + t.Fatal(err) + } + if g, w := txStatus, "autocommit"; g != w { + t.Fatalf("tx_status=%q, wanted %q", g, w) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + // Ignore dirty session for this test. + // A failing test should trigger the dirty session flag as well, + // but that isn't exactly what this should test for. + tx.txi.(*fakeTx).c.skipDirtySession = true + + defer tx.Rollback() + + err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus) + if err != nil { + t.Fatal(err) + } + if g, w := txStatus, "transaction"; g != w { + t.Fatalf("tx_status=%q, wanted %q", g, w) + } + + // 1. Begin a transaction. + // 2. (A) Start a query, (B) begin Tx rollback through a ctx cancel. + // 3. Check if 2.A has committed in Tx (pass) or outside of Tx (fail). + sendQuery := make(chan struct{}) + // The Tx status is returned through the row results, ensure + // that the rows results are not canceled. + bypassRowsAwaitDone = true + hookTxGrabConn = func() { + cancel() + <-sendQuery + } + rollbackHook = func() { + close(sendQuery) + } + defer func() { + hookTxGrabConn = nil + rollbackHook = nil + bypassRowsAwaitDone = false + }() + + err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus) + if err != nil { + // A failure here would be expected if skipDirtySession was not set to true above. + t.Fatal(err) + } + if g, w := txStatus, "transaction"; g != w { + t.Fatalf("tx_status=%q, wanted %q", g, w) + } +} + +// Issue 40985 transaction statement deadlock while context cancel. +func TestTxStmtDeadlock(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + stmt, err := tx.Prepare("SELECT|people|name,age|age=?") + if err != nil { + t.Fatal(err) + } + cancel() + // Run number of stmt queries to reproduce deadlock from context cancel + for i := 0; i < 1e3; i++ { + // Encounter any close related errors (e.g. ErrTxDone, stmt is closed) + // is expected due to context cancel. + _, err = stmt.Query(1) + if err != nil { + break + } + } + _ = tx.Rollback() +} + +// Issue32530 encounters an issue where a connection may +// expire right after it comes out of a used connection pool +// even when a new connection is requested. +func TestConnExpiresFreshOutOfPool(t *testing.T) { + execCases := []struct { + expired bool + badReset bool + }{ + {false, false}, + {true, false}, + {false, true}, + } + + t0 := time.Unix(1000000, 0) + offset := time.Duration(0) + offsetMu := sync.RWMutex{} + + nowFunc = func() time.Time { + offsetMu.RLock() + defer offsetMu.RUnlock() + return t0.Add(offset) + } + defer func() { nowFunc = time.Now }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + db := newTestDB(t, "magicquery") + defer closeDB(t, db) + + db.SetMaxOpenConns(1) + + for _, ec := range execCases { + ec := ec + name := fmt.Sprintf("expired=%t,badReset=%t", ec.expired, ec.badReset) + t.Run(name, func(t *testing.T) { + db.clearAllConns(t) + + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(10 * time.Second) + + conn, err := db.conn(ctx, alwaysNewConn) + if err != nil { + t.Fatal(err) + } + + afterPutConn := make(chan struct{}) + waitingForConn := make(chan struct{}) + + go func() { + defer close(afterPutConn) + + conn, err := db.conn(ctx, alwaysNewConn) + if err == nil { + db.putConn(conn, err, false) + } else { + t.Errorf("db.conn: %v", err) + } + }() + go func() { + defer close(waitingForConn) + + for { + if t.Failed() { + return + } + db.mu.Lock() + ct := len(db.connRequests) + db.mu.Unlock() + if ct > 0 { + return + } + time.Sleep(pollDuration) + } + }() + + <-waitingForConn + + if t.Failed() { + return + } + + offsetMu.Lock() + if ec.expired { + offset = 11 * time.Second + } else { + offset = time.Duration(0) + } + offsetMu.Unlock() + + conn.ci.(*fakeConn).stickyBad = ec.badReset + + db.putConn(conn, err, true) + + <-afterPutConn + }) + } +} + +// TestIssue20575 ensures the Rows from query does not block +// closing a transaction. Ensure Rows is closed while closing a transaction. +func TestIssue20575(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + _, err = tx.QueryContext(ctx, "SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + // Do not close Rows from QueryContext. + err = tx.Rollback() + if err != nil { + t.Fatal(err) + } + select { + default: + case <-ctx.Done(): + t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err()) + } +} + +// TestIssue20622 tests closing the transaction before rows is closed, requires +// the race detector to fail. +func TestIssue20622(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + rows, err := tx.Query("SELECT|people|age,name|") + if err != nil { + t.Fatal(err) + } + + count := 0 + for rows.Next() { + count++ + var age int + var name string + if err := rows.Scan(&age, &name); err != nil { + t.Fatal("scan failed", err) + } + + if count == 1 { + cancel() + } + time.Sleep(100 * time.Millisecond) + } + rows.Close() + tx.Commit() +} + +// golang.org/issue/5718 +func TestErrBadConnReconnect(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + + simulateBadConn := func(name string, hook *func() bool, op func() error) { + broken, retried := false, false + numOpen := db.numOpen + + // simulate a broken connection on the first try + *hook = func() bool { + if !broken { + broken = true + return true + } + retried = true + return false + } + + if err := op(); err != nil { + t.Errorf(name+": %v", err) + return + } + + if !broken || !retried { + t.Error(name + ": Failed to simulate broken connection") + } + *hook = nil + + if numOpen != db.numOpen { + t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen) + numOpen = db.numOpen + } + } + + // db.Exec + dbExec := func() error { + _, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true) + return err + } + simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec) + simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec) + + // db.Query + dbQuery := func() error { + rows, err := db.Query("SELECT|t1|age,name|") + if err == nil { + err = rows.Close() + } + return err + } + simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery) + simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery) + + // db.Prepare + simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error { + stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?") + if err != nil { + return err + } + stmt.Close() + return nil + }) + + // Provide a way to force a re-prepare of a statement on next execution + forcePrepare := func(stmt *Stmt) { + stmt.css = nil + } + + // stmt.Exec + stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt1.Close() + // make sure we must prepare the stmt first + forcePrepare(stmt1) + + stmtExec := func() error { + _, err := stmt1.Exec("Gopher", 3, false) + return err + } + simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec) + simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec) + + // stmt.Query + stmt2, err := db.Prepare("SELECT|t1|age,name|") + if err != nil { + t.Fatalf("prepare: %v", err) + } + defer stmt2.Close() + // make sure we must prepare the stmt first + forcePrepare(stmt2) + + stmtQuery := func() error { + rows, err := stmt2.Query() + if err == nil { + err = rows.Close() + } + return err + } + simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery) + simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery) +} + +// golang.org/issue/11264 +func TestTxEndBadConn(t *testing.T) { + db := newTestDB(t, "foo") + defer closeDB(t, db) + db.SetMaxIdleConns(0) + exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") + db.SetMaxIdleConns(1) + + simulateBadConn := func(name string, hook *func() bool, op func() error) { + broken := false + numOpen := db.numOpen + + *hook = func() bool { + if !broken { + broken = true + } + return broken + } + + if err := op(); !errors.Is(err, driver.ErrBadConn) { + t.Errorf(name+": %v", err) + return + } + + if !broken { + t.Error(name + ": Failed to simulate broken connection") + } + *hook = nil + + if numOpen != db.numOpen { + t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen) + } + } + + // db.Exec + dbExec := func(endTx func(tx *Tx) error) func() error { + return func() error { + tx, err := db.Begin() + if err != nil { + return err + } + _, err = tx.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true) + if err != nil { + return err + } + return endTx(tx) + } + } + simulateBadConn("db.Tx.Exec commit", &hookCommitBadConn, dbExec((*Tx).Commit)) + simulateBadConn("db.Tx.Exec rollback", &hookRollbackBadConn, dbExec((*Tx).Rollback)) + + // db.Query + dbQuery := func(endTx func(tx *Tx) error) func() error { + return func() error { + tx, err := db.Begin() + if err != nil { + return err + } + rows, err := tx.Query("SELECT|t1|age,name|") + if err == nil { + err = rows.Close() + } else { + return err + } + return endTx(tx) + } + } + simulateBadConn("db.Tx.Query commit", &hookCommitBadConn, dbQuery((*Tx).Commit)) + simulateBadConn("db.Tx.Query rollback", &hookRollbackBadConn, dbQuery((*Tx).Rollback)) +} + +type concurrentTest interface { + init(t testing.TB, db *DB) + finish(t testing.TB) + test(t testing.TB) error +} + +type concurrentDBQueryTest struct { + db *DB +} + +func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) { + c.db = db +} + +func (c *concurrentDBQueryTest) finish(t testing.TB) { + c.db = nil +} + +func (c *concurrentDBQueryTest) test(t testing.TB) error { + rows, err := c.db.Query("SELECT|people|name|") + if err != nil { + t.Error(err) + return err + } + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentDBExecTest struct { + db *DB +} + +func (c *concurrentDBExecTest) init(t testing.TB, db *DB) { + c.db = db +} + +func (c *concurrentDBExecTest) finish(t testing.TB) { + c.db = nil +} + +func (c *concurrentDBExecTest) test(t testing.TB) error { + _, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + if err != nil { + t.Error(err) + return err + } + return nil +} + +type concurrentStmtQueryTest struct { + db *DB + stmt *Stmt +} + +func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.stmt, err = db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentStmtQueryTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + c.db = nil +} + +func (c *concurrentStmtQueryTest) test(t testing.TB) error { + rows, err := c.stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + return err + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentStmtExecTest struct { + db *DB + stmt *Stmt +} + +func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentStmtExecTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + c.db = nil +} + +func (c *concurrentStmtExecTest) test(t testing.TB) error { + _, err := c.stmt.Exec(3, chrisBirthday) + if err != nil { + t.Errorf("error on exec: %v", err) + return err + } + return nil +} + +type concurrentTxQueryTest struct { + db *DB + tx *Tx +} + +func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxQueryTest) finish(t testing.TB) { + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxQueryTest) test(t testing.TB) error { + rows, err := c.db.Query("SELECT|people|name|") + if err != nil { + t.Error(err) + return err + } + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentTxExecTest struct { + db *DB + tx *Tx +} + +func (c *concurrentTxExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxExecTest) finish(t testing.TB) { + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxExecTest) test(t testing.TB) error { + _, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday) + if err != nil { + t.Error(err) + return err + } + return nil +} + +type concurrentTxStmtQueryTest struct { + db *DB + tx *Tx + stmt *Stmt +} + +func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } + c.stmt, err = c.tx.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxStmtQueryTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxStmtQueryTest) test(t testing.TB) error { + rows, err := c.stmt.Query() + if err != nil { + t.Errorf("error on query: %v", err) + return err + } + + var name string + for rows.Next() { + rows.Scan(&name) + } + rows.Close() + return nil +} + +type concurrentTxStmtExecTest struct { + db *DB + tx *Tx + stmt *Stmt +} + +func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) { + c.db = db + var err error + c.tx, err = c.db.Begin() + if err != nil { + t.Fatal(err) + } + c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?") + if err != nil { + t.Fatal(err) + } +} + +func (c *concurrentTxStmtExecTest) finish(t testing.TB) { + if c.stmt != nil { + c.stmt.Close() + c.stmt = nil + } + if c.tx != nil { + c.tx.Rollback() + c.tx = nil + } + c.db = nil +} + +func (c *concurrentTxStmtExecTest) test(t testing.TB) error { + _, err := c.stmt.Exec(3, chrisBirthday) + if err != nil { + t.Errorf("error on exec: %v", err) + return err + } + return nil +} + +type concurrentRandomTest struct { + tests []concurrentTest +} + +func (c *concurrentRandomTest) init(t testing.TB, db *DB) { + c.tests = []concurrentTest{ + new(concurrentDBQueryTest), + new(concurrentDBExecTest), + new(concurrentStmtQueryTest), + new(concurrentStmtExecTest), + new(concurrentTxQueryTest), + new(concurrentTxExecTest), + new(concurrentTxStmtQueryTest), + new(concurrentTxStmtExecTest), + } + for _, ct := range c.tests { + ct.init(t, db) + } +} + +func (c *concurrentRandomTest) finish(t testing.TB) { + for _, ct := range c.tests { + ct.finish(t) + } +} + +func (c *concurrentRandomTest) test(t testing.TB) error { + ct := c.tests[rand.Intn(len(c.tests))] + return ct.test(t) +} + +func doConcurrentTest(t testing.TB, ct concurrentTest) { + maxProcs, numReqs := 1, 500 + if testing.Short() { + maxProcs, numReqs = 4, 50 + } + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) + + db := newTestDB(t, "people") + defer closeDB(t, db) + + ct.init(t, db) + defer ct.finish(t) + + var wg sync.WaitGroup + wg.Add(numReqs) + + reqs := make(chan bool) + defer close(reqs) + + for i := 0; i < maxProcs*2; i++ { + go func() { + for range reqs { + err := ct.test(t) + if err != nil { + wg.Done() + continue + } + wg.Done() + } + }() + } + + for i := 0; i < numReqs; i++ { + reqs <- true + } + + wg.Wait() +} + +func TestIssue6081(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + drv := db.Driver().(*fakeDriver) + drv.mu.Lock() + opens0 := drv.openCount + closes0 := drv.closeCount + drv.mu.Unlock() + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + setRowsCloseHook(func(rows *Rows, err *error) { + *err = driver.ErrBadConn + }) + defer setRowsCloseHook(nil) + for i := 0; i < 10; i++ { + rows, err := stmt.Query() + if err != nil { + t.Fatal(err) + } + rows.Close() + } + if n := len(stmt.css); n > 1 { + t.Errorf("len(css slice) = %d; want <= 1", n) + } + stmt.Close() + if n := len(stmt.css); n != 0 { + t.Errorf("len(css slice) after Close = %d; want 0", n) + } + + drv.mu.Lock() + opens := drv.openCount - opens0 + closes := drv.closeCount - closes0 + drv.mu.Unlock() + if opens < 9 { + t.Errorf("opens = %d; want >= 9", opens) + } + if closes < 9 { + t.Errorf("closes = %d; want >= 9", closes) + } +} + +// TestIssue18429 attempts to stress rolling back the transaction from a +// context cancel while simultaneously calling Tx.Rollback. Rolling back from a +// context happens concurrently so tx.rollback and tx.Commit must guard against +// double entry. +// +// In the test, a context is canceled while the query is in process so +// the internal rollback will run concurrently with the explicitly called +// Tx.Rollback. +// +// The addition of calling rows.Next also tests +// Issue 21117. +func TestIssue18429(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx := context.Background() + sem := make(chan bool, 20) + var wg sync.WaitGroup + + const milliWait = 30 + + for i := 0; i < 100; i++ { + sem <- true + wg.Add(1) + go func() { + defer func() { + <-sem + wg.Done() + }() + qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String() + + ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return + } + // This is expected to give a cancel error most, but not all the time. + // Test failure will happen with a panic or other race condition being + // reported. + rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|") + if rows != nil { + var name string + // Call Next to test Issue 21117 and check for races. + for rows.Next() { + // Scan the buffer so it is read and checked for races. + rows.Scan(&name) + } + rows.Close() + } + // This call will race with the context cancel rollback to complete + // if the rollback itself isn't guarded. + tx.Rollback() + }() + } + wg.Wait() +} + +// TestIssue20160 attempts to test a short context life on a stmt Query. +func TestIssue20160(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx := context.Background() + sem := make(chan bool, 20) + var wg sync.WaitGroup + + const milliWait = 30 + + stmt, err := db.PrepareContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + for i := 0; i < 100; i++ { + sem <- true + wg.Add(1) + go func() { + defer func() { + <-sem + wg.Done() + }() + ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond) + defer cancel() + + // This is expected to give a cancel error most, but not all the time. + // Test failure will happen with a panic or other race condition being + // reported. + rows, _ := stmt.QueryContext(ctx) + if rows != nil { + rows.Close() + } + }() + } + wg.Wait() +} + +// TestIssue18719 closes the context right before use. The sql.driverConn +// will nil out the ci on close in a lock, but if another process uses it right after +// it will panic with on the nil ref. +// +// See https://golang.org/cl/35550 . +func TestIssue18719(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + hookTxGrabConn = func() { + cancel() + + // Wait for the context to cancel and tx to rollback. + for tx.isDone() == false { + time.Sleep(pollDuration) + } + } + defer func() { hookTxGrabConn = nil }() + + // This call will grab the connection and cancel the context + // after it has done so. Code after must deal with the canceled state. + _, err = tx.QueryContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatalf("expected error %v but got %v", nil, err) + } + + // Rows may be ignored because it will be closed when the context is canceled. + + // Do not explicitly rollback. The rollback will happen from the + // canceled context. + + cancel() +} + +func TestIssue20647(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + conn.dc.ci.(*fakeConn).skipDirtySession = true + defer conn.Close() + + stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + rows1, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatal("rows1", err) + } + defer rows1.Close() + + rows2, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatal("rows2", err) + } + defer rows2.Close() + + if rows1.dc != rows2.dc { + t.Fatal("stmt prepared on Conn does not use same connection") + } +} + +func TestConcurrency(t *testing.T) { + list := []struct { + name string + ct concurrentTest + }{ + {"Query", new(concurrentDBQueryTest)}, + {"Exec", new(concurrentDBExecTest)}, + {"StmtQuery", new(concurrentStmtQueryTest)}, + {"StmtExec", new(concurrentStmtExecTest)}, + {"TxQuery", new(concurrentTxQueryTest)}, + {"TxExec", new(concurrentTxExecTest)}, + {"TxStmtQuery", new(concurrentTxStmtQueryTest)}, + {"TxStmtExec", new(concurrentTxStmtExecTest)}, + {"Random", new(concurrentRandomTest)}, + } + for _, item := range list { + t.Run(item.name, func(t *testing.T) { + doConcurrentTest(t, item.ct) + }) + } +} + +func TestConnectionLeak(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + // Start by opening defaultMaxIdleConns + rows := make([]*Rows, defaultMaxIdleConns) + // We need to SetMaxOpenConns > MaxIdleConns, so the DB can open + // a new connection and we can fill the idle queue with the released + // connections. + db.SetMaxOpenConns(len(rows) + 1) + for ii := range rows { + r, err := db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + r.Next() + if err := r.Err(); err != nil { + t.Fatal(err) + } + rows[ii] = r + } + // Now we have defaultMaxIdleConns busy connections. Open + // a new one, but wait until the busy connections are released + // before returning control to DB. + drv := db.Driver().(*fakeDriver) + drv.waitCh = make(chan struct{}, 1) + drv.waitingCh = make(chan struct{}, 1) + var wg sync.WaitGroup + wg.Add(1) + go func() { + r, err := db.Query("SELECT|people|name|") + if err != nil { + t.Error(err) + return + } + r.Close() + wg.Done() + }() + // Wait until the goroutine we've just created has started waiting. + <-drv.waitingCh + // Now close the busy connections. This provides a connection for + // the blocked goroutine and then fills up the idle queue. + for _, v := range rows { + v.Close() + } + // At this point we give the new connection to DB. This connection is + // now useless, since the idle queue is full and there are no pending + // requests. DB should deal with this situation without leaking the + // connection. + drv.waitCh <- struct{}{} + wg.Wait() +} + +func TestStatsMaxIdleClosedZero(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + db.SetConnMaxLifetime(0) + + preMaxIdleClosed := db.Stats().MaxIdleClosed + + for i := 0; i < 10; i++ { + rows, err := db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + rows.Close() + } + + st := db.Stats() + maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed + t.Logf("MaxIdleClosed: %d", maxIdleClosed) + if maxIdleClosed != 0 { + t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed) + } +} + +func TestStatsMaxIdleClosedTen(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(0) + db.SetConnMaxLifetime(0) + + preMaxIdleClosed := db.Stats().MaxIdleClosed + + for i := 0; i < 10; i++ { + rows, err := db.Query("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + rows.Close() + } + + st := db.Stats() + maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed + t.Logf("MaxIdleClosed: %d", maxIdleClosed) + if maxIdleClosed != 10 { + t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed) + } +} + +// testUseConns uses count concurrent connections with 1 nanosecond apart. +// Returns the returnedAt time of the final connection. +func testUseConns(t *testing.T, count int, tm time.Time, db *DB) time.Time { + conns := make([]*Conn, count) + ctx := context.Background() + for i := range conns { + tm = tm.Add(time.Nanosecond) + nowFunc = func() time.Time { + return tm + } + c, err := db.Conn(ctx) + if err != nil { + t.Error(err) + } + conns[i] = c + } + + for i := len(conns) - 1; i >= 0; i-- { + tm = tm.Add(time.Nanosecond) + nowFunc = func() time.Time { + return tm + } + if err := conns[i].Close(); err != nil { + t.Error(err) + } + } + + return tm +} + +func TestMaxIdleTime(t *testing.T) { + usedConns := 5 + reusedConns := 2 + list := []struct { + wantMaxIdleTime time.Duration + wantMaxLifetime time.Duration + wantNextCheck time.Duration + wantIdleClosed int64 + wantMaxIdleClosed int64 + timeOffset time.Duration + secondTimeOffset time.Duration + }{ + { + time.Millisecond, + 0, + time.Millisecond - time.Nanosecond, + int64(usedConns - reusedConns), + int64(usedConns - reusedConns), + 10 * time.Millisecond, + 0, + }, + { + // Want to close some connections via max idle time and one by max lifetime. + time.Millisecond, + // nowFunc() - MaxLifetime should be 1 * time.Nanosecond in connectionCleanerRunLocked. + // This guarantees that first opened connection is to be closed. + // Thus it is timeOffset + secondTimeOffset + 3 (+2 for Close while reusing conns and +1 for Conn). + 10*time.Millisecond + 100*time.Nanosecond + 3*time.Nanosecond, + time.Nanosecond, + // Closed all not reused connections and extra one by max lifetime. + int64(usedConns - reusedConns + 1), + int64(usedConns - reusedConns), + 10 * time.Millisecond, + // Add second offset because otherwise connections are expired via max lifetime in Close. + 100 * time.Nanosecond, + }, + { + time.Hour, + 0, + time.Second, + 0, + 0, + 10 * time.Millisecond, + 0}, + } + baseTime := time.Unix(0, 0) + defer func() { + nowFunc = time.Now + }() + for _, item := range list { + nowFunc = func() time.Time { + return baseTime + } + t.Run(fmt.Sprintf("%v", item.wantMaxIdleTime), func(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + db.SetMaxOpenConns(usedConns) + db.SetMaxIdleConns(usedConns) + db.SetConnMaxIdleTime(item.wantMaxIdleTime) + db.SetConnMaxLifetime(item.wantMaxLifetime) + + preMaxIdleClosed := db.Stats().MaxIdleTimeClosed + + // Busy usedConns. + testUseConns(t, usedConns, baseTime, db) + + tm := baseTime.Add(item.timeOffset) + + // Reuse connections which should never be considered idle + // and exercises the sorting for issue 39471. + tm = testUseConns(t, reusedConns, tm, db) + + tm = tm.Add(item.secondTimeOffset) + nowFunc = func() time.Time { + return tm + } + + db.mu.Lock() + nc, closing := db.connectionCleanerRunLocked(time.Second) + if nc != item.wantNextCheck { + t.Errorf("got %v; want %v next check duration", nc, item.wantNextCheck) + } + + // Validate freeConn order. + var last time.Time + for _, c := range db.freeConn { + if last.After(c.returnedAt) { + t.Error("freeConn is not ordered by returnedAt") + break + } + last = c.returnedAt + } + + db.mu.Unlock() + for _, c := range closing { + c.Close() + } + if g, w := int64(len(closing)), item.wantIdleClosed; g != w { + t.Errorf("got: %d; want %d closed conns", g, w) + } + + st := db.Stats() + maxIdleClosed := st.MaxIdleTimeClosed - preMaxIdleClosed + if g, w := maxIdleClosed, item.wantMaxIdleClosed; g != w { + t.Errorf("got: %d; want %d max idle closed conns", g, w) + } + }) + } +} + +type nvcDriver struct { + fakeDriver + skipNamedValueCheck bool +} + +func (d *nvcDriver) Open(dsn string) (driver.Conn, error) { + c, err := d.fakeDriver.Open(dsn) + fc := c.(*fakeConn) + fc.db.allowAny = true + return &nvcConn{fc, d.skipNamedValueCheck}, err +} + +type nvcConn struct { + *fakeConn + skipNamedValueCheck bool +} + +type decimalInt struct { + value int +} + +type doNotInclude struct{} + +var _ driver.NamedValueChecker = &nvcConn{} + +func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error { + if c.skipNamedValueCheck { + return driver.ErrSkip + } + switch v := nv.Value.(type) { + default: + return driver.ErrSkip + case Out: + switch ov := v.Dest.(type) { + default: + return errors.New("unknown NameValueCheck OUTPUT type") + case *string: + *ov = "from-server" + nv.Value = "OUT:*string" + } + return nil + case decimalInt, []int64: + return nil + case doNotInclude: + return driver.ErrRemoveArgument + } +} + +func TestNamedValueChecker(t *testing.T) { + Register("NamedValueCheck", &nvcDriver{}) + db, err := Open("NamedValueCheck", "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = db.ExecContext(ctx, "WIPE") + if err != nil { + t.Fatal("exec wipe", err) + } + + _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any") + if err != nil { + t.Fatal("exec create", err) + } + + o1 := "" + _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimalInt{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{}) + if err != nil { + t.Fatal("exec insert", err) + } + var ( + str1 string + dec1 decimalInt + arr1 []int64 + ) + err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1) + if err != nil { + t.Fatal("select", err) + } + + list := []struct{ got, want any }{ + {o1, "from-server"}, + {dec1, decimalInt{123}}, + {str1, "hello"}, + {arr1, []int64{42, 128, 707}}, + } + + for index, item := range list { + if !reflect.DeepEqual(item.got, item.want) { + t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index) + } + } +} + +func TestNamedValueCheckerSkip(t *testing.T) { + Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true}) + db, err := Open("NamedValueCheckSkip", "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + _, err = db.ExecContext(ctx, "WIPE") + if err != nil { + t.Fatal("exec wipe", err) + } + + _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any") + if err != nil { + t.Fatal("exec create", err) + } + + _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimalInt{123})) + if err == nil { + t.Fatalf("expected error with bad argument, got %v", err) + } +} + +func TestOpenConnector(t *testing.T) { + Register("testctx", &fakeDriverCtx{}) + db, err := Open("testctx", "people") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + c, ok := db.connector.(*fakeConnector) + if !ok { + t.Fatal("not using *fakeConnector") + } + + if err := db.Close(); err != nil { + t.Fatal(err) + } + + if !c.closed { + t.Fatal("connector is not closed") + } +} + +type ctxOnlyDriver struct { + fakeDriver +} + +func (d *ctxOnlyDriver) Open(dsn string) (driver.Conn, error) { + conn, err := d.fakeDriver.Open(dsn) + if err != nil { + return nil, err + } + return &ctxOnlyConn{fc: conn.(*fakeConn)}, nil +} + +var ( + _ driver.Conn = &ctxOnlyConn{} + _ driver.QueryerContext = &ctxOnlyConn{} + _ driver.ExecerContext = &ctxOnlyConn{} +) + +type ctxOnlyConn struct { + fc *fakeConn + + queryCtxCalled bool + execCtxCalled bool +} + +func (c *ctxOnlyConn) Begin() (driver.Tx, error) { + return c.fc.Begin() +} + +func (c *ctxOnlyConn) Close() error { + return c.fc.Close() +} + +// Prepare is still part of the Conn interface, so while it isn't used +// must be defined for compatibility. +func (c *ctxOnlyConn) Prepare(q string) (driver.Stmt, error) { + panic("not used") +} + +func (c *ctxOnlyConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) { + return c.fc.PrepareContext(ctx, q) +} + +func (c *ctxOnlyConn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) { + c.queryCtxCalled = true + return c.fc.QueryContext(ctx, q, args) +} + +func (c *ctxOnlyConn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) { + c.execCtxCalled = true + return c.fc.ExecContext(ctx, q, args) +} + +// TestQueryExecContextOnly ensures drivers only need to implement QueryContext +// and ExecContext methods. +func TestQueryExecContextOnly(t *testing.T) { + // Ensure connection does not implement non-context interfaces. + var connType driver.Conn = &ctxOnlyConn{} + if _, ok := connType.(driver.Execer); ok { + t.Fatalf("%T must not implement driver.Execer", connType) + } + if _, ok := connType.(driver.Queryer); ok { + t.Fatalf("%T must not implement driver.Queryer", connType) + } + + Register("ContextOnly", &ctxOnlyDriver{}) + db, err := Open("ContextOnly", "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal("db.Conn", err) + } + defer conn.Close() + coc := conn.dc.ci.(*ctxOnlyConn) + coc.fc.skipDirtySession = true + + _, err = conn.ExecContext(ctx, "WIPE") + if err != nil { + t.Fatal("exec wipe", err) + } + + _, err = conn.ExecContext(ctx, "CREATE|keys|v1=string") + if err != nil { + t.Fatal("exec create", err) + } + expectedValue := "value1" + _, err = conn.ExecContext(ctx, "INSERT|keys|v1=?", expectedValue) + if err != nil { + t.Fatal("exec insert", err) + } + rows, err := conn.QueryContext(ctx, "SELECT|keys|v1|") + if err != nil { + t.Fatal("query select", err) + } + v1 := "" + for rows.Next() { + err = rows.Scan(&v1) + if err != nil { + t.Fatal("rows scan", err) + } + } + rows.Close() + + if v1 != expectedValue { + t.Fatalf("expected %q, got %q", expectedValue, v1) + } + + if !coc.execCtxCalled { + t.Error("ExecContext not called") + } + if !coc.queryCtxCalled { + t.Error("QueryContext not called") + } +} + +type alwaysErrScanner struct{} + +var errTestScanWrap = errors.New("errTestScanWrap") + +func (alwaysErrScanner) Scan(any) error { + return errTestScanWrap +} + +// Issue 38099: Ensure that Rows.Scan properly wraps underlying errors. +func TestRowsScanProperlyWrapsErrors(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + rows, err := db.Query("SELECT|people|age|") + if err != nil { + t.Fatalf("Query: %v", err) + } + + var res alwaysErrScanner + + for rows.Next() { + err = rows.Scan(&res) + if err == nil { + t.Fatal("expecting back an error") + } + if !errors.Is(err, errTestScanWrap) { + t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errTestScanWrap) + } + // Ensure that error substring matching still correctly works. + if !strings.Contains(err.Error(), errTestScanWrap.Error()) { + t.Fatalf("Error %v does not contain %v", err, errTestScanWrap) + } + } +} + +// badConn implements a bad driver.Conn, for TestBadDriver. +// The Exec method panics. +type badConn struct{} + +func (bc badConn) Prepare(query string) (driver.Stmt, error) { + return nil, errors.New("badConn Prepare") +} + +func (bc badConn) Close() error { + return nil +} + +func (bc badConn) Begin() (driver.Tx, error) { + return nil, errors.New("badConn Begin") +} + +func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) { + panic("badConn.Exec") +} + +// badDriver is a driver.Driver that uses badConn. +type badDriver struct{} + +func (bd badDriver) Open(name string) (driver.Conn, error) { + return badConn{}, nil +} + +// Issue 15901. +func TestBadDriver(t *testing.T) { + Register("bad", badDriver{}) + db, err := Open("bad", "ignored") + if err != nil { + t.Fatal(err) + } + defer func() { + if r := recover(); r == nil { + t.Error("expected panic") + } else { + if want := "badConn.Exec"; r.(string) != want { + t.Errorf("panic was %v, expected %v", r, want) + } + } + }() + defer db.Close() + db.Exec("ignored") +} + +type pingDriver struct { + fails bool +} + +type pingConn struct { + badConn + driver *pingDriver +} + +var pingError = errors.New("Ping failed") + +func (pc pingConn) Ping(ctx context.Context) error { + if pc.driver.fails { + return pingError + } + return nil +} + +var _ driver.Pinger = pingConn{} + +func (pd *pingDriver) Open(name string) (driver.Conn, error) { + return pingConn{driver: pd}, nil +} + +func TestPing(t *testing.T) { + driver := &pingDriver{} + Register("ping", driver) + + db, err := Open("ping", "ignored") + if err != nil { + t.Fatal(err) + } + + if err := db.Ping(); err != nil { + t.Errorf("err was %#v, expected nil", err) + return + } + + driver.fails = true + if err := db.Ping(); err != pingError { + t.Errorf("err was %#v, expected pingError", err) + } +} + +// Issue 18101. +func TestTypedString(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + type Str string + var scanned Str + + err := db.QueryRow("SELECT|people|name|name=?", "Alice").Scan(&scanned) + if err != nil { + t.Fatal(err) + } + expected := Str("Alice") + if scanned != expected { + t.Errorf("expected %+v, got %+v", expected, scanned) + } +} + +func BenchmarkConcurrentDBExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentDBExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentStmtQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentStmtQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentStmtExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentStmtExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxStmtQuery(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxStmtQueryTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentTxStmtExec(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentTxStmtExecTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkConcurrentRandom(b *testing.B) { + b.ReportAllocs() + ct := new(concurrentRandomTest) + for i := 0; i < b.N; i++ { + doConcurrentTest(b, ct) + } +} + +func BenchmarkManyConcurrentQueries(b *testing.B) { + b.ReportAllocs() + // To see lock contention in Go 1.4, 16~ cores and 128~ goroutines are required. + const parallelism = 16 + + db := newTestDB(b, "magicquery") + defer closeDB(b, db) + db.SetMaxIdleConns(runtime.GOMAXPROCS(0) * parallelism) + + stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?") + if err != nil { + b.Fatal(err) + } + defer stmt.Close() + + b.SetParallelism(parallelism) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + rows, err := stmt.Query("sleep", 1) + if err != nil { + b.Error(err) + return + } + rows.Close() + } + }) +} |