summaryrefslogtreecommitdiffstats
path: root/src/database/sql
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:19:13 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-04-16 19:19:13 +0000
commitccd992355df7192993c666236047820244914598 (patch)
treef00fea65147227b7743083c6148396f74cd66935 /src/database/sql
parentInitial commit. (diff)
downloadgolang-1.21-ccd992355df7192993c666236047820244914598.tar.xz
golang-1.21-ccd992355df7192993c666236047820244914598.zip
Adding upstream version 1.21.8.upstream/1.21.8
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/database/sql')
-rw-r--r--src/database/sql/convert.go591
-rw-r--r--src/database/sql/convert_test.go597
-rw-r--r--src/database/sql/ctxutil.go146
-rw-r--r--src/database/sql/doc.txt46
-rw-r--r--src/database/sql/driver/driver.go552
-rw-r--r--src/database/sql/driver/types.go297
-rw-r--r--src/database/sql/driver/types_test.go95
-rw-r--r--src/database/sql/example_cli_test.go84
-rw-r--r--src/database/sql/example_service_test.go158
-rw-r--r--src/database/sql/example_test.go369
-rw-r--r--src/database/sql/fakedb_test.go1283
-rw-r--r--src/database/sql/sql.go3503
-rw-r--r--src/database/sql/sql_test.go4752
13 files changed, 12473 insertions, 0 deletions
diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go
new file mode 100644
index 0000000..ffc4e49
--- /dev/null
+++ b/src/database/sql/convert.go
@@ -0,0 +1,591 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Type conversions for Scan.
+
+package sql
+
+import (
+ "bytes"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "reflect"
+ "strconv"
+ "time"
+ "unicode"
+ "unicode/utf8"
+)
+
+var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
+
+func describeNamedValue(nv *driver.NamedValue) string {
+ if len(nv.Name) == 0 {
+ return fmt.Sprintf("$%d", nv.Ordinal)
+ }
+ return fmt.Sprintf("with name %q", nv.Name)
+}
+
+func validateNamedValueName(name string) error {
+ if len(name) == 0 {
+ return nil
+ }
+ r, _ := utf8.DecodeRuneInString(name)
+ if unicode.IsLetter(r) {
+ return nil
+ }
+ return fmt.Errorf("name %q does not begin with a letter", name)
+}
+
+// ccChecker wraps the driver.ColumnConverter and allows it to be used
+// as if it were a NamedValueChecker. If the driver ColumnConverter
+// is not present then the NamedValueChecker will return driver.ErrSkip.
+type ccChecker struct {
+ cci driver.ColumnConverter
+ want int
+}
+
+func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
+ if c.cci == nil {
+ return driver.ErrSkip
+ }
+ // The column converter shouldn't be called on any index
+ // it isn't expecting. The final error will be thrown
+ // in the argument converter loop.
+ index := nv.Ordinal - 1
+ if c.want <= index {
+ return nil
+ }
+
+ // First, see if the value itself knows how to convert
+ // itself to a driver type. For example, a NullString
+ // struct changing into a string or nil.
+ if vr, ok := nv.Value.(driver.Valuer); ok {
+ sv, err := callValuerValue(vr)
+ if err != nil {
+ return err
+ }
+ if !driver.IsValue(sv) {
+ return fmt.Errorf("non-subset type %T returned from Value", sv)
+ }
+ nv.Value = sv
+ }
+
+ // Second, ask the column to sanity check itself. For
+ // example, drivers might use this to make sure that
+ // an int64 values being inserted into a 16-bit
+ // integer field is in range (before getting
+ // truncated), or that a nil can't go into a NOT NULL
+ // column before going across the network to get the
+ // same error.
+ var err error
+ arg := nv.Value
+ nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
+ if err != nil {
+ return err
+ }
+ if !driver.IsValue(nv.Value) {
+ return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
+ }
+ return nil
+}
+
+// defaultCheckNamedValue wraps the default ColumnConverter to have the same
+// function signature as the CheckNamedValue in the driver.NamedValueChecker
+// interface.
+func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
+ nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
+ return err
+}
+
+// driverArgsConnLocked converts arguments from callers of Stmt.Exec and
+// Stmt.Query into driver Values.
+//
+// The statement ds may be nil, if no statement is available.
+//
+// ci must be locked.
+func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []any) ([]driver.NamedValue, error) {
+ nvargs := make([]driver.NamedValue, len(args))
+
+ // -1 means the driver doesn't know how to count the number of
+ // placeholders, so we won't sanity check input here and instead let the
+ // driver deal with errors.
+ want := -1
+
+ var si driver.Stmt
+ var cc ccChecker
+ if ds != nil {
+ si = ds.si
+ want = ds.si.NumInput()
+ cc.want = want
+ }
+
+ // Check all types of interfaces from the start.
+ // Drivers may opt to use the NamedValueChecker for special
+ // argument types, then return driver.ErrSkip to pass it along
+ // to the column converter.
+ nvc, ok := si.(driver.NamedValueChecker)
+ if !ok {
+ nvc, ok = ci.(driver.NamedValueChecker)
+ }
+ cci, ok := si.(driver.ColumnConverter)
+ if ok {
+ cc.cci = cci
+ }
+
+ // Loop through all the arguments, checking each one.
+ // If no error is returned simply increment the index
+ // and continue. However if driver.ErrRemoveArgument
+ // is returned the argument is not included in the query
+ // argument list.
+ var err error
+ var n int
+ for _, arg := range args {
+ nv := &nvargs[n]
+ if np, ok := arg.(NamedArg); ok {
+ if err = validateNamedValueName(np.Name); err != nil {
+ return nil, err
+ }
+ arg = np.Value
+ nv.Name = np.Name
+ }
+ nv.Ordinal = n + 1
+ nv.Value = arg
+
+ // Checking sequence has four routes:
+ // A: 1. Default
+ // B: 1. NamedValueChecker 2. Column Converter 3. Default
+ // C: 1. NamedValueChecker 3. Default
+ // D: 1. Column Converter 2. Default
+ //
+ // The only time a Column Converter is called is first
+ // or after NamedValueConverter. If first it is handled before
+ // the nextCheck label. Thus for repeats tries only when the
+ // NamedValueConverter is selected should the Column Converter
+ // be used in the retry.
+ checker := defaultCheckNamedValue
+ nextCC := false
+ switch {
+ case nvc != nil:
+ nextCC = cci != nil
+ checker = nvc.CheckNamedValue
+ case cci != nil:
+ checker = cc.CheckNamedValue
+ }
+
+ nextCheck:
+ err = checker(nv)
+ switch err {
+ case nil:
+ n++
+ continue
+ case driver.ErrRemoveArgument:
+ nvargs = nvargs[:len(nvargs)-1]
+ continue
+ case driver.ErrSkip:
+ if nextCC {
+ nextCC = false
+ checker = cc.CheckNamedValue
+ } else {
+ checker = defaultCheckNamedValue
+ }
+ goto nextCheck
+ default:
+ return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
+ }
+ }
+
+ // Check the length of arguments after conversion to allow for omitted
+ // arguments.
+ if want != -1 && len(nvargs) != want {
+ return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
+ }
+
+ return nvargs, nil
+
+}
+
+// convertAssign is the same as convertAssignRows, but without the optional
+// rows argument.
+func convertAssign(dest, src any) error {
+ return convertAssignRows(dest, src, nil)
+}
+
+// convertAssignRows copies to dest the value in src, converting it if possible.
+// An error is returned if the copy would result in loss of information.
+// dest should be a pointer type. If rows is passed in, the rows will
+// be used as the parent for any cursor values converted from a
+// driver.Rows to a *Rows.
+func convertAssignRows(dest, src any, rows *Rows) error {
+ // Common cases, without reflect.
+ switch s := src.(type) {
+ case string:
+ switch d := dest.(type) {
+ case *string:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = s
+ return nil
+ case *[]byte:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = []byte(s)
+ return nil
+ case *RawBytes:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = append((*d)[:0], s...)
+ return nil
+ }
+ case []byte:
+ switch d := dest.(type) {
+ case *string:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = string(s)
+ return nil
+ case *any:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = bytes.Clone(s)
+ return nil
+ case *[]byte:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = bytes.Clone(s)
+ return nil
+ case *RawBytes:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = s
+ return nil
+ }
+ case time.Time:
+ switch d := dest.(type) {
+ case *time.Time:
+ *d = s
+ return nil
+ case *string:
+ *d = s.Format(time.RFC3339Nano)
+ return nil
+ case *[]byte:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = []byte(s.Format(time.RFC3339Nano))
+ return nil
+ case *RawBytes:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
+ return nil
+ }
+ case decimalDecompose:
+ switch d := dest.(type) {
+ case decimalCompose:
+ return d.Compose(s.Decompose(nil))
+ }
+ case nil:
+ switch d := dest.(type) {
+ case *any:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = nil
+ return nil
+ case *[]byte:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = nil
+ return nil
+ case *RawBytes:
+ if d == nil {
+ return errNilPtr
+ }
+ *d = nil
+ return nil
+ }
+ // The driver is returning a cursor the client may iterate over.
+ case driver.Rows:
+ switch d := dest.(type) {
+ case *Rows:
+ if d == nil {
+ return errNilPtr
+ }
+ if rows == nil {
+ return errors.New("invalid context to convert cursor rows, missing parent *Rows")
+ }
+ rows.closemu.Lock()
+ *d = Rows{
+ dc: rows.dc,
+ releaseConn: func(error) {},
+ rowsi: s,
+ }
+ // Chain the cancel function.
+ parentCancel := rows.cancel
+ rows.cancel = func() {
+ // When Rows.cancel is called, the closemu will be locked as well.
+ // So we can access rs.lasterr.
+ d.close(rows.lasterr)
+ if parentCancel != nil {
+ parentCancel()
+ }
+ }
+ rows.closemu.Unlock()
+ return nil
+ }
+ }
+
+ var sv reflect.Value
+
+ switch d := dest.(type) {
+ case *string:
+ sv = reflect.ValueOf(src)
+ switch sv.Kind() {
+ case reflect.Bool,
+ reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+ reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+ reflect.Float32, reflect.Float64:
+ *d = asString(src)
+ return nil
+ }
+ case *[]byte:
+ sv = reflect.ValueOf(src)
+ if b, ok := asBytes(nil, sv); ok {
+ *d = b
+ return nil
+ }
+ case *RawBytes:
+ sv = reflect.ValueOf(src)
+ if b, ok := asBytes([]byte(*d)[:0], sv); ok {
+ *d = RawBytes(b)
+ return nil
+ }
+ case *bool:
+ bv, err := driver.Bool.ConvertValue(src)
+ if err == nil {
+ *d = bv.(bool)
+ }
+ return err
+ case *any:
+ *d = src
+ return nil
+ }
+
+ if scanner, ok := dest.(Scanner); ok {
+ return scanner.Scan(src)
+ }
+
+ dpv := reflect.ValueOf(dest)
+ if dpv.Kind() != reflect.Pointer {
+ return errors.New("destination not a pointer")
+ }
+ if dpv.IsNil() {
+ return errNilPtr
+ }
+
+ if !sv.IsValid() {
+ sv = reflect.ValueOf(src)
+ }
+
+ dv := reflect.Indirect(dpv)
+ if sv.IsValid() && sv.Type().AssignableTo(dv.Type()) {
+ switch b := src.(type) {
+ case []byte:
+ dv.Set(reflect.ValueOf(bytes.Clone(b)))
+ default:
+ dv.Set(sv)
+ }
+ return nil
+ }
+
+ if dv.Kind() == sv.Kind() && sv.Type().ConvertibleTo(dv.Type()) {
+ dv.Set(sv.Convert(dv.Type()))
+ return nil
+ }
+
+ // The following conversions use a string value as an intermediate representation
+ // to convert between various numeric types.
+ //
+ // This also allows scanning into user defined types such as "type Int int64".
+ // For symmetry, also check for string destination types.
+ switch dv.Kind() {
+ case reflect.Pointer:
+ if src == nil {
+ dv.SetZero()
+ return nil
+ }
+ dv.Set(reflect.New(dv.Type().Elem()))
+ return convertAssignRows(dv.Interface(), src, rows)
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ if src == nil {
+ return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
+ }
+ s := asString(src)
+ i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
+ if err != nil {
+ err = strconvErr(err)
+ return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
+ }
+ dv.SetInt(i64)
+ return nil
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ if src == nil {
+ return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
+ }
+ s := asString(src)
+ u64, err := strconv.ParseUint(s, 10, dv.Type().Bits())
+ if err != nil {
+ err = strconvErr(err)
+ return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
+ }
+ dv.SetUint(u64)
+ return nil
+ case reflect.Float32, reflect.Float64:
+ if src == nil {
+ return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
+ }
+ s := asString(src)
+ f64, err := strconv.ParseFloat(s, dv.Type().Bits())
+ if err != nil {
+ err = strconvErr(err)
+ return fmt.Errorf("converting driver.Value type %T (%q) to a %s: %v", src, s, dv.Kind(), err)
+ }
+ dv.SetFloat(f64)
+ return nil
+ case reflect.String:
+ if src == nil {
+ return fmt.Errorf("converting NULL to %s is unsupported", dv.Kind())
+ }
+ switch v := src.(type) {
+ case string:
+ dv.SetString(v)
+ return nil
+ case []byte:
+ dv.SetString(string(v))
+ return nil
+ }
+ }
+
+ return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
+}
+
+func strconvErr(err error) error {
+ if ne, ok := err.(*strconv.NumError); ok {
+ return ne.Err
+ }
+ return err
+}
+
+func asString(src any) string {
+ switch v := src.(type) {
+ case string:
+ return v
+ case []byte:
+ return string(v)
+ }
+ rv := reflect.ValueOf(src)
+ switch rv.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return strconv.FormatInt(rv.Int(), 10)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return strconv.FormatUint(rv.Uint(), 10)
+ case reflect.Float64:
+ return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
+ case reflect.Float32:
+ return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
+ case reflect.Bool:
+ return strconv.FormatBool(rv.Bool())
+ }
+ return fmt.Sprintf("%v", src)
+}
+
+func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
+ switch rv.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return strconv.AppendInt(buf, rv.Int(), 10), true
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ return strconv.AppendUint(buf, rv.Uint(), 10), true
+ case reflect.Float32:
+ return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
+ case reflect.Float64:
+ return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
+ case reflect.Bool:
+ return strconv.AppendBool(buf, rv.Bool()), true
+ case reflect.String:
+ s := rv.String()
+ return append(buf, s...), true
+ }
+ return
+}
+
+var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
+
+// callValuerValue returns vr.Value(), with one exception:
+// If vr.Value is an auto-generated method on a pointer type and the
+// pointer is nil, it would panic at runtime in the panicwrap
+// method. Treat it like nil instead.
+// Issue 8415.
+//
+// This is so people can implement driver.Value on value types and
+// still use nil pointers to those types to mean nil/NULL, just like
+// string/*string.
+//
+// This function is mirrored in the database/sql/driver package.
+func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
+ if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
+ rv.IsNil() &&
+ rv.Type().Elem().Implements(valuerReflectType) {
+ return nil, nil
+ }
+ return vr.Value()
+}
+
+// decimal composes or decomposes a decimal value to and from individual parts.
+// There are four parts: a boolean negative flag, a form byte with three possible states
+// (finite=0, infinite=1, NaN=2), a base-2 big-endian integer
+// coefficient (also known as a significand) as a []byte, and an int32 exponent.
+// These are composed into a final value as "decimal = (neg) (form=finite) coefficient * 10 ^ exponent".
+// A zero length coefficient is a zero value.
+// The big-endian integer coefficient stores the most significant byte first (at coefficient[0]).
+// If the form is not finite the coefficient and exponent should be ignored.
+// The negative parameter may be set to true for any form, although implementations are not required
+// to respect the negative parameter in the non-finite form.
+//
+// Implementations may choose to set the negative parameter to true on a zero or NaN value,
+// but implementations that do not differentiate between negative and positive
+// zero or NaN values should ignore the negative parameter without error.
+// If an implementation does not support Infinity it may be converted into a NaN without error.
+// If a value is set that is larger than what is supported by an implementation,
+// an error must be returned.
+// Implementations must return an error if a NaN or Infinity is attempted to be set while neither
+// are supported.
+//
+// NOTE(kardianos): This is an experimental interface. See https://golang.org/issue/30870
+type decimal interface {
+ decimalDecompose
+ decimalCompose
+}
+
+type decimalDecompose interface {
+ // Decompose returns the internal decimal state in parts.
+ // If the provided buf has sufficient capacity, buf may be returned as the coefficient with
+ // the value set and length set as appropriate.
+ Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
+}
+
+type decimalCompose interface {
+ // Compose sets the internal decimal value from parts. If the value cannot be
+ // represented then an error should be returned.
+ Compose(form byte, negative bool, coefficient []byte, exponent int32) error
+}
diff --git a/src/database/sql/convert_test.go b/src/database/sql/convert_test.go
new file mode 100644
index 0000000..6d09fa1
--- /dev/null
+++ b/src/database/sql/convert_test.go
@@ -0,0 +1,597 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+ "database/sql/driver"
+ "fmt"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+)
+
+var someTime = time.Unix(123, 0)
+var answer int64 = 42
+
+type (
+ userDefined float64
+ userDefinedSlice []int
+ userDefinedString string
+)
+
+type conversionTest struct {
+ s, d any // source and destination
+
+ // following are used if they're non-zero
+ wantint int64
+ wantuint uint64
+ wantstr string
+ wantbytes []byte
+ wantraw RawBytes
+ wantf32 float32
+ wantf64 float64
+ wanttime time.Time
+ wantbool bool // used if d is of type *bool
+ wanterr string
+ wantiface any
+ wantptr *int64 // if non-nil, *d's pointed value must be equal to *wantptr
+ wantnil bool // if true, *d must be *int64(nil)
+ wantusrdef userDefined
+ wantusrstr userDefinedString
+}
+
+// Target variables for scanning into.
+var (
+ scanstr string
+ scanbytes []byte
+ scanraw RawBytes
+ scanint int
+ scanuint8 uint8
+ scanuint16 uint16
+ scanbool bool
+ scanf32 float32
+ scanf64 float64
+ scantime time.Time
+ scanptr *int64
+ scaniface any
+)
+
+func conversionTests() []conversionTest {
+ // Return a fresh instance to test so "go test -count 2" works correctly.
+ return []conversionTest{
+ // Exact conversions (destination pointer type matches source type)
+ {s: "foo", d: &scanstr, wantstr: "foo"},
+ {s: 123, d: &scanint, wantint: 123},
+ {s: someTime, d: &scantime, wanttime: someTime},
+
+ // To strings
+ {s: "string", d: &scanstr, wantstr: "string"},
+ {s: []byte("byteslice"), d: &scanstr, wantstr: "byteslice"},
+ {s: 123, d: &scanstr, wantstr: "123"},
+ {s: int8(123), d: &scanstr, wantstr: "123"},
+ {s: int64(123), d: &scanstr, wantstr: "123"},
+ {s: uint8(123), d: &scanstr, wantstr: "123"},
+ {s: uint16(123), d: &scanstr, wantstr: "123"},
+ {s: uint32(123), d: &scanstr, wantstr: "123"},
+ {s: uint64(123), d: &scanstr, wantstr: "123"},
+ {s: 1.5, d: &scanstr, wantstr: "1.5"},
+
+ // From time.Time:
+ {s: time.Unix(1, 0).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01Z"},
+ {s: time.Unix(1453874597, 0).In(time.FixedZone("here", -3600*8)), d: &scanstr, wantstr: "2016-01-26T22:03:17-08:00"},
+ {s: time.Unix(1, 2).UTC(), d: &scanstr, wantstr: "1970-01-01T00:00:01.000000002Z"},
+ {s: time.Time{}, d: &scanstr, wantstr: "0001-01-01T00:00:00Z"},
+ {s: time.Unix(1, 2).UTC(), d: &scanbytes, wantbytes: []byte("1970-01-01T00:00:01.000000002Z")},
+ {s: time.Unix(1, 2).UTC(), d: &scaniface, wantiface: time.Unix(1, 2).UTC()},
+
+ // To []byte
+ {s: nil, d: &scanbytes, wantbytes: nil},
+ {s: "string", d: &scanbytes, wantbytes: []byte("string")},
+ {s: []byte("byteslice"), d: &scanbytes, wantbytes: []byte("byteslice")},
+ {s: 123, d: &scanbytes, wantbytes: []byte("123")},
+ {s: int8(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: int64(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: uint8(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: uint16(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: uint32(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: uint64(123), d: &scanbytes, wantbytes: []byte("123")},
+ {s: 1.5, d: &scanbytes, wantbytes: []byte("1.5")},
+
+ // To RawBytes
+ {s: nil, d: &scanraw, wantraw: nil},
+ {s: []byte("byteslice"), d: &scanraw, wantraw: RawBytes("byteslice")},
+ {s: "string", d: &scanraw, wantraw: RawBytes("string")},
+ {s: 123, d: &scanraw, wantraw: RawBytes("123")},
+ {s: int8(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: int64(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: uint8(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: uint16(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: uint32(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: uint64(123), d: &scanraw, wantraw: RawBytes("123")},
+ {s: 1.5, d: &scanraw, wantraw: RawBytes("1.5")},
+ // time.Time has been placed here to check that the RawBytes slice gets
+ // correctly reset when calling time.Time.AppendFormat.
+ {s: time.Unix(2, 5).UTC(), d: &scanraw, wantraw: RawBytes("1970-01-01T00:00:02.000000005Z")},
+
+ // Strings to integers
+ {s: "255", d: &scanuint8, wantuint: 255},
+ {s: "256", d: &scanuint8, wanterr: "converting driver.Value type string (\"256\") to a uint8: value out of range"},
+ {s: "256", d: &scanuint16, wantuint: 256},
+ {s: "-1", d: &scanint, wantint: -1},
+ {s: "foo", d: &scanint, wanterr: "converting driver.Value type string (\"foo\") to a int: invalid syntax"},
+
+ // int64 to smaller integers
+ {s: int64(5), d: &scanuint8, wantuint: 5},
+ {s: int64(256), d: &scanuint8, wanterr: "converting driver.Value type int64 (\"256\") to a uint8: value out of range"},
+ {s: int64(256), d: &scanuint16, wantuint: 256},
+ {s: int64(65536), d: &scanuint16, wanterr: "converting driver.Value type int64 (\"65536\") to a uint16: value out of range"},
+
+ // True bools
+ {s: true, d: &scanbool, wantbool: true},
+ {s: "True", d: &scanbool, wantbool: true},
+ {s: "TRUE", d: &scanbool, wantbool: true},
+ {s: "1", d: &scanbool, wantbool: true},
+ {s: 1, d: &scanbool, wantbool: true},
+ {s: int64(1), d: &scanbool, wantbool: true},
+ {s: uint16(1), d: &scanbool, wantbool: true},
+
+ // False bools
+ {s: false, d: &scanbool, wantbool: false},
+ {s: "false", d: &scanbool, wantbool: false},
+ {s: "FALSE", d: &scanbool, wantbool: false},
+ {s: "0", d: &scanbool, wantbool: false},
+ {s: 0, d: &scanbool, wantbool: false},
+ {s: int64(0), d: &scanbool, wantbool: false},
+ {s: uint16(0), d: &scanbool, wantbool: false},
+
+ // Not bools
+ {s: "yup", d: &scanbool, wanterr: `sql/driver: couldn't convert "yup" into type bool`},
+ {s: 2, d: &scanbool, wanterr: `sql/driver: couldn't convert 2 into type bool`},
+
+ // Floats
+ {s: float64(1.5), d: &scanf64, wantf64: float64(1.5)},
+ {s: int64(1), d: &scanf64, wantf64: float64(1)},
+ {s: float64(1.5), d: &scanf32, wantf32: float32(1.5)},
+ {s: "1.5", d: &scanf32, wantf32: float32(1.5)},
+ {s: "1.5", d: &scanf64, wantf64: float64(1.5)},
+
+ // Pointers
+ {s: any(nil), d: &scanptr, wantnil: true},
+ {s: int64(42), d: &scanptr, wantptr: &answer},
+
+ // To interface{}
+ {s: float64(1.5), d: &scaniface, wantiface: float64(1.5)},
+ {s: int64(1), d: &scaniface, wantiface: int64(1)},
+ {s: "str", d: &scaniface, wantiface: "str"},
+ {s: []byte("byteslice"), d: &scaniface, wantiface: []byte("byteslice")},
+ {s: true, d: &scaniface, wantiface: true},
+ {s: nil, d: &scaniface},
+ {s: []byte(nil), d: &scaniface, wantiface: []byte(nil)},
+
+ // To a user-defined type
+ {s: 1.5, d: new(userDefined), wantusrdef: 1.5},
+ {s: int64(123), d: new(userDefined), wantusrdef: 123},
+ {s: "1.5", d: new(userDefined), wantusrdef: 1.5},
+ {s: []byte{1, 2, 3}, d: new(userDefinedSlice), wanterr: `unsupported Scan, storing driver.Value type []uint8 into type *sql.userDefinedSlice`},
+ {s: "str", d: new(userDefinedString), wantusrstr: "str"},
+
+ // Other errors
+ {s: complex(1, 2), d: &scanstr, wanterr: `unsupported Scan, storing driver.Value type complex128 into type *string`},
+ }
+}
+
+func intPtrValue(intptr any) any {
+ return reflect.Indirect(reflect.Indirect(reflect.ValueOf(intptr))).Int()
+}
+
+func intValue(intptr any) int64 {
+ return reflect.Indirect(reflect.ValueOf(intptr)).Int()
+}
+
+func uintValue(intptr any) uint64 {
+ return reflect.Indirect(reflect.ValueOf(intptr)).Uint()
+}
+
+func float64Value(ptr any) float64 {
+ return *(ptr.(*float64))
+}
+
+func float32Value(ptr any) float32 {
+ return *(ptr.(*float32))
+}
+
+func timeValue(ptr any) time.Time {
+ return *(ptr.(*time.Time))
+}
+
+func TestConversions(t *testing.T) {
+ for n, ct := range conversionTests() {
+ err := convertAssign(ct.d, ct.s)
+ errstr := ""
+ if err != nil {
+ errstr = err.Error()
+ }
+ errf := func(format string, args ...any) {
+ base := fmt.Sprintf("convertAssign #%d: for %v (%T) -> %T, ", n, ct.s, ct.s, ct.d)
+ t.Errorf(base+format, args...)
+ }
+ if errstr != ct.wanterr {
+ errf("got error %q, want error %q", errstr, ct.wanterr)
+ }
+ if ct.wantstr != "" && ct.wantstr != scanstr {
+ errf("want string %q, got %q", ct.wantstr, scanstr)
+ }
+ if ct.wantbytes != nil && string(ct.wantbytes) != string(scanbytes) {
+ errf("want byte %q, got %q", ct.wantbytes, scanbytes)
+ }
+ if ct.wantraw != nil && string(ct.wantraw) != string(scanraw) {
+ errf("want RawBytes %q, got %q", ct.wantraw, scanraw)
+ }
+ if ct.wantint != 0 && ct.wantint != intValue(ct.d) {
+ errf("want int %d, got %d", ct.wantint, intValue(ct.d))
+ }
+ if ct.wantuint != 0 && ct.wantuint != uintValue(ct.d) {
+ errf("want uint %d, got %d", ct.wantuint, uintValue(ct.d))
+ }
+ if ct.wantf32 != 0 && ct.wantf32 != float32Value(ct.d) {
+ errf("want float32 %v, got %v", ct.wantf32, float32Value(ct.d))
+ }
+ if ct.wantf64 != 0 && ct.wantf64 != float64Value(ct.d) {
+ errf("want float32 %v, got %v", ct.wantf64, float64Value(ct.d))
+ }
+ if bp, boolTest := ct.d.(*bool); boolTest && *bp != ct.wantbool && ct.wanterr == "" {
+ errf("want bool %v, got %v", ct.wantbool, *bp)
+ }
+ if !ct.wanttime.IsZero() && !ct.wanttime.Equal(timeValue(ct.d)) {
+ errf("want time %v, got %v", ct.wanttime, timeValue(ct.d))
+ }
+ if ct.wantnil && *ct.d.(**int64) != nil {
+ errf("want nil, got %v", intPtrValue(ct.d))
+ }
+ if ct.wantptr != nil {
+ if *ct.d.(**int64) == nil {
+ errf("want pointer to %v, got nil", *ct.wantptr)
+ } else if *ct.wantptr != intPtrValue(ct.d) {
+ errf("want pointer to %v, got %v", *ct.wantptr, intPtrValue(ct.d))
+ }
+ }
+ if ifptr, ok := ct.d.(*any); ok {
+ if !reflect.DeepEqual(ct.wantiface, scaniface) {
+ errf("want interface %#v, got %#v", ct.wantiface, scaniface)
+ continue
+ }
+ if srcBytes, ok := ct.s.([]byte); ok {
+ dstBytes := (*ifptr).([]byte)
+ if len(srcBytes) > 0 && &dstBytes[0] == &srcBytes[0] {
+ errf("copy into interface{} didn't copy []byte data")
+ }
+ }
+ }
+ if ct.wantusrdef != 0 && ct.wantusrdef != *ct.d.(*userDefined) {
+ errf("want userDefined %f, got %f", ct.wantusrdef, *ct.d.(*userDefined))
+ }
+ if len(ct.wantusrstr) != 0 && ct.wantusrstr != *ct.d.(*userDefinedString) {
+ errf("want userDefined %q, got %q", ct.wantusrstr, *ct.d.(*userDefinedString))
+ }
+ }
+}
+
+func TestNullString(t *testing.T) {
+ var ns NullString
+ convertAssign(&ns, []byte("foo"))
+ if !ns.Valid {
+ t.Errorf("expecting not null")
+ }
+ if ns.String != "foo" {
+ t.Errorf("expecting foo; got %q", ns.String)
+ }
+ convertAssign(&ns, nil)
+ if ns.Valid {
+ t.Errorf("expecting null on nil")
+ }
+ if ns.String != "" {
+ t.Errorf("expecting blank on nil; got %q", ns.String)
+ }
+}
+
+type valueConverterTest struct {
+ c driver.ValueConverter
+ in, out any
+ err string
+}
+
+var valueConverterTests = []valueConverterTest{
+ {driver.DefaultParameterConverter, NullString{"hi", true}, "hi", ""},
+ {driver.DefaultParameterConverter, NullString{"", false}, nil, ""},
+}
+
+func TestValueConverters(t *testing.T) {
+ for i, tt := range valueConverterTests {
+ out, err := tt.c.ConvertValue(tt.in)
+ goterr := ""
+ if err != nil {
+ goterr = err.Error()
+ }
+ if goterr != tt.err {
+ t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q",
+ i, tt.c, tt.in, tt.in, goterr, tt.err)
+ }
+ if tt.err != "" {
+ continue
+ }
+ if !reflect.DeepEqual(out, tt.out) {
+ t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)",
+ i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
+ }
+ }
+}
+
+// Tests that assigning to RawBytes doesn't allocate (and also works).
+func TestRawBytesAllocs(t *testing.T) {
+ var tests = []struct {
+ name string
+ in any
+ want string
+ }{
+ {"uint64", uint64(12345678), "12345678"},
+ {"uint32", uint32(1234), "1234"},
+ {"uint16", uint16(12), "12"},
+ {"uint8", uint8(1), "1"},
+ {"uint", uint(123), "123"},
+ {"int", int(123), "123"},
+ {"int8", int8(1), "1"},
+ {"int16", int16(12), "12"},
+ {"int32", int32(1234), "1234"},
+ {"int64", int64(12345678), "12345678"},
+ {"float32", float32(1.5), "1.5"},
+ {"float64", float64(64), "64"},
+ {"bool", false, "false"},
+ {"time", time.Unix(2, 5).UTC(), "1970-01-01T00:00:02.000000005Z"},
+ }
+
+ buf := make(RawBytes, 10)
+ test := func(name string, in any, want string) {
+ if err := convertAssign(&buf, in); err != nil {
+ t.Fatalf("%s: convertAssign = %v", name, err)
+ }
+ match := len(buf) == len(want)
+ if match {
+ for i, b := range buf {
+ if want[i] != b {
+ match = false
+ break
+ }
+ }
+ }
+ if !match {
+ t.Fatalf("%s: got %q (len %d); want %q (len %d)", name, buf, len(buf), want, len(want))
+ }
+ }
+
+ n := testing.AllocsPerRun(100, func() {
+ for _, tt := range tests {
+ test(tt.name, tt.in, tt.want)
+ }
+ })
+
+ // The numbers below are only valid for 64-bit interface word sizes,
+ // and gc. With 32-bit words there are more convT2E allocs, and
+ // with gccgo, only pointers currently go in interface data.
+ // So only care on amd64 gc for now.
+ measureAllocs := runtime.GOARCH == "amd64" && runtime.Compiler == "gc"
+
+ if n > 0.5 && measureAllocs {
+ t.Fatalf("allocs = %v; want 0", n)
+ }
+
+ // This one involves a convT2E allocation, string -> interface{}
+ n = testing.AllocsPerRun(100, func() {
+ test("string", "foo", "foo")
+ })
+ if n > 1.5 && measureAllocs {
+ t.Fatalf("allocs = %v; want max 1", n)
+ }
+}
+
+// https://golang.org/issues/13905
+func TestUserDefinedBytes(t *testing.T) {
+ type userDefinedBytes []byte
+ var u userDefinedBytes
+ v := []byte("foo")
+
+ convertAssign(&u, v)
+ if &u[0] == &v[0] {
+ t.Fatal("userDefinedBytes got potentially dirty driver memory")
+ }
+}
+
+type Valuer_V string
+
+func (v Valuer_V) Value() (driver.Value, error) {
+ return strings.ToUpper(string(v)), nil
+}
+
+type Valuer_P string
+
+func (p *Valuer_P) Value() (driver.Value, error) {
+ if p == nil {
+ return "nil-to-str", nil
+ }
+ return strings.ToUpper(string(*p)), nil
+}
+
+func TestDriverArgs(t *testing.T) {
+ var nilValuerVPtr *Valuer_V
+ var nilValuerPPtr *Valuer_P
+ var nilStrPtr *string
+ tests := []struct {
+ args []any
+ want []driver.NamedValue
+ }{
+ 0: {
+ args: []any{Valuer_V("foo")},
+ want: []driver.NamedValue{
+ {
+ Ordinal: 1,
+ Value: "FOO",
+ },
+ },
+ },
+ 1: {
+ args: []any{nilValuerVPtr},
+ want: []driver.NamedValue{
+ {
+ Ordinal: 1,
+ Value: nil,
+ },
+ },
+ },
+ 2: {
+ args: []any{nilValuerPPtr},
+ want: []driver.NamedValue{
+ {
+ Ordinal: 1,
+ Value: "nil-to-str",
+ },
+ },
+ },
+ 3: {
+ args: []any{"plain-str"},
+ want: []driver.NamedValue{
+ {
+ Ordinal: 1,
+ Value: "plain-str",
+ },
+ },
+ },
+ 4: {
+ args: []any{nilStrPtr},
+ want: []driver.NamedValue{
+ {
+ Ordinal: 1,
+ Value: nil,
+ },
+ },
+ },
+ }
+ for i, tt := range tests {
+ ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
+ got, err := driverArgsConnLocked(nil, ds, tt.args)
+ if err != nil {
+ t.Errorf("test[%d]: %v", i, err)
+ continue
+ }
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("test[%d]: got %v, want %v", i, got, tt.want)
+ }
+ }
+}
+
+type dec struct {
+ form byte
+ neg bool
+ coefficient [16]byte
+ exponent int32
+}
+
+func (d dec) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
+ coef := make([]byte, 16)
+ copy(coef, d.coefficient[:])
+ return d.form, d.neg, coef, d.exponent
+}
+
+func (d *dec) Compose(form byte, negative bool, coefficient []byte, exponent int32) error {
+ switch form {
+ default:
+ return fmt.Errorf("unknown form %d", form)
+ case 1, 2:
+ d.form = form
+ d.neg = negative
+ return nil
+ case 0:
+ }
+ d.form = form
+ d.neg = negative
+ d.exponent = exponent
+
+ // This isn't strictly correct, as the extra bytes could be all zero,
+ // ignore this for this test.
+ if len(coefficient) > 16 {
+ return fmt.Errorf("coefficient too large")
+ }
+ copy(d.coefficient[:], coefficient)
+
+ return nil
+}
+
+type decFinite struct {
+ neg bool
+ coefficient [16]byte
+ exponent int32
+}
+
+func (d decFinite) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
+ coef := make([]byte, 16)
+ copy(coef, d.coefficient[:])
+ return 0, d.neg, coef, d.exponent
+}
+
+func (d *decFinite) Compose(form byte, negative bool, coefficient []byte, exponent int32) error {
+ switch form {
+ default:
+ return fmt.Errorf("unknown form %d", form)
+ case 1, 2:
+ return fmt.Errorf("unsupported form %d", form)
+ case 0:
+ }
+ d.neg = negative
+ d.exponent = exponent
+
+ // This isn't strictly correct, as the extra bytes could be all zero,
+ // ignore this for this test.
+ if len(coefficient) > 16 {
+ return fmt.Errorf("coefficient too large")
+ }
+ copy(d.coefficient[:], coefficient)
+
+ return nil
+}
+
+func TestDecimal(t *testing.T) {
+ list := []struct {
+ name string
+ in decimalDecompose
+ out dec
+ err bool
+ }{
+ {name: "same", in: dec{exponent: -6}, out: dec{exponent: -6}},
+
+ // Ensure reflection is not used to assign the value by using different types.
+ {name: "diff", in: decFinite{exponent: -6}, out: dec{exponent: -6}},
+
+ {name: "bad-form", in: dec{form: 200}, err: true},
+ }
+ for _, item := range list {
+ t.Run(item.name, func(t *testing.T) {
+ out := dec{}
+ err := convertAssign(&out, item.in)
+ if item.err {
+ if err == nil {
+ t.Fatalf("unexpected nil error")
+ }
+ return
+ }
+ if err != nil {
+ t.Fatalf("unexpected error: %v", err)
+ }
+ if !reflect.DeepEqual(out, item.out) {
+ t.Fatalf("got %#v want %#v", out, item.out)
+ }
+ })
+ }
+}
diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go
new file mode 100644
index 0000000..4dbe6af
--- /dev/null
+++ b/src/database/sql/ctxutil.go
@@ -0,0 +1,146 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+ "context"
+ "database/sql/driver"
+ "errors"
+)
+
+func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
+ if ciCtx, is := ci.(driver.ConnPrepareContext); is {
+ return ciCtx.PrepareContext(ctx, query)
+ }
+ si, err := ci.Prepare(query)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ si.Close()
+ return nil, ctx.Err()
+ }
+ }
+ return si, err
+}
+
+func ctxDriverExec(ctx context.Context, execerCtx driver.ExecerContext, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
+ if execerCtx != nil {
+ return execerCtx.ExecContext(ctx, query, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return execer.Exec(query, dargs)
+}
+
+func ctxDriverQuery(ctx context.Context, queryerCtx driver.QueryerContext, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
+ if queryerCtx != nil {
+ return queryerCtx.QueryContext(ctx, query, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return queryer.Query(query, dargs)
+}
+
+func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
+ if siCtx, is := si.(driver.StmtExecContext); is {
+ return siCtx.ExecContext(ctx, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return si.Exec(dargs)
+}
+
+func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
+ if siCtx, is := si.(driver.StmtQueryContext); is {
+ return siCtx.QueryContext(ctx, nvdargs)
+ }
+ dargs, err := namedValueToValue(nvdargs)
+ if err != nil {
+ return nil, err
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+ return si.Query(dargs)
+}
+
+func ctxDriverBegin(ctx context.Context, opts *TxOptions, ci driver.Conn) (driver.Tx, error) {
+ if ciCtx, is := ci.(driver.ConnBeginTx); is {
+ dopts := driver.TxOptions{}
+ if opts != nil {
+ dopts.Isolation = driver.IsolationLevel(opts.Isolation)
+ dopts.ReadOnly = opts.ReadOnly
+ }
+ return ciCtx.BeginTx(ctx, dopts)
+ }
+
+ if opts != nil {
+ // Check the transaction level. If the transaction level is non-default
+ // then return an error here as the BeginTx driver value is not supported.
+ if opts.Isolation != LevelDefault {
+ return nil, errors.New("sql: driver does not support non-default isolation level")
+ }
+
+ // If a read-only transaction is requested return an error as the
+ // BeginTx driver value is not supported.
+ if opts.ReadOnly {
+ return nil, errors.New("sql: driver does not support read-only transactions")
+ }
+ }
+
+ if ctx.Done() == nil {
+ return ci.Begin()
+ }
+
+ txi, err := ci.Begin()
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ txi.Rollback()
+ return nil, ctx.Err()
+ }
+ }
+ return txi, err
+}
+
+func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
+ dargs := make([]driver.Value, len(named))
+ for n, param := range named {
+ if len(param.Name) > 0 {
+ return nil, errors.New("sql: driver does not support the use of Named Parameters")
+ }
+ dargs[n] = param.Value
+ }
+ return dargs, nil
+}
diff --git a/src/database/sql/doc.txt b/src/database/sql/doc.txt
new file mode 100644
index 0000000..1341b57
--- /dev/null
+++ b/src/database/sql/doc.txt
@@ -0,0 +1,46 @@
+Goals of the sql and sql/driver packages:
+
+* Provide a generic database API for a variety of SQL or SQL-like
+ databases. There currently exist Go libraries for SQLite, MySQL,
+ and Postgres, but all with a very different feel, and often
+ a non-Go-like feel.
+
+* Feel like Go.
+
+* Care mostly about the common cases. Common SQL should be portable.
+ SQL edge cases or db-specific extensions can be detected and
+ conditionally used by the application. It is a non-goal to care
+ about every particular db's extension or quirk.
+
+* Separate out the basic implementation of a database driver
+ (implementing the sql/driver interfaces) vs the implementation
+ of all the user-level types and convenience methods.
+ In a nutshell:
+
+ User Code ---> sql package (concrete types) ---> sql/driver (interfaces)
+ Database Driver -> sql (to register) + sql/driver (implement interfaces)
+
+* Make type casting/conversions consistent between all drivers. To
+ achieve this, most of the conversions are done in the sql package,
+ not in each driver. The drivers then only have to deal with a
+ smaller set of types.
+
+* Be flexible with type conversions, but be paranoid about silent
+ truncation or other loss of precision.
+
+* Handle concurrency well. Users shouldn't need to care about the
+ database's per-connection thread safety issues (or lack thereof),
+ and shouldn't have to maintain their own free pools of connections.
+ The 'sql' package should deal with that bookkeeping as needed. Given
+ an *sql.DB, it should be possible to share that instance between
+ multiple goroutines, without any extra synchronization.
+
+* Push complexity, where necessary, down into the sql+driver packages,
+ rather than exposing it to users. Said otherwise, the sql package
+ should expose an ideal database that's not finnicky about how it's
+ accessed, even if that's not true.
+
+* Provide optional interfaces in sql/driver for drivers to implement
+ for special cases or fastpaths. But the only party that knows about
+ those is the sql package. To user code, some stuff just might start
+ working or start working slightly faster.
diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go
new file mode 100644
index 0000000..daf282b
--- /dev/null
+++ b/src/database/sql/driver/driver.go
@@ -0,0 +1,552 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package driver defines interfaces to be implemented by database
+// drivers as used by package sql.
+//
+// Most code should use package sql.
+//
+// The driver interface has evolved over time. Drivers should implement
+// Connector and DriverContext interfaces.
+// The Connector.Connect and Driver.Open methods should never return ErrBadConn.
+// ErrBadConn should only be returned from Validator, SessionResetter, or
+// a query method if the connection is already in an invalid (e.g. closed) state.
+//
+// All Conn implementations should implement the following interfaces:
+// Pinger, SessionResetter, and Validator.
+//
+// If named parameters or context are supported, the driver's Conn should implement:
+// ExecerContext, QueryerContext, ConnPrepareContext, and ConnBeginTx.
+//
+// To support custom data types, implement NamedValueChecker. NamedValueChecker
+// also allows queries to accept per-query options as a parameter by returning
+// ErrRemoveArgument from CheckNamedValue.
+//
+// If multiple result sets are supported, Rows should implement RowsNextResultSet.
+// If the driver knows how to describe the types present in the returned result
+// it should implement the following interfaces: RowsColumnTypeScanType,
+// RowsColumnTypeDatabaseTypeName, RowsColumnTypeLength, RowsColumnTypeNullable,
+// and RowsColumnTypePrecisionScale. A given row value may also return a Rows
+// type, which may represent a database cursor value.
+//
+// Before a connection is returned to the connection pool after use, IsValid is
+// called if implemented. Before a connection is reused for another query,
+// ResetSession is called if implemented. If a connection is never returned to the
+// connection pool but immediately reused, then ResetSession is called prior to
+// reuse but IsValid is not called.
+package driver
+
+import (
+ "context"
+ "errors"
+ "reflect"
+)
+
+// Value is a value that drivers must be able to handle.
+// It is either nil, a type handled by a database driver's NamedValueChecker
+// interface, or an instance of one of these types:
+//
+// int64
+// float64
+// bool
+// []byte
+// string
+// time.Time
+//
+// If the driver supports cursors, a returned Value may also implement the Rows interface
+// in this package. This is used, for example, when a user selects a cursor
+// such as "select cursor(select * from my_table) from dual". If the Rows
+// from the select is closed, the cursor Rows will also be closed.
+type Value any
+
+// NamedValue holds both the value name and value.
+type NamedValue struct {
+ // If the Name is not empty it should be used for the parameter identifier and
+ // not the ordinal position.
+ //
+ // Name will not have a symbol prefix.
+ Name string
+
+ // Ordinal position of the parameter starting from one and is always set.
+ Ordinal int
+
+ // Value is the parameter value.
+ Value Value
+}
+
+// Driver is the interface that must be implemented by a database
+// driver.
+//
+// Database drivers may implement DriverContext for access
+// to contexts and to parse the name only once for a pool of connections,
+// instead of once per connection.
+type Driver interface {
+ // Open returns a new connection to the database.
+ // The name is a string in a driver-specific format.
+ //
+ // Open may return a cached connection (one previously
+ // closed), but doing so is unnecessary; the sql package
+ // maintains a pool of idle connections for efficient re-use.
+ //
+ // The returned connection is only used by one goroutine at a
+ // time.
+ Open(name string) (Conn, error)
+}
+
+// If a Driver implements DriverContext, then sql.DB will call
+// OpenConnector to obtain a Connector and then invoke
+// that Connector's Connect method to obtain each needed connection,
+// instead of invoking the Driver's Open method for each connection.
+// The two-step sequence allows drivers to parse the name just once
+// and also provides access to per-Conn contexts.
+type DriverContext interface {
+ // OpenConnector must parse the name in the same format that Driver.Open
+ // parses the name parameter.
+ OpenConnector(name string) (Connector, error)
+}
+
+// A Connector represents a driver in a fixed configuration
+// and can create any number of equivalent Conns for use
+// by multiple goroutines.
+//
+// A Connector can be passed to sql.OpenDB, to allow drivers
+// to implement their own sql.DB constructors, or returned by
+// DriverContext's OpenConnector method, to allow drivers
+// access to context and to avoid repeated parsing of driver
+// configuration.
+//
+// If a Connector implements io.Closer, the sql package's DB.Close
+// method will call Close and return error (if any).
+type Connector interface {
+ // Connect returns a connection to the database.
+ // Connect may return a cached connection (one previously
+ // closed), but doing so is unnecessary; the sql package
+ // maintains a pool of idle connections for efficient re-use.
+ //
+ // The provided context.Context is for dialing purposes only
+ // (see net.DialContext) and should not be stored or used for
+ // other purposes. A default timeout should still be used
+ // when dialing as a connection pool may call Connect
+ // asynchronously to any query.
+ //
+ // The returned connection is only used by one goroutine at a
+ // time.
+ Connect(context.Context) (Conn, error)
+
+ // Driver returns the underlying Driver of the Connector,
+ // mainly to maintain compatibility with the Driver method
+ // on sql.DB.
+ Driver() Driver
+}
+
+// ErrSkip may be returned by some optional interfaces' methods to
+// indicate at runtime that the fast path is unavailable and the sql
+// package should continue as if the optional interface was not
+// implemented. ErrSkip is only supported where explicitly
+// documented.
+var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
+
+// ErrBadConn should be returned by a driver to signal to the sql
+// package that a driver.Conn is in a bad state (such as the server
+// having earlier closed the connection) and the sql package should
+// retry on a new connection.
+//
+// To prevent duplicate operations, ErrBadConn should NOT be returned
+// if there's a possibility that the database server might have
+// performed the operation. Even if the server sends back an error,
+// you shouldn't return ErrBadConn.
+//
+// Errors will be checked using errors.Is. An error may
+// wrap ErrBadConn or implement the Is(error) bool method.
+var ErrBadConn = errors.New("driver: bad connection")
+
+// Pinger is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement Pinger, the sql package's DB.Ping and
+// DB.PingContext will check if there is at least one Conn available.
+//
+// If Conn.Ping returns ErrBadConn, DB.Ping and DB.PingContext will remove
+// the Conn from pool.
+type Pinger interface {
+ Ping(ctx context.Context) error
+}
+
+// Execer is an optional interface that may be implemented by a Conn.
+//
+// If a Conn implements neither ExecerContext nor Execer,
+// the sql package's DB.Exec will first prepare a query, execute the statement,
+// and then close the statement.
+//
+// Exec may return ErrSkip.
+//
+// Deprecated: Drivers should implement ExecerContext instead.
+type Execer interface {
+ Exec(query string, args []Value) (Result, error)
+}
+
+// ExecerContext is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement ExecerContext, the sql package's DB.Exec
+// will fall back to Execer; if the Conn does not implement Execer either,
+// DB.Exec will first prepare a query, execute the statement, and then
+// close the statement.
+//
+// ExecContext may return ErrSkip.
+//
+// ExecContext must honor the context timeout and return when the context is canceled.
+type ExecerContext interface {
+ ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
+}
+
+// Queryer is an optional interface that may be implemented by a Conn.
+//
+// If a Conn implements neither QueryerContext nor Queryer,
+// the sql package's DB.Query will first prepare a query, execute the statement,
+// and then close the statement.
+//
+// Query may return ErrSkip.
+//
+// Deprecated: Drivers should implement QueryerContext instead.
+type Queryer interface {
+ Query(query string, args []Value) (Rows, error)
+}
+
+// QueryerContext is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement QueryerContext, the sql package's DB.Query
+// will fall back to Queryer; if the Conn does not implement Queryer either,
+// DB.Query will first prepare a query, execute the statement, and then
+// close the statement.
+//
+// QueryContext may return ErrSkip.
+//
+// QueryContext must honor the context timeout and return when the context is canceled.
+type QueryerContext interface {
+ QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
+}
+
+// Conn is a connection to a database. It is not used concurrently
+// by multiple goroutines.
+//
+// Conn is assumed to be stateful.
+type Conn interface {
+ // Prepare returns a prepared statement, bound to this connection.
+ Prepare(query string) (Stmt, error)
+
+ // Close invalidates and potentially stops any current
+ // prepared statements and transactions, marking this
+ // connection as no longer in use.
+ //
+ // Because the sql package maintains a free pool of
+ // connections and only calls Close when there's a surplus of
+ // idle connections, it shouldn't be necessary for drivers to
+ // do their own connection caching.
+ //
+ // Drivers must ensure all network calls made by Close
+ // do not block indefinitely (e.g. apply a timeout).
+ Close() error
+
+ // Begin starts and returns a new transaction.
+ //
+ // Deprecated: Drivers should implement ConnBeginTx instead (or additionally).
+ Begin() (Tx, error)
+}
+
+// ConnPrepareContext enhances the Conn interface with context.
+type ConnPrepareContext interface {
+ // PrepareContext returns a prepared statement, bound to this connection.
+ // context is for the preparation of the statement,
+ // it must not store the context within the statement itself.
+ PrepareContext(ctx context.Context, query string) (Stmt, error)
+}
+
+// IsolationLevel is the transaction isolation level stored in TxOptions.
+//
+// This type should be considered identical to sql.IsolationLevel along
+// with any values defined on it.
+type IsolationLevel int
+
+// TxOptions holds the transaction options.
+//
+// This type should be considered identical to sql.TxOptions.
+type TxOptions struct {
+ Isolation IsolationLevel
+ ReadOnly bool
+}
+
+// ConnBeginTx enhances the Conn interface with context and TxOptions.
+type ConnBeginTx interface {
+ // BeginTx starts and returns a new transaction.
+ // If the context is canceled by the user the sql package will
+ // call Tx.Rollback before discarding and closing the connection.
+ //
+ // This must check opts.Isolation to determine if there is a set
+ // isolation level. If the driver does not support a non-default
+ // level and one is set or if there is a non-default isolation level
+ // that is not supported, an error must be returned.
+ //
+ // This must also check opts.ReadOnly to determine if the read-only
+ // value is true to either set the read-only transaction property if supported
+ // or return an error if it is not supported.
+ BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
+}
+
+// SessionResetter may be implemented by Conn to allow drivers to reset the
+// session state associated with the connection and to signal a bad connection.
+type SessionResetter interface {
+ // ResetSession is called prior to executing a query on the connection
+ // if the connection has been used before. If the driver returns ErrBadConn
+ // the connection is discarded.
+ ResetSession(ctx context.Context) error
+}
+
+// Validator may be implemented by Conn to allow drivers to
+// signal if a connection is valid or if it should be discarded.
+//
+// If implemented, drivers may return the underlying error from queries,
+// even if the connection should be discarded by the connection pool.
+type Validator interface {
+ // IsValid is called prior to placing the connection into the
+ // connection pool. The connection will be discarded if false is returned.
+ IsValid() bool
+}
+
+// Result is the result of a query execution.
+type Result interface {
+ // LastInsertId returns the database's auto-generated ID
+ // after, for example, an INSERT into a table with primary
+ // key.
+ LastInsertId() (int64, error)
+
+ // RowsAffected returns the number of rows affected by the
+ // query.
+ RowsAffected() (int64, error)
+}
+
+// Stmt is a prepared statement. It is bound to a Conn and not
+// used by multiple goroutines concurrently.
+type Stmt interface {
+ // Close closes the statement.
+ //
+ // As of Go 1.1, a Stmt will not be closed if it's in use
+ // by any queries.
+ //
+ // Drivers must ensure all network calls made by Close
+ // do not block indefinitely (e.g. apply a timeout).
+ Close() error
+
+ // NumInput returns the number of placeholder parameters.
+ //
+ // If NumInput returns >= 0, the sql package will sanity check
+ // argument counts from callers and return errors to the caller
+ // before the statement's Exec or Query methods are called.
+ //
+ // NumInput may also return -1, if the driver doesn't know
+ // its number of placeholders. In that case, the sql package
+ // will not sanity check Exec or Query argument counts.
+ NumInput() int
+
+ // Exec executes a query that doesn't return rows, such
+ // as an INSERT or UPDATE.
+ //
+ // Deprecated: Drivers should implement StmtExecContext instead (or additionally).
+ Exec(args []Value) (Result, error)
+
+ // Query executes a query that may return rows, such as a
+ // SELECT.
+ //
+ // Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
+ Query(args []Value) (Rows, error)
+}
+
+// StmtExecContext enhances the Stmt interface by providing Exec with context.
+type StmtExecContext interface {
+ // ExecContext executes a query that doesn't return rows, such
+ // as an INSERT or UPDATE.
+ //
+ // ExecContext must honor the context timeout and return when it is canceled.
+ ExecContext(ctx context.Context, args []NamedValue) (Result, error)
+}
+
+// StmtQueryContext enhances the Stmt interface by providing Query with context.
+type StmtQueryContext interface {
+ // QueryContext executes a query that may return rows, such as a
+ // SELECT.
+ //
+ // QueryContext must honor the context timeout and return when it is canceled.
+ QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
+}
+
+// ErrRemoveArgument may be returned from NamedValueChecker to instruct the
+// sql package to not pass the argument to the driver query interface.
+// Return when accepting query specific options or structures that aren't
+// SQL query arguments.
+var ErrRemoveArgument = errors.New("driver: remove argument from query")
+
+// NamedValueChecker may be optionally implemented by Conn or Stmt. It provides
+// the driver more control to handle Go and database types beyond the default
+// Values types allowed.
+//
+// The sql package checks for value checkers in the following order,
+// stopping at the first found match: Stmt.NamedValueChecker, Conn.NamedValueChecker,
+// Stmt.ColumnConverter, DefaultParameterConverter.
+//
+// If CheckNamedValue returns ErrRemoveArgument, the NamedValue will not be included in
+// the final query arguments. This may be used to pass special options to
+// the query itself.
+//
+// If ErrSkip is returned the column converter error checking
+// path is used for the argument. Drivers may wish to return ErrSkip after
+// they have exhausted their own special cases.
+type NamedValueChecker interface {
+ // CheckNamedValue is called before passing arguments to the driver
+ // and is called in place of any ColumnConverter. CheckNamedValue must do type
+ // validation and conversion as appropriate for the driver.
+ CheckNamedValue(*NamedValue) error
+}
+
+// ColumnConverter may be optionally implemented by Stmt if the
+// statement is aware of its own columns' types and can convert from
+// any type to a driver Value.
+//
+// Deprecated: Drivers should implement NamedValueChecker.
+type ColumnConverter interface {
+ // ColumnConverter returns a ValueConverter for the provided
+ // column index. If the type of a specific column isn't known
+ // or shouldn't be handled specially, DefaultValueConverter
+ // can be returned.
+ ColumnConverter(idx int) ValueConverter
+}
+
+// Rows is an iterator over an executed query's results.
+type Rows interface {
+ // Columns returns the names of the columns. The number of
+ // columns of the result is inferred from the length of the
+ // slice. If a particular column name isn't known, an empty
+ // string should be returned for that entry.
+ Columns() []string
+
+ // Close closes the rows iterator.
+ Close() error
+
+ // Next is called to populate the next row of data into
+ // the provided slice. The provided slice will be the same
+ // size as the Columns() are wide.
+ //
+ // Next should return io.EOF when there are no more rows.
+ //
+ // The dest should not be written to outside of Next. Care
+ // should be taken when closing Rows not to modify
+ // a buffer held in dest.
+ Next(dest []Value) error
+}
+
+// RowsNextResultSet extends the Rows interface by providing a way to signal
+// the driver to advance to the next result set.
+type RowsNextResultSet interface {
+ Rows
+
+ // HasNextResultSet is called at the end of the current result set and
+ // reports whether there is another result set after the current one.
+ HasNextResultSet() bool
+
+ // NextResultSet advances the driver to the next result set even
+ // if there are remaining rows in the current result set.
+ //
+ // NextResultSet should return io.EOF when there are no more result sets.
+ NextResultSet() error
+}
+
+// RowsColumnTypeScanType may be implemented by Rows. It should return
+// the value type that can be used to scan types into. For example, the database
+// column type "bigint" this should return "reflect.TypeOf(int64(0))".
+type RowsColumnTypeScanType interface {
+ Rows
+ ColumnTypeScanType(index int) reflect.Type
+}
+
+// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
+// database system type name without the length. Type names should be uppercase.
+// Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
+// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
+// "TIMESTAMP".
+type RowsColumnTypeDatabaseTypeName interface {
+ Rows
+ ColumnTypeDatabaseTypeName(index int) string
+}
+
+// RowsColumnTypeLength may be implemented by Rows. It should return the length
+// of the column type if the column is a variable length type. If the column is
+// not a variable length type ok should return false.
+// If length is not limited other than system limits, it should return math.MaxInt64.
+// The following are examples of returned values for various types:
+//
+// TEXT (math.MaxInt64, true)
+// varchar(10) (10, true)
+// nvarchar(10) (10, true)
+// decimal (0, false)
+// int (0, false)
+// bytea(30) (30, true)
+type RowsColumnTypeLength interface {
+ Rows
+ ColumnTypeLength(index int) (length int64, ok bool)
+}
+
+// RowsColumnTypeNullable may be implemented by Rows. The nullable value should
+// be true if it is known the column may be null, or false if the column is known
+// to be not nullable.
+// If the column nullability is unknown, ok should be false.
+type RowsColumnTypeNullable interface {
+ Rows
+ ColumnTypeNullable(index int) (nullable, ok bool)
+}
+
+// RowsColumnTypePrecisionScale may be implemented by Rows. It should return
+// the precision and scale for decimal types. If not applicable, ok should be false.
+// The following are examples of returned values for various types:
+//
+// decimal(38, 4) (38, 4, true)
+// int (0, 0, false)
+// decimal (math.MaxInt64, math.MaxInt64, true)
+type RowsColumnTypePrecisionScale interface {
+ Rows
+ ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
+}
+
+// Tx is a transaction.
+type Tx interface {
+ Commit() error
+ Rollback() error
+}
+
+// RowsAffected implements Result for an INSERT or UPDATE operation
+// which mutates a number of rows.
+type RowsAffected int64
+
+var _ Result = RowsAffected(0)
+
+func (RowsAffected) LastInsertId() (int64, error) {
+ return 0, errors.New("LastInsertId is not supported by this driver")
+}
+
+func (v RowsAffected) RowsAffected() (int64, error) {
+ return int64(v), nil
+}
+
+// ResultNoRows is a pre-defined Result for drivers to return when a DDL
+// command (such as a CREATE TABLE) succeeds. It returns an error for both
+// LastInsertId and RowsAffected.
+var ResultNoRows noRows
+
+type noRows struct{}
+
+var _ Result = noRows{}
+
+func (noRows) LastInsertId() (int64, error) {
+ return 0, errors.New("no LastInsertId available after DDL statement")
+}
+
+func (noRows) RowsAffected() (int64, error) {
+ return 0, errors.New("no RowsAffected available after DDL statement")
+}
diff --git a/src/database/sql/driver/types.go b/src/database/sql/driver/types.go
new file mode 100644
index 0000000..fa98df7
--- /dev/null
+++ b/src/database/sql/driver/types.go
@@ -0,0 +1,297 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package driver
+
+import (
+ "fmt"
+ "reflect"
+ "strconv"
+ "time"
+)
+
+// ValueConverter is the interface providing the ConvertValue method.
+//
+// Various implementations of ValueConverter are provided by the
+// driver package to provide consistent implementations of conversions
+// between drivers. The ValueConverters have several uses:
+//
+// - converting from the Value types as provided by the sql package
+// into a database table's specific column type and making sure it
+// fits, such as making sure a particular int64 fits in a
+// table's uint16 column.
+//
+// - converting a value as given from the database into one of the
+// driver Value types.
+//
+// - by the sql package, for converting from a driver's Value type
+// to a user's type in a scan.
+type ValueConverter interface {
+ // ConvertValue converts a value to a driver Value.
+ ConvertValue(v any) (Value, error)
+}
+
+// Valuer is the interface providing the Value method.
+//
+// Types implementing Valuer interface are able to convert
+// themselves to a driver Value.
+type Valuer interface {
+ // Value returns a driver Value.
+ // Value must not panic.
+ Value() (Value, error)
+}
+
+// Bool is a ValueConverter that converts input values to bools.
+//
+// The conversion rules are:
+// - booleans are returned unchanged
+// - for integer types,
+// 1 is true
+// 0 is false,
+// other integers are an error
+// - for strings and []byte, same rules as strconv.ParseBool
+// - all other types are an error
+var Bool boolType
+
+type boolType struct{}
+
+var _ ValueConverter = boolType{}
+
+func (boolType) String() string { return "Bool" }
+
+func (boolType) ConvertValue(src any) (Value, error) {
+ switch s := src.(type) {
+ case bool:
+ return s, nil
+ case string:
+ b, err := strconv.ParseBool(s)
+ if err != nil {
+ return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
+ }
+ return b, nil
+ case []byte:
+ b, err := strconv.ParseBool(string(s))
+ if err != nil {
+ return nil, fmt.Errorf("sql/driver: couldn't convert %q into type bool", s)
+ }
+ return b, nil
+ }
+
+ sv := reflect.ValueOf(src)
+ switch sv.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ iv := sv.Int()
+ if iv == 1 || iv == 0 {
+ return iv == 1, nil
+ }
+ return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", iv)
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ uv := sv.Uint()
+ if uv == 1 || uv == 0 {
+ return uv == 1, nil
+ }
+ return nil, fmt.Errorf("sql/driver: couldn't convert %d into type bool", uv)
+ }
+
+ return nil, fmt.Errorf("sql/driver: couldn't convert %v (%T) into type bool", src, src)
+}
+
+// Int32 is a ValueConverter that converts input values to int64,
+// respecting the limits of an int32 value.
+var Int32 int32Type
+
+type int32Type struct{}
+
+var _ ValueConverter = int32Type{}
+
+func (int32Type) ConvertValue(v any) (Value, error) {
+ rv := reflect.ValueOf(v)
+ switch rv.Kind() {
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ i64 := rv.Int()
+ if i64 > (1<<31)-1 || i64 < -(1<<31) {
+ return nil, fmt.Errorf("sql/driver: value %d overflows int32", v)
+ }
+ return i64, nil
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ u64 := rv.Uint()
+ if u64 > (1<<31)-1 {
+ return nil, fmt.Errorf("sql/driver: value %d overflows int32", v)
+ }
+ return int64(u64), nil
+ case reflect.String:
+ i, err := strconv.Atoi(rv.String())
+ if err != nil {
+ return nil, fmt.Errorf("sql/driver: value %q can't be converted to int32", v)
+ }
+ return int64(i), nil
+ }
+ return nil, fmt.Errorf("sql/driver: unsupported value %v (type %T) converting to int32", v, v)
+}
+
+// String is a ValueConverter that converts its input to a string.
+// If the value is already a string or []byte, it's unchanged.
+// If the value is of another type, conversion to string is done
+// with fmt.Sprintf("%v", v).
+var String stringType
+
+type stringType struct{}
+
+func (stringType) ConvertValue(v any) (Value, error) {
+ switch v.(type) {
+ case string, []byte:
+ return v, nil
+ }
+ return fmt.Sprintf("%v", v), nil
+}
+
+// Null is a type that implements ValueConverter by allowing nil
+// values but otherwise delegating to another ValueConverter.
+type Null struct {
+ Converter ValueConverter
+}
+
+func (n Null) ConvertValue(v any) (Value, error) {
+ if v == nil {
+ return nil, nil
+ }
+ return n.Converter.ConvertValue(v)
+}
+
+// NotNull is a type that implements ValueConverter by disallowing nil
+// values but otherwise delegating to another ValueConverter.
+type NotNull struct {
+ Converter ValueConverter
+}
+
+func (n NotNull) ConvertValue(v any) (Value, error) {
+ if v == nil {
+ return nil, fmt.Errorf("nil value not allowed")
+ }
+ return n.Converter.ConvertValue(v)
+}
+
+// IsValue reports whether v is a valid Value parameter type.
+func IsValue(v any) bool {
+ if v == nil {
+ return true
+ }
+ switch v.(type) {
+ case []byte, bool, float64, int64, string, time.Time:
+ return true
+ case decimalDecompose:
+ return true
+ }
+ return false
+}
+
+// IsScanValue is equivalent to IsValue.
+// It exists for compatibility.
+func IsScanValue(v any) bool {
+ return IsValue(v)
+}
+
+// DefaultParameterConverter is the default implementation of
+// ValueConverter that's used when a Stmt doesn't implement
+// ColumnConverter.
+//
+// DefaultParameterConverter returns its argument directly if
+// IsValue(arg). Otherwise, if the argument implements Valuer, its
+// Value method is used to return a Value. As a fallback, the provided
+// argument's underlying type is used to convert it to a Value:
+// underlying integer types are converted to int64, floats to float64,
+// bool, string, and []byte to themselves. If the argument is a nil
+// pointer, ConvertValue returns a nil Value. If the argument is a
+// non-nil pointer, it is dereferenced and ConvertValue is called
+// recursively. Other types are an error.
+var DefaultParameterConverter defaultConverter
+
+type defaultConverter struct{}
+
+var _ ValueConverter = defaultConverter{}
+
+var valuerReflectType = reflect.TypeOf((*Valuer)(nil)).Elem()
+
+// callValuerValue returns vr.Value(), with one exception:
+// If vr.Value is an auto-generated method on a pointer type and the
+// pointer is nil, it would panic at runtime in the panicwrap
+// method. Treat it like nil instead.
+// Issue 8415.
+//
+// This is so people can implement driver.Value on value types and
+// still use nil pointers to those types to mean nil/NULL, just like
+// string/*string.
+//
+// This function is mirrored in the database/sql package.
+func callValuerValue(vr Valuer) (v Value, err error) {
+ if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Pointer &&
+ rv.IsNil() &&
+ rv.Type().Elem().Implements(valuerReflectType) {
+ return nil, nil
+ }
+ return vr.Value()
+}
+
+func (defaultConverter) ConvertValue(v any) (Value, error) {
+ if IsValue(v) {
+ return v, nil
+ }
+
+ switch vr := v.(type) {
+ case Valuer:
+ sv, err := callValuerValue(vr)
+ if err != nil {
+ return nil, err
+ }
+ if !IsValue(sv) {
+ return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
+ }
+ return sv, nil
+
+ // For now, continue to prefer the Valuer interface over the decimal decompose interface.
+ case decimalDecompose:
+ return vr, nil
+ }
+
+ rv := reflect.ValueOf(v)
+ switch rv.Kind() {
+ case reflect.Pointer:
+ // indirect pointers
+ if rv.IsNil() {
+ return nil, nil
+ } else {
+ return defaultConverter{}.ConvertValue(rv.Elem().Interface())
+ }
+ case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+ return rv.Int(), nil
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
+ return int64(rv.Uint()), nil
+ case reflect.Uint64:
+ u64 := rv.Uint()
+ if u64 >= 1<<63 {
+ return nil, fmt.Errorf("uint64 values with high bit set are not supported")
+ }
+ return int64(u64), nil
+ case reflect.Float32, reflect.Float64:
+ return rv.Float(), nil
+ case reflect.Bool:
+ return rv.Bool(), nil
+ case reflect.Slice:
+ ek := rv.Type().Elem().Kind()
+ if ek == reflect.Uint8 {
+ return rv.Bytes(), nil
+ }
+ return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
+ case reflect.String:
+ return rv.String(), nil
+ }
+ return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
+}
+
+type decimalDecompose interface {
+ // Decompose returns the internal decimal state into parts.
+ // If the provided buf has sufficient capacity, buf may be returned as the coefficient with
+ // the value set and length set as appropriate.
+ Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32)
+}
diff --git a/src/database/sql/driver/types_test.go b/src/database/sql/driver/types_test.go
new file mode 100644
index 0000000..80e5e05
--- /dev/null
+++ b/src/database/sql/driver/types_test.go
@@ -0,0 +1,95 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package driver
+
+import (
+ "reflect"
+ "testing"
+ "time"
+)
+
+type valueConverterTest struct {
+ c ValueConverter
+ in any
+ out any
+ err string
+}
+
+var now = time.Now()
+var answer int64 = 42
+
+type (
+ i int64
+ f float64
+ b bool
+ bs []byte
+ s string
+ t time.Time
+ is []int
+)
+
+var valueConverterTests = []valueConverterTest{
+ {Bool, "true", true, ""},
+ {Bool, "True", true, ""},
+ {Bool, []byte("t"), true, ""},
+ {Bool, true, true, ""},
+ {Bool, "1", true, ""},
+ {Bool, 1, true, ""},
+ {Bool, int64(1), true, ""},
+ {Bool, uint16(1), true, ""},
+ {Bool, "false", false, ""},
+ {Bool, false, false, ""},
+ {Bool, "0", false, ""},
+ {Bool, 0, false, ""},
+ {Bool, int64(0), false, ""},
+ {Bool, uint16(0), false, ""},
+ {c: Bool, in: "foo", err: "sql/driver: couldn't convert \"foo\" into type bool"},
+ {c: Bool, in: 2, err: "sql/driver: couldn't convert 2 into type bool"},
+ {DefaultParameterConverter, now, now, ""},
+ {DefaultParameterConverter, (*int64)(nil), nil, ""},
+ {DefaultParameterConverter, &answer, answer, ""},
+ {DefaultParameterConverter, &now, now, ""},
+ {DefaultParameterConverter, i(9), int64(9), ""},
+ {DefaultParameterConverter, f(0.1), float64(0.1), ""},
+ {DefaultParameterConverter, b(true), true, ""},
+ {DefaultParameterConverter, bs{1}, []byte{1}, ""},
+ {DefaultParameterConverter, s("a"), "a", ""},
+ {DefaultParameterConverter, is{1}, nil, "unsupported type driver.is, a slice of int"},
+ {DefaultParameterConverter, dec{exponent: -6}, dec{exponent: -6}, ""},
+}
+
+func TestValueConverters(t *testing.T) {
+ for i, tt := range valueConverterTests {
+ out, err := tt.c.ConvertValue(tt.in)
+ goterr := ""
+ if err != nil {
+ goterr = err.Error()
+ }
+ if goterr != tt.err {
+ t.Errorf("test %d: %T(%T(%v)) error = %q; want error = %q",
+ i, tt.c, tt.in, tt.in, goterr, tt.err)
+ }
+ if tt.err != "" {
+ continue
+ }
+ if !reflect.DeepEqual(out, tt.out) {
+ t.Errorf("test %d: %T(%T(%v)) = %v (%T); want %v (%T)",
+ i, tt.c, tt.in, tt.in, out, out, tt.out, tt.out)
+ }
+ }
+}
+
+type dec struct {
+ form byte
+ neg bool
+ coefficient [16]byte
+ exponent int32
+}
+
+func (d dec) Decompose(buf []byte) (form byte, negative bool, coefficient []byte, exponent int32) {
+ coef := make([]byte, 16)
+ copy(coef, d.coefficient[:])
+ return d.form, d.neg, coef, d.exponent
+}
diff --git a/src/database/sql/example_cli_test.go b/src/database/sql/example_cli_test.go
new file mode 100644
index 0000000..1e297af
--- /dev/null
+++ b/src/database/sql/example_cli_test.go
@@ -0,0 +1,84 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql_test
+
+import (
+ "context"
+ "database/sql"
+ "flag"
+ "log"
+ "os"
+ "os/signal"
+ "time"
+)
+
+var pool *sql.DB // Database connection pool.
+
+func Example_openDBCLI() {
+ id := flag.Int64("id", 0, "person ID to find")
+ dsn := flag.String("dsn", os.Getenv("DSN"), "connection data source name")
+ flag.Parse()
+
+ if len(*dsn) == 0 {
+ log.Fatal("missing dsn flag")
+ }
+ if *id == 0 {
+ log.Fatal("missing person ID")
+ }
+ var err error
+
+ // Opening a driver typically will not attempt to connect to the database.
+ pool, err = sql.Open("driver-name", *dsn)
+ if err != nil {
+ // This will not be a connection error, but a DSN parse error or
+ // another initialization error.
+ log.Fatal("unable to use data source name", err)
+ }
+ defer pool.Close()
+
+ pool.SetConnMaxLifetime(0)
+ pool.SetMaxIdleConns(3)
+ pool.SetMaxOpenConns(3)
+
+ ctx, stop := context.WithCancel(context.Background())
+ defer stop()
+
+ appSignal := make(chan os.Signal, 3)
+ signal.Notify(appSignal, os.Interrupt)
+
+ go func() {
+ <-appSignal
+ stop()
+ }()
+
+ Ping(ctx)
+
+ Query(ctx, *id)
+}
+
+// Ping the database to verify DSN provided by the user is valid and the
+// server accessible. If the ping fails exit the program with an error.
+func Ping(ctx context.Context) {
+ ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
+ defer cancel()
+
+ if err := pool.PingContext(ctx); err != nil {
+ log.Fatalf("unable to connect to database: %v", err)
+ }
+}
+
+// Query the database for the information requested and prints the results.
+// If the query fails exit the program with an error.
+func Query(ctx context.Context, id int64) {
+ ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
+ defer cancel()
+
+ var name string
+ err := pool.QueryRowContext(ctx, "select p.name from people as p where p.id = :id;", sql.Named("id", id)).Scan(&name)
+ if err != nil {
+ log.Fatal("unable to execute search query", err)
+ }
+ log.Println("name=", name)
+}
diff --git a/src/database/sql/example_service_test.go b/src/database/sql/example_service_test.go
new file mode 100644
index 0000000..768307c
--- /dev/null
+++ b/src/database/sql/example_service_test.go
@@ -0,0 +1,158 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql_test
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "fmt"
+ "io"
+ "log"
+ "net/http"
+ "time"
+)
+
+func Example_openDBService() {
+ // Opening a driver typically will not attempt to connect to the database.
+ db, err := sql.Open("driver-name", "database=test1")
+ if err != nil {
+ // This will not be a connection error, but a DSN parse error or
+ // another initialization error.
+ log.Fatal(err)
+ }
+ db.SetConnMaxLifetime(0)
+ db.SetMaxIdleConns(50)
+ db.SetMaxOpenConns(50)
+
+ s := &Service{db: db}
+
+ http.ListenAndServe(":8080", s)
+}
+
+type Service struct {
+ db *sql.DB
+}
+
+func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ db := s.db
+ switch r.URL.Path {
+ default:
+ http.Error(w, "not found", http.StatusNotFound)
+ return
+ case "/healthz":
+ ctx, cancel := context.WithTimeout(r.Context(), 1*time.Second)
+ defer cancel()
+
+ err := s.db.PingContext(ctx)
+ if err != nil {
+ http.Error(w, fmt.Sprintf("db down: %v", err), http.StatusFailedDependency)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ return
+ case "/quick-action":
+ // This is a short SELECT. Use the request context as the base of
+ // the context timeout.
+ ctx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ defer cancel()
+
+ id := 5
+ org := 10
+ var name string
+ err := db.QueryRowContext(ctx, `
+select
+ p.name
+from
+ people as p
+ join organization as o on p.organization = o.id
+where
+ p.id = :id
+ and o.id = :org
+;`,
+ sql.Named("id", id),
+ sql.Named("org", org),
+ ).Scan(&name)
+ if err != nil {
+ if err == sql.ErrNoRows {
+ http.Error(w, "not found", http.StatusNotFound)
+ return
+ }
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ io.WriteString(w, name)
+ return
+ case "/long-action":
+ // This is a long SELECT. Use the request context as the base of
+ // the context timeout, but give it some time to finish. If
+ // the client cancels before the query is done the query will also
+ // be canceled.
+ ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second)
+ defer cancel()
+
+ var names []string
+ rows, err := db.QueryContext(ctx, "select p.name from people as p where p.active = true;")
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ for rows.Next() {
+ var name string
+ err = rows.Scan(&name)
+ if err != nil {
+ break
+ }
+ names = append(names, name)
+ }
+ // Check for errors during rows "Close".
+ // This may be more important if multiple statements are executed
+ // in a single batch and rows were written as well as read.
+ if closeErr := rows.Close(); closeErr != nil {
+ http.Error(w, closeErr.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ // Check for row scan error.
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ // Check for errors during row iteration.
+ if err = rows.Err(); err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ json.NewEncoder(w).Encode(names)
+ return
+ case "/async-action":
+ // This action has side effects that we want to preserve
+ // even if the client cancels the HTTP request part way through.
+ // For this we do not use the http request context as a base for
+ // the timeout.
+ ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ var orderRef = "ABC123"
+ tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
+ _, err = tx.ExecContext(ctx, "stored_proc_name", orderRef)
+
+ if err != nil {
+ tx.Rollback()
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+ err = tx.Commit()
+ if err != nil {
+ http.Error(w, "action in unknown state, check state before attempting again", http.StatusInternalServerError)
+ return
+ }
+ w.WriteHeader(http.StatusOK)
+ return
+ }
+}
diff --git a/src/database/sql/example_test.go b/src/database/sql/example_test.go
new file mode 100644
index 0000000..aafb0e3
--- /dev/null
+++ b/src/database/sql/example_test.go
@@ -0,0 +1,369 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql_test
+
+import (
+ "context"
+ "database/sql"
+ "fmt"
+ "log"
+ "strings"
+ "time"
+)
+
+var (
+ ctx context.Context
+ db *sql.DB
+)
+
+func ExampleDB_QueryContext() {
+ age := 27
+ rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer rows.Close()
+ names := make([]string, 0)
+
+ for rows.Next() {
+ var name string
+ if err := rows.Scan(&name); err != nil {
+ // Check for a scan error.
+ // Query rows will be closed with defer.
+ log.Fatal(err)
+ }
+ names = append(names, name)
+ }
+ // If the database is being written to ensure to check for Close
+ // errors that may be returned from the driver. The query may
+ // encounter an auto-commit error and be forced to rollback changes.
+ rerr := rows.Close()
+ if rerr != nil {
+ log.Fatal(rerr)
+ }
+
+ // Rows.Err will report the last error encountered by Rows.Scan.
+ if err := rows.Err(); err != nil {
+ log.Fatal(err)
+ }
+ fmt.Printf("%s are %d years old", strings.Join(names, ", "), age)
+}
+
+func ExampleDB_QueryRowContext() {
+ id := 123
+ var username string
+ var created time.Time
+ err := db.QueryRowContext(ctx, "SELECT username, created_at FROM users WHERE id=?", id).Scan(&username, &created)
+ switch {
+ case err == sql.ErrNoRows:
+ log.Printf("no user with id %d\n", id)
+ case err != nil:
+ log.Fatalf("query error: %v\n", err)
+ default:
+ log.Printf("username is %q, account created on %s\n", username, created)
+ }
+}
+
+func ExampleDB_ExecContext() {
+ id := 47
+ result, err := db.ExecContext(ctx, "UPDATE balances SET balance = balance + 10 WHERE user_id = ?", id)
+ if err != nil {
+ log.Fatal(err)
+ }
+ rows, err := result.RowsAffected()
+ if err != nil {
+ log.Fatal(err)
+ }
+ if rows != 1 {
+ log.Fatalf("expected to affect 1 row, affected %d", rows)
+ }
+}
+
+func ExampleDB_Query_multipleResultSets() {
+ age := 27
+ q := `
+create temp table uid (id bigint); -- Create temp table for queries.
+insert into uid
+select id from users where age < ?; -- Populate temp table.
+
+-- First result set.
+select
+ users.id, name
+from
+ users
+ join uid on users.id = uid.id
+;
+
+-- Second result set.
+select
+ ur.user, ur.role
+from
+ user_roles as ur
+ join uid on uid.id = ur.user
+;
+ `
+ rows, err := db.Query(q, age)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer rows.Close()
+
+ for rows.Next() {
+ var (
+ id int64
+ name string
+ )
+ if err := rows.Scan(&id, &name); err != nil {
+ log.Fatal(err)
+ }
+ log.Printf("id %d name is %s\n", id, name)
+ }
+ if !rows.NextResultSet() {
+ log.Fatalf("expected more result sets: %v", rows.Err())
+ }
+ var roleMap = map[int64]string{
+ 1: "user",
+ 2: "admin",
+ 3: "gopher",
+ }
+ for rows.Next() {
+ var (
+ id int64
+ role int64
+ )
+ if err := rows.Scan(&id, &role); err != nil {
+ log.Fatal(err)
+ }
+ log.Printf("id %d has role %s\n", id, roleMap[role])
+ }
+ if err := rows.Err(); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func ExampleDB_PingContext() {
+ // Ping and PingContext may be used to determine if communication with
+ // the database server is still possible.
+ //
+ // When used in a command line application Ping may be used to establish
+ // that further queries are possible; that the provided DSN is valid.
+ //
+ // When used in long running service Ping may be part of the health
+ // checking system.
+
+ ctx, cancel := context.WithTimeout(ctx, 1*time.Second)
+ defer cancel()
+
+ status := "up"
+ if err := db.PingContext(ctx); err != nil {
+ status = "down"
+ }
+ log.Println(status)
+}
+
+func ExampleDB_Prepare() {
+ projects := []struct {
+ mascot string
+ release int
+ }{
+ {"tux", 1991},
+ {"duke", 1996},
+ {"gopher", 2009},
+ {"moby dock", 2013},
+ }
+
+ stmt, err := db.Prepare("INSERT INTO projects(id, mascot, release, category) VALUES( ?, ?, ?, ? )")
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer stmt.Close() // Prepared statements take up server resources and should be closed after use.
+
+ for id, project := range projects {
+ if _, err := stmt.Exec(id+1, project.mascot, project.release, "open source"); err != nil {
+ log.Fatal(err)
+ }
+ }
+}
+
+func ExampleTx_Prepare() {
+ projects := []struct {
+ mascot string
+ release int
+ }{
+ {"tux", 1991},
+ {"duke", 1996},
+ {"gopher", 2009},
+ {"moby dock", 2013},
+ }
+
+ tx, err := db.Begin()
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer tx.Rollback() // The rollback will be ignored if the tx has been committed later in the function.
+
+ stmt, err := tx.Prepare("INSERT INTO projects(id, mascot, release, category) VALUES( ?, ?, ?, ? )")
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer stmt.Close() // Prepared statements take up server resources and should be closed after use.
+
+ for id, project := range projects {
+ if _, err := stmt.Exec(id+1, project.mascot, project.release, "open source"); err != nil {
+ log.Fatal(err)
+ }
+ }
+ if err := tx.Commit(); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func ExampleDB_BeginTx() {
+ tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
+ if err != nil {
+ log.Fatal(err)
+ }
+ id := 37
+ _, execErr := tx.Exec(`UPDATE users SET status = ? WHERE id = ?`, "paid", id)
+ if execErr != nil {
+ _ = tx.Rollback()
+ log.Fatal(execErr)
+ }
+ if err := tx.Commit(); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func ExampleConn_ExecContext() {
+ // A *DB is a pool of connections. Call Conn to reserve a connection for
+ // exclusive use.
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer conn.Close() // Return the connection to the pool.
+ id := 41
+ result, err := conn.ExecContext(ctx, `UPDATE balances SET balance = balance + 10 WHERE user_id = ?;`, id)
+ if err != nil {
+ log.Fatal(err)
+ }
+ rows, err := result.RowsAffected()
+ if err != nil {
+ log.Fatal(err)
+ }
+ if rows != 1 {
+ log.Fatalf("expected single row affected, got %d rows affected", rows)
+ }
+}
+
+func ExampleTx_ExecContext() {
+ tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
+ if err != nil {
+ log.Fatal(err)
+ }
+ id := 37
+ _, execErr := tx.ExecContext(ctx, "UPDATE users SET status = ? WHERE id = ?", "paid", id)
+ if execErr != nil {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ log.Fatalf("update failed: %v, unable to rollback: %v\n", execErr, rollbackErr)
+ }
+ log.Fatalf("update failed: %v", execErr)
+ }
+ if err := tx.Commit(); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func ExampleTx_Rollback() {
+ tx, err := db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelSerializable})
+ if err != nil {
+ log.Fatal(err)
+ }
+ id := 53
+ _, err = tx.ExecContext(ctx, "UPDATE drivers SET status = ? WHERE id = ?;", "assigned", id)
+ if err != nil {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ log.Fatalf("update drivers: unable to rollback: %v", rollbackErr)
+ }
+ log.Fatal(err)
+ }
+ _, err = tx.ExecContext(ctx, "UPDATE pickups SET driver_id = $1;", id)
+ if err != nil {
+ if rollbackErr := tx.Rollback(); rollbackErr != nil {
+ log.Fatalf("update failed: %v, unable to back: %v", err, rollbackErr)
+ }
+ log.Fatal(err)
+ }
+ if err := tx.Commit(); err != nil {
+ log.Fatal(err)
+ }
+}
+
+func ExampleStmt() {
+ // In normal use, create one Stmt when your process starts.
+ stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?")
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer stmt.Close()
+
+ // Then reuse it each time you need to issue the query.
+ id := 43
+ var username string
+ err = stmt.QueryRowContext(ctx, id).Scan(&username)
+ switch {
+ case err == sql.ErrNoRows:
+ log.Fatalf("no user with id %d", id)
+ case err != nil:
+ log.Fatal(err)
+ default:
+ log.Printf("username is %s\n", username)
+ }
+}
+
+func ExampleStmt_QueryRowContext() {
+ // In normal use, create one Stmt when your process starts.
+ stmt, err := db.PrepareContext(ctx, "SELECT username FROM users WHERE id = ?")
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer stmt.Close()
+
+ // Then reuse it each time you need to issue the query.
+ id := 43
+ var username string
+ err = stmt.QueryRowContext(ctx, id).Scan(&username)
+ switch {
+ case err == sql.ErrNoRows:
+ log.Fatalf("no user with id %d", id)
+ case err != nil:
+ log.Fatal(err)
+ default:
+ log.Printf("username is %s\n", username)
+ }
+}
+
+func ExampleRows() {
+ age := 27
+ rows, err := db.QueryContext(ctx, "SELECT name FROM users WHERE age=?", age)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer rows.Close()
+
+ names := make([]string, 0)
+ for rows.Next() {
+ var name string
+ if err := rows.Scan(&name); err != nil {
+ log.Fatal(err)
+ }
+ names = append(names, name)
+ }
+ // Check for errors from iterating over rows.
+ if err := rows.Err(); err != nil {
+ log.Fatal(err)
+ }
+ log.Printf("%s are %d years old", strings.Join(names, ", "), age)
+}
diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go
new file mode 100644
index 0000000..cfeb3b3
--- /dev/null
+++ b/src/database/sql/fakedb_test.go
@@ -0,0 +1,1283 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+ "context"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+// fakeDriver is a fake database that implements Go's driver.Driver
+// interface, just for testing.
+//
+// It speaks a query language that's semantically similar to but
+// syntactically different and simpler than SQL. The syntax is as
+// follows:
+//
+// WIPE
+// CREATE|<tablename>|<col>=<type>,<col>=<type>,...
+// where types are: "string", [u]int{8,16,32,64}, "bool"
+// INSERT|<tablename>|col=val,col2=val2,col3=?
+// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
+// SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
+//
+// Any of these can be preceded by PANIC|<method>|, to cause the
+// named method on fakeStmt to panic.
+//
+// Any of these can be proceeded by WAIT|<duration>|, to cause the
+// named method on fakeStmt to sleep for the specified duration.
+//
+// Multiple of these can be combined when separated with a semicolon.
+//
+// When opening a fakeDriver's database, it starts empty with no
+// tables. All tables and data are stored in memory only.
+type fakeDriver struct {
+ mu sync.Mutex // guards 3 following fields
+ openCount int // conn opens
+ closeCount int // conn closes
+ waitCh chan struct{}
+ waitingCh chan struct{}
+ dbs map[string]*fakeDB
+}
+
+type fakeConnector struct {
+ name string
+
+ waiter func(context.Context)
+ closed bool
+}
+
+func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
+ conn, err := fdriver.Open(c.name)
+ conn.(*fakeConn).waiter = c.waiter
+ return conn, err
+}
+
+func (c *fakeConnector) Driver() driver.Driver {
+ return fdriver
+}
+
+func (c *fakeConnector) Close() error {
+ if c.closed {
+ return errors.New("fakedb: connector is closed")
+ }
+ c.closed = true
+ return nil
+}
+
+type fakeDriverCtx struct {
+ fakeDriver
+}
+
+var _ driver.DriverContext = &fakeDriverCtx{}
+
+func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
+ return &fakeConnector{name: name}, nil
+}
+
+type fakeDB struct {
+ name string
+
+ useRawBytes atomic.Bool
+
+ mu sync.Mutex
+ tables map[string]*table
+ badConn bool
+ allowAny bool
+}
+
+type fakeError struct {
+ Message string
+ Wrapped error
+}
+
+func (err fakeError) Error() string {
+ return err.Message
+}
+
+func (err fakeError) Unwrap() error {
+ return err.Wrapped
+}
+
+type table struct {
+ mu sync.Mutex
+ colname []string
+ coltype []string
+ rows []*row
+}
+
+func (t *table) columnIndex(name string) int {
+ for n, nname := range t.colname {
+ if name == nname {
+ return n
+ }
+ }
+ return -1
+}
+
+type row struct {
+ cols []any // must be same size as its table colname + coltype
+}
+
+type memToucher interface {
+ // touchMem reads & writes some memory, to help find data races.
+ touchMem()
+}
+
+type fakeConn struct {
+ db *fakeDB // where to return ourselves to
+
+ currTx *fakeTx
+
+ // Every operation writes to line to enable the race detector
+ // check for data races.
+ line int64
+
+ // Stats for tests:
+ mu sync.Mutex
+ stmtsMade int
+ stmtsClosed int
+ numPrepare int
+
+ // bad connection tests; see isBad()
+ bad bool
+ stickyBad bool
+
+ skipDirtySession bool // tests that use Conn should set this to true.
+
+ // dirtySession tests ResetSession, true if a query has executed
+ // until ResetSession is called.
+ dirtySession bool
+
+ // The waiter is called before each query. May be used in place of the "WAIT"
+ // directive.
+ waiter func(context.Context)
+}
+
+func (c *fakeConn) touchMem() {
+ c.line++
+}
+
+func (c *fakeConn) incrStat(v *int) {
+ c.mu.Lock()
+ *v++
+ c.mu.Unlock()
+}
+
+type fakeTx struct {
+ c *fakeConn
+}
+
+type boundCol struct {
+ Column string
+ Placeholder string
+ Ordinal int
+}
+
+type fakeStmt struct {
+ memToucher
+ c *fakeConn
+ q string // just for debugging
+
+ cmd string
+ table string
+ panic string
+ wait time.Duration
+
+ next *fakeStmt // used for returning multiple results.
+
+ closed bool
+
+ colName []string // used by CREATE, INSERT, SELECT (selected columns)
+ colType []string // used by CREATE
+ colValue []any // used by INSERT (mix of strings and "?" for bound params)
+ placeholders int // used by INSERT/SELECT: number of ? params
+
+ whereCol []boundCol // used by SELECT (all placeholders)
+
+ placeholderConverter []driver.ValueConverter // used by INSERT
+}
+
+var fdriver driver.Driver = &fakeDriver{}
+
+func init() {
+ Register("test", fdriver)
+}
+
+func contains(list []string, y string) bool {
+ for _, x := range list {
+ if x == y {
+ return true
+ }
+ }
+ return false
+}
+
+type Dummy struct {
+ driver.Driver
+}
+
+func TestDrivers(t *testing.T) {
+ unregisterAllDrivers()
+ Register("test", fdriver)
+ Register("invalid", Dummy{})
+ all := Drivers()
+ if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
+ t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
+ }
+}
+
+// hook to simulate connection failures
+var hookOpenErr struct {
+ sync.Mutex
+ fn func() error
+}
+
+func setHookOpenErr(fn func() error) {
+ hookOpenErr.Lock()
+ defer hookOpenErr.Unlock()
+ hookOpenErr.fn = fn
+}
+
+// Supports dsn forms:
+//
+// <dbname>
+// <dbname>;<opts> (only currently supported option is `badConn`,
+// which causes driver.ErrBadConn to be returned on
+// every other conn.Begin())
+func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
+ hookOpenErr.Lock()
+ fn := hookOpenErr.fn
+ hookOpenErr.Unlock()
+ if fn != nil {
+ if err := fn(); err != nil {
+ return nil, err
+ }
+ }
+ parts := strings.Split(dsn, ";")
+ if len(parts) < 1 {
+ return nil, errors.New("fakedb: no database name")
+ }
+ name := parts[0]
+
+ db := d.getDB(name)
+
+ d.mu.Lock()
+ d.openCount++
+ d.mu.Unlock()
+ conn := &fakeConn{db: db}
+
+ if len(parts) >= 2 && parts[1] == "badConn" {
+ conn.bad = true
+ }
+ if d.waitCh != nil {
+ d.waitingCh <- struct{}{}
+ <-d.waitCh
+ d.waitCh = nil
+ d.waitingCh = nil
+ }
+ return conn, nil
+}
+
+func (d *fakeDriver) getDB(name string) *fakeDB {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ if d.dbs == nil {
+ d.dbs = make(map[string]*fakeDB)
+ }
+ db, ok := d.dbs[name]
+ if !ok {
+ db = &fakeDB{name: name}
+ d.dbs[name] = db
+ }
+ return db
+}
+
+func (db *fakeDB) wipe() {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ db.tables = nil
+}
+
+func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if db.tables == nil {
+ db.tables = make(map[string]*table)
+ }
+ if _, exist := db.tables[name]; exist {
+ return fmt.Errorf("fakedb: table %q already exists", name)
+ }
+ if len(columnNames) != len(columnTypes) {
+ return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
+ name, len(columnNames), len(columnTypes))
+ }
+ db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
+ return nil
+}
+
+// must be called with db.mu lock held
+func (db *fakeDB) table(table string) (*table, bool) {
+ if db.tables == nil {
+ return nil, false
+ }
+ t, ok := db.tables[table]
+ return t, ok
+}
+
+func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ t, ok := db.table(table)
+ if !ok {
+ return
+ }
+ for n, cname := range t.colname {
+ if cname == column {
+ return t.coltype[n], true
+ }
+ }
+ return "", false
+}
+
+func (c *fakeConn) isBad() bool {
+ if c.stickyBad {
+ return true
+ } else if c.bad {
+ if c.db == nil {
+ return false
+ }
+ // alternate between bad conn and not bad conn
+ c.db.badConn = !c.db.badConn
+ return c.db.badConn
+ } else {
+ return false
+ }
+}
+
+func (c *fakeConn) isDirtyAndMark() bool {
+ if c.skipDirtySession {
+ return false
+ }
+ if c.currTx != nil {
+ c.dirtySession = true
+ return false
+ }
+ if c.dirtySession {
+ return true
+ }
+ c.dirtySession = true
+ return false
+}
+
+func (c *fakeConn) Begin() (driver.Tx, error) {
+ if c.isBad() {
+ return nil, fakeError{Wrapped: driver.ErrBadConn}
+ }
+ if c.currTx != nil {
+ return nil, errors.New("fakedb: already in a transaction")
+ }
+ c.touchMem()
+ c.currTx = &fakeTx{c: c}
+ return c.currTx, nil
+}
+
+var hookPostCloseConn struct {
+ sync.Mutex
+ fn func(*fakeConn, error)
+}
+
+func setHookpostCloseConn(fn func(*fakeConn, error)) {
+ hookPostCloseConn.Lock()
+ defer hookPostCloseConn.Unlock()
+ hookPostCloseConn.fn = fn
+}
+
+var testStrictClose *testing.T
+
+// setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
+// fails to close. If nil, the check is disabled.
+func setStrictFakeConnClose(t *testing.T) {
+ testStrictClose = t
+}
+
+func (c *fakeConn) ResetSession(ctx context.Context) error {
+ c.dirtySession = false
+ c.currTx = nil
+ if c.isBad() {
+ return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
+ }
+ return nil
+}
+
+var _ driver.Validator = (*fakeConn)(nil)
+
+func (c *fakeConn) IsValid() bool {
+ return !c.isBad()
+}
+
+func (c *fakeConn) Close() (err error) {
+ drv := fdriver.(*fakeDriver)
+ defer func() {
+ if err != nil && testStrictClose != nil {
+ testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
+ }
+ hookPostCloseConn.Lock()
+ fn := hookPostCloseConn.fn
+ hookPostCloseConn.Unlock()
+ if fn != nil {
+ fn(c, err)
+ }
+ if err == nil {
+ drv.mu.Lock()
+ drv.closeCount++
+ drv.mu.Unlock()
+ }
+ }()
+ c.touchMem()
+ if c.currTx != nil {
+ return errors.New("fakedb: can't close fakeConn; in a Transaction")
+ }
+ if c.db == nil {
+ return errors.New("fakedb: can't close fakeConn; already closed")
+ }
+ if c.stmtsMade > c.stmtsClosed {
+ return errors.New("fakedb: can't close; dangling statement(s)")
+ }
+ c.db = nil
+ return nil
+}
+
+func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
+ for _, arg := range args {
+ switch arg.Value.(type) {
+ case int64, float64, bool, nil, []byte, string, time.Time:
+ default:
+ if !allowAny {
+ return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
+ }
+ }
+ }
+ return nil
+}
+
+func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ // Ensure that ExecContext is called if available.
+ panic("ExecContext was not called.")
+}
+
+func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
+ // This is an optional interface, but it's implemented here
+ // just to check that all the args are of the proper types.
+ // ErrSkip is returned so the caller acts as if we didn't
+ // implement this at all.
+ err := checkSubsetTypes(c.db.allowAny, args)
+ if err != nil {
+ return nil, err
+ }
+ return nil, driver.ErrSkip
+}
+
+func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
+ // Ensure that ExecContext is called if available.
+ panic("QueryContext was not called.")
+}
+
+func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
+ // This is an optional interface, but it's implemented here
+ // just to check that all the args are of the proper types.
+ // ErrSkip is returned so the caller acts as if we didn't
+ // implement this at all.
+ err := checkSubsetTypes(c.db.allowAny, args)
+ if err != nil {
+ return nil, err
+ }
+ return nil, driver.ErrSkip
+}
+
+func errf(msg string, args ...any) error {
+ return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
+}
+
+// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
+// (note that where columns must always contain ? marks,
+// just a limitation for fakedb)
+func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
+ if len(parts) != 3 {
+ stmt.Close()
+ return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
+ }
+ stmt.table = parts[0]
+
+ stmt.colName = strings.Split(parts[1], ",")
+ for n, colspec := range strings.Split(parts[2], ",") {
+ if colspec == "" {
+ continue
+ }
+ nameVal := strings.Split(colspec, "=")
+ if len(nameVal) != 2 {
+ stmt.Close()
+ return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+ }
+ column, value := nameVal[0], nameVal[1]
+ _, ok := c.db.columnType(stmt.table, column)
+ if !ok {
+ stmt.Close()
+ return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
+ }
+ if !strings.HasPrefix(value, "?") {
+ stmt.Close()
+ return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
+ stmt.table, column)
+ }
+ stmt.placeholders++
+ stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
+ }
+ return stmt, nil
+}
+
+// parts are table|col=type,col2=type2
+func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
+ if len(parts) != 2 {
+ stmt.Close()
+ return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
+ }
+ stmt.table = parts[0]
+ for n, colspec := range strings.Split(parts[1], ",") {
+ nameType := strings.Split(colspec, "=")
+ if len(nameType) != 2 {
+ stmt.Close()
+ return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+ }
+ stmt.colName = append(stmt.colName, nameType[0])
+ stmt.colType = append(stmt.colType, nameType[1])
+ }
+ return stmt, nil
+}
+
+// parts are table|col=?,col2=val
+func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
+ if len(parts) != 2 {
+ stmt.Close()
+ return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
+ }
+ stmt.table = parts[0]
+ for n, colspec := range strings.Split(parts[1], ",") {
+ nameVal := strings.Split(colspec, "=")
+ if len(nameVal) != 2 {
+ stmt.Close()
+ return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
+ }
+ column, value := nameVal[0], nameVal[1]
+ ctype, ok := c.db.columnType(stmt.table, column)
+ if !ok {
+ stmt.Close()
+ return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
+ }
+ stmt.colName = append(stmt.colName, column)
+
+ if !strings.HasPrefix(value, "?") {
+ var subsetVal any
+ // Convert to driver subset type
+ switch ctype {
+ case "string":
+ subsetVal = []byte(value)
+ case "blob":
+ subsetVal = []byte(value)
+ case "int32":
+ i, err := strconv.Atoi(value)
+ if err != nil {
+ stmt.Close()
+ return nil, errf("invalid conversion to int32 from %q", value)
+ }
+ subsetVal = int64(i) // int64 is a subset type, but not int32
+ case "table": // For testing cursor reads.
+ c.skipDirtySession = true
+ vparts := strings.Split(value, "!")
+
+ substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
+ if err != nil {
+ return nil, err
+ }
+ cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
+ substmt.Close()
+ if err != nil {
+ return nil, err
+ }
+ subsetVal = cursor
+ default:
+ stmt.Close()
+ return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
+ }
+ stmt.colValue = append(stmt.colValue, subsetVal)
+ } else {
+ stmt.placeholders++
+ stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
+ stmt.colValue = append(stmt.colValue, value)
+ }
+ }
+ return stmt, nil
+}
+
+// hook to simulate broken connections
+var hookPrepareBadConn func() bool
+
+func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
+ panic("use PrepareContext")
+}
+
+func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
+ c.numPrepare++
+ if c.db == nil {
+ panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
+ }
+
+ if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
+ return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
+ }
+
+ c.touchMem()
+ var firstStmt, prev *fakeStmt
+ for _, query := range strings.Split(query, ";") {
+ parts := strings.Split(query, "|")
+ if len(parts) < 1 {
+ return nil, errf("empty query")
+ }
+ stmt := &fakeStmt{q: query, c: c, memToucher: c}
+ if firstStmt == nil {
+ firstStmt = stmt
+ }
+ if len(parts) >= 3 {
+ switch parts[0] {
+ case "PANIC":
+ stmt.panic = parts[1]
+ parts = parts[2:]
+ case "WAIT":
+ wait, err := time.ParseDuration(parts[1])
+ if err != nil {
+ return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
+ }
+ parts = parts[2:]
+ stmt.wait = wait
+ }
+ }
+ cmd := parts[0]
+ stmt.cmd = cmd
+ parts = parts[1:]
+
+ if c.waiter != nil {
+ c.waiter(ctx)
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
+ }
+
+ if stmt.wait > 0 {
+ wait := time.NewTimer(stmt.wait)
+ select {
+ case <-wait.C:
+ case <-ctx.Done():
+ wait.Stop()
+ return nil, ctx.Err()
+ }
+ }
+
+ c.incrStat(&c.stmtsMade)
+ var err error
+ switch cmd {
+ case "WIPE":
+ // Nothing
+ case "USE_RAWBYTES":
+ c.db.useRawBytes.Store(true)
+ case "SELECT":
+ stmt, err = c.prepareSelect(stmt, parts)
+ case "CREATE":
+ stmt, err = c.prepareCreate(stmt, parts)
+ case "INSERT":
+ stmt, err = c.prepareInsert(ctx, stmt, parts)
+ case "NOSERT":
+ // Do all the prep-work like for an INSERT but don't actually insert the row.
+ // Used for some of the concurrent tests.
+ stmt, err = c.prepareInsert(ctx, stmt, parts)
+ default:
+ stmt.Close()
+ return nil, errf("unsupported command type %q", cmd)
+ }
+ if err != nil {
+ return nil, err
+ }
+ if prev != nil {
+ prev.next = stmt
+ }
+ prev = stmt
+ }
+ return firstStmt, nil
+}
+
+func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
+ if s.panic == "ColumnConverter" {
+ panic(s.panic)
+ }
+ if len(s.placeholderConverter) == 0 {
+ return driver.DefaultParameterConverter
+ }
+ return s.placeholderConverter[idx]
+}
+
+func (s *fakeStmt) Close() error {
+ if s.panic == "Close" {
+ panic(s.panic)
+ }
+ if s.c == nil {
+ panic("nil conn in fakeStmt.Close")
+ }
+ if s.c.db == nil {
+ panic("in fakeStmt.Close, conn's db is nil (already closed)")
+ }
+ s.touchMem()
+ if !s.closed {
+ s.c.incrStat(&s.c.stmtsClosed)
+ s.closed = true
+ }
+ if s.next != nil {
+ s.next.Close()
+ }
+ return nil
+}
+
+var errClosed = errors.New("fakedb: statement has been closed")
+
+// hook to simulate broken connections
+var hookExecBadConn func() bool
+
+func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
+ panic("Using ExecContext")
+}
+
+var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
+
+func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
+ if s.panic == "Exec" {
+ panic(s.panic)
+ }
+ if s.closed {
+ return nil, errClosed
+ }
+
+ if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
+ return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
+ }
+ if s.c.isDirtyAndMark() {
+ return nil, errFakeConnSessionDirty
+ }
+
+ err := checkSubsetTypes(s.c.db.allowAny, args)
+ if err != nil {
+ return nil, err
+ }
+ s.touchMem()
+
+ if s.wait > 0 {
+ time.Sleep(s.wait)
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+
+ db := s.c.db
+ switch s.cmd {
+ case "WIPE":
+ db.wipe()
+ return driver.ResultNoRows, nil
+ case "USE_RAWBYTES":
+ s.c.db.useRawBytes.Store(true)
+ return driver.ResultNoRows, nil
+ case "CREATE":
+ if err := db.createTable(s.table, s.colName, s.colType); err != nil {
+ return nil, err
+ }
+ return driver.ResultNoRows, nil
+ case "INSERT":
+ return s.execInsert(args, true)
+ case "NOSERT":
+ // Do all the prep-work like for an INSERT but don't actually insert the row.
+ // Used for some of the concurrent tests.
+ return s.execInsert(args, false)
+ }
+ return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
+}
+
+// When doInsert is true, add the row to the table.
+// When doInsert is false do prep-work and error checking, but don't
+// actually add the row to the table.
+func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
+ db := s.c.db
+ if len(args) != s.placeholders {
+ panic("error in pkg db; should only get here if size is correct")
+ }
+ db.mu.Lock()
+ t, ok := db.table(s.table)
+ db.mu.Unlock()
+ if !ok {
+ return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
+ }
+
+ t.mu.Lock()
+ defer t.mu.Unlock()
+
+ var cols []any
+ if doInsert {
+ cols = make([]any, len(t.colname))
+ }
+ argPos := 0
+ for n, colname := range s.colName {
+ colidx := t.columnIndex(colname)
+ if colidx == -1 {
+ return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
+ }
+ var val any
+ if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
+ if strvalue == "?" {
+ val = args[argPos].Value
+ } else {
+ // Assign value from argument placeholder name.
+ for _, a := range args {
+ if a.Name == strvalue[1:] {
+ val = a.Value
+ break
+ }
+ }
+ }
+ argPos++
+ } else {
+ val = s.colValue[n]
+ }
+ if doInsert {
+ cols[colidx] = val
+ }
+ }
+
+ if doInsert {
+ t.rows = append(t.rows, &row{cols: cols})
+ }
+ return driver.RowsAffected(1), nil
+}
+
+// hook to simulate broken connections
+var hookQueryBadConn func() bool
+
+func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
+ panic("Use QueryContext")
+}
+
+func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+ if s.panic == "Query" {
+ panic(s.panic)
+ }
+ if s.closed {
+ return nil, errClosed
+ }
+
+ if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
+ return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
+ }
+ if s.c.isDirtyAndMark() {
+ return nil, errFakeConnSessionDirty
+ }
+
+ err := checkSubsetTypes(s.c.db.allowAny, args)
+ if err != nil {
+ return nil, err
+ }
+
+ s.touchMem()
+ db := s.c.db
+ if len(args) != s.placeholders {
+ panic("error in pkg db; should only get here if size is correct")
+ }
+
+ setMRows := make([][]*row, 0, 1)
+ setColumns := make([][]string, 0, 1)
+ setColType := make([][]string, 0, 1)
+
+ for {
+ db.mu.Lock()
+ t, ok := db.table(s.table)
+ db.mu.Unlock()
+ if !ok {
+ return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
+ }
+
+ if s.table == "magicquery" {
+ if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
+ if args[0].Value == "sleep" {
+ time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
+ }
+ }
+ }
+ if s.table == "tx_status" && s.colName[0] == "tx_status" {
+ txStatus := "autocommit"
+ if s.c.currTx != nil {
+ txStatus = "transaction"
+ }
+ cursor := &rowsCursor{
+ db: s.c.db,
+ parentMem: s.c,
+ posRow: -1,
+ rows: [][]*row{
+ {
+ {
+ cols: []any{
+ txStatus,
+ },
+ },
+ },
+ },
+ cols: [][]string{
+ {
+ "tx_status",
+ },
+ },
+ colType: [][]string{
+ {
+ "string",
+ },
+ },
+ errPos: -1,
+ }
+ return cursor, nil
+ }
+
+ t.mu.Lock()
+
+ colIdx := make(map[string]int) // select column name -> column index in table
+ for _, name := range s.colName {
+ idx := t.columnIndex(name)
+ if idx == -1 {
+ t.mu.Unlock()
+ return nil, fmt.Errorf("fakedb: unknown column name %q", name)
+ }
+ colIdx[name] = idx
+ }
+
+ mrows := []*row{}
+ rows:
+ for _, trow := range t.rows {
+ // Process the where clause, skipping non-match rows. This is lazy
+ // and just uses fmt.Sprintf("%v") to test equality. Good enough
+ // for test code.
+ for _, wcol := range s.whereCol {
+ idx := t.columnIndex(wcol.Column)
+ if idx == -1 {
+ t.mu.Unlock()
+ return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
+ }
+ tcol := trow.cols[idx]
+ if bs, ok := tcol.([]byte); ok {
+ // lazy hack to avoid sprintf %v on a []byte
+ tcol = string(bs)
+ }
+ var argValue any
+ if wcol.Placeholder == "?" {
+ argValue = args[wcol.Ordinal-1].Value
+ } else {
+ // Assign arg value from placeholder name.
+ for _, a := range args {
+ if a.Name == wcol.Placeholder[1:] {
+ argValue = a.Value
+ break
+ }
+ }
+ }
+ if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
+ continue rows
+ }
+ }
+ mrow := &row{cols: make([]any, len(s.colName))}
+ for seli, name := range s.colName {
+ mrow.cols[seli] = trow.cols[colIdx[name]]
+ }
+ mrows = append(mrows, mrow)
+ }
+
+ var colType []string
+ for _, column := range s.colName {
+ colType = append(colType, t.coltype[t.columnIndex(column)])
+ }
+
+ t.mu.Unlock()
+
+ setMRows = append(setMRows, mrows)
+ setColumns = append(setColumns, s.colName)
+ setColType = append(setColType, colType)
+
+ if s.next == nil {
+ break
+ }
+ s = s.next
+ }
+
+ cursor := &rowsCursor{
+ db: s.c.db,
+ parentMem: s.c,
+ posRow: -1,
+ rows: setMRows,
+ cols: setColumns,
+ colType: setColType,
+ errPos: -1,
+ }
+ return cursor, nil
+}
+
+func (s *fakeStmt) NumInput() int {
+ if s.panic == "NumInput" {
+ panic(s.panic)
+ }
+ return s.placeholders
+}
+
+// hook to simulate broken connections
+var hookCommitBadConn func() bool
+
+func (tx *fakeTx) Commit() error {
+ tx.c.currTx = nil
+ if hookCommitBadConn != nil && hookCommitBadConn() {
+ return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
+ }
+ tx.c.touchMem()
+ return nil
+}
+
+// hook to simulate broken connections
+var hookRollbackBadConn func() bool
+
+func (tx *fakeTx) Rollback() error {
+ tx.c.currTx = nil
+ if hookRollbackBadConn != nil && hookRollbackBadConn() {
+ return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
+ }
+ tx.c.touchMem()
+ return nil
+}
+
+type rowsCursor struct {
+ db *fakeDB
+ parentMem memToucher
+ cols [][]string
+ colType [][]string
+ posSet int
+ posRow int
+ rows [][]*row
+ closed bool
+
+ // errPos and err are for making Next return early with error.
+ errPos int
+ err error
+
+ // a clone of slices to give out to clients, indexed by the
+ // original slice's first byte address. we clone them
+ // just so we're able to corrupt them on close.
+ bytesClone map[*byte][]byte
+
+ // Every operation writes to line to enable the race detector
+ // check for data races.
+ // This is separate from the fakeConn.line to allow for drivers that
+ // can start multiple queries on the same transaction at the same time.
+ line int64
+
+ // closeErr is returned when rowsCursor.Close
+ closeErr error
+}
+
+func (rc *rowsCursor) touchMem() {
+ rc.parentMem.touchMem()
+ rc.line++
+}
+
+func (rc *rowsCursor) Close() error {
+ rc.touchMem()
+ rc.parentMem.touchMem()
+ rc.closed = true
+ return rc.closeErr
+}
+
+func (rc *rowsCursor) Columns() []string {
+ return rc.cols[rc.posSet]
+}
+
+func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
+ return colTypeToReflectType(rc.colType[rc.posSet][index])
+}
+
+var rowsCursorNextHook func(dest []driver.Value) error
+
+func (rc *rowsCursor) Next(dest []driver.Value) error {
+ if rowsCursorNextHook != nil {
+ return rowsCursorNextHook(dest)
+ }
+
+ if rc.closed {
+ return errors.New("fakedb: cursor is closed")
+ }
+ rc.touchMem()
+ rc.posRow++
+ if rc.posRow == rc.errPos {
+ return rc.err
+ }
+ if rc.posRow >= len(rc.rows[rc.posSet]) {
+ return io.EOF // per interface spec
+ }
+ for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
+ // TODO(bradfitz): convert to subset types? naah, I
+ // think the subset types should only be input to
+ // driver, but the sql package should be able to handle
+ // a wider range of types coming out of drivers. all
+ // for ease of drivers, and to prevent drivers from
+ // messing up conversions or doing them differently.
+ dest[i] = v
+
+ if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
+ if rc.bytesClone == nil {
+ rc.bytesClone = make(map[*byte][]byte)
+ }
+ clone, ok := rc.bytesClone[&bs[0]]
+ if !ok {
+ clone = make([]byte, len(bs))
+ copy(clone, bs)
+ rc.bytesClone[&bs[0]] = clone
+ }
+ dest[i] = clone
+ }
+ }
+ return nil
+}
+
+func (rc *rowsCursor) HasNextResultSet() bool {
+ rc.touchMem()
+ return rc.posSet < len(rc.rows)-1
+}
+
+func (rc *rowsCursor) NextResultSet() error {
+ rc.touchMem()
+ if rc.HasNextResultSet() {
+ rc.posSet++
+ rc.posRow = -1
+ return nil
+ }
+ return io.EOF // Per interface spec.
+}
+
+// fakeDriverString is like driver.String, but indirects pointers like
+// DefaultValueConverter.
+//
+// This could be surprising behavior to retroactively apply to
+// driver.String now that Go1 is out, but this is convenient for
+// our TestPointerParamsAndScans.
+type fakeDriverString struct{}
+
+func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
+ switch c := v.(type) {
+ case string, []byte:
+ return v, nil
+ case *string:
+ if c == nil {
+ return nil, nil
+ }
+ return *c, nil
+ }
+ return fmt.Sprintf("%v", v), nil
+}
+
+type anyTypeConverter struct{}
+
+func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
+ return v, nil
+}
+
+func converterForType(typ string) driver.ValueConverter {
+ switch typ {
+ case "bool":
+ return driver.Bool
+ case "nullbool":
+ return driver.Null{Converter: driver.Bool}
+ case "byte", "int16":
+ return driver.NotNull{Converter: driver.DefaultParameterConverter}
+ case "int32":
+ return driver.Int32
+ case "nullbyte", "nullint32", "nullint16":
+ return driver.Null{Converter: driver.DefaultParameterConverter}
+ case "string":
+ return driver.NotNull{Converter: fakeDriverString{}}
+ case "nullstring":
+ return driver.Null{Converter: fakeDriverString{}}
+ case "int64":
+ // TODO(coopernurse): add type-specific converter
+ return driver.NotNull{Converter: driver.DefaultParameterConverter}
+ case "nullint64":
+ // TODO(coopernurse): add type-specific converter
+ return driver.Null{Converter: driver.DefaultParameterConverter}
+ case "float64":
+ // TODO(coopernurse): add type-specific converter
+ return driver.NotNull{Converter: driver.DefaultParameterConverter}
+ case "nullfloat64":
+ // TODO(coopernurse): add type-specific converter
+ return driver.Null{Converter: driver.DefaultParameterConverter}
+ case "datetime":
+ return driver.NotNull{Converter: driver.DefaultParameterConverter}
+ case "nulldatetime":
+ return driver.Null{Converter: driver.DefaultParameterConverter}
+ case "any":
+ return anyTypeConverter{}
+ }
+ panic("invalid fakedb column type of " + typ)
+}
+
+func colTypeToReflectType(typ string) reflect.Type {
+ switch typ {
+ case "bool":
+ return reflect.TypeOf(false)
+ case "nullbool":
+ return reflect.TypeOf(NullBool{})
+ case "int16":
+ return reflect.TypeOf(int16(0))
+ case "nullint16":
+ return reflect.TypeOf(NullInt16{})
+ case "int32":
+ return reflect.TypeOf(int32(0))
+ case "nullint32":
+ return reflect.TypeOf(NullInt32{})
+ case "string":
+ return reflect.TypeOf("")
+ case "nullstring":
+ return reflect.TypeOf(NullString{})
+ case "int64":
+ return reflect.TypeOf(int64(0))
+ case "nullint64":
+ return reflect.TypeOf(NullInt64{})
+ case "float64":
+ return reflect.TypeOf(float64(0))
+ case "nullfloat64":
+ return reflect.TypeOf(NullFloat64{})
+ case "datetime":
+ return reflect.TypeOf(time.Time{})
+ case "any":
+ return reflect.TypeOf(new(any)).Elem()
+ }
+ panic("invalid fakedb column type of " + typ)
+}
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go
new file mode 100644
index 0000000..836fe83
--- /dev/null
+++ b/src/database/sql/sql.go
@@ -0,0 +1,3503 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package sql provides a generic interface around SQL (or SQL-like)
+// databases.
+//
+// The sql package must be used in conjunction with a database driver.
+// See https://golang.org/s/sqldrivers for a list of drivers.
+//
+// Drivers that do not support context cancellation will not return until
+// after the query is completed.
+//
+// For usage examples, see the wiki page at
+// https://golang.org/s/sqlwiki.
+package sql
+
+import (
+ "context"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "io"
+ "reflect"
+ "runtime"
+ "sort"
+ "strconv"
+ "sync"
+ "sync/atomic"
+ "time"
+)
+
+var (
+ driversMu sync.RWMutex
+ drivers = make(map[string]driver.Driver)
+)
+
+// nowFunc returns the current time; it's overridden in tests.
+var nowFunc = time.Now
+
+// Register makes a database driver available by the provided name.
+// If Register is called twice with the same name or if driver is nil,
+// it panics.
+func Register(name string, driver driver.Driver) {
+ driversMu.Lock()
+ defer driversMu.Unlock()
+ if driver == nil {
+ panic("sql: Register driver is nil")
+ }
+ if _, dup := drivers[name]; dup {
+ panic("sql: Register called twice for driver " + name)
+ }
+ drivers[name] = driver
+}
+
+func unregisterAllDrivers() {
+ driversMu.Lock()
+ defer driversMu.Unlock()
+ // For tests.
+ drivers = make(map[string]driver.Driver)
+}
+
+// Drivers returns a sorted list of the names of the registered drivers.
+func Drivers() []string {
+ driversMu.RLock()
+ defer driversMu.RUnlock()
+ list := make([]string, 0, len(drivers))
+ for name := range drivers {
+ list = append(list, name)
+ }
+ sort.Strings(list)
+ return list
+}
+
+// A NamedArg is a named argument. NamedArg values may be used as
+// arguments to Query or Exec and bind to the corresponding named
+// parameter in the SQL statement.
+//
+// For a more concise way to create NamedArg values, see
+// the Named function.
+type NamedArg struct {
+ _NamedFieldsRequired struct{}
+
+ // Name is the name of the parameter placeholder.
+ //
+ // If empty, the ordinal position in the argument list will be
+ // used.
+ //
+ // Name must omit any symbol prefix.
+ Name string
+
+ // Value is the value of the parameter.
+ // It may be assigned the same value types as the query
+ // arguments.
+ Value any
+}
+
+// Named provides a more concise way to create NamedArg values.
+//
+// Example usage:
+//
+// db.ExecContext(ctx, `
+// delete from Invoice
+// where
+// TimeCreated < @end
+// and TimeCreated >= @start;`,
+// sql.Named("start", startTime),
+// sql.Named("end", endTime),
+// )
+func Named(name string, value any) NamedArg {
+ // This method exists because the go1compat promise
+ // doesn't guarantee that structs don't grow more fields,
+ // so unkeyed struct literals are a vet error. Thus, we don't
+ // want to allow sql.NamedArg{name, value}.
+ return NamedArg{Name: name, Value: value}
+}
+
+// IsolationLevel is the transaction isolation level used in TxOptions.
+type IsolationLevel int
+
+// Various isolation levels that drivers may support in BeginTx.
+// If a driver does not support a given isolation level an error may be returned.
+//
+// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels.
+const (
+ LevelDefault IsolationLevel = iota
+ LevelReadUncommitted
+ LevelReadCommitted
+ LevelWriteCommitted
+ LevelRepeatableRead
+ LevelSnapshot
+ LevelSerializable
+ LevelLinearizable
+)
+
+// String returns the name of the transaction isolation level.
+func (i IsolationLevel) String() string {
+ switch i {
+ case LevelDefault:
+ return "Default"
+ case LevelReadUncommitted:
+ return "Read Uncommitted"
+ case LevelReadCommitted:
+ return "Read Committed"
+ case LevelWriteCommitted:
+ return "Write Committed"
+ case LevelRepeatableRead:
+ return "Repeatable Read"
+ case LevelSnapshot:
+ return "Snapshot"
+ case LevelSerializable:
+ return "Serializable"
+ case LevelLinearizable:
+ return "Linearizable"
+ default:
+ return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
+ }
+}
+
+var _ fmt.Stringer = LevelDefault
+
+// TxOptions holds the transaction options to be used in DB.BeginTx.
+type TxOptions struct {
+ // Isolation is the transaction isolation level.
+ // If zero, the driver or database's default level is used.
+ Isolation IsolationLevel
+ ReadOnly bool
+}
+
+// RawBytes is a byte slice that holds a reference to memory owned by
+// the database itself. After a Scan into a RawBytes, the slice is only
+// valid until the next call to Next, Scan, or Close.
+type RawBytes []byte
+
+// NullString represents a string that may be null.
+// NullString implements the Scanner interface so
+// it can be used as a scan destination:
+//
+// var s NullString
+// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s)
+// ...
+// if s.Valid {
+// // use s.String
+// } else {
+// // NULL value
+// }
+type NullString struct {
+ String string
+ Valid bool // Valid is true if String is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (ns *NullString) Scan(value any) error {
+ if value == nil {
+ ns.String, ns.Valid = "", false
+ return nil
+ }
+ ns.Valid = true
+ return convertAssign(&ns.String, value)
+}
+
+// Value implements the driver Valuer interface.
+func (ns NullString) Value() (driver.Value, error) {
+ if !ns.Valid {
+ return nil, nil
+ }
+ return ns.String, nil
+}
+
+// NullInt64 represents an int64 that may be null.
+// NullInt64 implements the Scanner interface so
+// it can be used as a scan destination, similar to NullString.
+type NullInt64 struct {
+ Int64 int64
+ Valid bool // Valid is true if Int64 is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (n *NullInt64) Scan(value any) error {
+ if value == nil {
+ n.Int64, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ return convertAssign(&n.Int64, value)
+}
+
+// Value implements the driver Valuer interface.
+func (n NullInt64) Value() (driver.Value, error) {
+ if !n.Valid {
+ return nil, nil
+ }
+ return n.Int64, nil
+}
+
+// NullInt32 represents an int32 that may be null.
+// NullInt32 implements the Scanner interface so
+// it can be used as a scan destination, similar to NullString.
+type NullInt32 struct {
+ Int32 int32
+ Valid bool // Valid is true if Int32 is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (n *NullInt32) Scan(value any) error {
+ if value == nil {
+ n.Int32, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ return convertAssign(&n.Int32, value)
+}
+
+// Value implements the driver Valuer interface.
+func (n NullInt32) Value() (driver.Value, error) {
+ if !n.Valid {
+ return nil, nil
+ }
+ return int64(n.Int32), nil
+}
+
+// NullInt16 represents an int16 that may be null.
+// NullInt16 implements the Scanner interface so
+// it can be used as a scan destination, similar to NullString.
+type NullInt16 struct {
+ Int16 int16
+ Valid bool // Valid is true if Int16 is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (n *NullInt16) Scan(value any) error {
+ if value == nil {
+ n.Int16, n.Valid = 0, false
+ return nil
+ }
+ err := convertAssign(&n.Int16, value)
+ n.Valid = err == nil
+ return err
+}
+
+// Value implements the driver Valuer interface.
+func (n NullInt16) Value() (driver.Value, error) {
+ if !n.Valid {
+ return nil, nil
+ }
+ return int64(n.Int16), nil
+}
+
+// NullByte represents a byte that may be null.
+// NullByte implements the Scanner interface so
+// it can be used as a scan destination, similar to NullString.
+type NullByte struct {
+ Byte byte
+ Valid bool // Valid is true if Byte is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (n *NullByte) Scan(value any) error {
+ if value == nil {
+ n.Byte, n.Valid = 0, false
+ return nil
+ }
+ err := convertAssign(&n.Byte, value)
+ n.Valid = err == nil
+ return err
+}
+
+// Value implements the driver Valuer interface.
+func (n NullByte) Value() (driver.Value, error) {
+ if !n.Valid {
+ return nil, nil
+ }
+ return int64(n.Byte), nil
+}
+
+// NullFloat64 represents a float64 that may be null.
+// NullFloat64 implements the Scanner interface so
+// it can be used as a scan destination, similar to NullString.
+type NullFloat64 struct {
+ Float64 float64
+ Valid bool // Valid is true if Float64 is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (n *NullFloat64) Scan(value any) error {
+ if value == nil {
+ n.Float64, n.Valid = 0, false
+ return nil
+ }
+ n.Valid = true
+ return convertAssign(&n.Float64, value)
+}
+
+// Value implements the driver Valuer interface.
+func (n NullFloat64) Value() (driver.Value, error) {
+ if !n.Valid {
+ return nil, nil
+ }
+ return n.Float64, nil
+}
+
+// NullBool represents a bool that may be null.
+// NullBool implements the Scanner interface so
+// it can be used as a scan destination, similar to NullString.
+type NullBool struct {
+ Bool bool
+ Valid bool // Valid is true if Bool is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (n *NullBool) Scan(value any) error {
+ if value == nil {
+ n.Bool, n.Valid = false, false
+ return nil
+ }
+ n.Valid = true
+ return convertAssign(&n.Bool, value)
+}
+
+// Value implements the driver Valuer interface.
+func (n NullBool) Value() (driver.Value, error) {
+ if !n.Valid {
+ return nil, nil
+ }
+ return n.Bool, nil
+}
+
+// NullTime represents a time.Time that may be null.
+// NullTime implements the Scanner interface so
+// it can be used as a scan destination, similar to NullString.
+type NullTime struct {
+ Time time.Time
+ Valid bool // Valid is true if Time is not NULL
+}
+
+// Scan implements the Scanner interface.
+func (n *NullTime) Scan(value any) error {
+ if value == nil {
+ n.Time, n.Valid = time.Time{}, false
+ return nil
+ }
+ n.Valid = true
+ return convertAssign(&n.Time, value)
+}
+
+// Value implements the driver Valuer interface.
+func (n NullTime) Value() (driver.Value, error) {
+ if !n.Valid {
+ return nil, nil
+ }
+ return n.Time, nil
+}
+
+// Scanner is an interface used by Scan.
+type Scanner interface {
+ // Scan assigns a value from a database driver.
+ //
+ // The src value will be of one of the following types:
+ //
+ // int64
+ // float64
+ // bool
+ // []byte
+ // string
+ // time.Time
+ // nil - for NULL values
+ //
+ // An error should be returned if the value cannot be stored
+ // without loss of information.
+ //
+ // Reference types such as []byte are only valid until the next call to Scan
+ // and should not be retained. Their underlying memory is owned by the driver.
+ // If retention is necessary, copy their values before the next call to Scan.
+ Scan(src any) error
+}
+
+// Out may be used to retrieve OUTPUT value parameters from stored procedures.
+//
+// Not all drivers and databases support OUTPUT value parameters.
+//
+// Example usage:
+//
+// var outArg string
+// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg}))
+type Out struct {
+ _NamedFieldsRequired struct{}
+
+ // Dest is a pointer to the value that will be set to the result of the
+ // stored procedure's OUTPUT parameter.
+ Dest any
+
+ // In is whether the parameter is an INOUT parameter. If so, the input value to the stored
+ // procedure is the dereferenced value of Dest's pointer, which is then replaced with
+ // the output value.
+ In bool
+}
+
+// ErrNoRows is returned by Scan when QueryRow doesn't return a
+// row. In such a case, QueryRow returns a placeholder *Row value that
+// defers this error until a Scan.
+var ErrNoRows = errors.New("sql: no rows in result set")
+
+// DB is a database handle representing a pool of zero or more
+// underlying connections. It's safe for concurrent use by multiple
+// goroutines.
+//
+// The sql package creates and frees connections automatically; it
+// also maintains a free pool of idle connections. If the database has
+// a concept of per-connection state, such state can be reliably observed
+// within a transaction (Tx) or connection (Conn). Once DB.Begin is called, the
+// returned Tx is bound to a single connection. Once Commit or
+// Rollback is called on the transaction, that transaction's
+// connection is returned to DB's idle connection pool. The pool size
+// can be controlled with SetMaxIdleConns.
+type DB struct {
+ // Total time waited for new connections.
+ waitDuration atomic.Int64
+
+ connector driver.Connector
+ // numClosed is an atomic counter which represents a total number of
+ // closed connections. Stmt.openStmt checks it before cleaning closed
+ // connections in Stmt.css.
+ numClosed atomic.Uint64
+
+ mu sync.Mutex // protects following fields
+ freeConn []*driverConn // free connections ordered by returnedAt oldest to newest
+ connRequests map[uint64]chan connRequest
+ nextRequest uint64 // Next key to use in connRequests.
+ numOpen int // number of opened and pending open connections
+ // Used to signal the need for new connections
+ // a goroutine running connectionOpener() reads on this chan and
+ // maybeOpenNewConnections sends on the chan (one send per needed connection)
+ // It is closed during db.Close(). The close tells the connectionOpener
+ // goroutine to exit.
+ openerCh chan struct{}
+ closed bool
+ dep map[finalCloser]depSet
+ lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
+ maxIdleCount int // zero means defaultMaxIdleConns; negative means 0
+ maxOpen int // <= 0 means unlimited
+ maxLifetime time.Duration // maximum amount of time a connection may be reused
+ maxIdleTime time.Duration // maximum amount of time a connection may be idle before being closed
+ cleanerCh chan struct{}
+ waitCount int64 // Total number of connections waited for.
+ maxIdleClosed int64 // Total number of connections closed due to idle count.
+ maxIdleTimeClosed int64 // Total number of connections closed due to idle time.
+ maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit.
+
+ stop func() // stop cancels the connection opener.
+}
+
+// connReuseStrategy determines how (*DB).conn returns database connections.
+type connReuseStrategy uint8
+
+const (
+ // alwaysNewConn forces a new connection to the database.
+ alwaysNewConn connReuseStrategy = iota
+ // cachedOrNewConn returns a cached connection, if available, else waits
+ // for one to become available (if MaxOpenConns has been reached) or
+ // creates a new database connection.
+ cachedOrNewConn
+)
+
+// driverConn wraps a driver.Conn with a mutex, to
+// be held during all calls into the Conn. (including any calls onto
+// interfaces returned via that Conn, such as calls on Tx, Stmt,
+// Result, Rows)
+type driverConn struct {
+ db *DB
+ createdAt time.Time
+
+ sync.Mutex // guards following
+ ci driver.Conn
+ needReset bool // The connection session should be reset before use if true.
+ closed bool
+ finalClosed bool // ci.Close has been called
+ openStmt map[*driverStmt]bool
+
+ // guarded by db.mu
+ inUse bool
+ returnedAt time.Time // Time the connection was created or returned.
+ onPut []func() // code (with db.mu held) run when conn is next returned
+ dbmuClosed bool // same as closed, but guarded by db.mu, for removeClosedStmtLocked
+}
+
+func (dc *driverConn) releaseConn(err error) {
+ dc.db.putConn(dc, err, true)
+}
+
+func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
+ dc.Lock()
+ defer dc.Unlock()
+ delete(dc.openStmt, ds)
+}
+
+func (dc *driverConn) expired(timeout time.Duration) bool {
+ if timeout <= 0 {
+ return false
+ }
+ return dc.createdAt.Add(timeout).Before(nowFunc())
+}
+
+// resetSession checks if the driver connection needs the
+// session to be reset and if required, resets it.
+func (dc *driverConn) resetSession(ctx context.Context) error {
+ dc.Lock()
+ defer dc.Unlock()
+
+ if !dc.needReset {
+ return nil
+ }
+ if cr, ok := dc.ci.(driver.SessionResetter); ok {
+ return cr.ResetSession(ctx)
+ }
+ return nil
+}
+
+// validateConnection checks if the connection is valid and can
+// still be used. It also marks the session for reset if required.
+func (dc *driverConn) validateConnection(needsReset bool) bool {
+ dc.Lock()
+ defer dc.Unlock()
+
+ if needsReset {
+ dc.needReset = true
+ }
+ if cv, ok := dc.ci.(driver.Validator); ok {
+ return cv.IsValid()
+ }
+ return true
+}
+
+// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
+// the prepared statements in a pool.
+func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
+ si, err := ctxDriverPrepare(ctx, dc.ci, query)
+ if err != nil {
+ return nil, err
+ }
+ ds := &driverStmt{Locker: dc, si: si}
+
+ // No need to manage open statements if there is a single connection grabber.
+ if cg != nil {
+ return ds, nil
+ }
+
+ // Track each driverConn's open statements, so we can close them
+ // before closing the conn.
+ //
+ // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
+ if dc.openStmt == nil {
+ dc.openStmt = make(map[*driverStmt]bool)
+ }
+ dc.openStmt[ds] = true
+ return ds, nil
+}
+
+// the dc.db's Mutex is held.
+func (dc *driverConn) closeDBLocked() func() error {
+ dc.Lock()
+ defer dc.Unlock()
+ if dc.closed {
+ return func() error { return errors.New("sql: duplicate driverConn close") }
+ }
+ dc.closed = true
+ return dc.db.removeDepLocked(dc, dc)
+}
+
+func (dc *driverConn) Close() error {
+ dc.Lock()
+ if dc.closed {
+ dc.Unlock()
+ return errors.New("sql: duplicate driverConn close")
+ }
+ dc.closed = true
+ dc.Unlock() // not defer; removeDep finalClose calls may need to lock
+
+ // And now updates that require holding dc.mu.Lock.
+ dc.db.mu.Lock()
+ dc.dbmuClosed = true
+ fn := dc.db.removeDepLocked(dc, dc)
+ dc.db.mu.Unlock()
+ return fn()
+}
+
+func (dc *driverConn) finalClose() error {
+ var err error
+
+ // Each *driverStmt has a lock to the dc. Copy the list out of the dc
+ // before calling close on each stmt.
+ var openStmt []*driverStmt
+ withLock(dc, func() {
+ openStmt = make([]*driverStmt, 0, len(dc.openStmt))
+ for ds := range dc.openStmt {
+ openStmt = append(openStmt, ds)
+ }
+ dc.openStmt = nil
+ })
+ for _, ds := range openStmt {
+ ds.Close()
+ }
+ withLock(dc, func() {
+ dc.finalClosed = true
+ err = dc.ci.Close()
+ dc.ci = nil
+ })
+
+ dc.db.mu.Lock()
+ dc.db.numOpen--
+ dc.db.maybeOpenNewConnections()
+ dc.db.mu.Unlock()
+
+ dc.db.numClosed.Add(1)
+ return err
+}
+
+// driverStmt associates a driver.Stmt with the
+// *driverConn from which it came, so the driverConn's lock can be
+// held during calls.
+type driverStmt struct {
+ sync.Locker // the *driverConn
+ si driver.Stmt
+ closed bool
+ closeErr error // return value of previous Close call
+}
+
+// Close ensures driver.Stmt is only closed once and always returns the same
+// result.
+func (ds *driverStmt) Close() error {
+ ds.Lock()
+ defer ds.Unlock()
+ if ds.closed {
+ return ds.closeErr
+ }
+ ds.closed = true
+ ds.closeErr = ds.si.Close()
+ return ds.closeErr
+}
+
+// depSet is a finalCloser's outstanding dependencies
+type depSet map[any]bool // set of true bools
+
+// The finalCloser interface is used by (*DB).addDep and related
+// dependency reference counting.
+type finalCloser interface {
+ // finalClose is called when the reference count of an object
+ // goes to zero. (*DB).mu is not held while calling it.
+ finalClose() error
+}
+
+// addDep notes that x now depends on dep, and x's finalClose won't be
+// called until all of x's dependencies are removed with removeDep.
+func (db *DB) addDep(x finalCloser, dep any) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ db.addDepLocked(x, dep)
+}
+
+func (db *DB) addDepLocked(x finalCloser, dep any) {
+ if db.dep == nil {
+ db.dep = make(map[finalCloser]depSet)
+ }
+ xdep := db.dep[x]
+ if xdep == nil {
+ xdep = make(depSet)
+ db.dep[x] = xdep
+ }
+ xdep[dep] = true
+}
+
+// removeDep notes that x no longer depends on dep.
+// If x still has dependencies, nil is returned.
+// If x no longer has any dependencies, its finalClose method will be
+// called and its error value will be returned.
+func (db *DB) removeDep(x finalCloser, dep any) error {
+ db.mu.Lock()
+ fn := db.removeDepLocked(x, dep)
+ db.mu.Unlock()
+ return fn()
+}
+
+func (db *DB) removeDepLocked(x finalCloser, dep any) func() error {
+ xdep, ok := db.dep[x]
+ if !ok {
+ panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
+ }
+
+ l0 := len(xdep)
+ delete(xdep, dep)
+
+ switch len(xdep) {
+ case l0:
+ // Nothing removed. Shouldn't happen.
+ panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
+ case 0:
+ // No more dependencies.
+ delete(db.dep, x)
+ return x.finalClose
+ default:
+ // Dependencies remain.
+ return func() error { return nil }
+ }
+}
+
+// This is the size of the connectionOpener request chan (DB.openerCh).
+// This value should be larger than the maximum typical value
+// used for db.maxOpen. If maxOpen is significantly larger than
+// connectionRequestQueueSize then it is possible for ALL calls into the *DB
+// to block until the connectionOpener can satisfy the backlog of requests.
+var connectionRequestQueueSize = 1000000
+
+type dsnConnector struct {
+ dsn string
+ driver driver.Driver
+}
+
+func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
+ return t.driver.Open(t.dsn)
+}
+
+func (t dsnConnector) Driver() driver.Driver {
+ return t.driver
+}
+
+// OpenDB opens a database using a Connector, allowing drivers to
+// bypass a string based data source name.
+//
+// Most users will open a database via a driver-specific connection
+// helper function that returns a *DB. No database drivers are included
+// in the Go standard library. See https://golang.org/s/sqldrivers for
+// a list of third-party drivers.
+//
+// OpenDB may just validate its arguments without creating a connection
+// to the database. To verify that the data source name is valid, call
+// Ping.
+//
+// The returned DB is safe for concurrent use by multiple goroutines
+// and maintains its own pool of idle connections. Thus, the OpenDB
+// function should be called just once. It is rarely necessary to
+// close a DB.
+func OpenDB(c driver.Connector) *DB {
+ ctx, cancel := context.WithCancel(context.Background())
+ db := &DB{
+ connector: c,
+ openerCh: make(chan struct{}, connectionRequestQueueSize),
+ lastPut: make(map[*driverConn]string),
+ connRequests: make(map[uint64]chan connRequest),
+ stop: cancel,
+ }
+
+ go db.connectionOpener(ctx)
+
+ return db
+}
+
+// Open opens a database specified by its database driver name and a
+// driver-specific data source name, usually consisting of at least a
+// database name and connection information.
+//
+// Most users will open a database via a driver-specific connection
+// helper function that returns a *DB. No database drivers are included
+// in the Go standard library. See https://golang.org/s/sqldrivers for
+// a list of third-party drivers.
+//
+// Open may just validate its arguments without creating a connection
+// to the database. To verify that the data source name is valid, call
+// Ping.
+//
+// The returned DB is safe for concurrent use by multiple goroutines
+// and maintains its own pool of idle connections. Thus, the Open
+// function should be called just once. It is rarely necessary to
+// close a DB.
+func Open(driverName, dataSourceName string) (*DB, error) {
+ driversMu.RLock()
+ driveri, ok := drivers[driverName]
+ driversMu.RUnlock()
+ if !ok {
+ return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
+ }
+
+ if driverCtx, ok := driveri.(driver.DriverContext); ok {
+ connector, err := driverCtx.OpenConnector(dataSourceName)
+ if err != nil {
+ return nil, err
+ }
+ return OpenDB(connector), nil
+ }
+
+ return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
+}
+
+func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
+ var err error
+ if pinger, ok := dc.ci.(driver.Pinger); ok {
+ withLock(dc, func() {
+ err = pinger.Ping(ctx)
+ })
+ }
+ release(err)
+ return err
+}
+
+// PingContext verifies a connection to the database is still alive,
+// establishing a connection if necessary.
+func (db *DB) PingContext(ctx context.Context) error {
+ var dc *driverConn
+ var err error
+
+ err = db.retry(func(strategy connReuseStrategy) error {
+ dc, err = db.conn(ctx, strategy)
+ return err
+ })
+
+ if err != nil {
+ return err
+ }
+
+ return db.pingDC(ctx, dc, dc.releaseConn)
+}
+
+// Ping verifies a connection to the database is still alive,
+// establishing a connection if necessary.
+//
+// Ping uses context.Background internally; to specify the context, use
+// PingContext.
+func (db *DB) Ping() error {
+ return db.PingContext(context.Background())
+}
+
+// Close closes the database and prevents new queries from starting.
+// Close then waits for all queries that have started processing on the server
+// to finish.
+//
+// It is rare to Close a DB, as the DB handle is meant to be
+// long-lived and shared between many goroutines.
+func (db *DB) Close() error {
+ db.mu.Lock()
+ if db.closed { // Make DB.Close idempotent
+ db.mu.Unlock()
+ return nil
+ }
+ if db.cleanerCh != nil {
+ close(db.cleanerCh)
+ }
+ var err error
+ fns := make([]func() error, 0, len(db.freeConn))
+ for _, dc := range db.freeConn {
+ fns = append(fns, dc.closeDBLocked())
+ }
+ db.freeConn = nil
+ db.closed = true
+ for _, req := range db.connRequests {
+ close(req)
+ }
+ db.mu.Unlock()
+ for _, fn := range fns {
+ err1 := fn()
+ if err1 != nil {
+ err = err1
+ }
+ }
+ db.stop()
+ if c, ok := db.connector.(io.Closer); ok {
+ err1 := c.Close()
+ if err1 != nil {
+ err = err1
+ }
+ }
+ return err
+}
+
+const defaultMaxIdleConns = 2
+
+func (db *DB) maxIdleConnsLocked() int {
+ n := db.maxIdleCount
+ switch {
+ case n == 0:
+ // TODO(bradfitz): ask driver, if supported, for its default preference
+ return defaultMaxIdleConns
+ case n < 0:
+ return 0
+ default:
+ return n
+ }
+}
+
+func (db *DB) shortestIdleTimeLocked() time.Duration {
+ if db.maxIdleTime <= 0 {
+ return db.maxLifetime
+ }
+ if db.maxLifetime <= 0 {
+ return db.maxIdleTime
+ }
+
+ min := db.maxIdleTime
+ if min > db.maxLifetime {
+ min = db.maxLifetime
+ }
+ return min
+}
+
+// SetMaxIdleConns sets the maximum number of connections in the idle
+// connection pool.
+//
+// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns,
+// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit.
+//
+// If n <= 0, no idle connections are retained.
+//
+// The default max idle connections is currently 2. This may change in
+// a future release.
+func (db *DB) SetMaxIdleConns(n int) {
+ db.mu.Lock()
+ if n > 0 {
+ db.maxIdleCount = n
+ } else {
+ // No idle connections.
+ db.maxIdleCount = -1
+ }
+ // Make sure maxIdle doesn't exceed maxOpen
+ if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
+ db.maxIdleCount = db.maxOpen
+ }
+ var closing []*driverConn
+ idleCount := len(db.freeConn)
+ maxIdle := db.maxIdleConnsLocked()
+ if idleCount > maxIdle {
+ closing = db.freeConn[maxIdle:]
+ db.freeConn = db.freeConn[:maxIdle]
+ }
+ db.maxIdleClosed += int64(len(closing))
+ db.mu.Unlock()
+ for _, c := range closing {
+ c.Close()
+ }
+}
+
+// SetMaxOpenConns sets the maximum number of open connections to the database.
+//
+// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
+// MaxIdleConns, then MaxIdleConns will be reduced to match the new
+// MaxOpenConns limit.
+//
+// If n <= 0, then there is no limit on the number of open connections.
+// The default is 0 (unlimited).
+func (db *DB) SetMaxOpenConns(n int) {
+ db.mu.Lock()
+ db.maxOpen = n
+ if n < 0 {
+ db.maxOpen = 0
+ }
+ syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
+ db.mu.Unlock()
+ if syncMaxIdle {
+ db.SetMaxIdleConns(n)
+ }
+}
+
+// SetConnMaxLifetime sets the maximum amount of time a connection may be reused.
+//
+// Expired connections may be closed lazily before reuse.
+//
+// If d <= 0, connections are not closed due to a connection's age.
+func (db *DB) SetConnMaxLifetime(d time.Duration) {
+ if d < 0 {
+ d = 0
+ }
+ db.mu.Lock()
+ // Wake cleaner up when lifetime is shortened.
+ if d > 0 && d < db.maxLifetime && db.cleanerCh != nil {
+ select {
+ case db.cleanerCh <- struct{}{}:
+ default:
+ }
+ }
+ db.maxLifetime = d
+ db.startCleanerLocked()
+ db.mu.Unlock()
+}
+
+// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle.
+//
+// Expired connections may be closed lazily before reuse.
+//
+// If d <= 0, connections are not closed due to a connection's idle time.
+func (db *DB) SetConnMaxIdleTime(d time.Duration) {
+ if d < 0 {
+ d = 0
+ }
+ db.mu.Lock()
+ defer db.mu.Unlock()
+
+ // Wake cleaner up when idle time is shortened.
+ if d > 0 && d < db.maxIdleTime && db.cleanerCh != nil {
+ select {
+ case db.cleanerCh <- struct{}{}:
+ default:
+ }
+ }
+ db.maxIdleTime = d
+ db.startCleanerLocked()
+}
+
+// startCleanerLocked starts connectionCleaner if needed.
+func (db *DB) startCleanerLocked() {
+ if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil {
+ db.cleanerCh = make(chan struct{}, 1)
+ go db.connectionCleaner(db.shortestIdleTimeLocked())
+ }
+}
+
+func (db *DB) connectionCleaner(d time.Duration) {
+ const minInterval = time.Second
+
+ if d < minInterval {
+ d = minInterval
+ }
+ t := time.NewTimer(d)
+
+ for {
+ select {
+ case <-t.C:
+ case <-db.cleanerCh: // maxLifetime was changed or db was closed.
+ }
+
+ db.mu.Lock()
+
+ d = db.shortestIdleTimeLocked()
+ if db.closed || db.numOpen == 0 || d <= 0 {
+ db.cleanerCh = nil
+ db.mu.Unlock()
+ return
+ }
+
+ d, closing := db.connectionCleanerRunLocked(d)
+ db.mu.Unlock()
+ for _, c := range closing {
+ c.Close()
+ }
+
+ if d < minInterval {
+ d = minInterval
+ }
+
+ if !t.Stop() {
+ select {
+ case <-t.C:
+ default:
+ }
+ }
+ t.Reset(d)
+ }
+}
+
+// connectionCleanerRunLocked removes connections that should be closed from
+// freeConn and returns them along side an updated duration to the next check
+// if a quicker check is required to ensure connections are checked appropriately.
+func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) {
+ var idleClosing int64
+ var closing []*driverConn
+ if db.maxIdleTime > 0 {
+ // As freeConn is ordered by returnedAt process
+ // in reverse order to minimise the work needed.
+ idleSince := nowFunc().Add(-db.maxIdleTime)
+ last := len(db.freeConn) - 1
+ for i := last; i >= 0; i-- {
+ c := db.freeConn[i]
+ if c.returnedAt.Before(idleSince) {
+ i++
+ closing = db.freeConn[:i:i]
+ db.freeConn = db.freeConn[i:]
+ idleClosing = int64(len(closing))
+ db.maxIdleTimeClosed += idleClosing
+ break
+ }
+ }
+
+ if len(db.freeConn) > 0 {
+ c := db.freeConn[0]
+ if d2 := c.returnedAt.Sub(idleSince); d2 < d {
+ // Ensure idle connections are cleaned up as soon as
+ // possible.
+ d = d2
+ }
+ }
+ }
+
+ if db.maxLifetime > 0 {
+ expiredSince := nowFunc().Add(-db.maxLifetime)
+ for i := 0; i < len(db.freeConn); i++ {
+ c := db.freeConn[i]
+ if c.createdAt.Before(expiredSince) {
+ closing = append(closing, c)
+
+ last := len(db.freeConn) - 1
+ // Use slow delete as order is required to ensure
+ // connections are reused least idle time first.
+ copy(db.freeConn[i:], db.freeConn[i+1:])
+ db.freeConn[last] = nil
+ db.freeConn = db.freeConn[:last]
+ i--
+ } else if d2 := c.createdAt.Sub(expiredSince); d2 < d {
+ // Prevent connections sitting the freeConn when they
+ // have expired by updating our next deadline d.
+ d = d2
+ }
+ }
+ db.maxLifetimeClosed += int64(len(closing)) - idleClosing
+ }
+
+ return d, closing
+}
+
+// DBStats contains database statistics.
+type DBStats struct {
+ MaxOpenConnections int // Maximum number of open connections to the database.
+
+ // Pool Status
+ OpenConnections int // The number of established connections both in use and idle.
+ InUse int // The number of connections currently in use.
+ Idle int // The number of idle connections.
+
+ // Counters
+ WaitCount int64 // The total number of connections waited for.
+ WaitDuration time.Duration // The total time blocked waiting for a new connection.
+ MaxIdleClosed int64 // The total number of connections closed due to SetMaxIdleConns.
+ MaxIdleTimeClosed int64 // The total number of connections closed due to SetConnMaxIdleTime.
+ MaxLifetimeClosed int64 // The total number of connections closed due to SetConnMaxLifetime.
+}
+
+// Stats returns database statistics.
+func (db *DB) Stats() DBStats {
+ wait := db.waitDuration.Load()
+
+ db.mu.Lock()
+ defer db.mu.Unlock()
+
+ stats := DBStats{
+ MaxOpenConnections: db.maxOpen,
+
+ Idle: len(db.freeConn),
+ OpenConnections: db.numOpen,
+ InUse: db.numOpen - len(db.freeConn),
+
+ WaitCount: db.waitCount,
+ WaitDuration: time.Duration(wait),
+ MaxIdleClosed: db.maxIdleClosed,
+ MaxIdleTimeClosed: db.maxIdleTimeClosed,
+ MaxLifetimeClosed: db.maxLifetimeClosed,
+ }
+ return stats
+}
+
+// Assumes db.mu is locked.
+// If there are connRequests and the connection limit hasn't been reached,
+// then tell the connectionOpener to open new connections.
+func (db *DB) maybeOpenNewConnections() {
+ numRequests := len(db.connRequests)
+ if db.maxOpen > 0 {
+ numCanOpen := db.maxOpen - db.numOpen
+ if numRequests > numCanOpen {
+ numRequests = numCanOpen
+ }
+ }
+ for numRequests > 0 {
+ db.numOpen++ // optimistically
+ numRequests--
+ if db.closed {
+ return
+ }
+ db.openerCh <- struct{}{}
+ }
+}
+
+// Runs in a separate goroutine, opens new connections when requested.
+func (db *DB) connectionOpener(ctx context.Context) {
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ case <-db.openerCh:
+ db.openNewConnection(ctx)
+ }
+ }
+}
+
+// Open one new connection
+func (db *DB) openNewConnection(ctx context.Context) {
+ // maybeOpenNewConnections has already executed db.numOpen++ before it sent
+ // on db.openerCh. This function must execute db.numOpen-- if the
+ // connection fails or is closed before returning.
+ ci, err := db.connector.Connect(ctx)
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if db.closed {
+ if err == nil {
+ ci.Close()
+ }
+ db.numOpen--
+ return
+ }
+ if err != nil {
+ db.numOpen--
+ db.putConnDBLocked(nil, err)
+ db.maybeOpenNewConnections()
+ return
+ }
+ dc := &driverConn{
+ db: db,
+ createdAt: nowFunc(),
+ returnedAt: nowFunc(),
+ ci: ci,
+ }
+ if db.putConnDBLocked(dc, err) {
+ db.addDepLocked(dc, dc)
+ } else {
+ db.numOpen--
+ ci.Close()
+ }
+}
+
+// connRequest represents one request for a new connection
+// When there are no idle connections available, DB.conn will create
+// a new connRequest and put it on the db.connRequests list.
+type connRequest struct {
+ conn *driverConn
+ err error
+}
+
+var errDBClosed = errors.New("sql: database is closed")
+
+// nextRequestKeyLocked returns the next connection request key.
+// It is assumed that nextRequest will not overflow.
+func (db *DB) nextRequestKeyLocked() uint64 {
+ next := db.nextRequest
+ db.nextRequest++
+ return next
+}
+
+// conn returns a newly-opened or cached *driverConn.
+func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
+ db.mu.Lock()
+ if db.closed {
+ db.mu.Unlock()
+ return nil, errDBClosed
+ }
+ // Check if the context is expired.
+ select {
+ default:
+ case <-ctx.Done():
+ db.mu.Unlock()
+ return nil, ctx.Err()
+ }
+ lifetime := db.maxLifetime
+
+ // Prefer a free connection, if possible.
+ last := len(db.freeConn) - 1
+ if strategy == cachedOrNewConn && last >= 0 {
+ // Reuse the lowest idle time connection so we can close
+ // connections which remain idle as soon as possible.
+ conn := db.freeConn[last]
+ db.freeConn = db.freeConn[:last]
+ conn.inUse = true
+ if conn.expired(lifetime) {
+ db.maxLifetimeClosed++
+ db.mu.Unlock()
+ conn.Close()
+ return nil, driver.ErrBadConn
+ }
+ db.mu.Unlock()
+
+ // Reset the session if required.
+ if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
+ conn.Close()
+ return nil, err
+ }
+
+ return conn, nil
+ }
+
+ // Out of free connections or we were asked not to use one. If we're not
+ // allowed to open any more connections, make a request and wait.
+ if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
+ // Make the connRequest channel. It's buffered so that the
+ // connectionOpener doesn't block while waiting for the req to be read.
+ req := make(chan connRequest, 1)
+ reqKey := db.nextRequestKeyLocked()
+ db.connRequests[reqKey] = req
+ db.waitCount++
+ db.mu.Unlock()
+
+ waitStart := nowFunc()
+
+ // Timeout the connection request with the context.
+ select {
+ case <-ctx.Done():
+ // Remove the connection request and ensure no value has been sent
+ // on it after removing.
+ db.mu.Lock()
+ delete(db.connRequests, reqKey)
+ db.mu.Unlock()
+
+ db.waitDuration.Add(int64(time.Since(waitStart)))
+
+ select {
+ default:
+ case ret, ok := <-req:
+ if ok && ret.conn != nil {
+ db.putConn(ret.conn, ret.err, false)
+ }
+ }
+ return nil, ctx.Err()
+ case ret, ok := <-req:
+ db.waitDuration.Add(int64(time.Since(waitStart)))
+
+ if !ok {
+ return nil, errDBClosed
+ }
+ // Only check if the connection is expired if the strategy is cachedOrNewConns.
+ // If we require a new connection, just re-use the connection without looking
+ // at the expiry time. If it is expired, it will be checked when it is placed
+ // back into the connection pool.
+ // This prioritizes giving a valid connection to a client over the exact connection
+ // lifetime, which could expire exactly after this point anyway.
+ if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
+ db.mu.Lock()
+ db.maxLifetimeClosed++
+ db.mu.Unlock()
+ ret.conn.Close()
+ return nil, driver.ErrBadConn
+ }
+ if ret.conn == nil {
+ return nil, ret.err
+ }
+
+ // Reset the session if required.
+ if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
+ ret.conn.Close()
+ return nil, err
+ }
+ return ret.conn, ret.err
+ }
+ }
+
+ db.numOpen++ // optimistically
+ db.mu.Unlock()
+ ci, err := db.connector.Connect(ctx)
+ if err != nil {
+ db.mu.Lock()
+ db.numOpen-- // correct for earlier optimism
+ db.maybeOpenNewConnections()
+ db.mu.Unlock()
+ return nil, err
+ }
+ db.mu.Lock()
+ dc := &driverConn{
+ db: db,
+ createdAt: nowFunc(),
+ returnedAt: nowFunc(),
+ ci: ci,
+ inUse: true,
+ }
+ db.addDepLocked(dc, dc)
+ db.mu.Unlock()
+ return dc, nil
+}
+
+// putConnHook is a hook for testing.
+var putConnHook func(*DB, *driverConn)
+
+// noteUnusedDriverStatement notes that ds is no longer used and should
+// be closed whenever possible (when c is next not in use), unless c is
+// already closed.
+func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if c.inUse {
+ c.onPut = append(c.onPut, func() {
+ ds.Close()
+ })
+ } else {
+ c.Lock()
+ fc := c.finalClosed
+ c.Unlock()
+ if !fc {
+ ds.Close()
+ }
+ }
+}
+
+// debugGetPut determines whether getConn & putConn calls' stack traces
+// are returned for more verbose crashes.
+const debugGetPut = false
+
+// putConn adds a connection to the db's free pool.
+// err is optionally the last error that occurred on this connection.
+func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
+ if !errors.Is(err, driver.ErrBadConn) {
+ if !dc.validateConnection(resetSession) {
+ err = driver.ErrBadConn
+ }
+ }
+ db.mu.Lock()
+ if !dc.inUse {
+ db.mu.Unlock()
+ if debugGetPut {
+ fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
+ }
+ panic("sql: connection returned that was never out")
+ }
+
+ if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
+ db.maxLifetimeClosed++
+ err = driver.ErrBadConn
+ }
+ if debugGetPut {
+ db.lastPut[dc] = stack()
+ }
+ dc.inUse = false
+ dc.returnedAt = nowFunc()
+
+ for _, fn := range dc.onPut {
+ fn()
+ }
+ dc.onPut = nil
+
+ if errors.Is(err, driver.ErrBadConn) {
+ // Don't reuse bad connections.
+ // Since the conn is considered bad and is being discarded, treat it
+ // as closed. Don't decrement the open count here, finalClose will
+ // take care of that.
+ db.maybeOpenNewConnections()
+ db.mu.Unlock()
+ dc.Close()
+ return
+ }
+ if putConnHook != nil {
+ putConnHook(db, dc)
+ }
+ added := db.putConnDBLocked(dc, nil)
+ db.mu.Unlock()
+
+ if !added {
+ dc.Close()
+ return
+ }
+}
+
+// Satisfy a connRequest or put the driverConn in the idle pool and return true
+// or return false.
+// putConnDBLocked will satisfy a connRequest if there is one, or it will
+// return the *driverConn to the freeConn list if err == nil and the idle
+// connection limit will not be exceeded.
+// If err != nil, the value of dc is ignored.
+// If err == nil, then dc must not equal nil.
+// If a connRequest was fulfilled or the *driverConn was placed in the
+// freeConn list, then true is returned, otherwise false is returned.
+func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
+ if db.closed {
+ return false
+ }
+ if db.maxOpen > 0 && db.numOpen > db.maxOpen {
+ return false
+ }
+ if c := len(db.connRequests); c > 0 {
+ var req chan connRequest
+ var reqKey uint64
+ for reqKey, req = range db.connRequests {
+ break
+ }
+ delete(db.connRequests, reqKey) // Remove from pending requests.
+ if err == nil {
+ dc.inUse = true
+ }
+ req <- connRequest{
+ conn: dc,
+ err: err,
+ }
+ return true
+ } else if err == nil && !db.closed {
+ if db.maxIdleConnsLocked() > len(db.freeConn) {
+ db.freeConn = append(db.freeConn, dc)
+ db.startCleanerLocked()
+ return true
+ }
+ db.maxIdleClosed++
+ }
+ return false
+}
+
+// maxBadConnRetries is the number of maximum retries if the driver returns
+// driver.ErrBadConn to signal a broken connection before forcing a new
+// connection to be opened.
+const maxBadConnRetries = 2
+
+func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
+ for i := int64(0); i < maxBadConnRetries; i++ {
+ err := fn(cachedOrNewConn)
+ // retry if err is driver.ErrBadConn
+ if err == nil || !errors.Is(err, driver.ErrBadConn) {
+ return err
+ }
+ }
+
+ return fn(alwaysNewConn)
+}
+
+// PrepareContext creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+//
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
+func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
+ var stmt *Stmt
+ var err error
+
+ err = db.retry(func(strategy connReuseStrategy) error {
+ stmt, err = db.prepare(ctx, query, strategy)
+ return err
+ })
+
+ return stmt, err
+}
+
+// Prepare creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+//
+// Prepare uses context.Background internally; to specify the context, use
+// PrepareContext.
+func (db *DB) Prepare(query string) (*Stmt, error) {
+ return db.PrepareContext(context.Background(), query)
+}
+
+func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
+ // TODO: check if db.driver supports an optional
+ // driver.Preparer interface and call that instead, if so,
+ // otherwise we make a prepared statement that's bound
+ // to a connection, and to execute this prepared statement
+ // we either need to use this connection (if it's free), else
+ // get a new connection + re-prepare + execute on that one.
+ dc, err := db.conn(ctx, strategy)
+ if err != nil {
+ return nil, err
+ }
+ return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
+}
+
+// prepareDC prepares a query on the driverConn and calls release before
+// returning. When cg == nil it implies that a connection pool is used, and
+// when cg != nil only a single driver connection is used.
+func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
+ var ds *driverStmt
+ var err error
+ defer func() {
+ release(err)
+ }()
+ withLock(dc, func() {
+ ds, err = dc.prepareLocked(ctx, cg, query)
+ })
+ if err != nil {
+ return nil, err
+ }
+ stmt := &Stmt{
+ db: db,
+ query: query,
+ cg: cg,
+ cgds: ds,
+ }
+
+ // When cg == nil this statement will need to keep track of various
+ // connections they are prepared on and record the stmt dependency on
+ // the DB.
+ if cg == nil {
+ stmt.css = []connStmt{{dc, ds}}
+ stmt.lastNumClosed = db.numClosed.Load()
+ db.addDep(stmt, stmt)
+ }
+ return stmt, nil
+}
+
+// ExecContext executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
+ var res Result
+ var err error
+
+ err = db.retry(func(strategy connReuseStrategy) error {
+ res, err = db.exec(ctx, query, args, strategy)
+ return err
+ })
+
+ return res, err
+}
+
+// Exec executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+//
+// Exec uses context.Background internally; to specify the context, use
+// ExecContext.
+func (db *DB) Exec(query string, args ...any) (Result, error) {
+ return db.ExecContext(context.Background(), query, args...)
+}
+
+func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
+ dc, err := db.conn(ctx, strategy)
+ if err != nil {
+ return nil, err
+ }
+ return db.execDC(ctx, dc, dc.releaseConn, query, args)
+}
+
+func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) {
+ defer func() {
+ release(err)
+ }()
+ execerCtx, ok := dc.ci.(driver.ExecerContext)
+ var execer driver.Execer
+ if !ok {
+ execer, ok = dc.ci.(driver.Execer)
+ }
+ if ok {
+ var nvdargs []driver.NamedValue
+ var resi driver.Result
+ withLock(dc, func() {
+ nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+ if err != nil {
+ return
+ }
+ resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
+ })
+ if err != driver.ErrSkip {
+ if err != nil {
+ return nil, err
+ }
+ return driverResult{dc, resi}, nil
+ }
+ }
+
+ var si driver.Stmt
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
+ if err != nil {
+ return nil, err
+ }
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
+ return resultFromStatement(ctx, dc.ci, ds, args...)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
+ var rows *Rows
+ var err error
+
+ err = db.retry(func(strategy connReuseStrategy) error {
+ rows, err = db.query(ctx, query, args, strategy)
+ return err
+ })
+
+ return rows, err
+}
+
+// Query executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+//
+// Query uses context.Background internally; to specify the context, use
+// QueryContext.
+func (db *DB) Query(query string, args ...any) (*Rows, error) {
+ return db.QueryContext(context.Background(), query, args...)
+}
+
+func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) {
+ dc, err := db.conn(ctx, strategy)
+ if err != nil {
+ return nil, err
+ }
+
+ return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
+}
+
+// queryDC executes a query on the given connection.
+// The connection gets released by the releaseConn function.
+// The ctx context is from a query method and the txctx context is from an
+// optional transaction context.
+func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) {
+ queryerCtx, ok := dc.ci.(driver.QueryerContext)
+ var queryer driver.Queryer
+ if !ok {
+ queryer, ok = dc.ci.(driver.Queryer)
+ }
+ if ok {
+ var nvdargs []driver.NamedValue
+ var rowsi driver.Rows
+ var err error
+ withLock(dc, func() {
+ nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
+ if err != nil {
+ return
+ }
+ rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
+ })
+ if err != driver.ErrSkip {
+ if err != nil {
+ releaseConn(err)
+ return nil, err
+ }
+ // Note: ownership of dc passes to the *Rows, to be freed
+ // with releaseConn.
+ rows := &Rows{
+ dc: dc,
+ releaseConn: releaseConn,
+ rowsi: rowsi,
+ }
+ rows.initContextClose(ctx, txctx)
+ return rows, nil
+ }
+ }
+
+ var si driver.Stmt
+ var err error
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
+ })
+ if err != nil {
+ releaseConn(err)
+ return nil, err
+ }
+
+ ds := &driverStmt{Locker: dc, si: si}
+ rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
+ if err != nil {
+ ds.Close()
+ releaseConn(err)
+ return nil, err
+ }
+
+ // Note: ownership of ci passes to the *Rows, to be freed
+ // with releaseConn.
+ rows := &Rows{
+ dc: dc,
+ releaseConn: releaseConn,
+ rowsi: rowsi,
+ closeStmt: ds,
+ }
+ rows.initContextClose(ctx, txctx)
+ return rows, nil
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
+ rows, err := db.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
+}
+
+// QueryRow executes a query that is expected to return at most one row.
+// QueryRow always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// QueryRow uses context.Background internally; to specify the context, use
+// QueryRowContext.
+func (db *DB) QueryRow(query string, args ...any) *Row {
+ return db.QueryRowContext(context.Background(), query, args...)
+}
+
+// BeginTx starts a transaction.
+//
+// The provided context is used until the transaction is committed or rolled back.
+// If the context is canceled, the sql package will roll back
+// the transaction. Tx.Commit will return an error if the context provided to
+// BeginTx is canceled.
+//
+// The provided TxOptions is optional and may be nil if defaults should be used.
+// If a non-default isolation level is used that the driver doesn't support,
+// an error will be returned.
+func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
+ var tx *Tx
+ var err error
+
+ err = db.retry(func(strategy connReuseStrategy) error {
+ tx, err = db.begin(ctx, opts, strategy)
+ return err
+ })
+
+ return tx, err
+}
+
+// Begin starts a transaction. The default isolation level is dependent on
+// the driver.
+//
+// Begin uses context.Background internally; to specify the context, use
+// BeginTx.
+func (db *DB) Begin() (*Tx, error) {
+ return db.BeginTx(context.Background(), nil)
+}
+
+func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
+ dc, err := db.conn(ctx, strategy)
+ if err != nil {
+ return nil, err
+ }
+ return db.beginDC(ctx, dc, dc.releaseConn, opts)
+}
+
+// beginDC starts a transaction. The provided dc must be valid and ready to use.
+func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
+ var txi driver.Tx
+ keepConnOnRollback := false
+ withLock(dc, func() {
+ _, hasSessionResetter := dc.ci.(driver.SessionResetter)
+ _, hasConnectionValidator := dc.ci.(driver.Validator)
+ keepConnOnRollback = hasSessionResetter && hasConnectionValidator
+ txi, err = ctxDriverBegin(ctx, opts, dc.ci)
+ })
+ if err != nil {
+ release(err)
+ return nil, err
+ }
+
+ // Schedule the transaction to rollback when the context is canceled.
+ // The cancel function in Tx will be called after done is set to true.
+ ctx, cancel := context.WithCancel(ctx)
+ tx = &Tx{
+ db: db,
+ dc: dc,
+ releaseConn: release,
+ txi: txi,
+ cancel: cancel,
+ keepConnOnRollback: keepConnOnRollback,
+ ctx: ctx,
+ }
+ go tx.awaitDone()
+ return tx, nil
+}
+
+// Driver returns the database's underlying driver.
+func (db *DB) Driver() driver.Driver {
+ return db.connector.Driver()
+}
+
+// ErrConnDone is returned by any operation that is performed on a connection
+// that has already been returned to the connection pool.
+var ErrConnDone = errors.New("sql: connection is already closed")
+
+// Conn returns a single connection by either opening a new connection
+// or returning an existing connection from the connection pool. Conn will
+// block until either a connection is returned or ctx is canceled.
+// Queries run on the same Conn will be run in the same database session.
+//
+// Every Conn must be returned to the database pool after use by
+// calling Conn.Close.
+func (db *DB) Conn(ctx context.Context) (*Conn, error) {
+ var dc *driverConn
+ var err error
+
+ err = db.retry(func(strategy connReuseStrategy) error {
+ dc, err = db.conn(ctx, strategy)
+ return err
+ })
+
+ if err != nil {
+ return nil, err
+ }
+
+ conn := &Conn{
+ db: db,
+ dc: dc,
+ }
+ return conn, nil
+}
+
+type releaseConn func(error)
+
+// Conn represents a single database connection rather than a pool of database
+// connections. Prefer running queries from DB unless there is a specific
+// need for a continuous single database connection.
+//
+// A Conn must call Close to return the connection to the database pool
+// and may do so concurrently with a running query.
+//
+// After a call to Close, all operations on the
+// connection fail with ErrConnDone.
+type Conn struct {
+ db *DB
+
+ // closemu prevents the connection from closing while there
+ // is an active query. It is held for read during queries
+ // and exclusively during close.
+ closemu sync.RWMutex
+
+ // dc is owned until close, at which point
+ // it's returned to the connection pool.
+ dc *driverConn
+
+ // done transitions from false to true exactly once, on close.
+ // Once done, all operations fail with ErrConnDone.
+ done atomic.Bool
+
+ // releaseConn is a cache of c.closemuRUnlockCondReleaseConn
+ // to save allocations in a call to grabConn.
+ releaseConnOnce sync.Once
+ releaseConnCache releaseConn
+}
+
+// grabConn takes a context to implement stmtConnGrabber
+// but the context is not used.
+func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
+ if c.done.Load() {
+ return nil, nil, ErrConnDone
+ }
+ c.releaseConnOnce.Do(func() {
+ c.releaseConnCache = c.closemuRUnlockCondReleaseConn
+ })
+ c.closemu.RLock()
+ return c.dc, c.releaseConnCache, nil
+}
+
+// PingContext verifies the connection to the database is still alive.
+func (c *Conn) PingContext(ctx context.Context) error {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return err
+ }
+ return c.db.pingDC(ctx, dc, release)
+}
+
+// ExecContext executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return c.db.execDC(ctx, dc, release, query, args)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return c.db.queryDC(ctx, nil, dc, release, query, args)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
+ rows, err := c.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
+}
+
+// PrepareContext creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+//
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
+func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return c.db.prepareDC(ctx, dc, release, c, query)
+}
+
+// Raw executes f exposing the underlying driver connection for the
+// duration of f. The driverConn must not be used outside of f.
+//
+// Once f returns and err is not driver.ErrBadConn, the Conn will continue to be usable
+// until Conn.Close is called.
+func (c *Conn) Raw(f func(driverConn any) error) (err error) {
+ var dc *driverConn
+ var release releaseConn
+
+ // grabConn takes a context to implement stmtConnGrabber, but the context is not used.
+ dc, release, err = c.grabConn(nil)
+ if err != nil {
+ return
+ }
+ fPanic := true
+ dc.Mutex.Lock()
+ defer func() {
+ dc.Mutex.Unlock()
+
+ // If f panics fPanic will remain true.
+ // Ensure an error is passed to release so the connection
+ // may be discarded.
+ if fPanic {
+ err = driver.ErrBadConn
+ }
+ release(err)
+ }()
+ err = f(dc.ci)
+ fPanic = false
+
+ return
+}
+
+// BeginTx starts a transaction.
+//
+// The provided context is used until the transaction is committed or rolled back.
+// If the context is canceled, the sql package will roll back
+// the transaction. Tx.Commit will return an error if the context provided to
+// BeginTx is canceled.
+//
+// The provided TxOptions is optional and may be nil if defaults should be used.
+// If a non-default isolation level is used that the driver doesn't support,
+// an error will be returned.
+func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
+ dc, release, err := c.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return c.db.beginDC(ctx, dc, release, opts)
+}
+
+// closemuRUnlockCondReleaseConn read unlocks closemu
+// as the sql operation is done with the dc.
+func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
+ c.closemu.RUnlock()
+ if errors.Is(err, driver.ErrBadConn) {
+ c.close(err)
+ }
+}
+
+func (c *Conn) txCtx() context.Context {
+ return nil
+}
+
+func (c *Conn) close(err error) error {
+ if !c.done.CompareAndSwap(false, true) {
+ return ErrConnDone
+ }
+
+ // Lock around releasing the driver connection
+ // to ensure all queries have been stopped before doing so.
+ c.closemu.Lock()
+ defer c.closemu.Unlock()
+
+ c.dc.releaseConn(err)
+ c.dc = nil
+ c.db = nil
+ return err
+}
+
+// Close returns the connection to the connection pool.
+// All operations after a Close will return with ErrConnDone.
+// Close is safe to call concurrently with other operations and will
+// block until all other operations finish. It may be useful to first
+// cancel any used context and then call close directly after.
+func (c *Conn) Close() error {
+ return c.close(nil)
+}
+
+// Tx is an in-progress database transaction.
+//
+// A transaction must end with a call to Commit or Rollback.
+//
+// After a call to Commit or Rollback, all operations on the
+// transaction fail with ErrTxDone.
+//
+// The statements prepared for a transaction by calling
+// the transaction's Prepare or Stmt methods are closed
+// by the call to Commit or Rollback.
+type Tx struct {
+ db *DB
+
+ // closemu prevents the transaction from closing while there
+ // is an active query. It is held for read during queries
+ // and exclusively during close.
+ closemu sync.RWMutex
+
+ // dc is owned exclusively until Commit or Rollback, at which point
+ // it's returned with putConn.
+ dc *driverConn
+ txi driver.Tx
+
+ // releaseConn is called once the Tx is closed to release
+ // any held driverConn back to the pool.
+ releaseConn func(error)
+
+ // done transitions from false to true exactly once, on Commit
+ // or Rollback. once done, all operations fail with
+ // ErrTxDone.
+ done atomic.Bool
+
+ // keepConnOnRollback is true if the driver knows
+ // how to reset the connection's session and if need be discard
+ // the connection.
+ keepConnOnRollback bool
+
+ // All Stmts prepared for this transaction. These will be closed after the
+ // transaction has been committed or rolled back.
+ stmts struct {
+ sync.Mutex
+ v []*Stmt
+ }
+
+ // cancel is called after done transitions from 0 to 1.
+ cancel func()
+
+ // ctx lives for the life of the transaction.
+ ctx context.Context
+}
+
+// awaitDone blocks until the context in Tx is canceled and rolls back
+// the transaction if it's not already done.
+func (tx *Tx) awaitDone() {
+ // Wait for either the transaction to be committed or rolled
+ // back, or for the associated context to be closed.
+ <-tx.ctx.Done()
+
+ // Discard and close the connection used to ensure the
+ // transaction is closed and the resources are released. This
+ // rollback does nothing if the transaction has already been
+ // committed or rolled back.
+ // Do not discard the connection if the connection knows
+ // how to reset the session.
+ discardConnection := !tx.keepConnOnRollback
+ tx.rollback(discardConnection)
+}
+
+func (tx *Tx) isDone() bool {
+ return tx.done.Load()
+}
+
+// ErrTxDone is returned by any operation that is performed on a transaction
+// that has already been committed or rolled back.
+var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
+
+// close returns the connection to the pool and
+// must only be called by Tx.rollback or Tx.Commit while
+// tx is already canceled and won't be executed concurrently.
+func (tx *Tx) close(err error) {
+ tx.releaseConn(err)
+ tx.dc = nil
+ tx.txi = nil
+}
+
+// hookTxGrabConn specifies an optional hook to be called on
+// a successful call to (*Tx).grabConn. For tests.
+var hookTxGrabConn func()
+
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, nil, ctx.Err()
+ }
+
+ // closemu.RLock must come before the check for isDone to prevent the Tx from
+ // closing while a query is executing.
+ tx.closemu.RLock()
+ if tx.isDone() {
+ tx.closemu.RUnlock()
+ return nil, nil, ErrTxDone
+ }
+ if hookTxGrabConn != nil { // test hook
+ hookTxGrabConn()
+ }
+ return tx.dc, tx.closemuRUnlockRelease, nil
+}
+
+func (tx *Tx) txCtx() context.Context {
+ return tx.ctx
+}
+
+// closemuRUnlockRelease is used as a func(error) method value in
+// ExecContext and QueryContext. Unlocking in the releaseConn keeps
+// the driver conn from being returned to the connection pool until
+// the Rows has been closed.
+func (tx *Tx) closemuRUnlockRelease(error) {
+ tx.closemu.RUnlock()
+}
+
+// Closes all Stmts prepared for this transaction.
+func (tx *Tx) closePrepared() {
+ tx.stmts.Lock()
+ defer tx.stmts.Unlock()
+ for _, stmt := range tx.stmts.v {
+ stmt.Close()
+ }
+}
+
+// Commit commits the transaction.
+func (tx *Tx) Commit() error {
+ // Check context first to avoid transaction leak.
+ // If put it behind tx.done CompareAndSwap statement, we can't ensure
+ // the consistency between tx.done and the real COMMIT operation.
+ select {
+ default:
+ case <-tx.ctx.Done():
+ if tx.done.Load() {
+ return ErrTxDone
+ }
+ return tx.ctx.Err()
+ }
+ if !tx.done.CompareAndSwap(false, true) {
+ return ErrTxDone
+ }
+
+ // Cancel the Tx to release any active R-closemu locks.
+ // This is safe to do because tx.done has already transitioned
+ // from 0 to 1. Hold the W-closemu lock prior to rollback
+ // to ensure no other connection has an active query.
+ tx.cancel()
+ tx.closemu.Lock()
+ tx.closemu.Unlock()
+
+ var err error
+ withLock(tx.dc, func() {
+ err = tx.txi.Commit()
+ })
+ if !errors.Is(err, driver.ErrBadConn) {
+ tx.closePrepared()
+ }
+ tx.close(err)
+ return err
+}
+
+var rollbackHook func()
+
+// rollback aborts the transaction and optionally forces the pool to discard
+// the connection.
+func (tx *Tx) rollback(discardConn bool) error {
+ if !tx.done.CompareAndSwap(false, true) {
+ return ErrTxDone
+ }
+
+ if rollbackHook != nil {
+ rollbackHook()
+ }
+
+ // Cancel the Tx to release any active R-closemu locks.
+ // This is safe to do because tx.done has already transitioned
+ // from 0 to 1. Hold the W-closemu lock prior to rollback
+ // to ensure no other connection has an active query.
+ tx.cancel()
+ tx.closemu.Lock()
+ tx.closemu.Unlock()
+
+ var err error
+ withLock(tx.dc, func() {
+ err = tx.txi.Rollback()
+ })
+ if !errors.Is(err, driver.ErrBadConn) {
+ tx.closePrepared()
+ }
+ if discardConn {
+ err = driver.ErrBadConn
+ }
+ tx.close(err)
+ return err
+}
+
+// Rollback aborts the transaction.
+func (tx *Tx) Rollback() error {
+ return tx.rollback(false)
+}
+
+// PrepareContext creates a prepared statement for use within a transaction.
+//
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
+//
+// The provided context will be used for the preparation of the context, not
+// for the execution of the returned statement. The returned statement
+// will run in the transaction context.
+func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
+ dc, release, err := tx.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
+ if err != nil {
+ return nil, err
+ }
+ tx.stmts.Lock()
+ tx.stmts.v = append(tx.stmts.v, stmt)
+ tx.stmts.Unlock()
+ return stmt, nil
+}
+
+// Prepare creates a prepared statement for use within a transaction.
+//
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
+//
+// Prepare uses context.Background internally; to specify the context, use
+// PrepareContext.
+func (tx *Tx) Prepare(query string) (*Stmt, error) {
+ return tx.PrepareContext(context.Background(), query)
+}
+
+// StmtContext returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+//
+// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+// ...
+// tx, err := db.Begin()
+// ...
+// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
+//
+// The provided context is used for the preparation of the statement, not for the
+// execution of the statement.
+//
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
+func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
+ dc, release, err := tx.grabConn(ctx)
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ defer release(nil)
+
+ if tx.db != stmt.db {
+ return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
+ }
+ var si driver.Stmt
+ var parentStmt *Stmt
+ stmt.mu.Lock()
+ if stmt.closed || stmt.cg != nil {
+ // If the statement has been closed or already belongs to a
+ // transaction, we can't reuse it in this connection.
+ // Since tx.StmtContext should never need to be called with a
+ // Stmt already belonging to tx, we ignore this edge case and
+ // re-prepare the statement in this case. No need to add
+ // code-complexity for this.
+ stmt.mu.Unlock()
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
+ })
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ } else {
+ stmt.removeClosedStmtLocked()
+ // See if the statement has already been prepared on this connection,
+ // and reuse it if possible.
+ for _, v := range stmt.css {
+ if v.dc == dc {
+ si = v.ds.si
+ break
+ }
+ }
+
+ stmt.mu.Unlock()
+
+ if si == nil {
+ var ds *driverStmt
+ withLock(dc, func() {
+ ds, err = stmt.prepareOnConnLocked(ctx, dc)
+ })
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ si = ds.si
+ }
+ parentStmt = stmt
+ }
+
+ txs := &Stmt{
+ db: tx.db,
+ cg: tx,
+ cgds: &driverStmt{
+ Locker: dc,
+ si: si,
+ },
+ parentStmt: parentStmt,
+ query: stmt.query,
+ }
+ if parentStmt != nil {
+ tx.db.addDep(parentStmt, txs)
+ }
+ tx.stmts.Lock()
+ tx.stmts.v = append(tx.stmts.v, txs)
+ tx.stmts.Unlock()
+ return txs
+}
+
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+//
+// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+// ...
+// tx, err := db.Begin()
+// ...
+// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//
+// The returned statement operates within the transaction and will be closed
+// when the transaction has been committed or rolled back.
+//
+// Stmt uses context.Background internally; to specify the context, use
+// StmtContext.
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+ return tx.StmtContext(context.Background(), stmt)
+}
+
+// ExecContext executes a query that doesn't return rows.
+// For example: an INSERT and UPDATE.
+func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
+ dc, release, err := tx.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+ return tx.db.execDC(ctx, dc, release, query, args)
+}
+
+// Exec executes a query that doesn't return rows.
+// For example: an INSERT and UPDATE.
+//
+// Exec uses context.Background internally; to specify the context, use
+// ExecContext.
+func (tx *Tx) Exec(query string, args ...any) (Result, error) {
+ return tx.ExecContext(context.Background(), query, args...)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
+ dc, release, err := tx.grabConn(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
+}
+
+// Query executes a query that returns rows, typically a SELECT.
+//
+// Query uses context.Background internally; to specify the context, use
+// QueryContext.
+func (tx *Tx) Query(query string, args ...any) (*Rows, error) {
+ return tx.QueryContext(context.Background(), query, args...)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row {
+ rows, err := tx.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
+}
+
+// QueryRow executes a query that is expected to return at most one row.
+// QueryRow always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// QueryRow uses context.Background internally; to specify the context, use
+// QueryRowContext.
+func (tx *Tx) QueryRow(query string, args ...any) *Row {
+ return tx.QueryRowContext(context.Background(), query, args...)
+}
+
+// connStmt is a prepared statement on a particular connection.
+type connStmt struct {
+ dc *driverConn
+ ds *driverStmt
+}
+
+// stmtConnGrabber represents a Tx or Conn that will return the underlying
+// driverConn and release function.
+type stmtConnGrabber interface {
+ // grabConn returns the driverConn and the associated release function
+ // that must be called when the operation completes.
+ grabConn(context.Context) (*driverConn, releaseConn, error)
+
+ // txCtx returns the transaction context if available.
+ // The returned context should be selected on along with
+ // any query context when awaiting a cancel.
+ txCtx() context.Context
+}
+
+var (
+ _ stmtConnGrabber = &Tx{}
+ _ stmtConnGrabber = &Conn{}
+)
+
+// Stmt is a prepared statement.
+// A Stmt is safe for concurrent use by multiple goroutines.
+//
+// If a Stmt is prepared on a Tx or Conn, it will be bound to a single
+// underlying connection forever. If the Tx or Conn closes, the Stmt will
+// become unusable and all operations will return an error.
+// If a Stmt is prepared on a DB, it will remain usable for the lifetime of the
+// DB. When the Stmt needs to execute on a new underlying connection, it will
+// prepare itself on the new connection automatically.
+type Stmt struct {
+ // Immutable:
+ db *DB // where we came from
+ query string // that created the Stmt
+ stickyErr error // if non-nil, this error is returned for all operations
+
+ closemu sync.RWMutex // held exclusively during close, for read otherwise.
+
+ // If Stmt is prepared on a Tx or Conn then cg is present and will
+ // only ever grab a connection from cg.
+ // If cg is nil then the Stmt must grab an arbitrary connection
+ // from db and determine if it must prepare the stmt again by
+ // inspecting css.
+ cg stmtConnGrabber
+ cgds *driverStmt
+
+ // parentStmt is set when a transaction-specific statement
+ // is requested from an identical statement prepared on the same
+ // conn. parentStmt is used to track the dependency of this statement
+ // on its originating ("parent") statement so that parentStmt may
+ // be closed by the user without them having to know whether or not
+ // any transactions are still using it.
+ parentStmt *Stmt
+
+ mu sync.Mutex // protects the rest of the fields
+ closed bool
+
+ // css is a list of underlying driver statement interfaces
+ // that are valid on particular connections. This is only
+ // used if cg == nil and one is found that has idle
+ // connections. If cg != nil, cgds is always used.
+ css []connStmt
+
+ // lastNumClosed is copied from db.numClosed when Stmt is created
+ // without tx and closed connections in css are removed.
+ lastNumClosed uint64
+}
+
+// ExecContext executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
+ s.closemu.RLock()
+ defer s.closemu.RUnlock()
+
+ var res Result
+ err := s.db.retry(func(strategy connReuseStrategy) error {
+ dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
+ if err != nil {
+ return err
+ }
+
+ res, err = resultFromStatement(ctx, dc.ci, ds, args...)
+ releaseConn(err)
+ return err
+ })
+
+ return res, err
+}
+
+// Exec executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+//
+// Exec uses context.Background internally; to specify the context, use
+// ExecContext.
+func (s *Stmt) Exec(args ...any) (Result, error) {
+ return s.ExecContext(context.Background(), args...)
+}
+
+func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) {
+ ds.Lock()
+ defer ds.Unlock()
+
+ dargs, err := driverArgsConnLocked(ci, ds, args)
+ if err != nil {
+ return nil, err
+ }
+
+ resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
+ if err != nil {
+ return nil, err
+ }
+ return driverResult{ds.Locker, resi}, nil
+}
+
+// removeClosedStmtLocked removes closed conns in s.css.
+//
+// To avoid lock contention on DB.mu, we do it only when
+// s.db.numClosed - s.lastNum is large enough.
+func (s *Stmt) removeClosedStmtLocked() {
+ t := len(s.css)/2 + 1
+ if t > 10 {
+ t = 10
+ }
+ dbClosed := s.db.numClosed.Load()
+ if dbClosed-s.lastNumClosed < uint64(t) {
+ return
+ }
+
+ s.db.mu.Lock()
+ for i := 0; i < len(s.css); i++ {
+ if s.css[i].dc.dbmuClosed {
+ s.css[i] = s.css[len(s.css)-1]
+ s.css = s.css[:len(s.css)-1]
+ i--
+ }
+ }
+ s.db.mu.Unlock()
+ s.lastNumClosed = dbClosed
+}
+
+// connStmt returns a free driver connection on which to execute the
+// statement, a function to call to release the connection, and a
+// statement bound to that connection.
+func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
+ if err = s.stickyErr; err != nil {
+ return
+ }
+ s.mu.Lock()
+ if s.closed {
+ s.mu.Unlock()
+ err = errors.New("sql: statement is closed")
+ return
+ }
+
+ // In a transaction or connection, we always use the connection that the
+ // stmt was created on.
+ if s.cg != nil {
+ s.mu.Unlock()
+ dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
+ if err != nil {
+ return
+ }
+ return dc, releaseConn, s.cgds, nil
+ }
+
+ s.removeClosedStmtLocked()
+ s.mu.Unlock()
+
+ dc, err = s.db.conn(ctx, strategy)
+ if err != nil {
+ return nil, nil, nil, err
+ }
+
+ s.mu.Lock()
+ for _, v := range s.css {
+ if v.dc == dc {
+ s.mu.Unlock()
+ return dc, dc.releaseConn, v.ds, nil
+ }
+ }
+ s.mu.Unlock()
+
+ // No luck; we need to prepare the statement on this connection
+ withLock(dc, func() {
+ ds, err = s.prepareOnConnLocked(ctx, dc)
+ })
+ if err != nil {
+ dc.releaseConn(err)
+ return nil, nil, nil, err
+ }
+
+ return dc, dc.releaseConn, ds, nil
+}
+
+// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
+// open connStmt on the statement. It assumes the caller is holding the lock on dc.
+func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
+ si, err := dc.prepareLocked(ctx, s.cg, s.query)
+ if err != nil {
+ return nil, err
+ }
+ cs := connStmt{dc, si}
+ s.mu.Lock()
+ s.css = append(s.css, cs)
+ s.mu.Unlock()
+ return cs.ds, nil
+}
+
+// QueryContext executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
+ s.closemu.RLock()
+ defer s.closemu.RUnlock()
+
+ var rowsi driver.Rows
+ var rows *Rows
+
+ err := s.db.retry(func(strategy connReuseStrategy) error {
+ dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
+ if err != nil {
+ return err
+ }
+
+ rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
+ if err == nil {
+ // Note: ownership of ci passes to the *Rows, to be freed
+ // with releaseConn.
+ rows = &Rows{
+ dc: dc,
+ rowsi: rowsi,
+ // releaseConn set below
+ }
+ // addDep must be added before initContextClose or it could attempt
+ // to removeDep before it has been added.
+ s.db.addDep(s, rows)
+
+ // releaseConn must be set before initContextClose or it could
+ // release the connection before it is set.
+ rows.releaseConn = func(err error) {
+ releaseConn(err)
+ s.db.removeDep(s, rows)
+ }
+ var txctx context.Context
+ if s.cg != nil {
+ txctx = s.cg.txCtx()
+ }
+ rows.initContextClose(ctx, txctx)
+ return nil
+ }
+
+ releaseConn(err)
+ return err
+ })
+
+ return rows, err
+}
+
+// Query executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+//
+// Query uses context.Background internally; to specify the context, use
+// QueryContext.
+func (s *Stmt) Query(args ...any) (*Rows, error) {
+ return s.QueryContext(context.Background(), args...)
+}
+
+func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) {
+ ds.Lock()
+ defer ds.Unlock()
+ dargs, err := driverArgsConnLocked(ci, ds, args)
+ if err != nil {
+ return nil, err
+ }
+ return ctxDriverStmtQuery(ctx, ds.si, dargs)
+}
+
+// QueryRowContext executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row {
+ rows, err := s.QueryContext(ctx, args...)
+ if err != nil {
+ return &Row{err: err}
+ }
+ return &Row{rows: rows}
+}
+
+// QueryRow executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// Example usage:
+//
+// var name string
+// err := nameByUseridStmt.QueryRow(id).Scan(&name)
+//
+// QueryRow uses context.Background internally; to specify the context, use
+// QueryRowContext.
+func (s *Stmt) QueryRow(args ...any) *Row {
+ return s.QueryRowContext(context.Background(), args...)
+}
+
+// Close closes the statement.
+func (s *Stmt) Close() error {
+ s.closemu.Lock()
+ defer s.closemu.Unlock()
+
+ if s.stickyErr != nil {
+ return s.stickyErr
+ }
+ s.mu.Lock()
+ if s.closed {
+ s.mu.Unlock()
+ return nil
+ }
+ s.closed = true
+ txds := s.cgds
+ s.cgds = nil
+
+ s.mu.Unlock()
+
+ if s.cg == nil {
+ return s.db.removeDep(s, s)
+ }
+
+ if s.parentStmt != nil {
+ // If parentStmt is set, we must not close s.txds since it's stored
+ // in the css array of the parentStmt.
+ return s.db.removeDep(s.parentStmt, s)
+ }
+ return txds.Close()
+}
+
+func (s *Stmt) finalClose() error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.css != nil {
+ for _, v := range s.css {
+ s.db.noteUnusedDriverStatement(v.dc, v.ds)
+ v.dc.removeOpenStmt(v.ds)
+ }
+ s.css = nil
+ }
+ return nil
+}
+
+// Rows is the result of a query. Its cursor starts before the first row
+// of the result set. Use Next to advance from row to row.
+type Rows struct {
+ dc *driverConn // owned; must call releaseConn when closed to release
+ releaseConn func(error)
+ rowsi driver.Rows
+ cancel func() // called when Rows is closed, may be nil.
+ closeStmt *driverStmt // if non-nil, statement to Close on close
+
+ contextDone atomic.Pointer[error] // error that awaitDone saw; set before close attempt
+
+ // closemu prevents Rows from closing while there
+ // is an active streaming result. It is held for read during non-close operations
+ // and exclusively during close.
+ //
+ // closemu guards lasterr and closed.
+ closemu sync.RWMutex
+ closed bool
+ lasterr error // non-nil only if closed is true
+
+ // lastcols is only used in Scan, Next, and NextResultSet which are expected
+ // not to be called concurrently.
+ lastcols []driver.Value
+
+ // closemuScanHold is whether the previous call to Scan kept closemu RLock'ed
+ // without unlocking it. It does that when the user passes a *RawBytes scan
+ // target. In that case, we need to prevent awaitDone from closing the Rows
+ // while the user's still using the memory. See go.dev/issue/60304.
+ //
+ // It is only used by Scan, Next, and NextResultSet which are expected
+ // not to be called concurrently.
+ closemuScanHold bool
+
+ // hitEOF is whether Next hit the end of the rows without
+ // encountering an error. It's set in Next before
+ // returning. It's only used by Next and Err which are
+ // expected not to be called concurrently.
+ hitEOF bool
+}
+
+// lasterrOrErrLocked returns either lasterr or the provided err.
+// rs.closemu must be read-locked.
+func (rs *Rows) lasterrOrErrLocked(err error) error {
+ if rs.lasterr != nil && rs.lasterr != io.EOF {
+ return rs.lasterr
+ }
+ return err
+}
+
+// bypassRowsAwaitDone is only used for testing.
+// If true, it will not close the Rows automatically from the context.
+var bypassRowsAwaitDone = false
+
+func (rs *Rows) initContextClose(ctx, txctx context.Context) {
+ if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
+ return
+ }
+ if bypassRowsAwaitDone {
+ return
+ }
+ closectx, cancel := context.WithCancel(ctx)
+ rs.cancel = cancel
+ go rs.awaitDone(ctx, txctx, closectx)
+}
+
+// awaitDone blocks until ctx, txctx, or closectx is canceled.
+// The ctx is provided from the query context.
+// If the query was issued in a transaction, the transaction's context
+// is also provided in txctx, to ensure Rows is closed if the Tx is closed.
+// The closectx is closed by an explicit call to rs.Close.
+func (rs *Rows) awaitDone(ctx, txctx, closectx context.Context) {
+ var txctxDone <-chan struct{}
+ if txctx != nil {
+ txctxDone = txctx.Done()
+ }
+ select {
+ case <-ctx.Done():
+ err := ctx.Err()
+ rs.contextDone.Store(&err)
+ case <-txctxDone:
+ err := txctx.Err()
+ rs.contextDone.Store(&err)
+ case <-closectx.Done():
+ // rs.cancel was called via Close(); don't store this into contextDone
+ // to ensure Err() is unaffected.
+ }
+ rs.close(ctx.Err())
+}
+
+// Next prepares the next result row for reading with the Scan method. It
+// returns true on success, or false if there is no next result row or an error
+// happened while preparing it. Err should be consulted to distinguish between
+// the two cases.
+//
+// Every call to Scan, even the first one, must be preceded by a call to Next.
+func (rs *Rows) Next() bool {
+ // If the user's calling Next, they're done with their previous row's Scan
+ // results (any RawBytes memory), so we can release the read lock that would
+ // be preventing awaitDone from calling close.
+ rs.closemuRUnlockIfHeldByScan()
+
+ if rs.contextDone.Load() != nil {
+ return false
+ }
+
+ var doClose, ok bool
+ withLock(rs.closemu.RLocker(), func() {
+ doClose, ok = rs.nextLocked()
+ })
+ if doClose {
+ rs.Close()
+ }
+ if doClose && !ok {
+ rs.hitEOF = true
+ }
+ return ok
+}
+
+func (rs *Rows) nextLocked() (doClose, ok bool) {
+ if rs.closed {
+ return false, false
+ }
+
+ // Lock the driver connection before calling the driver interface
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
+ if rs.lastcols == nil {
+ rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
+ }
+
+ rs.lasterr = rs.rowsi.Next(rs.lastcols)
+ if rs.lasterr != nil {
+ // Close the connection if there is a driver error.
+ if rs.lasterr != io.EOF {
+ return true, false
+ }
+ nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+ if !ok {
+ return true, false
+ }
+ // The driver is at the end of the current result set.
+ // Test to see if there is another result set after the current one.
+ // Only close Rows if there is no further result sets to read.
+ if !nextResultSet.HasNextResultSet() {
+ doClose = true
+ }
+ return doClose, false
+ }
+ return false, true
+}
+
+// NextResultSet prepares the next result set for reading. It reports whether
+// there is further result sets, or false if there is no further result set
+// or if there is an error advancing to it. The Err method should be consulted
+// to distinguish between the two cases.
+//
+// After calling NextResultSet, the Next method should always be called before
+// scanning. If there are further result sets they may not have rows in the result
+// set.
+func (rs *Rows) NextResultSet() bool {
+ // If the user's calling NextResultSet, they're done with their previous
+ // row's Scan results (any RawBytes memory), so we can release the read lock
+ // that would be preventing awaitDone from calling close.
+ rs.closemuRUnlockIfHeldByScan()
+
+ var doClose bool
+ defer func() {
+ if doClose {
+ rs.Close()
+ }
+ }()
+ rs.closemu.RLock()
+ defer rs.closemu.RUnlock()
+
+ if rs.closed {
+ return false
+ }
+
+ rs.lastcols = nil
+ nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+ if !ok {
+ doClose = true
+ return false
+ }
+
+ // Lock the driver connection before calling the driver interface
+ // rowsi to prevent a Tx from rolling back the connection at the same time.
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
+ rs.lasterr = nextResultSet.NextResultSet()
+ if rs.lasterr != nil {
+ doClose = true
+ return false
+ }
+ return true
+}
+
+// Err returns the error, if any, that was encountered during iteration.
+// Err may be called after an explicit or implicit Close.
+func (rs *Rows) Err() error {
+ // Return any context error that might've happened during row iteration,
+ // but only if we haven't reported the final Next() = false after rows
+ // are done, in which case the user might've canceled their own context
+ // before calling Rows.Err.
+ if !rs.hitEOF {
+ if errp := rs.contextDone.Load(); errp != nil {
+ return *errp
+ }
+ }
+
+ rs.closemu.RLock()
+ defer rs.closemu.RUnlock()
+ return rs.lasterrOrErrLocked(nil)
+}
+
+var errRowsClosed = errors.New("sql: Rows are closed")
+var errNoRows = errors.New("sql: no Rows available")
+
+// Columns returns the column names.
+// Columns returns an error if the rows are closed.
+func (rs *Rows) Columns() ([]string, error) {
+ rs.closemu.RLock()
+ defer rs.closemu.RUnlock()
+ if rs.closed {
+ return nil, rs.lasterrOrErrLocked(errRowsClosed)
+ }
+ if rs.rowsi == nil {
+ return nil, rs.lasterrOrErrLocked(errNoRows)
+ }
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
+ return rs.rowsi.Columns(), nil
+}
+
+// ColumnTypes returns column information such as column type, length,
+// and nullable. Some information may not be available from some drivers.
+func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
+ rs.closemu.RLock()
+ defer rs.closemu.RUnlock()
+ if rs.closed {
+ return nil, rs.lasterrOrErrLocked(errRowsClosed)
+ }
+ if rs.rowsi == nil {
+ return nil, rs.lasterrOrErrLocked(errNoRows)
+ }
+ rs.dc.Lock()
+ defer rs.dc.Unlock()
+
+ return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
+}
+
+// ColumnType contains the name and type of a column.
+type ColumnType struct {
+ name string
+
+ hasNullable bool
+ hasLength bool
+ hasPrecisionScale bool
+
+ nullable bool
+ length int64
+ databaseType string
+ precision int64
+ scale int64
+ scanType reflect.Type
+}
+
+// Name returns the name or alias of the column.
+func (ci *ColumnType) Name() string {
+ return ci.name
+}
+
+// Length returns the column type length for variable length column types such
+// as text and binary field types. If the type length is unbounded the value will
+// be math.MaxInt64 (any database limits will still apply).
+// If the column type is not variable length, such as an int, or if not supported
+// by the driver ok is false.
+func (ci *ColumnType) Length() (length int64, ok bool) {
+ return ci.length, ci.hasLength
+}
+
+// DecimalSize returns the scale and precision of a decimal type.
+// If not applicable or if not supported ok is false.
+func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
+ return ci.precision, ci.scale, ci.hasPrecisionScale
+}
+
+// ScanType returns a Go type suitable for scanning into using Rows.Scan.
+// If a driver does not support this property ScanType will return
+// the type of an empty interface.
+func (ci *ColumnType) ScanType() reflect.Type {
+ return ci.scanType
+}
+
+// Nullable reports whether the column may be null.
+// If a driver does not support this property ok will be false.
+func (ci *ColumnType) Nullable() (nullable, ok bool) {
+ return ci.nullable, ci.hasNullable
+}
+
+// DatabaseTypeName returns the database system name of the column type. If an empty
+// string is returned, then the driver type name is not supported.
+// Consult your driver documentation for a list of driver data types. Length specifiers
+// are not included.
+// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL",
+// "INT", and "BIGINT".
+func (ci *ColumnType) DatabaseTypeName() string {
+ return ci.databaseType
+}
+
+func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
+ names := rowsi.Columns()
+
+ list := make([]*ColumnType, len(names))
+ for i := range list {
+ ci := &ColumnType{
+ name: names[i],
+ }
+ list[i] = ci
+
+ if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
+ ci.scanType = prop.ColumnTypeScanType(i)
+ } else {
+ ci.scanType = reflect.TypeOf(new(any)).Elem()
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
+ ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
+ ci.length, ci.hasLength = prop.ColumnTypeLength(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
+ ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
+ }
+ if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
+ ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
+ }
+ }
+ return list
+}
+
+// Scan copies the columns in the current row into the values pointed
+// at by dest. The number of values in dest must be the same as the
+// number of columns in Rows.
+//
+// Scan converts columns read from the database into the following
+// common Go types and special types provided by the sql package:
+//
+// *string
+// *[]byte
+// *int, *int8, *int16, *int32, *int64
+// *uint, *uint8, *uint16, *uint32, *uint64
+// *bool
+// *float32, *float64
+// *interface{}
+// *RawBytes
+// *Rows (cursor value)
+// any type implementing Scanner (see Scanner docs)
+//
+// In the most simple case, if the type of the value from the source
+// column is an integer, bool or string type T and dest is of type *T,
+// Scan simply assigns the value through the pointer.
+//
+// Scan also converts between string and numeric types, as long as no
+// information would be lost. While Scan stringifies all numbers
+// scanned from numeric database columns into *string, scans into
+// numeric types are checked for overflow. For example, a float64 with
+// value 300 or a string with value "300" can scan into a uint16, but
+// not into a uint8, though float64(255) or "255" can scan into a
+// uint8. One exception is that scans of some float64 numbers to
+// strings may lose information when stringifying. In general, scan
+// floating point columns into *float64.
+//
+// If a dest argument has type *[]byte, Scan saves in that argument a
+// copy of the corresponding data. The copy is owned by the caller and
+// can be modified and held indefinitely. The copy can be avoided by
+// using an argument of type *RawBytes instead; see the documentation
+// for RawBytes for restrictions on its use.
+//
+// If an argument has type *interface{}, Scan copies the value
+// provided by the underlying driver without conversion. When scanning
+// from a source value of type []byte to *interface{}, a copy of the
+// slice is made and the caller owns the result.
+//
+// Source values of type time.Time may be scanned into values of type
+// *time.Time, *interface{}, *string, or *[]byte. When converting to
+// the latter two, time.RFC3339Nano is used.
+//
+// Source values of type bool may be scanned into types *bool,
+// *interface{}, *string, *[]byte, or *RawBytes.
+//
+// For scanning into *bool, the source may be true, false, 1, 0, or
+// string inputs parseable by strconv.ParseBool.
+//
+// Scan can also convert a cursor returned from a query, such as
+// "select cursor(select * from my_table) from dual", into a
+// *Rows value that can itself be scanned from. The parent
+// select query will close any cursor *Rows if the parent *Rows is closed.
+//
+// If any of the first arguments implementing Scanner returns an error,
+// that error will be wrapped in the returned error.
+func (rs *Rows) Scan(dest ...any) error {
+ if rs.closemuScanHold {
+ // This should only be possible if the user calls Scan twice in a row
+ // without calling Next.
+ return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
+ }
+ rs.closemu.RLock()
+
+ if rs.lasterr != nil && rs.lasterr != io.EOF {
+ rs.closemu.RUnlock()
+ return rs.lasterr
+ }
+ if rs.closed {
+ err := rs.lasterrOrErrLocked(errRowsClosed)
+ rs.closemu.RUnlock()
+ return err
+ }
+
+ if scanArgsContainRawBytes(dest) {
+ rs.closemuScanHold = true
+ } else {
+ rs.closemu.RUnlock()
+ }
+
+ if rs.lastcols == nil {
+ rs.closemuRUnlockIfHeldByScan()
+ return errors.New("sql: Scan called without calling Next")
+ }
+ if len(dest) != len(rs.lastcols) {
+ rs.closemuRUnlockIfHeldByScan()
+ return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
+ }
+
+ for i, sv := range rs.lastcols {
+ err := convertAssignRows(dest[i], sv, rs)
+ if err != nil {
+ rs.closemuRUnlockIfHeldByScan()
+ return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
+ }
+ }
+ return nil
+}
+
+// closemuRUnlockIfHeldByScan releases any closemu.RLock held open by a previous
+// call to Scan with *RawBytes.
+func (rs *Rows) closemuRUnlockIfHeldByScan() {
+ if rs.closemuScanHold {
+ rs.closemuScanHold = false
+ rs.closemu.RUnlock()
+ }
+}
+
+func scanArgsContainRawBytes(args []any) bool {
+ for _, a := range args {
+ if _, ok := a.(*RawBytes); ok {
+ return true
+ }
+ }
+ return false
+}
+
+// rowsCloseHook returns a function so tests may install the
+// hook through a test only mutex.
+var rowsCloseHook = func() func(*Rows, *error) { return nil }
+
+// Close closes the Rows, preventing further enumeration. If Next is called
+// and returns false and there are no further result sets,
+// the Rows are closed automatically and it will suffice to check the
+// result of Err. Close is idempotent and does not affect the result of Err.
+func (rs *Rows) Close() error {
+ // If the user's calling Close, they're done with their previous row's Scan
+ // results (any RawBytes memory), so we can release the read lock that would
+ // be preventing awaitDone from calling the unexported close before we do so.
+ rs.closemuRUnlockIfHeldByScan()
+
+ return rs.close(nil)
+}
+
+func (rs *Rows) close(err error) error {
+ rs.closemu.Lock()
+ defer rs.closemu.Unlock()
+
+ if rs.closed {
+ return nil
+ }
+ rs.closed = true
+
+ if rs.lasterr == nil {
+ rs.lasterr = err
+ }
+
+ withLock(rs.dc, func() {
+ err = rs.rowsi.Close()
+ })
+ if fn := rowsCloseHook(); fn != nil {
+ fn(rs, &err)
+ }
+ if rs.cancel != nil {
+ rs.cancel()
+ }
+
+ if rs.closeStmt != nil {
+ rs.closeStmt.Close()
+ }
+ rs.releaseConn(err)
+
+ rs.lasterr = rs.lasterrOrErrLocked(err)
+ return err
+}
+
+// Row is the result of calling QueryRow to select a single row.
+type Row struct {
+ // One of these two will be non-nil:
+ err error // deferred error for easy chaining
+ rows *Rows
+}
+
+// Scan copies the columns from the matched row into the values
+// pointed at by dest. See the documentation on Rows.Scan for details.
+// If more than one row matches the query,
+// Scan uses the first row and discards the rest. If no row matches
+// the query, Scan returns ErrNoRows.
+func (r *Row) Scan(dest ...any) error {
+ if r.err != nil {
+ return r.err
+ }
+
+ // TODO(bradfitz): for now we need to defensively clone all
+ // []byte that the driver returned (not permitting
+ // *RawBytes in Rows.Scan), since we're about to close
+ // the Rows in our defer, when we return from this function.
+ // the contract with the driver.Next(...) interface is that it
+ // can return slices into read-only temporary memory that's
+ // only valid until the next Scan/Close. But the TODO is that
+ // for a lot of drivers, this copy will be unnecessary. We
+ // should provide an optional interface for drivers to
+ // implement to say, "don't worry, the []bytes that I return
+ // from Next will not be modified again." (for instance, if
+ // they were obtained from the network anyway) But for now we
+ // don't care.
+ defer r.rows.Close()
+ for _, dp := range dest {
+ if _, ok := dp.(*RawBytes); ok {
+ return errors.New("sql: RawBytes isn't allowed on Row.Scan")
+ }
+ }
+
+ if !r.rows.Next() {
+ if err := r.rows.Err(); err != nil {
+ return err
+ }
+ return ErrNoRows
+ }
+ err := r.rows.Scan(dest...)
+ if err != nil {
+ return err
+ }
+ // Make sure the query can be processed to completion with no errors.
+ return r.rows.Close()
+}
+
+// Err provides a way for wrapping packages to check for
+// query errors without calling Scan.
+// Err returns the error, if any, that was encountered while running the query.
+// If this error is not nil, this error will also be returned from Scan.
+func (r *Row) Err() error {
+ return r.err
+}
+
+// A Result summarizes an executed SQL command.
+type Result interface {
+ // LastInsertId returns the integer generated by the database
+ // in response to a command. Typically this will be from an
+ // "auto increment" column when inserting a new row. Not all
+ // databases support this feature, and the syntax of such
+ // statements varies.
+ LastInsertId() (int64, error)
+
+ // RowsAffected returns the number of rows affected by an
+ // update, insert, or delete. Not every database or database
+ // driver may support this.
+ RowsAffected() (int64, error)
+}
+
+type driverResult struct {
+ sync.Locker // the *driverConn
+ resi driver.Result
+}
+
+func (dr driverResult) LastInsertId() (int64, error) {
+ dr.Lock()
+ defer dr.Unlock()
+ return dr.resi.LastInsertId()
+}
+
+func (dr driverResult) RowsAffected() (int64, error) {
+ dr.Lock()
+ defer dr.Unlock()
+ return dr.resi.RowsAffected()
+}
+
+func stack() string {
+ var buf [2 << 10]byte
+ return string(buf[:runtime.Stack(buf[:], false)])
+}
+
+// withLock runs while holding lk.
+func withLock(lk sync.Locker, fn func()) {
+ lk.Lock()
+ defer lk.Unlock() // in case fn panics
+ fn()
+}
diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go
new file mode 100644
index 0000000..e6a5cd9
--- /dev/null
+++ b/src/database/sql/sql_test.go
@@ -0,0 +1,4752 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+ "context"
+ "database/sql/driver"
+ "errors"
+ "fmt"
+ "internal/race"
+ "internal/testenv"
+ "math/rand"
+ "reflect"
+ "runtime"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+func init() {
+ type dbConn struct {
+ db *DB
+ c *driverConn
+ }
+ freedFrom := make(map[dbConn]string)
+ var mu sync.Mutex
+ getFreedFrom := func(c dbConn) string {
+ mu.Lock()
+ defer mu.Unlock()
+ return freedFrom[c]
+ }
+ setFreedFrom := func(c dbConn, s string) {
+ mu.Lock()
+ defer mu.Unlock()
+ freedFrom[c] = s
+ }
+ putConnHook = func(db *DB, c *driverConn) {
+ idx := -1
+ for i, v := range db.freeConn {
+ if v == c {
+ idx = i
+ break
+ }
+ }
+ if idx >= 0 {
+ // print before panic, as panic may get lost due to conflicting panic
+ // (all goroutines asleep) elsewhere, since we might not unlock
+ // the mutex in freeConn here.
+ println("double free of conn. conflicts are:\nA) " + getFreedFrom(dbConn{db, c}) + "\n\nand\nB) " + stack())
+ panic("double free of conn.")
+ }
+ setFreedFrom(dbConn{db, c}, stack())
+ }
+}
+
+// pollDuration is an arbitrary interval to wait between checks when polling for
+// a condition to occur.
+const pollDuration = 5 * time.Millisecond
+
+const fakeDBName = "foo"
+
+var chrisBirthday = time.Unix(123456789, 0)
+
+func newTestDB(t testing.TB, name string) *DB {
+ return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name)
+}
+
+func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB {
+ fc.name = fakeDBName
+ db := OpenDB(fc)
+ if _, err := db.Exec("WIPE"); err != nil {
+ t.Fatalf("exec wipe: %v", err)
+ }
+ if name == "people" {
+ exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
+ exec(t, db, "INSERT|people|name=Alice,age=?,photo=APHOTO", 1)
+ exec(t, db, "INSERT|people|name=Bob,age=?,photo=BPHOTO", 2)
+ exec(t, db, "INSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
+ }
+ if name == "magicquery" {
+ // Magic table name and column, known by fakedb_test.go.
+ exec(t, db, "CREATE|magicquery|op=string,millis=int32")
+ exec(t, db, "INSERT|magicquery|op=sleep,millis=10")
+ }
+ if name == "tx_status" {
+ // Magic table name and column, known by fakedb_test.go.
+ exec(t, db, "CREATE|tx_status|tx_status=string")
+ exec(t, db, "INSERT|tx_status|tx_status=invalid")
+ }
+ return db
+}
+
+func TestOpenDB(t *testing.T) {
+ db := OpenDB(dsnConnector{dsn: fakeDBName, driver: fdriver})
+ if db.Driver() != fdriver {
+ t.Fatalf("OpenDB should return the driver of the Connector")
+ }
+}
+
+func TestDriverPanic(t *testing.T) {
+ // Test that if driver panics, database/sql does not deadlock.
+ db, err := Open("test", fakeDBName)
+ if err != nil {
+ t.Fatalf("Open: %v", err)
+ }
+ expectPanic := func(name string, f func()) {
+ defer func() {
+ err := recover()
+ if err == nil {
+ t.Fatalf("%s did not panic", name)
+ }
+ }()
+ f()
+ }
+
+ expectPanic("Exec Exec", func() { db.Exec("PANIC|Exec|WIPE") })
+ exec(t, db, "WIPE") // check not deadlocked
+ expectPanic("Exec NumInput", func() { db.Exec("PANIC|NumInput|WIPE") })
+ exec(t, db, "WIPE") // check not deadlocked
+ expectPanic("Exec Close", func() { db.Exec("PANIC|Close|WIPE") })
+ exec(t, db, "WIPE") // check not deadlocked
+ exec(t, db, "PANIC|Query|WIPE") // should run successfully: Exec does not call Query
+ exec(t, db, "WIPE") // check not deadlocked
+
+ exec(t, db, "CREATE|people|name=string,age=int32,photo=blob,dead=bool,bdate=datetime")
+
+ expectPanic("Query Query", func() { db.Query("PANIC|Query|SELECT|people|age,name|") })
+ expectPanic("Query NumInput", func() { db.Query("PANIC|NumInput|SELECT|people|age,name|") })
+ expectPanic("Query Close", func() {
+ rows, err := db.Query("PANIC|Close|SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ rows.Close()
+ })
+ db.Query("PANIC|Exec|SELECT|people|age,name|") // should run successfully: Query does not call Exec
+ exec(t, db, "WIPE") // check not deadlocked
+}
+
+func exec(t testing.TB, db *DB, query string, args ...any) {
+ t.Helper()
+ _, err := db.Exec(query, args...)
+ if err != nil {
+ t.Fatalf("Exec of %q: %v", query, err)
+ }
+}
+
+func closeDB(t testing.TB, db *DB) {
+ if e := recover(); e != nil {
+ fmt.Printf("Panic: %v\n", e)
+ panic(e)
+ }
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v", err)
+ }
+ })
+ db.mu.Lock()
+ for i, dc := range db.freeConn {
+ if n := len(dc.openStmt); n > 0 {
+ // Just a sanity check. This is legal in
+ // general, but if we make the tests clean up
+ // their statements first, then we can safely
+ // verify this is always zero here, and any
+ // other value is a leak.
+ t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n)
+ }
+ }
+ db.mu.Unlock()
+
+ err := db.Close()
+ if err != nil {
+ t.Fatalf("error closing DB: %v", err)
+ }
+
+ var numOpen int
+ if !waitCondition(t, func() bool {
+ numOpen = db.numOpenConns()
+ return numOpen == 0
+ }) {
+ t.Fatalf("%d connections still open after closing DB", numOpen)
+ }
+}
+
+// numPrepares assumes that db has exactly 1 idle conn and returns
+// its count of calls to Prepare
+func numPrepares(t *testing.T, db *DB) int {
+ if n := len(db.freeConn); n != 1 {
+ t.Fatalf("free conns = %d; want 1", n)
+ }
+ return db.freeConn[0].ci.(*fakeConn).numPrepare
+}
+
+func (db *DB) numDeps() int {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ return len(db.dep)
+}
+
+// Dependencies are closed via a goroutine, so this polls waiting for
+// numDeps to fall to want, waiting up to nearly the test's deadline.
+func (db *DB) numDepsPoll(t *testing.T, want int) int {
+ var n int
+ waitCondition(t, func() bool {
+ n = db.numDeps()
+ return n <= want
+ })
+ return n
+}
+
+func (db *DB) numFreeConns() int {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ return len(db.freeConn)
+}
+
+func (db *DB) numOpenConns() int {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ return db.numOpen
+}
+
+// clearAllConns closes all connections in db.
+func (db *DB) clearAllConns(t *testing.T) {
+ db.SetMaxIdleConns(0)
+
+ if g, w := db.numFreeConns(), 0; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPoll(t, 0); n > 0 {
+ t.Errorf("number of dependencies = %d; expected 0", n)
+ db.dumpDeps(t)
+ }
+}
+
+func (db *DB) dumpDeps(t *testing.T) {
+ for fc := range db.dep {
+ db.dumpDep(t, 0, fc, map[finalCloser]bool{})
+ }
+}
+
+func (db *DB) dumpDep(t *testing.T, depth int, dep finalCloser, seen map[finalCloser]bool) {
+ seen[dep] = true
+ indent := strings.Repeat(" ", depth)
+ ds := db.dep[dep]
+ for k := range ds {
+ t.Logf("%s%T (%p) waiting for -> %T (%p)", indent, dep, dep, k, k)
+ if fc, ok := k.(finalCloser); ok {
+ if !seen[fc] {
+ db.dumpDep(t, depth+1, fc, seen)
+ }
+ }
+ }
+}
+
+func TestQuery(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ for rows.Next() {
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got = append(got, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want := []row{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ {age: 3, name: "Chris"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+// TestQueryContext tests canceling the context while scanning the rows.
+func TestQueryContext(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ index := 0
+ for rows.Next() {
+ if index == 2 {
+ cancel()
+ waitForRowsClose(t, rows)
+ }
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ if index == 2 {
+ break
+ }
+ t.Fatalf("Scan: %v", err)
+ }
+ if index == 2 && err != context.Canceled {
+ t.Fatalf("Scan: %v; want context.Canceled", err)
+ }
+ got = append(got, r)
+ index++
+ }
+ select {
+ case <-ctx.Done():
+ if err := ctx.Err(); err != context.Canceled {
+ t.Fatalf("context err = %v; want context.Canceled", err)
+ }
+ default:
+ t.Fatalf("context err = nil; want context.Canceled")
+ }
+ want := []row{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ waitForRowsClose(t, rows)
+ waitForFree(t, db, 1)
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func waitCondition(t testing.TB, fn func() bool) bool {
+ timeout := 5 * time.Second
+
+ type deadliner interface {
+ Deadline() (time.Time, bool)
+ }
+ if td, ok := t.(deadliner); ok {
+ if deadline, ok := td.Deadline(); ok {
+ timeout = time.Until(deadline)
+ timeout = timeout * 19 / 20 // Give 5% headroom for cleanup and error-reporting.
+ }
+ }
+
+ deadline := time.Now().Add(timeout)
+ for {
+ if fn() {
+ return true
+ }
+ if time.Until(deadline) < pollDuration {
+ return false
+ }
+ time.Sleep(pollDuration)
+ }
+}
+
+// waitForFree checks db.numFreeConns until either it equals want or
+// the maxWait time elapses.
+func waitForFree(t *testing.T, db *DB, want int) {
+ var numFree int
+ if !waitCondition(t, func() bool {
+ numFree = db.numFreeConns()
+ return numFree == want
+ }) {
+ t.Fatalf("free conns after hitting EOF = %d; want %d", numFree, want)
+ }
+}
+
+func waitForRowsClose(t *testing.T, rows *Rows) {
+ if !waitCondition(t, func() bool {
+ rows.closemu.RLock()
+ defer rows.closemu.RUnlock()
+ return rows.closed
+ }) {
+ t.Fatal("failed to close rows")
+ }
+}
+
+// TestQueryContextWait ensures that rows and all internal statements are closed when
+// a query context is closed during execution.
+func TestQueryContextWait(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ // This will trigger the *fakeConn.Prepare method which will take time
+ // performing the query. The ctxDriverPrepare func will check the context
+ // after this and close the rows and return an error.
+ c, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ c.dc.ci.(*fakeConn).waiter = func(c context.Context) {
+ cancel()
+ <-ctx.Done()
+ }
+ _, err = c.QueryContext(ctx, "SELECT|people|age,name|")
+ c.Close()
+ if err != context.Canceled {
+ t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
+ }
+
+ // Verify closed rows connection after error condition.
+ waitForFree(t, db, 1)
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Fatalf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+// TestTxContextWait tests the transaction behavior when the tx context is canceled
+// during execution of the query.
+func TestTxContextWait(t *testing.T) {
+ testContextWait(t, false)
+}
+
+// TestTxContextWaitNoDiscard is the same as TestTxContextWait, but should not discard
+// the final connection.
+func TestTxContextWaitNoDiscard(t *testing.T) {
+ testContextWait(t, true)
+}
+
+func testContextWait(t *testing.T, keepConnOnRollback bool) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ tx.keepConnOnRollback = keepConnOnRollback
+
+ tx.dc.ci.(*fakeConn).waiter = func(c context.Context) {
+ cancel()
+ <-ctx.Done()
+ }
+ // This will trigger the *fakeConn.Prepare method which will take time
+ // performing the query. The ctxDriverPrepare func will check the context
+ // after this and close the rows and return an error.
+ _, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
+ if err != context.Canceled {
+ t.Fatalf("expected QueryContext to error with context canceled but returned %v", err)
+ }
+
+ if keepConnOnRollback {
+ waitForFree(t, db, 1)
+ } else {
+ waitForFree(t, db, 0)
+ }
+}
+
+// TestUnsupportedOptions checks that the database fails when a driver that
+// doesn't implement ConnBeginTx is used with non-default options and an
+// un-cancellable context.
+func TestUnsupportedOptions(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ _, err := db.BeginTx(context.Background(), &TxOptions{
+ Isolation: LevelSerializable, ReadOnly: true,
+ })
+ if err == nil {
+ t.Fatal("expected error when using unsupported options, got nil")
+ }
+}
+
+func TestMultiResultSetQuery(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row1 struct {
+ age int
+ name string
+ }
+ type row2 struct {
+ name string
+ }
+ got1 := []row1{}
+ for rows.Next() {
+ var r row1
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got1 = append(got1, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want1 := []row1{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ {age: 3, name: "Chris"},
+ }
+ if !reflect.DeepEqual(got1, want1) {
+ t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
+ }
+
+ if !rows.NextResultSet() {
+ t.Errorf("expected another result set")
+ }
+
+ got2 := []row2{}
+ for rows.Next() {
+ var r row2
+ err = rows.Scan(&r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got2 = append(got2, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want2 := []row2{
+ {name: "Alice"},
+ {name: "Bob"},
+ {name: "Chris"},
+ }
+ if !reflect.DeepEqual(got2, want2) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
+ }
+ if rows.NextResultSet() {
+ t.Errorf("expected no more result sets")
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ waitForFree(t, db, 1)
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestQueryNamedArg(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query(
+ // Ensure the name and age parameters only match on placeholder name, not position.
+ "SELECT|people|age,name|name=?name,age=?age",
+ Named("age", 2),
+ Named("name", "Bob"),
+ )
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ for rows.Next() {
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got = append(got, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want := []row{
+ {age: 2, name: "Bob"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestPoolExhaustOnCancel(t *testing.T) {
+ if testing.Short() {
+ t.Skip("long test")
+ }
+
+ max := 3
+ var saturate, saturateDone sync.WaitGroup
+ saturate.Add(max)
+ saturateDone.Add(max)
+
+ donePing := make(chan bool)
+ state := 0
+
+ // waiter will be called for all queries, including
+ // initial setup queries. The state is only assigned when
+ // no queries are made.
+ //
+ // Only allow the first batch of queries to finish once the
+ // second batch of Ping queries have finished.
+ waiter := func(ctx context.Context) {
+ switch state {
+ case 0:
+ // Nothing. Initial database setup.
+ case 1:
+ saturate.Done()
+ select {
+ case <-ctx.Done():
+ case <-donePing:
+ }
+ case 2:
+ }
+ }
+ db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxOpenConns(max)
+
+ // First saturate the connection pool.
+ // Then start new requests for a connection that is canceled after it is requested.
+
+ state = 1
+ for i := 0; i < max; i++ {
+ go func() {
+ rows, err := db.Query("SELECT|people|name,photo|")
+ if err != nil {
+ t.Errorf("Query: %v", err)
+ return
+ }
+ rows.Close()
+ saturateDone.Done()
+ }()
+ }
+
+ saturate.Wait()
+ if t.Failed() {
+ t.FailNow()
+ }
+ state = 2
+
+ // Now cancel the request while it is waiting.
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+
+ for i := 0; i < max; i++ {
+ ctxReq, cancelReq := context.WithCancel(ctx)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ cancelReq()
+ }()
+ err := db.PingContext(ctxReq)
+ if err != context.Canceled {
+ t.Fatalf("PingContext (Exhaust): %v", err)
+ }
+ }
+ close(donePing)
+ saturateDone.Wait()
+
+ // Now try to open a normal connection.
+ err := db.PingContext(ctx)
+ if err != nil {
+ t.Fatalf("PingContext (Normal): %v", err)
+ }
+}
+
+func TestRowsColumns(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ cols, err := rows.Columns()
+ if err != nil {
+ t.Fatalf("Columns: %v", err)
+ }
+ want := []string{"age", "name"}
+ if !reflect.DeepEqual(cols, want) {
+ t.Errorf("got %#v; want %#v", cols, want)
+ }
+ if err := rows.Close(); err != nil {
+ t.Errorf("error closing rows: %s", err)
+ }
+}
+
+func TestRowsColumnTypes(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ tt, err := rows.ColumnTypes()
+ if err != nil {
+ t.Fatalf("ColumnTypes: %v", err)
+ }
+
+ types := make([]reflect.Type, len(tt))
+ for i, tp := range tt {
+ st := tp.ScanType()
+ if st == nil {
+ t.Errorf("scantype is null for column %q", tp.Name())
+ continue
+ }
+ types[i] = st
+ }
+ values := make([]any, len(tt))
+ for i := range values {
+ values[i] = reflect.New(types[i]).Interface()
+ }
+ ct := 0
+ for rows.Next() {
+ err = rows.Scan(values...)
+ if err != nil {
+ t.Fatalf("failed to scan values in %v", err)
+ }
+ if ct == 1 {
+ if age := *values[0].(*int32); age != 2 {
+ t.Errorf("Expected 2, got %v", age)
+ }
+ if name := *values[1].(*string); name != "Bob" {
+ t.Errorf("Expected Bob, got %v", name)
+ }
+ }
+ ct++
+ }
+ if ct != 3 {
+ t.Errorf("expected 3 rows, got %d", ct)
+ }
+
+ if err := rows.Close(); err != nil {
+ t.Errorf("error closing rows: %s", err)
+ }
+}
+
+func TestQueryRow(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ var name string
+ var age int
+ var birthday time.Time
+
+ err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age)
+ if err == nil || !strings.Contains(err.Error(), "expected 2 destination arguments") {
+ t.Errorf("expected error from wrong number of arguments; actually got: %v", err)
+ }
+
+ err = db.QueryRow("SELECT|people|bdate|age=?", 3).Scan(&birthday)
+ if err != nil || !birthday.Equal(chrisBirthday) {
+ t.Errorf("chris birthday = %v, err = %v; want %v", birthday, err, chrisBirthday)
+ }
+
+ err = db.QueryRow("SELECT|people|age,name|age=?", 2).Scan(&age, &name)
+ if err != nil {
+ t.Fatalf("age QueryRow+Scan: %v", err)
+ }
+ if name != "Bob" {
+ t.Errorf("expected name Bob, got %q", name)
+ }
+ if age != 2 {
+ t.Errorf("expected age 2, got %d", age)
+ }
+
+ err = db.QueryRow("SELECT|people|age,name|name=?", "Alice").Scan(&age, &name)
+ if err != nil {
+ t.Fatalf("name QueryRow+Scan: %v", err)
+ }
+ if name != "Alice" {
+ t.Errorf("expected name Alice, got %q", name)
+ }
+ if age != 1 {
+ t.Errorf("expected age 1, got %d", age)
+ }
+
+ var photo []byte
+ err = db.QueryRow("SELECT|people|photo|name=?", "Alice").Scan(&photo)
+ if err != nil {
+ t.Fatalf("photo QueryRow+Scan: %v", err)
+ }
+ want := []byte("APHOTO")
+ if !reflect.DeepEqual(photo, want) {
+ t.Errorf("photo = %q; want %q", photo, want)
+ }
+}
+
+func TestRowErr(t *testing.T) {
+ db := newTestDB(t, "people")
+
+ err := db.QueryRowContext(context.Background(), "SELECT|people|bdate|age=?", 3).Err()
+ if err != nil {
+ t.Errorf("Unexpected err = %v; want %v", err, nil)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ err = db.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 3).Err()
+ exp := "context canceled"
+ if err == nil || !strings.Contains(err.Error(), exp) {
+ t.Errorf("Expected err = %v; got %v", exp, err)
+ }
+}
+
+func TestTxRollbackCommitErr(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Errorf("expected nil error from Rollback; got %v", err)
+ }
+ err = tx.Commit()
+ if err != ErrTxDone {
+ t.Errorf("expected %q from Commit; got %q", ErrTxDone, err)
+ }
+
+ tx, err = db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Errorf("expected nil error from Commit; got %v", err)
+ }
+ err = tx.Rollback()
+ if err != ErrTxDone {
+ t.Errorf("expected %q from Rollback; got %q", ErrTxDone, err)
+ }
+}
+
+func TestStatementErrorAfterClose(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ stmt, err := db.Prepare("SELECT|people|age|name=?")
+ if err != nil {
+ t.Fatalf("Prepare: %v", err)
+ }
+ err = stmt.Close()
+ if err != nil {
+ t.Fatalf("Close: %v", err)
+ }
+ var name string
+ err = stmt.QueryRow("foo").Scan(&name)
+ if err == nil {
+ t.Errorf("expected error from QueryRow.Scan after Stmt.Close")
+ }
+}
+
+func TestStatementQueryRow(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ stmt, err := db.Prepare("SELECT|people|age|name=?")
+ if err != nil {
+ t.Fatalf("Prepare: %v", err)
+ }
+ defer stmt.Close()
+ var age int
+ for n, tt := range []struct {
+ name string
+ want int
+ }{
+ {"Alice", 1},
+ {"Bob", 2},
+ {"Chris", 3},
+ } {
+ if err := stmt.QueryRow(tt.name).Scan(&age); err != nil {
+ t.Errorf("%d: on %q, QueryRow/Scan: %v", n, tt.name, err)
+ } else if age != tt.want {
+ t.Errorf("%d: age=%d, want %d", n, age, tt.want)
+ }
+ }
+}
+
+type stubDriverStmt struct {
+ err error
+}
+
+func (s stubDriverStmt) Close() error {
+ return s.err
+}
+
+func (s stubDriverStmt) NumInput() int {
+ return -1
+}
+
+func (s stubDriverStmt) Exec(args []driver.Value) (driver.Result, error) {
+ return nil, nil
+}
+
+func (s stubDriverStmt) Query(args []driver.Value) (driver.Rows, error) {
+ return nil, nil
+}
+
+// golang.org/issue/12798
+func TestStatementClose(t *testing.T) {
+ want := errors.New("STMT ERROR")
+
+ tests := []struct {
+ stmt *Stmt
+ msg string
+ }{
+ {&Stmt{stickyErr: want}, "stickyErr not propagated"},
+ {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
+ }
+ for _, test := range tests {
+ if err := test.stmt.Close(); err != want {
+ t.Errorf("%s. Got stmt.Close() = %v, want = %v", test.msg, err, want)
+ }
+ }
+}
+
+// golang.org/issue/3734
+func TestStatementQueryRowConcurrent(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ stmt, err := db.Prepare("SELECT|people|age|name=?")
+ if err != nil {
+ t.Fatalf("Prepare: %v", err)
+ }
+ defer stmt.Close()
+
+ const n = 10
+ ch := make(chan error, n)
+ for i := 0; i < n; i++ {
+ go func() {
+ var age int
+ err := stmt.QueryRow("Alice").Scan(&age)
+ if err == nil && age != 1 {
+ err = fmt.Errorf("unexpected age %d", age)
+ }
+ ch <- err
+ }()
+ }
+ for i := 0; i < n; i++ {
+ if err := <-ch; err != nil {
+ t.Error(err)
+ }
+ }
+}
+
+// just a test of fakedb itself
+func TestBogusPreboundParameters(t *testing.T) {
+ db := newTestDB(t, "foo")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+ _, err := db.Prepare("INSERT|t1|name=?,age=bogusconversion")
+ if err == nil {
+ t.Fatalf("expected error")
+ }
+ if err.Error() != `fakedb: invalid conversion to int32 from "bogusconversion"` {
+ t.Errorf("unexpected error: %v", err)
+ }
+}
+
+func TestExec(t *testing.T) {
+ db := newTestDB(t, "foo")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Errorf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+
+ type execTest struct {
+ args []any
+ wantErr string
+ }
+ execTests := []execTest{
+ // Okay:
+ {[]any{"Brad", 31}, ""},
+ {[]any{"Brad", int64(31)}, ""},
+ {[]any{"Bob", "32"}, ""},
+ {[]any{7, 9}, ""},
+
+ // Invalid conversions:
+ {[]any{"Brad", int64(0xFFFFFFFF)}, "sql: converting argument $2 type: sql/driver: value 4294967295 overflows int32"},
+ {[]any{"Brad", "strconv fail"}, `sql: converting argument $2 type: sql/driver: value "strconv fail" can't be converted to int32`},
+
+ // Wrong number of args:
+ {[]any{}, "sql: expected 2 arguments, got 0"},
+ {[]any{1, 2, 3}, "sql: expected 2 arguments, got 3"},
+ }
+ for n, et := range execTests {
+ _, err := stmt.Exec(et.args...)
+ errStr := ""
+ if err != nil {
+ errStr = err.Error()
+ }
+ if errStr != et.wantErr {
+ t.Errorf("stmt.Execute #%d: for %v, got error %q, want error %q",
+ n, et.args, errStr, et.wantErr)
+ }
+ }
+}
+
+func TestTxPrepare(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ stmt, err := tx.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+ _, err = stmt.Exec("Bobby", 7)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Fatalf("Commit = %v", err)
+ }
+ // Commit() should have closed the statement
+ if !stmt.closed {
+ t.Fatal("Stmt not closed after Commit")
+ }
+}
+
+func TestTxStmt(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ txs := tx.Stmt(stmt)
+ defer txs.Close()
+ _, err = txs.Exec("Bobby", 7)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Fatalf("Commit = %v", err)
+ }
+ // Commit() should have closed the statement
+ if !txs.closed {
+ t.Fatal("Stmt not closed after Commit")
+ }
+}
+
+func TestTxStmtPreparedOnce(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+
+ txs1 := tx.Stmt(stmt)
+ txs2 := tx.Stmt(stmt)
+
+ _, err = txs1.Exec("Go", 7)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ txs1.Close()
+
+ _, err = txs2.Exec("Gopher", 8)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ txs2.Close()
+
+ err = tx.Commit()
+ if err != nil {
+ t.Fatalf("Commit = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestTxStmtClosedRePrepares(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ err = stmt.Close()
+ if err != nil {
+ t.Fatalf("stmt.Close() = %v", err)
+ }
+ // tx.Stmt increments numPrepares because stmt is closed.
+ txs := tx.Stmt(stmt)
+ if txs.stickyErr != nil {
+ t.Fatal(txs.stickyErr)
+ }
+ if txs.parentStmt != nil {
+ t.Fatal("expected nil parentStmt")
+ }
+ _, err = txs.Exec(`Eric`, 82)
+ if err != nil {
+ t.Fatalf("txs.Exec = %v", err)
+ }
+
+ err = txs.Close()
+ if err != nil {
+ t.Fatalf("txs.Close = %v", err)
+ }
+
+ tx.Rollback()
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
+ t.Errorf("executed %d Prepare statements; want 2", prepares)
+ }
+}
+
+func TestParentStmtOutlivesTxStmt(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ // Make sure everything happens on the same connection.
+ db.SetMaxOpenConns(1)
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ txs := tx.Stmt(stmt)
+ if len(stmt.css) != 1 {
+ t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css))
+ }
+ err = txs.Close()
+ if err != nil {
+ t.Fatalf("txs.Close() = %v", err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatalf("tx.Rollback() = %v", err)
+ }
+ // txs must not be valid.
+ _, err = txs.Exec("Suzan", 30)
+ if err == nil {
+ t.Fatalf("txs.Exec(), expected err")
+ }
+ // Stmt must still be valid.
+ _, err = stmt.Exec("Janina", 25)
+ if err != nil {
+ t.Fatalf("stmt.Exec() = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+// Test that tx.Stmt called with a statement already
+// associated with tx as argument re-prepares the same
+// statement again.
+func TestTxStmtFromTxStmtRePrepares(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+ prepares0 := numPrepares(t, db)
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ txs1 := tx.Stmt(stmt)
+
+ // tx.Stmt(txs1) increments numPrepares because txs1 already
+ // belongs to a transaction (albeit the same transaction).
+ txs2 := tx.Stmt(txs1)
+ if txs2.stickyErr != nil {
+ t.Fatal(txs2.stickyErr)
+ }
+ if txs2.parentStmt != nil {
+ t.Fatal("expected nil parentStmt")
+ }
+ _, err = txs2.Exec(`Eric`, 82)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = txs1.Close()
+ if err != nil {
+ t.Fatalf("txs1.Close = %v", err)
+ }
+ err = txs2.Close()
+ if err != nil {
+ t.Fatalf("txs1.Close = %v", err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatalf("tx.Rollback = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
+ t.Errorf("executed %d Prepare statements; want 2", prepares)
+ }
+}
+
+// Issue: https://golang.org/issue/2784
+// This test didn't fail before because we got lucky with the fakedb driver.
+// It was failing, and now not, in github.com/bradfitz/go-sql-test
+func TestTxQuery(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+ exec(t, db, "INSERT|t1|name=Alice")
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tx.Rollback()
+
+ r, err := tx.Query("SELECT|t1|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r.Close()
+
+ if !r.Next() {
+ if r.Err() != nil {
+ t.Fatal(r.Err())
+ }
+ t.Fatal("expected one row")
+ }
+
+ var x string
+ err = r.Scan(&x)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestTxQueryInvalid(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tx.Rollback()
+
+ _, err = tx.Query("SELECT|t1|name|")
+ if err == nil {
+ t.Fatal("Error expected")
+ }
+}
+
+// Tests fix for issue 4433, that retries in Begin happen when
+// conn.Begin() returns ErrBadConn
+func TestTxErrBadConn(t *testing.T) {
+ db, err := Open("test", fakeDBName+";badConn")
+ if err != nil {
+ t.Fatalf("Open: %v", err)
+ }
+ if _, err := db.Exec("WIPE"); err != nil {
+ t.Fatalf("exec wipe: %v", err)
+ }
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ txs := tx.Stmt(stmt)
+ defer txs.Close()
+ _, err = txs.Exec("Bobby", 7)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Fatalf("Commit = %v", err)
+ }
+}
+
+func TestConnQuery(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
+ defer conn.Close()
+
+ var name string
+ err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", 3).Scan(&name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if name != "Chris" {
+ t.Fatalf("unexpected result, got %q want Chris", name)
+ }
+
+ err = conn.PingContext(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestConnRaw(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
+ defer conn.Close()
+
+ sawFunc := false
+ err = conn.Raw(func(dc any) error {
+ sawFunc = true
+ if _, ok := dc.(*fakeConn); !ok {
+ return fmt.Errorf("got %T want *fakeConn", dc)
+ }
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !sawFunc {
+ t.Fatal("Raw func not called")
+ }
+
+ func() {
+ defer func() {
+ x := recover()
+ if x == nil {
+ t.Fatal("expected panic")
+ }
+ conn.closemu.Lock()
+ closed := conn.dc == nil
+ conn.closemu.Unlock()
+ if !closed {
+ t.Fatal("expected connection to be closed after panic")
+ }
+ }()
+ err = conn.Raw(func(dc any) error {
+ panic("Conn.Raw panic should return an error")
+ })
+ t.Fatal("expected panic from Raw func")
+ }()
+}
+
+func TestCursorFake(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+ defer cancel()
+
+ exec(t, db, "CREATE|peoplecursor|list=table")
+ exec(t, db, "INSERT|peoplecursor|list=people!name!age")
+
+ rows, err := db.QueryContext(ctx, `SELECT|peoplecursor|list|`)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rows.Close()
+
+ if !rows.Next() {
+ t.Fatal("no rows")
+ }
+ var cursor = &Rows{}
+ err = rows.Scan(cursor)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cursor.Close()
+
+ const expectedRows = 3
+ var currentRow int64
+
+ var n int64
+ var s string
+ for cursor.Next() {
+ currentRow++
+ err = cursor.Scan(&s, &n)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != currentRow {
+ t.Errorf("expected number(Age)=%d, got %d", currentRow, n)
+ }
+ }
+ if currentRow != expectedRows {
+ t.Errorf("expected %d rows, got %d rows", expectedRows, currentRow)
+ }
+}
+
+func TestInvalidNilValues(t *testing.T) {
+ var date1 time.Time
+ var date2 int
+
+ tests := []struct {
+ name string
+ input any
+ expectedError string
+ }{
+ {
+ name: "time.Time",
+ input: &date1,
+ expectedError: `sql: Scan error on column index 0, name "bdate": unsupported Scan, storing driver.Value type <nil> into type *time.Time`,
+ },
+ {
+ name: "int",
+ input: &date2,
+ expectedError: `sql: Scan error on column index 0, name "bdate": converting NULL to int is unsupported`,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
+ defer conn.Close()
+
+ err = conn.QueryRowContext(ctx, "SELECT|people|bdate|age=?", 1).Scan(tt.input)
+ if err == nil {
+ t.Fatal("expected error when querying nil column, but succeeded")
+ }
+ if err.Error() != tt.expectedError {
+ t.Fatalf("Expected error: %s\nReceived: %s", tt.expectedError, err.Error())
+ }
+
+ err = conn.PingContext(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ }
+}
+
+func TestConnTx(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
+ defer conn.Close()
+
+ tx, err := conn.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ insertName, insertAge := "Nancy", 33
+ _, err = tx.ExecContext(ctx, "INSERT|people|name=?,age=?,photo=APHOTO", insertName, insertAge)
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = tx.Commit()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var selectName string
+ err = conn.QueryRowContext(ctx, "SELECT|people|name|age=?", insertAge).Scan(&selectName)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if selectName != insertName {
+ t.Fatalf("got %q want %q", selectName, insertName)
+ }
+}
+
+// TestConnIsValid verifies that a database connection that should be discarded,
+// is actually discarded and does not re-enter the connection pool.
+// If the IsValid method from *fakeConn is removed, this test will fail.
+func TestConnIsValid(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxOpenConns(1)
+
+ ctx := context.Background()
+
+ c, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = c.Raw(func(raw any) error {
+ dc := raw.(*fakeConn)
+ dc.stickyBad = true
+ return nil
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.Close()
+
+ if len(db.freeConn) > 0 && db.freeConn[0].ci.(*fakeConn).stickyBad {
+ t.Fatal("bad connection returned to pool; expected bad connection to be discarded")
+ }
+}
+
+// Tests fix for issue 2542, that we release a lock when querying on
+// a closed connection.
+func TestIssue2542Deadlock(t *testing.T) {
+ db := newTestDB(t, "people")
+ closeDB(t, db)
+ for i := 0; i < 2; i++ {
+ _, err := db.Query("SELECT|people|age,name|")
+ if err == nil {
+ t.Fatalf("expected error")
+ }
+ }
+}
+
+// From golang.org/issue/3865
+func TestCloseStmtBeforeRows(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ s, err := db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ r, err := s.Query()
+ if err != nil {
+ s.Close()
+ t.Fatal(err)
+ }
+
+ err = s.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ r.Close()
+}
+
+// Tests fix for issue 2788, that we bind nil to a []byte if the
+// value in the column is sql null
+func TestNullByteSlice(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t|id=int32,name=nullstring")
+ exec(t, db, "INSERT|t|id=10,name=?", nil)
+
+ var name []byte
+
+ err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if name != nil {
+ t.Fatalf("name []byte should be nil for null column value, got: %#v", name)
+ }
+
+ exec(t, db, "INSERT|t|id=11,name=?", "bob")
+ err = db.QueryRow("SELECT|t|name|id=?", 11).Scan(&name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(name) != "bob" {
+ t.Fatalf("name []byte should be bob, got: %q", string(name))
+ }
+}
+
+func TestPointerParamsAndScans(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t|id=int32,name=nullstring")
+
+ bob := "bob"
+ var name *string
+
+ name = &bob
+ exec(t, db, "INSERT|t|id=10,name=?", name)
+ name = nil
+ exec(t, db, "INSERT|t|id=20,name=?", name)
+
+ err := db.QueryRow("SELECT|t|name|id=?", 10).Scan(&name)
+ if err != nil {
+ t.Fatalf("querying id 10: %v", err)
+ }
+ if name == nil {
+ t.Errorf("id 10's name = nil; want bob")
+ } else if *name != "bob" {
+ t.Errorf("id 10's name = %q; want bob", *name)
+ }
+
+ err = db.QueryRow("SELECT|t|name|id=?", 20).Scan(&name)
+ if err != nil {
+ t.Fatalf("querying id 20: %v", err)
+ }
+ if name != nil {
+ t.Errorf("id 20 = %q; want nil", *name)
+ }
+}
+
+func TestQueryRowClosingStmt(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ var name string
+ var age int
+ err := db.QueryRow("SELECT|people|age,name|age=?", 3).Scan(&age, &name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(db.freeConn) != 1 {
+ t.Fatalf("expected 1 free conn")
+ }
+ fakeConn := db.freeConn[0].ci.(*fakeConn)
+ if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
+ t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
+ }
+}
+
+var atomicRowsCloseHook atomic.Value // of func(*Rows, *error)
+
+func init() {
+ rowsCloseHook = func() func(*Rows, *error) {
+ fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
+ return fn
+ }
+}
+
+func setRowsCloseHook(fn func(*Rows, *error)) {
+ if fn == nil {
+ // Can't change an atomic.Value back to nil, so set it to this
+ // no-op func instead.
+ fn = func(*Rows, *error) {}
+ }
+ atomicRowsCloseHook.Store(fn)
+}
+
+// Test issue 6651
+func TestIssue6651(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ var v string
+
+ want := "error in rows.Next"
+ rowsCursorNextHook = func(dest []driver.Value) error {
+ return fmt.Errorf(want)
+ }
+ defer func() { rowsCursorNextHook = nil }()
+
+ err := db.QueryRow("SELECT|people|name|").Scan(&v)
+ if err == nil || err.Error() != want {
+ t.Errorf("error = %q; want %q", err, want)
+ }
+ rowsCursorNextHook = nil
+
+ want = "error in rows.Close"
+ setRowsCloseHook(func(rows *Rows, err *error) {
+ *err = fmt.Errorf(want)
+ })
+ defer setRowsCloseHook(nil)
+ err = db.QueryRow("SELECT|people|name|").Scan(&v)
+ if err == nil || err.Error() != want {
+ t.Errorf("error = %q; want %q", err, want)
+ }
+}
+
+type nullTestRow struct {
+ nullParam any
+ notNullParam any
+ scanNullVal any
+}
+
+type nullTestSpec struct {
+ nullType string
+ notNullType string
+ rows [6]nullTestRow
+}
+
+func TestNullStringParam(t *testing.T) {
+ spec := nullTestSpec{"nullstring", "string", [6]nullTestRow{
+ {NullString{"aqua", true}, "", NullString{"aqua", true}},
+ {NullString{"brown", false}, "", NullString{"", false}},
+ {"chartreuse", "", NullString{"chartreuse", true}},
+ {NullString{"darkred", true}, "", NullString{"darkred", true}},
+ {NullString{"eel", false}, "", NullString{"", false}},
+ {"foo", NullString{"black", false}, nil},
+ }}
+ nullTestRun(t, spec)
+}
+
+func TestNullInt64Param(t *testing.T) {
+ spec := nullTestSpec{"nullint64", "int64", [6]nullTestRow{
+ {NullInt64{31, true}, 1, NullInt64{31, true}},
+ {NullInt64{-22, false}, 1, NullInt64{0, false}},
+ {22, 1, NullInt64{22, true}},
+ {NullInt64{33, true}, 1, NullInt64{33, true}},
+ {NullInt64{222, false}, 1, NullInt64{0, false}},
+ {0, NullInt64{31, false}, nil},
+ }}
+ nullTestRun(t, spec)
+}
+
+func TestNullInt32Param(t *testing.T) {
+ spec := nullTestSpec{"nullint32", "int32", [6]nullTestRow{
+ {NullInt32{31, true}, 1, NullInt32{31, true}},
+ {NullInt32{-22, false}, 1, NullInt32{0, false}},
+ {22, 1, NullInt32{22, true}},
+ {NullInt32{33, true}, 1, NullInt32{33, true}},
+ {NullInt32{222, false}, 1, NullInt32{0, false}},
+ {0, NullInt32{31, false}, nil},
+ }}
+ nullTestRun(t, spec)
+}
+
+func TestNullInt16Param(t *testing.T) {
+ spec := nullTestSpec{"nullint16", "int16", [6]nullTestRow{
+ {NullInt16{31, true}, 1, NullInt16{31, true}},
+ {NullInt16{-22, false}, 1, NullInt16{0, false}},
+ {22, 1, NullInt16{22, true}},
+ {NullInt16{33, true}, 1, NullInt16{33, true}},
+ {NullInt16{222, false}, 1, NullInt16{0, false}},
+ {0, NullInt16{31, false}, nil},
+ }}
+ nullTestRun(t, spec)
+}
+
+func TestNullByteParam(t *testing.T) {
+ spec := nullTestSpec{"nullbyte", "byte", [6]nullTestRow{
+ {NullByte{31, true}, 1, NullByte{31, true}},
+ {NullByte{0, false}, 1, NullByte{0, false}},
+ {22, 1, NullByte{22, true}},
+ {NullByte{33, true}, 1, NullByte{33, true}},
+ {NullByte{222, false}, 1, NullByte{0, false}},
+ {0, NullByte{31, false}, nil},
+ }}
+ nullTestRun(t, spec)
+}
+
+func TestNullFloat64Param(t *testing.T) {
+ spec := nullTestSpec{"nullfloat64", "float64", [6]nullTestRow{
+ {NullFloat64{31.2, true}, 1, NullFloat64{31.2, true}},
+ {NullFloat64{13.1, false}, 1, NullFloat64{0, false}},
+ {-22.9, 1, NullFloat64{-22.9, true}},
+ {NullFloat64{33.81, true}, 1, NullFloat64{33.81, true}},
+ {NullFloat64{222, false}, 1, NullFloat64{0, false}},
+ {10, NullFloat64{31.2, false}, nil},
+ }}
+ nullTestRun(t, spec)
+}
+
+func TestNullBoolParam(t *testing.T) {
+ spec := nullTestSpec{"nullbool", "bool", [6]nullTestRow{
+ {NullBool{false, true}, true, NullBool{false, true}},
+ {NullBool{true, false}, false, NullBool{false, false}},
+ {true, true, NullBool{true, true}},
+ {NullBool{true, true}, false, NullBool{true, true}},
+ {NullBool{true, false}, true, NullBool{false, false}},
+ {true, NullBool{true, false}, nil},
+ }}
+ nullTestRun(t, spec)
+}
+
+func TestNullTimeParam(t *testing.T) {
+ t0 := time.Time{}
+ t1 := time.Date(2000, 1, 1, 8, 9, 10, 11, time.UTC)
+ t2 := time.Date(2010, 1, 1, 8, 9, 10, 11, time.UTC)
+ spec := nullTestSpec{"nulldatetime", "datetime", [6]nullTestRow{
+ {NullTime{t1, true}, t2, NullTime{t1, true}},
+ {NullTime{t1, false}, t2, NullTime{t0, false}},
+ {t1, t2, NullTime{t1, true}},
+ {NullTime{t1, true}, t2, NullTime{t1, true}},
+ {NullTime{t1, false}, t2, NullTime{t0, false}},
+ {t2, NullTime{t1, false}, nil},
+ }}
+ nullTestRun(t, spec)
+}
+
+func nullTestRun(t *testing.T, spec nullTestSpec) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, fmt.Sprintf("CREATE|t|id=int32,name=string,nullf=%s,notnullf=%s", spec.nullType, spec.notNullType))
+
+ // Inserts with db.Exec:
+ exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 1, "alice", spec.rows[0].nullParam, spec.rows[0].notNullParam)
+ exec(t, db, "INSERT|t|id=?,name=?,nullf=?,notnullf=?", 2, "bob", spec.rows[1].nullParam, spec.rows[1].notNullParam)
+
+ // Inserts with a prepared statement:
+ stmt, err := db.Prepare("INSERT|t|id=?,name=?,nullf=?,notnullf=?")
+ if err != nil {
+ t.Fatalf("prepare: %v", err)
+ }
+ defer stmt.Close()
+ if _, err := stmt.Exec(3, "chris", spec.rows[2].nullParam, spec.rows[2].notNullParam); err != nil {
+ t.Errorf("exec insert chris: %v", err)
+ }
+ if _, err := stmt.Exec(4, "dave", spec.rows[3].nullParam, spec.rows[3].notNullParam); err != nil {
+ t.Errorf("exec insert dave: %v", err)
+ }
+ if _, err := stmt.Exec(5, "eleanor", spec.rows[4].nullParam, spec.rows[4].notNullParam); err != nil {
+ t.Errorf("exec insert eleanor: %v", err)
+ }
+
+ // Can't put null val into non-null col
+ if _, err := stmt.Exec(6, "bob", spec.rows[5].nullParam, spec.rows[5].notNullParam); err == nil {
+ t.Errorf("expected error inserting nil val with prepared statement Exec")
+ }
+
+ _, err = db.Exec("INSERT|t|id=?,name=?,nullf=?", 999, nil, nil)
+ if err == nil {
+ // TODO: this test fails, but it's just because
+ // fakeConn implements the optional Execer interface,
+ // so arguably this is the correct behavior. But
+ // maybe I should flesh out the fakeConn.Exec
+ // implementation so this properly fails.
+ // t.Errorf("expected error inserting nil name with Exec")
+ }
+
+ paramtype := reflect.TypeOf(spec.rows[0].nullParam)
+ bindVal := reflect.New(paramtype).Interface()
+
+ for i := 0; i < 5; i++ {
+ id := i + 1
+ if err := db.QueryRow("SELECT|t|nullf|id=?", id).Scan(bindVal); err != nil {
+ t.Errorf("id=%d Scan: %v", id, err)
+ }
+ bindValDeref := reflect.ValueOf(bindVal).Elem().Interface()
+ if !reflect.DeepEqual(bindValDeref, spec.rows[i].scanNullVal) {
+ t.Errorf("id=%d got %#v, want %#v", id, bindValDeref, spec.rows[i].scanNullVal)
+ }
+ }
+}
+
+// golang.org/issue/4859
+func TestQueryRowNilScanDest(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ var name *string // nil pointer
+ err := db.QueryRow("SELECT|people|name|").Scan(name)
+ want := `sql: Scan error on column index 0, name "name": destination pointer is nil`
+ if err == nil || err.Error() != want {
+ t.Errorf("error = %q; want %q", err.Error(), want)
+ }
+}
+
+func TestIssue4902(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ driver := db.Driver().(*fakeDriver)
+ opens0 := driver.openCount
+
+ var stmt *Stmt
+ var err error
+ for i := 0; i < 10; i++ {
+ stmt, err = db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = stmt.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ opens := driver.openCount - opens0
+ if opens > 1 {
+ t.Errorf("opens = %d; want <= 1", opens)
+ t.Logf("db = %#v", db)
+ t.Logf("driver = %#v", driver)
+ t.Logf("stmt = %#v", stmt)
+ }
+}
+
+// Issue 3857
+// This used to deadlock.
+func TestSimultaneousQueries(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer tx.Rollback()
+
+ r1, err := tx.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r1.Close()
+
+ r2, err := tx.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer r2.Close()
+}
+
+func TestMaxIdleConns(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tx.Commit()
+ if got := len(db.freeConn); got != 1 {
+ t.Errorf("freeConns = %d; want 1", got)
+ }
+
+ db.SetMaxIdleConns(0)
+
+ if got := len(db.freeConn); got != 0 {
+ t.Errorf("freeConns after set to zero = %d; want 0", got)
+ }
+
+ tx, err = db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tx.Commit()
+ if got := len(db.freeConn); got != 0 {
+ t.Errorf("freeConns = %d; want 0", got)
+ }
+}
+
+func TestMaxOpenConns(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v", err)
+ }
+ })
+
+ db := newTestDB(t, "magicquery")
+ defer closeDB(t, db)
+
+ driver := db.Driver().(*fakeDriver)
+
+ // Force the number of open connections to 0 so we can get an accurate
+ // count for the test
+ db.clearAllConns(t)
+
+ driver.mu.Lock()
+ opens0 := driver.openCount
+ closes0 := driver.closeCount
+ driver.mu.Unlock()
+
+ db.SetMaxIdleConns(10)
+ db.SetMaxOpenConns(10)
+
+ stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Start 50 parallel slow queries.
+ const (
+ nquery = 50
+ sleepMillis = 25
+ nbatch = 2
+ )
+ var wg sync.WaitGroup
+ for batch := 0; batch < nbatch; batch++ {
+ for i := 0; i < nquery; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ var op string
+ if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
+ t.Error(err)
+ }
+ }()
+ }
+ // Wait for the batch of queries above to finish before starting the next round.
+ wg.Wait()
+ }
+
+ if g, w := db.numFreeConns(), 10; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPoll(t, 20); n > 20 {
+ t.Errorf("number of dependencies = %d; expected <= 20", n)
+ db.dumpDeps(t)
+ }
+
+ driver.mu.Lock()
+ opens := driver.openCount - opens0
+ closes := driver.closeCount - closes0
+ driver.mu.Unlock()
+
+ if opens > 10 {
+ t.Logf("open calls = %d", opens)
+ t.Logf("close calls = %d", closes)
+ t.Errorf("db connections opened = %d; want <= 10", opens)
+ db.dumpDeps(t)
+ }
+
+ if err := stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if g, w := db.numFreeConns(), 10; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPoll(t, 10); n > 10 {
+ t.Errorf("number of dependencies = %d; expected <= 10", n)
+ db.dumpDeps(t)
+ }
+
+ db.SetMaxOpenConns(5)
+
+ if g, w := db.numFreeConns(), 5; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPoll(t, 5); n > 5 {
+ t.Errorf("number of dependencies = %d; expected 0", n)
+ db.dumpDeps(t)
+ }
+
+ db.SetMaxOpenConns(0)
+
+ if g, w := db.numFreeConns(), 5; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPoll(t, 5); n > 5 {
+ t.Errorf("number of dependencies = %d; expected 0", n)
+ db.dumpDeps(t)
+ }
+
+ db.clearAllConns(t)
+}
+
+// Issue 9453: tests that SetMaxOpenConns can be lowered at runtime
+// and affects the subsequent release of connections.
+func TestMaxOpenConnsOnBusy(t *testing.T) {
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v", err)
+ }
+ })
+
+ db := newTestDB(t, "magicquery")
+ defer closeDB(t, db)
+
+ db.SetMaxOpenConns(3)
+
+ ctx := context.Background()
+
+ conn0, err := db.conn(ctx, cachedOrNewConn)
+ if err != nil {
+ t.Fatalf("db open conn fail: %v", err)
+ }
+
+ conn1, err := db.conn(ctx, cachedOrNewConn)
+ if err != nil {
+ t.Fatalf("db open conn fail: %v", err)
+ }
+
+ conn2, err := db.conn(ctx, cachedOrNewConn)
+ if err != nil {
+ t.Fatalf("db open conn fail: %v", err)
+ }
+
+ if g, w := db.numOpen, 3; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ db.SetMaxOpenConns(2)
+ if g, w := db.numOpen, 3; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ conn0.releaseConn(nil)
+ conn1.releaseConn(nil)
+ if g, w := db.numOpen, 2; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ conn2.releaseConn(nil)
+ if g, w := db.numOpen, 2; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+}
+
+// Issue 10886: tests that all connection attempts return when more than
+// DB.maxOpen connections are in flight and the first DB.maxOpen fail.
+func TestPendingConnsAfterErr(t *testing.T) {
+ const (
+ maxOpen = 2
+ tryOpen = maxOpen*2 + 2
+ )
+
+ // No queries will be run.
+ db, err := Open("test", fakeDBName)
+ if err != nil {
+ t.Fatalf("Open: %v", err)
+ }
+ defer closeDB(t, db)
+ defer func() {
+ for k, v := range db.lastPut {
+ t.Logf("%p: %v", k, v)
+ }
+ }()
+
+ db.SetMaxOpenConns(maxOpen)
+ db.SetMaxIdleConns(0)
+
+ errOffline := errors.New("db offline")
+
+ defer func() { setHookOpenErr(nil) }()
+
+ errs := make(chan error, tryOpen)
+
+ var opening sync.WaitGroup
+ opening.Add(tryOpen)
+
+ setHookOpenErr(func() error {
+ // Wait for all connections to enqueue.
+ opening.Wait()
+ return errOffline
+ })
+
+ for i := 0; i < tryOpen; i++ {
+ go func() {
+ opening.Done() // signal one connection is in flight
+ _, err := db.Exec("will never run")
+ errs <- err
+ }()
+ }
+
+ opening.Wait() // wait for all workers to begin running
+
+ const timeout = 5 * time.Second
+ to := time.NewTimer(timeout)
+ defer to.Stop()
+
+ // check that all connections fail without deadlock
+ for i := 0; i < tryOpen; i++ {
+ select {
+ case err := <-errs:
+ if got, want := err, errOffline; got != want {
+ t.Errorf("unexpected err: got %v, want %v", got, want)
+ }
+ case <-to.C:
+ t.Fatalf("orphaned connection request(s), still waiting after %v", timeout)
+ }
+ }
+
+ // Wait a reasonable time for the database to close all connections.
+ tick := time.NewTicker(3 * time.Millisecond)
+ defer tick.Stop()
+ for {
+ select {
+ case <-tick.C:
+ db.mu.Lock()
+ if db.numOpen == 0 {
+ db.mu.Unlock()
+ return
+ }
+ db.mu.Unlock()
+ case <-to.C:
+ // Closing the database will check for numOpen and fail the test.
+ return
+ }
+ }
+}
+
+func TestSingleOpenConn(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxOpenConns(1)
+
+ rows, err := db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+ // shouldn't deadlock
+ rows, err = db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestStats(t *testing.T) {
+ db := newTestDB(t, "people")
+ stats := db.Stats()
+ if got := stats.OpenConnections; got != 1 {
+ t.Errorf("stats.OpenConnections = %d; want 1", got)
+ }
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tx.Commit()
+
+ closeDB(t, db)
+ stats = db.Stats()
+ if got := stats.OpenConnections; got != 0 {
+ t.Errorf("stats.OpenConnections = %d; want 0", got)
+ }
+}
+
+func TestConnMaxLifetime(t *testing.T) {
+ t0 := time.Unix(1000000, 0)
+ offset := time.Duration(0)
+
+ nowFunc = func() time.Time { return t0.Add(offset) }
+ defer func() { nowFunc = time.Now }()
+
+ db := newTestDB(t, "magicquery")
+ defer closeDB(t, db)
+
+ driver := db.Driver().(*fakeDriver)
+
+ // Force the number of open connections to 0 so we can get an accurate
+ // count for the test
+ db.clearAllConns(t)
+
+ driver.mu.Lock()
+ opens0 := driver.openCount
+ closes0 := driver.closeCount
+ driver.mu.Unlock()
+
+ db.SetMaxIdleConns(10)
+ db.SetMaxOpenConns(10)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ offset = time.Second
+ tx2, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tx.Commit()
+ tx2.Commit()
+
+ driver.mu.Lock()
+ opens := driver.openCount - opens0
+ closes := driver.closeCount - closes0
+ driver.mu.Unlock()
+
+ if opens != 2 {
+ t.Errorf("opens = %d; want 2", opens)
+ }
+ if closes != 0 {
+ t.Errorf("closes = %d; want 0", closes)
+ }
+ if g, w := db.numFreeConns(), 2; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ // Expire first conn
+ offset = 11 * time.Second
+ db.SetConnMaxLifetime(10 * time.Second)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tx, err = db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tx2, err = db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ tx.Commit()
+ tx2.Commit()
+
+ // Give connectionCleaner chance to run.
+ waitCondition(t, func() bool {
+ driver.mu.Lock()
+ opens = driver.openCount - opens0
+ closes = driver.closeCount - closes0
+ driver.mu.Unlock()
+
+ return closes == 1
+ })
+
+ if opens != 3 {
+ t.Errorf("opens = %d; want 3", opens)
+ }
+ if closes != 1 {
+ t.Errorf("closes = %d; want 1", closes)
+ }
+
+ if s := db.Stats(); s.MaxLifetimeClosed != 1 {
+ t.Errorf("MaxLifetimeClosed = %d; want 1 %#v", s.MaxLifetimeClosed, s)
+ }
+}
+
+// golang.org/issue/5323
+func TestStmtCloseDeps(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v", err)
+ }
+ })
+
+ db := newTestDB(t, "magicquery")
+ defer closeDB(t, db)
+
+ driver := db.Driver().(*fakeDriver)
+
+ driver.mu.Lock()
+ opens0 := driver.openCount
+ closes0 := driver.closeCount
+ driver.mu.Unlock()
+ openDelta0 := opens0 - closes0
+
+ stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Start 50 parallel slow queries.
+ const (
+ nquery = 50
+ sleepMillis = 25
+ nbatch = 2
+ )
+ var wg sync.WaitGroup
+ for batch := 0; batch < nbatch; batch++ {
+ for i := 0; i < nquery; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ var op string
+ if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
+ t.Error(err)
+ }
+ }()
+ }
+ // Wait for the batch of queries above to finish before starting the next round.
+ wg.Wait()
+ }
+
+ if g, w := db.numFreeConns(), 2; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPoll(t, 4); n > 4 {
+ t.Errorf("number of dependencies = %d; expected <= 4", n)
+ db.dumpDeps(t)
+ }
+
+ driver.mu.Lock()
+ opens := driver.openCount - opens0
+ closes := driver.closeCount - closes0
+ openDelta := (driver.openCount - driver.closeCount) - openDelta0
+ driver.mu.Unlock()
+
+ if openDelta > 2 {
+ t.Logf("open calls = %d", opens)
+ t.Logf("close calls = %d", closes)
+ t.Logf("open delta = %d", openDelta)
+ t.Errorf("db connections opened = %d; want <= 2", openDelta)
+ db.dumpDeps(t)
+ }
+
+ if !waitCondition(t, func() bool {
+ return len(stmt.css) <= nquery
+ }) {
+ t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
+ }
+
+ if err := stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if g, w := db.numFreeConns(), 2; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPoll(t, 2); n > 2 {
+ t.Errorf("number of dependencies = %d; expected <= 2", n)
+ db.dumpDeps(t)
+ }
+
+ db.clearAllConns(t)
+}
+
+// golang.org/issue/5046
+func TestCloseConnBeforeStmts(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v; from %s", err, stack())
+ db.dumpDeps(t)
+ t.Errorf("DB = %#v", db)
+ }
+ })
+
+ stmt, err := db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if len(db.freeConn) != 1 {
+ t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
+ }
+ dc := db.freeConn[0]
+ if dc.closed {
+ t.Errorf("conn shouldn't be closed")
+ }
+
+ if n := len(dc.openStmt); n != 1 {
+ t.Errorf("driverConn num openStmt = %d; want 1", n)
+ }
+ err = db.Close()
+ if err != nil {
+ t.Errorf("db Close = %v", err)
+ }
+ if !dc.closed {
+ t.Errorf("after db.Close, driverConn should be closed")
+ }
+ if n := len(dc.openStmt); n != 0 {
+ t.Errorf("driverConn num openStmt = %d; want 0", n)
+ }
+
+ err = stmt.Close()
+ if err != nil {
+ t.Errorf("Stmt close = %v", err)
+ }
+
+ if !dc.closed {
+ t.Errorf("conn should be closed")
+ }
+ if dc.ci != nil {
+ t.Errorf("after Stmt Close, driverConn's Conn interface should be nil")
+ }
+}
+
+// golang.org/issue/5283: don't release the Rows' connection in Close
+// before calling Stmt.Close.
+func TestRowsCloseOrder(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxIdleConns(0)
+ setStrictFakeConnClose(t)
+ defer setStrictFakeConnClose(nil)
+
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ err = rows.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestRowsImplicitClose(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ want, fail := 2, errors.New("fail")
+ r := rows.rowsi.(*rowsCursor)
+ r.errPos, r.err = want, fail
+
+ got := 0
+ for rows.Next() {
+ got++
+ }
+ if got != want {
+ t.Errorf("got %d rows, want %d", got, want)
+ }
+ if err := rows.Err(); err != fail {
+ t.Errorf("got error %v, want %v", err, fail)
+ }
+ if !r.closed {
+ t.Errorf("r.closed is false, want true")
+ }
+}
+
+func TestRowsCloseError(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer db.Close()
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+
+ rc, ok := rows.rowsi.(*rowsCursor)
+ if !ok {
+ t.Fatal("not using *rowsCursor")
+ }
+ rc.closeErr = errors.New("rowsCursor: failed to close")
+
+ for rows.Next() {
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got = append(got, r)
+ }
+ err = rows.Err()
+ if err != rc.closeErr {
+ t.Fatalf("unexpected err: got %v, want %v", err, rc.closeErr)
+ }
+}
+
+func TestStmtCloseOrder(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxIdleConns(0)
+ setStrictFakeConnClose(t)
+ defer setStrictFakeConnClose(nil)
+
+ _, err := db.Query("SELECT|non_existent|name|")
+ if err == nil {
+ t.Fatal("Querying non-existent table should fail")
+ }
+}
+
+// Test cases where there's more than maxBadConnRetries bad connections in the
+// pool (issue 8834)
+func TestManyErrBadConn(t *testing.T) {
+ manyErrBadConnSetup := func(first ...func(db *DB)) *DB {
+ db := newTestDB(t, "people")
+
+ for _, f := range first {
+ f(db)
+ }
+
+ nconn := maxBadConnRetries + 1
+ db.SetMaxIdleConns(nconn)
+ db.SetMaxOpenConns(nconn)
+ // open enough connections
+ func() {
+ for i := 0; i < nconn; i++ {
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer rows.Close()
+ }
+ }()
+
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if db.numOpen != nconn {
+ t.Fatalf("unexpected numOpen %d (was expecting %d)", db.numOpen, nconn)
+ } else if len(db.freeConn) != nconn {
+ t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn)
+ }
+ for _, conn := range db.freeConn {
+ conn.Lock()
+ conn.ci.(*fakeConn).stickyBad = true
+ conn.Unlock()
+ }
+ return db
+ }
+
+ // Query
+ db := manyErrBadConnSetup()
+ defer closeDB(t, db)
+ rows, err := db.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ // Exec
+ db = manyErrBadConnSetup()
+ defer closeDB(t, db)
+ _, err = db.Exec("INSERT|people|name=Julia,age=19")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Begin
+ db = manyErrBadConnSetup()
+ defer closeDB(t, db)
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = tx.Rollback(); err != nil {
+ t.Fatal(err)
+ }
+
+ // Prepare
+ db = manyErrBadConnSetup()
+ defer closeDB(t, db)
+ stmt, err := db.Prepare("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ // Stmt.Exec
+ db = manyErrBadConnSetup(func(db *DB) {
+ stmt, err = db.Prepare("INSERT|people|name=Julia,age=19")
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ defer closeDB(t, db)
+ _, err = stmt.Exec()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ // Stmt.Query
+ db = manyErrBadConnSetup(func(db *DB) {
+ stmt, err = db.Prepare("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ })
+ defer closeDB(t, db)
+ rows, err = stmt.Query()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err = rows.Close(); err != nil {
+ t.Fatal(err)
+ }
+ if err = stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ // Conn
+ db = manyErrBadConnSetup()
+ defer closeDB(t, db)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
+ err = conn.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Ping
+ db = manyErrBadConnSetup()
+ defer closeDB(t, db)
+ err = db.PingContext(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+// Issue 34775: Ensure that a Tx cannot commit after a rollback.
+func TestTxCannotCommitAfterRollback(t *testing.T) {
+ db := newTestDB(t, "tx_status")
+ defer closeDB(t, db)
+
+ // First check query reporting is correct.
+ var txStatus string
+ err := db.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, w := txStatus, "autocommit"; g != w {
+ t.Fatalf("tx_status=%q, wanted %q", g, w)
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Ignore dirty session for this test.
+ // A failing test should trigger the dirty session flag as well,
+ // but that isn't exactly what this should test for.
+ tx.txi.(*fakeTx).c.skipDirtySession = true
+
+ defer tx.Rollback()
+
+ err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if g, w := txStatus, "transaction"; g != w {
+ t.Fatalf("tx_status=%q, wanted %q", g, w)
+ }
+
+ // 1. Begin a transaction.
+ // 2. (A) Start a query, (B) begin Tx rollback through a ctx cancel.
+ // 3. Check if 2.A has committed in Tx (pass) or outside of Tx (fail).
+ sendQuery := make(chan struct{})
+ // The Tx status is returned through the row results, ensure
+ // that the rows results are not canceled.
+ bypassRowsAwaitDone = true
+ hookTxGrabConn = func() {
+ cancel()
+ <-sendQuery
+ }
+ rollbackHook = func() {
+ close(sendQuery)
+ }
+ defer func() {
+ hookTxGrabConn = nil
+ rollbackHook = nil
+ bypassRowsAwaitDone = false
+ }()
+
+ err = tx.QueryRow("SELECT|tx_status|tx_status|").Scan(&txStatus)
+ if err != nil {
+ // A failure here would be expected if skipDirtySession was not set to true above.
+ t.Fatal(err)
+ }
+ if g, w := txStatus, "transaction"; g != w {
+ t.Fatalf("tx_status=%q, wanted %q", g, w)
+ }
+}
+
+// Issue 40985 transaction statement deadlock while context cancel.
+func TestTxStmtDeadlock(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ stmt, err := tx.Prepare("SELECT|people|name,age|age=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+ cancel()
+ // Run number of stmt queries to reproduce deadlock from context cancel
+ for i := 0; i < 1e3; i++ {
+ // Encounter any close related errors (e.g. ErrTxDone, stmt is closed)
+ // is expected due to context cancel.
+ _, err = stmt.Query(1)
+ if err != nil {
+ break
+ }
+ }
+ _ = tx.Rollback()
+}
+
+// Issue32530 encounters an issue where a connection may
+// expire right after it comes out of a used connection pool
+// even when a new connection is requested.
+func TestConnExpiresFreshOutOfPool(t *testing.T) {
+ execCases := []struct {
+ expired bool
+ badReset bool
+ }{
+ {false, false},
+ {true, false},
+ {false, true},
+ }
+
+ t0 := time.Unix(1000000, 0)
+ offset := time.Duration(0)
+ offsetMu := sync.RWMutex{}
+
+ nowFunc = func() time.Time {
+ offsetMu.RLock()
+ defer offsetMu.RUnlock()
+ return t0.Add(offset)
+ }
+ defer func() { nowFunc = time.Now }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ db := newTestDB(t, "magicquery")
+ defer closeDB(t, db)
+
+ db.SetMaxOpenConns(1)
+
+ for _, ec := range execCases {
+ ec := ec
+ name := fmt.Sprintf("expired=%t,badReset=%t", ec.expired, ec.badReset)
+ t.Run(name, func(t *testing.T) {
+ db.clearAllConns(t)
+
+ db.SetMaxIdleConns(1)
+ db.SetConnMaxLifetime(10 * time.Second)
+
+ conn, err := db.conn(ctx, alwaysNewConn)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ afterPutConn := make(chan struct{})
+ waitingForConn := make(chan struct{})
+
+ go func() {
+ defer close(afterPutConn)
+
+ conn, err := db.conn(ctx, alwaysNewConn)
+ if err == nil {
+ db.putConn(conn, err, false)
+ } else {
+ t.Errorf("db.conn: %v", err)
+ }
+ }()
+ go func() {
+ defer close(waitingForConn)
+
+ for {
+ if t.Failed() {
+ return
+ }
+ db.mu.Lock()
+ ct := len(db.connRequests)
+ db.mu.Unlock()
+ if ct > 0 {
+ return
+ }
+ time.Sleep(pollDuration)
+ }
+ }()
+
+ <-waitingForConn
+
+ if t.Failed() {
+ return
+ }
+
+ offsetMu.Lock()
+ if ec.expired {
+ offset = 11 * time.Second
+ } else {
+ offset = time.Duration(0)
+ }
+ offsetMu.Unlock()
+
+ conn.ci.(*fakeConn).stickyBad = ec.badReset
+
+ db.putConn(conn, err, true)
+
+ <-afterPutConn
+ })
+ }
+}
+
+// TestIssue20575 ensures the Rows from query does not block
+// closing a transaction. Ensure Rows is closed while closing a transaction.
+func TestIssue20575(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ _, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ // Do not close Rows from QueryContext.
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatal(err)
+ }
+ select {
+ default:
+ case <-ctx.Done():
+ t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err())
+ }
+}
+
+// TestIssue20622 tests closing the transaction before rows is closed, requires
+// the race detector to fail.
+func TestIssue20622(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rows, err := tx.Query("SELECT|people|age,name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ count := 0
+ for rows.Next() {
+ count++
+ var age int
+ var name string
+ if err := rows.Scan(&age, &name); err != nil {
+ t.Fatal("scan failed", err)
+ }
+
+ if count == 1 {
+ cancel()
+ }
+ time.Sleep(100 * time.Millisecond)
+ }
+ rows.Close()
+ tx.Commit()
+}
+
+// golang.org/issue/5718
+func TestErrBadConnReconnect(t *testing.T) {
+ db := newTestDB(t, "foo")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+
+ simulateBadConn := func(name string, hook *func() bool, op func() error) {
+ broken, retried := false, false
+ numOpen := db.numOpen
+
+ // simulate a broken connection on the first try
+ *hook = func() bool {
+ if !broken {
+ broken = true
+ return true
+ }
+ retried = true
+ return false
+ }
+
+ if err := op(); err != nil {
+ t.Errorf(name+": %v", err)
+ return
+ }
+
+ if !broken || !retried {
+ t.Error(name + ": Failed to simulate broken connection")
+ }
+ *hook = nil
+
+ if numOpen != db.numOpen {
+ t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
+ numOpen = db.numOpen
+ }
+ }
+
+ // db.Exec
+ dbExec := func() error {
+ _, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
+ return err
+ }
+ simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec)
+ simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec)
+
+ // db.Query
+ dbQuery := func() error {
+ rows, err := db.Query("SELECT|t1|age,name|")
+ if err == nil {
+ err = rows.Close()
+ }
+ return err
+ }
+ simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery)
+ simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery)
+
+ // db.Prepare
+ simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error {
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
+ if err != nil {
+ return err
+ }
+ stmt.Close()
+ return nil
+ })
+
+ // Provide a way to force a re-prepare of a statement on next execution
+ forcePrepare := func(stmt *Stmt) {
+ stmt.css = nil
+ }
+
+ // stmt.Exec
+ stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
+ if err != nil {
+ t.Fatalf("prepare: %v", err)
+ }
+ defer stmt1.Close()
+ // make sure we must prepare the stmt first
+ forcePrepare(stmt1)
+
+ stmtExec := func() error {
+ _, err := stmt1.Exec("Gopher", 3, false)
+ return err
+ }
+ simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec)
+ simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec)
+
+ // stmt.Query
+ stmt2, err := db.Prepare("SELECT|t1|age,name|")
+ if err != nil {
+ t.Fatalf("prepare: %v", err)
+ }
+ defer stmt2.Close()
+ // make sure we must prepare the stmt first
+ forcePrepare(stmt2)
+
+ stmtQuery := func() error {
+ rows, err := stmt2.Query()
+ if err == nil {
+ err = rows.Close()
+ }
+ return err
+ }
+ simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery)
+ simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery)
+}
+
+// golang.org/issue/11264
+func TestTxEndBadConn(t *testing.T) {
+ db := newTestDB(t, "foo")
+ defer closeDB(t, db)
+ db.SetMaxIdleConns(0)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+ db.SetMaxIdleConns(1)
+
+ simulateBadConn := func(name string, hook *func() bool, op func() error) {
+ broken := false
+ numOpen := db.numOpen
+
+ *hook = func() bool {
+ if !broken {
+ broken = true
+ }
+ return broken
+ }
+
+ if err := op(); !errors.Is(err, driver.ErrBadConn) {
+ t.Errorf(name+": %v", err)
+ return
+ }
+
+ if !broken {
+ t.Error(name + ": Failed to simulate broken connection")
+ }
+ *hook = nil
+
+ if numOpen != db.numOpen {
+ t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
+ }
+ }
+
+ // db.Exec
+ dbExec := func(endTx func(tx *Tx) error) func() error {
+ return func() error {
+ tx, err := db.Begin()
+ if err != nil {
+ return err
+ }
+ _, err = tx.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
+ if err != nil {
+ return err
+ }
+ return endTx(tx)
+ }
+ }
+ simulateBadConn("db.Tx.Exec commit", &hookCommitBadConn, dbExec((*Tx).Commit))
+ simulateBadConn("db.Tx.Exec rollback", &hookRollbackBadConn, dbExec((*Tx).Rollback))
+
+ // db.Query
+ dbQuery := func(endTx func(tx *Tx) error) func() error {
+ return func() error {
+ tx, err := db.Begin()
+ if err != nil {
+ return err
+ }
+ rows, err := tx.Query("SELECT|t1|age,name|")
+ if err == nil {
+ err = rows.Close()
+ } else {
+ return err
+ }
+ return endTx(tx)
+ }
+ }
+ simulateBadConn("db.Tx.Query commit", &hookCommitBadConn, dbQuery((*Tx).Commit))
+ simulateBadConn("db.Tx.Query rollback", &hookRollbackBadConn, dbQuery((*Tx).Rollback))
+}
+
+type concurrentTest interface {
+ init(t testing.TB, db *DB)
+ finish(t testing.TB)
+ test(t testing.TB) error
+}
+
+type concurrentDBQueryTest struct {
+ db *DB
+}
+
+func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) {
+ c.db = db
+}
+
+func (c *concurrentDBQueryTest) finish(t testing.TB) {
+ c.db = nil
+}
+
+func (c *concurrentDBQueryTest) test(t testing.TB) error {
+ rows, err := c.db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Error(err)
+ return err
+ }
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+ return nil
+}
+
+type concurrentDBExecTest struct {
+ db *DB
+}
+
+func (c *concurrentDBExecTest) init(t testing.TB, db *DB) {
+ c.db = db
+}
+
+func (c *concurrentDBExecTest) finish(t testing.TB) {
+ c.db = nil
+}
+
+func (c *concurrentDBExecTest) test(t testing.TB) error {
+ _, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
+ if err != nil {
+ t.Error(err)
+ return err
+ }
+ return nil
+}
+
+type concurrentStmtQueryTest struct {
+ db *DB
+ stmt *Stmt
+}
+
+func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.stmt, err = db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentStmtQueryTest) finish(t testing.TB) {
+ if c.stmt != nil {
+ c.stmt.Close()
+ c.stmt = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentStmtQueryTest) test(t testing.TB) error {
+ rows, err := c.stmt.Query()
+ if err != nil {
+ t.Errorf("error on query: %v", err)
+ return err
+ }
+
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+ return nil
+}
+
+type concurrentStmtExecTest struct {
+ db *DB
+ stmt *Stmt
+}
+
+func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentStmtExecTest) finish(t testing.TB) {
+ if c.stmt != nil {
+ c.stmt.Close()
+ c.stmt = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentStmtExecTest) test(t testing.TB) error {
+ _, err := c.stmt.Exec(3, chrisBirthday)
+ if err != nil {
+ t.Errorf("error on exec: %v", err)
+ return err
+ }
+ return nil
+}
+
+type concurrentTxQueryTest struct {
+ db *DB
+ tx *Tx
+}
+
+func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.tx, err = c.db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentTxQueryTest) finish(t testing.TB) {
+ if c.tx != nil {
+ c.tx.Rollback()
+ c.tx = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentTxQueryTest) test(t testing.TB) error {
+ rows, err := c.db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Error(err)
+ return err
+ }
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+ return nil
+}
+
+type concurrentTxExecTest struct {
+ db *DB
+ tx *Tx
+}
+
+func (c *concurrentTxExecTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.tx, err = c.db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentTxExecTest) finish(t testing.TB) {
+ if c.tx != nil {
+ c.tx.Rollback()
+ c.tx = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentTxExecTest) test(t testing.TB) error {
+ _, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
+ if err != nil {
+ t.Error(err)
+ return err
+ }
+ return nil
+}
+
+type concurrentTxStmtQueryTest struct {
+ db *DB
+ tx *Tx
+ stmt *Stmt
+}
+
+func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.tx, err = c.db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.stmt, err = c.tx.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentTxStmtQueryTest) finish(t testing.TB) {
+ if c.stmt != nil {
+ c.stmt.Close()
+ c.stmt = nil
+ }
+ if c.tx != nil {
+ c.tx.Rollback()
+ c.tx = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentTxStmtQueryTest) test(t testing.TB) error {
+ rows, err := c.stmt.Query()
+ if err != nil {
+ t.Errorf("error on query: %v", err)
+ return err
+ }
+
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+ return nil
+}
+
+type concurrentTxStmtExecTest struct {
+ db *DB
+ tx *Tx
+ stmt *Stmt
+}
+
+func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.tx, err = c.db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentTxStmtExecTest) finish(t testing.TB) {
+ if c.stmt != nil {
+ c.stmt.Close()
+ c.stmt = nil
+ }
+ if c.tx != nil {
+ c.tx.Rollback()
+ c.tx = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentTxStmtExecTest) test(t testing.TB) error {
+ _, err := c.stmt.Exec(3, chrisBirthday)
+ if err != nil {
+ t.Errorf("error on exec: %v", err)
+ return err
+ }
+ return nil
+}
+
+type concurrentRandomTest struct {
+ tests []concurrentTest
+}
+
+func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
+ c.tests = []concurrentTest{
+ new(concurrentDBQueryTest),
+ new(concurrentDBExecTest),
+ new(concurrentStmtQueryTest),
+ new(concurrentStmtExecTest),
+ new(concurrentTxQueryTest),
+ new(concurrentTxExecTest),
+ new(concurrentTxStmtQueryTest),
+ new(concurrentTxStmtExecTest),
+ }
+ for _, ct := range c.tests {
+ ct.init(t, db)
+ }
+}
+
+func (c *concurrentRandomTest) finish(t testing.TB) {
+ for _, ct := range c.tests {
+ ct.finish(t)
+ }
+}
+
+func (c *concurrentRandomTest) test(t testing.TB) error {
+ ct := c.tests[rand.Intn(len(c.tests))]
+ return ct.test(t)
+}
+
+func doConcurrentTest(t testing.TB, ct concurrentTest) {
+ maxProcs, numReqs := 1, 500
+ if testing.Short() {
+ maxProcs, numReqs = 4, 50
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ct.init(t, db)
+ defer ct.finish(t)
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+
+ reqs := make(chan bool)
+ defer close(reqs)
+
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for range reqs {
+ err := ct.test(t)
+ if err != nil {
+ wg.Done()
+ continue
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+ for i := 0; i < numReqs; i++ {
+ reqs <- true
+ }
+
+ wg.Wait()
+}
+
+func TestIssue6081(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ drv := db.Driver().(*fakeDriver)
+ drv.mu.Lock()
+ opens0 := drv.openCount
+ closes0 := drv.closeCount
+ drv.mu.Unlock()
+
+ stmt, err := db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ setRowsCloseHook(func(rows *Rows, err *error) {
+ *err = driver.ErrBadConn
+ })
+ defer setRowsCloseHook(nil)
+ for i := 0; i < 10; i++ {
+ rows, err := stmt.Query()
+ if err != nil {
+ t.Fatal(err)
+ }
+ rows.Close()
+ }
+ if n := len(stmt.css); n > 1 {
+ t.Errorf("len(css slice) = %d; want <= 1", n)
+ }
+ stmt.Close()
+ if n := len(stmt.css); n != 0 {
+ t.Errorf("len(css slice) after Close = %d; want 0", n)
+ }
+
+ drv.mu.Lock()
+ opens := drv.openCount - opens0
+ closes := drv.closeCount - closes0
+ drv.mu.Unlock()
+ if opens < 9 {
+ t.Errorf("opens = %d; want >= 9", opens)
+ }
+ if closes < 9 {
+ t.Errorf("closes = %d; want >= 9", closes)
+ }
+}
+
+// TestIssue18429 attempts to stress rolling back the transaction from a
+// context cancel while simultaneously calling Tx.Rollback. Rolling back from a
+// context happens concurrently so tx.rollback and tx.Commit must guard against
+// double entry.
+//
+// In the test, a context is canceled while the query is in process so
+// the internal rollback will run concurrently with the explicitly called
+// Tx.Rollback.
+//
+// The addition of calling rows.Next also tests
+// Issue 21117.
+func TestIssue18429(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx := context.Background()
+ sem := make(chan bool, 20)
+ var wg sync.WaitGroup
+
+ const milliWait = 30
+
+ for i := 0; i < 100; i++ {
+ sem <- true
+ wg.Add(1)
+ go func() {
+ defer func() {
+ <-sem
+ wg.Done()
+ }()
+ qwait := (time.Duration(rand.Intn(milliWait)) * time.Millisecond).String()
+
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ return
+ }
+ // This is expected to give a cancel error most, but not all the time.
+ // Test failure will happen with a panic or other race condition being
+ // reported.
+ rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
+ if rows != nil {
+ var name string
+ // Call Next to test Issue 21117 and check for races.
+ for rows.Next() {
+ // Scan the buffer so it is read and checked for races.
+ rows.Scan(&name)
+ }
+ rows.Close()
+ }
+ // This call will race with the context cancel rollback to complete
+ // if the rollback itself isn't guarded.
+ tx.Rollback()
+ }()
+ }
+ wg.Wait()
+}
+
+// TestIssue20160 attempts to test a short context life on a stmt Query.
+func TestIssue20160(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx := context.Background()
+ sem := make(chan bool, 20)
+ var wg sync.WaitGroup
+
+ const milliWait = 30
+
+ stmt, err := db.PrepareContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stmt.Close()
+
+ for i := 0; i < 100; i++ {
+ sem <- true
+ wg.Add(1)
+ go func() {
+ defer func() {
+ <-sem
+ wg.Done()
+ }()
+ ctx, cancel := context.WithTimeout(ctx, time.Duration(rand.Intn(milliWait))*time.Millisecond)
+ defer cancel()
+
+ // This is expected to give a cancel error most, but not all the time.
+ // Test failure will happen with a panic or other race condition being
+ // reported.
+ rows, _ := stmt.QueryContext(ctx)
+ if rows != nil {
+ rows.Close()
+ }
+ }()
+ }
+ wg.Wait()
+}
+
+// TestIssue18719 closes the context right before use. The sql.driverConn
+// will nil out the ci on close in a lock, but if another process uses it right after
+// it will panic with on the nil ref.
+//
+// See https://golang.org/cl/35550 .
+func TestIssue18719(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ hookTxGrabConn = func() {
+ cancel()
+
+ // Wait for the context to cancel and tx to rollback.
+ for tx.isDone() == false {
+ time.Sleep(pollDuration)
+ }
+ }
+ defer func() { hookTxGrabConn = nil }()
+
+ // This call will grab the connection and cancel the context
+ // after it has done so. Code after must deal with the canceled state.
+ _, err = tx.QueryContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatalf("expected error %v but got %v", nil, err)
+ }
+
+ // Rows may be ignored because it will be closed when the context is canceled.
+
+ // Do not explicitly rollback. The rollback will happen from the
+ // canceled context.
+
+ cancel()
+}
+
+func TestIssue20647(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ conn.dc.ci.(*fakeConn).skipDirtySession = true
+ defer conn.Close()
+
+ stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer stmt.Close()
+
+ rows1, err := stmt.QueryContext(ctx)
+ if err != nil {
+ t.Fatal("rows1", err)
+ }
+ defer rows1.Close()
+
+ rows2, err := stmt.QueryContext(ctx)
+ if err != nil {
+ t.Fatal("rows2", err)
+ }
+ defer rows2.Close()
+
+ if rows1.dc != rows2.dc {
+ t.Fatal("stmt prepared on Conn does not use same connection")
+ }
+}
+
+func TestConcurrency(t *testing.T) {
+ list := []struct {
+ name string
+ ct concurrentTest
+ }{
+ {"Query", new(concurrentDBQueryTest)},
+ {"Exec", new(concurrentDBExecTest)},
+ {"StmtQuery", new(concurrentStmtQueryTest)},
+ {"StmtExec", new(concurrentStmtExecTest)},
+ {"TxQuery", new(concurrentTxQueryTest)},
+ {"TxExec", new(concurrentTxExecTest)},
+ {"TxStmtQuery", new(concurrentTxStmtQueryTest)},
+ {"TxStmtExec", new(concurrentTxStmtExecTest)},
+ {"Random", new(concurrentRandomTest)},
+ }
+ for _, item := range list {
+ t.Run(item.name, func(t *testing.T) {
+ doConcurrentTest(t, item.ct)
+ })
+ }
+}
+
+func TestConnectionLeak(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ // Start by opening defaultMaxIdleConns
+ rows := make([]*Rows, defaultMaxIdleConns)
+ // We need to SetMaxOpenConns > MaxIdleConns, so the DB can open
+ // a new connection and we can fill the idle queue with the released
+ // connections.
+ db.SetMaxOpenConns(len(rows) + 1)
+ for ii := range rows {
+ r, err := db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ r.Next()
+ if err := r.Err(); err != nil {
+ t.Fatal(err)
+ }
+ rows[ii] = r
+ }
+ // Now we have defaultMaxIdleConns busy connections. Open
+ // a new one, but wait until the busy connections are released
+ // before returning control to DB.
+ drv := db.Driver().(*fakeDriver)
+ drv.waitCh = make(chan struct{}, 1)
+ drv.waitingCh = make(chan struct{}, 1)
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ r, err := db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ r.Close()
+ wg.Done()
+ }()
+ // Wait until the goroutine we've just created has started waiting.
+ <-drv.waitingCh
+ // Now close the busy connections. This provides a connection for
+ // the blocked goroutine and then fills up the idle queue.
+ for _, v := range rows {
+ v.Close()
+ }
+ // At this point we give the new connection to DB. This connection is
+ // now useless, since the idle queue is full and there are no pending
+ // requests. DB should deal with this situation without leaking the
+ // connection.
+ drv.waitCh <- struct{}{}
+ wg.Wait()
+}
+
+func TestStatsMaxIdleClosedZero(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxOpenConns(1)
+ db.SetMaxIdleConns(1)
+ db.SetConnMaxLifetime(0)
+
+ preMaxIdleClosed := db.Stats().MaxIdleClosed
+
+ for i := 0; i < 10; i++ {
+ rows, err := db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ rows.Close()
+ }
+
+ st := db.Stats()
+ maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed
+ t.Logf("MaxIdleClosed: %d", maxIdleClosed)
+ if maxIdleClosed != 0 {
+ t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed)
+ }
+}
+
+func TestStatsMaxIdleClosedTen(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxOpenConns(1)
+ db.SetMaxIdleConns(0)
+ db.SetConnMaxLifetime(0)
+
+ preMaxIdleClosed := db.Stats().MaxIdleClosed
+
+ for i := 0; i < 10; i++ {
+ rows, err := db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ rows.Close()
+ }
+
+ st := db.Stats()
+ maxIdleClosed := st.MaxIdleClosed - preMaxIdleClosed
+ t.Logf("MaxIdleClosed: %d", maxIdleClosed)
+ if maxIdleClosed != 10 {
+ t.Fatal("expected 0 max idle closed conns, got: ", maxIdleClosed)
+ }
+}
+
+// testUseConns uses count concurrent connections with 1 nanosecond apart.
+// Returns the returnedAt time of the final connection.
+func testUseConns(t *testing.T, count int, tm time.Time, db *DB) time.Time {
+ conns := make([]*Conn, count)
+ ctx := context.Background()
+ for i := range conns {
+ tm = tm.Add(time.Nanosecond)
+ nowFunc = func() time.Time {
+ return tm
+ }
+ c, err := db.Conn(ctx)
+ if err != nil {
+ t.Error(err)
+ }
+ conns[i] = c
+ }
+
+ for i := len(conns) - 1; i >= 0; i-- {
+ tm = tm.Add(time.Nanosecond)
+ nowFunc = func() time.Time {
+ return tm
+ }
+ if err := conns[i].Close(); err != nil {
+ t.Error(err)
+ }
+ }
+
+ return tm
+}
+
+func TestMaxIdleTime(t *testing.T) {
+ usedConns := 5
+ reusedConns := 2
+ list := []struct {
+ wantMaxIdleTime time.Duration
+ wantMaxLifetime time.Duration
+ wantNextCheck time.Duration
+ wantIdleClosed int64
+ wantMaxIdleClosed int64
+ timeOffset time.Duration
+ secondTimeOffset time.Duration
+ }{
+ {
+ time.Millisecond,
+ 0,
+ time.Millisecond - time.Nanosecond,
+ int64(usedConns - reusedConns),
+ int64(usedConns - reusedConns),
+ 10 * time.Millisecond,
+ 0,
+ },
+ {
+ // Want to close some connections via max idle time and one by max lifetime.
+ time.Millisecond,
+ // nowFunc() - MaxLifetime should be 1 * time.Nanosecond in connectionCleanerRunLocked.
+ // This guarantees that first opened connection is to be closed.
+ // Thus it is timeOffset + secondTimeOffset + 3 (+2 for Close while reusing conns and +1 for Conn).
+ 10*time.Millisecond + 100*time.Nanosecond + 3*time.Nanosecond,
+ time.Nanosecond,
+ // Closed all not reused connections and extra one by max lifetime.
+ int64(usedConns - reusedConns + 1),
+ int64(usedConns - reusedConns),
+ 10 * time.Millisecond,
+ // Add second offset because otherwise connections are expired via max lifetime in Close.
+ 100 * time.Nanosecond,
+ },
+ {
+ time.Hour,
+ 0,
+ time.Second,
+ 0,
+ 0,
+ 10 * time.Millisecond,
+ 0},
+ }
+ baseTime := time.Unix(0, 0)
+ defer func() {
+ nowFunc = time.Now
+ }()
+ for _, item := range list {
+ nowFunc = func() time.Time {
+ return baseTime
+ }
+ t.Run(fmt.Sprintf("%v", item.wantMaxIdleTime), func(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ db.SetMaxOpenConns(usedConns)
+ db.SetMaxIdleConns(usedConns)
+ db.SetConnMaxIdleTime(item.wantMaxIdleTime)
+ db.SetConnMaxLifetime(item.wantMaxLifetime)
+
+ preMaxIdleClosed := db.Stats().MaxIdleTimeClosed
+
+ // Busy usedConns.
+ testUseConns(t, usedConns, baseTime, db)
+
+ tm := baseTime.Add(item.timeOffset)
+
+ // Reuse connections which should never be considered idle
+ // and exercises the sorting for issue 39471.
+ tm = testUseConns(t, reusedConns, tm, db)
+
+ tm = tm.Add(item.secondTimeOffset)
+ nowFunc = func() time.Time {
+ return tm
+ }
+
+ db.mu.Lock()
+ nc, closing := db.connectionCleanerRunLocked(time.Second)
+ if nc != item.wantNextCheck {
+ t.Errorf("got %v; want %v next check duration", nc, item.wantNextCheck)
+ }
+
+ // Validate freeConn order.
+ var last time.Time
+ for _, c := range db.freeConn {
+ if last.After(c.returnedAt) {
+ t.Error("freeConn is not ordered by returnedAt")
+ break
+ }
+ last = c.returnedAt
+ }
+
+ db.mu.Unlock()
+ for _, c := range closing {
+ c.Close()
+ }
+ if g, w := int64(len(closing)), item.wantIdleClosed; g != w {
+ t.Errorf("got: %d; want %d closed conns", g, w)
+ }
+
+ st := db.Stats()
+ maxIdleClosed := st.MaxIdleTimeClosed - preMaxIdleClosed
+ if g, w := maxIdleClosed, item.wantMaxIdleClosed; g != w {
+ t.Errorf("got: %d; want %d max idle closed conns", g, w)
+ }
+ })
+ }
+}
+
+type nvcDriver struct {
+ fakeDriver
+ skipNamedValueCheck bool
+}
+
+func (d *nvcDriver) Open(dsn string) (driver.Conn, error) {
+ c, err := d.fakeDriver.Open(dsn)
+ fc := c.(*fakeConn)
+ fc.db.allowAny = true
+ return &nvcConn{fc, d.skipNamedValueCheck}, err
+}
+
+type nvcConn struct {
+ *fakeConn
+ skipNamedValueCheck bool
+}
+
+type decimalInt struct {
+ value int
+}
+
+type doNotInclude struct{}
+
+var _ driver.NamedValueChecker = &nvcConn{}
+
+func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
+ if c.skipNamedValueCheck {
+ return driver.ErrSkip
+ }
+ switch v := nv.Value.(type) {
+ default:
+ return driver.ErrSkip
+ case Out:
+ switch ov := v.Dest.(type) {
+ default:
+ return errors.New("unknown NameValueCheck OUTPUT type")
+ case *string:
+ *ov = "from-server"
+ nv.Value = "OUT:*string"
+ }
+ return nil
+ case decimalInt, []int64:
+ return nil
+ case doNotInclude:
+ return driver.ErrRemoveArgument
+ }
+}
+
+func TestNamedValueChecker(t *testing.T) {
+ Register("NamedValueCheck", &nvcDriver{})
+ db, err := Open("NamedValueCheck", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, err = db.ExecContext(ctx, "WIPE")
+ if err != nil {
+ t.Fatal("exec wipe", err)
+ }
+
+ _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any")
+ if err != nil {
+ t.Fatal("exec create", err)
+ }
+
+ o1 := ""
+ _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimalInt{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{})
+ if err != nil {
+ t.Fatal("exec insert", err)
+ }
+ var (
+ str1 string
+ dec1 decimalInt
+ arr1 []int64
+ )
+ err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1)
+ if err != nil {
+ t.Fatal("select", err)
+ }
+
+ list := []struct{ got, want any }{
+ {o1, "from-server"},
+ {dec1, decimalInt{123}},
+ {str1, "hello"},
+ {arr1, []int64{42, 128, 707}},
+ }
+
+ for index, item := range list {
+ if !reflect.DeepEqual(item.got, item.want) {
+ t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
+ }
+ }
+}
+
+func TestNamedValueCheckerSkip(t *testing.T) {
+ Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true})
+ db, err := Open("NamedValueCheckSkip", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ _, err = db.ExecContext(ctx, "WIPE")
+ if err != nil {
+ t.Fatal("exec wipe", err)
+ }
+
+ _, err = db.ExecContext(ctx, "CREATE|keys|dec1=any")
+ if err != nil {
+ t.Fatal("exec create", err)
+ }
+
+ _, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimalInt{123}))
+ if err == nil {
+ t.Fatalf("expected error with bad argument, got %v", err)
+ }
+}
+
+func TestOpenConnector(t *testing.T) {
+ Register("testctx", &fakeDriverCtx{})
+ db, err := Open("testctx", "people")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ c, ok := db.connector.(*fakeConnector)
+ if !ok {
+ t.Fatal("not using *fakeConnector")
+ }
+
+ if err := db.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if !c.closed {
+ t.Fatal("connector is not closed")
+ }
+}
+
+type ctxOnlyDriver struct {
+ fakeDriver
+}
+
+func (d *ctxOnlyDriver) Open(dsn string) (driver.Conn, error) {
+ conn, err := d.fakeDriver.Open(dsn)
+ if err != nil {
+ return nil, err
+ }
+ return &ctxOnlyConn{fc: conn.(*fakeConn)}, nil
+}
+
+var (
+ _ driver.Conn = &ctxOnlyConn{}
+ _ driver.QueryerContext = &ctxOnlyConn{}
+ _ driver.ExecerContext = &ctxOnlyConn{}
+)
+
+type ctxOnlyConn struct {
+ fc *fakeConn
+
+ queryCtxCalled bool
+ execCtxCalled bool
+}
+
+func (c *ctxOnlyConn) Begin() (driver.Tx, error) {
+ return c.fc.Begin()
+}
+
+func (c *ctxOnlyConn) Close() error {
+ return c.fc.Close()
+}
+
+// Prepare is still part of the Conn interface, so while it isn't used
+// must be defined for compatibility.
+func (c *ctxOnlyConn) Prepare(q string) (driver.Stmt, error) {
+ panic("not used")
+}
+
+func (c *ctxOnlyConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
+ return c.fc.PrepareContext(ctx, q)
+}
+
+func (c *ctxOnlyConn) QueryContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Rows, error) {
+ c.queryCtxCalled = true
+ return c.fc.QueryContext(ctx, q, args)
+}
+
+func (c *ctxOnlyConn) ExecContext(ctx context.Context, q string, args []driver.NamedValue) (driver.Result, error) {
+ c.execCtxCalled = true
+ return c.fc.ExecContext(ctx, q, args)
+}
+
+// TestQueryExecContextOnly ensures drivers only need to implement QueryContext
+// and ExecContext methods.
+func TestQueryExecContextOnly(t *testing.T) {
+ // Ensure connection does not implement non-context interfaces.
+ var connType driver.Conn = &ctxOnlyConn{}
+ if _, ok := connType.(driver.Execer); ok {
+ t.Fatalf("%T must not implement driver.Execer", connType)
+ }
+ if _, ok := connType.(driver.Queryer); ok {
+ t.Fatalf("%T must not implement driver.Queryer", connType)
+ }
+
+ Register("ContextOnly", &ctxOnlyDriver{})
+ db, err := Open("ContextOnly", "")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer db.Close()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ conn, err := db.Conn(ctx)
+ if err != nil {
+ t.Fatal("db.Conn", err)
+ }
+ defer conn.Close()
+ coc := conn.dc.ci.(*ctxOnlyConn)
+ coc.fc.skipDirtySession = true
+
+ _, err = conn.ExecContext(ctx, "WIPE")
+ if err != nil {
+ t.Fatal("exec wipe", err)
+ }
+
+ _, err = conn.ExecContext(ctx, "CREATE|keys|v1=string")
+ if err != nil {
+ t.Fatal("exec create", err)
+ }
+ expectedValue := "value1"
+ _, err = conn.ExecContext(ctx, "INSERT|keys|v1=?", expectedValue)
+ if err != nil {
+ t.Fatal("exec insert", err)
+ }
+ rows, err := conn.QueryContext(ctx, "SELECT|keys|v1|")
+ if err != nil {
+ t.Fatal("query select", err)
+ }
+ v1 := ""
+ for rows.Next() {
+ err = rows.Scan(&v1)
+ if err != nil {
+ t.Fatal("rows scan", err)
+ }
+ }
+ rows.Close()
+
+ if v1 != expectedValue {
+ t.Fatalf("expected %q, got %q", expectedValue, v1)
+ }
+
+ if !coc.execCtxCalled {
+ t.Error("ExecContext not called")
+ }
+ if !coc.queryCtxCalled {
+ t.Error("QueryContext not called")
+ }
+}
+
+type alwaysErrScanner struct{}
+
+var errTestScanWrap = errors.New("errTestScanWrap")
+
+func (alwaysErrScanner) Scan(any) error {
+ return errTestScanWrap
+}
+
+// Issue 38099: Ensure that Rows.Scan properly wraps underlying errors.
+func TestRowsScanProperlyWrapsErrors(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ rows, err := db.Query("SELECT|people|age|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+
+ var res alwaysErrScanner
+
+ for rows.Next() {
+ err = rows.Scan(&res)
+ if err == nil {
+ t.Fatal("expecting back an error")
+ }
+ if !errors.Is(err, errTestScanWrap) {
+ t.Fatalf("errors.Is mismatch\n%v\nWant: %v", err, errTestScanWrap)
+ }
+ // Ensure that error substring matching still correctly works.
+ if !strings.Contains(err.Error(), errTestScanWrap.Error()) {
+ t.Fatalf("Error %v does not contain %v", err, errTestScanWrap)
+ }
+ }
+}
+
+func TestContextCancelDuringRawBytesScan(t *testing.T) {
+ for _, mode := range []string{"nocancel", "top", "bottom", "go"} {
+ t.Run(mode, func(t *testing.T) {
+ testContextCancelDuringRawBytesScan(t, mode)
+ })
+ }
+}
+
+// From go.dev/issue/60304
+func testContextCancelDuringRawBytesScan(t *testing.T, mode string) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ if _, err := db.Exec("USE_RAWBYTES"); err != nil {
+ t.Fatal(err)
+ }
+
+ // cancel used to call close asynchronously.
+ // This test checks that it waits so as not to interfere with RawBytes.
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ r, err := db.QueryContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ numRows := 0
+ var sink byte
+ for r.Next() {
+ if mode == "top" && numRows == 2 {
+ // cancel between Next and Scan is observed by Scan as err = context.Canceled.
+ // The sleep here is only to make it more likely that the cancel will be observed.
+ // If not, the test should still pass, like in "go" mode.
+ cancel()
+ time.Sleep(100 * time.Millisecond)
+ }
+ numRows++
+ var s RawBytes
+ err = r.Scan(&s)
+ if numRows == 3 && err == context.Canceled {
+ if r.closemuScanHold {
+ t.Errorf("expected closemu NOT to be held")
+ }
+ break
+ }
+ if !r.closemuScanHold {
+ t.Errorf("expected closemu to be held")
+ }
+ if err != nil {
+ t.Fatal(err)
+ }
+ t.Logf("read %q", s)
+ if mode == "bottom" && numRows == 2 {
+ // cancel before Next should be observed by Next, exiting the loop.
+ // The sleep here is only to make it more likely that the cancel will be observed.
+ // If not, the test should still pass, like in "go" mode.
+ cancel()
+ time.Sleep(100 * time.Millisecond)
+ }
+ if mode == "go" && numRows == 2 {
+ // cancel at any future time, to catch other cases
+ go cancel()
+ }
+ for _, b := range s { // some operation reading from the raw memory
+ sink += b
+ }
+ }
+ if r.closemuScanHold {
+ t.Errorf("closemu held; should not be")
+ }
+
+ // There are 3 rows. We canceled after reading 2 so we expect either
+ // 2 or 3 depending on how the awaitDone goroutine schedules.
+ switch numRows {
+ case 0, 1:
+ t.Errorf("got %d rows; want 2+", numRows)
+ case 2:
+ if err := r.Err(); err != context.Canceled {
+ t.Errorf("unexpected error: %v (%T)", err, err)
+ }
+ default:
+ // Made it to the end. This is rare, but fine. Permit it.
+ }
+
+ if err := r.Close(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestContextCancelBetweenNextAndErr(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ r, err := db.QueryContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+ for r.Next() {
+ }
+ cancel() // wake up the awaitDone goroutine
+ time.Sleep(10 * time.Millisecond) // increase odds of seeing failure
+ if err := r.Err(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+func TestNilErrorAfterClose(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ // This WithCancel is important; Rows contains an optimization to avoid
+ // spawning a goroutine when the query/transaction context cannot be
+ // canceled, but this test tests a bug which is caused by said goroutine.
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ r, err := db.QueryContext(ctx, "SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err := r.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ time.Sleep(10 * time.Millisecond) // increase odds of seeing failure
+ if err := r.Err(); err != nil {
+ t.Fatal(err)
+ }
+}
+
+// badConn implements a bad driver.Conn, for TestBadDriver.
+// The Exec method panics.
+type badConn struct{}
+
+func (bc badConn) Prepare(query string) (driver.Stmt, error) {
+ return nil, errors.New("badConn Prepare")
+}
+
+func (bc badConn) Close() error {
+ return nil
+}
+
+func (bc badConn) Begin() (driver.Tx, error) {
+ return nil, errors.New("badConn Begin")
+}
+
+func (bc badConn) Exec(query string, args []driver.Value) (driver.Result, error) {
+ panic("badConn.Exec")
+}
+
+// badDriver is a driver.Driver that uses badConn.
+type badDriver struct{}
+
+func (bd badDriver) Open(name string) (driver.Conn, error) {
+ return badConn{}, nil
+}
+
+// Issue 15901.
+func TestBadDriver(t *testing.T) {
+ Register("bad", badDriver{})
+ db, err := Open("bad", "ignored")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer func() {
+ if r := recover(); r == nil {
+ t.Error("expected panic")
+ } else {
+ if want := "badConn.Exec"; r.(string) != want {
+ t.Errorf("panic was %v, expected %v", r, want)
+ }
+ }
+ }()
+ defer db.Close()
+ db.Exec("ignored")
+}
+
+type pingDriver struct {
+ fails bool
+}
+
+type pingConn struct {
+ badConn
+ driver *pingDriver
+}
+
+var pingError = errors.New("Ping failed")
+
+func (pc pingConn) Ping(ctx context.Context) error {
+ if pc.driver.fails {
+ return pingError
+ }
+ return nil
+}
+
+var _ driver.Pinger = pingConn{}
+
+func (pd *pingDriver) Open(name string) (driver.Conn, error) {
+ return pingConn{driver: pd}, nil
+}
+
+func TestPing(t *testing.T) {
+ driver := &pingDriver{}
+ Register("ping", driver)
+
+ db, err := Open("ping", "ignored")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err := db.Ping(); err != nil {
+ t.Errorf("err was %#v, expected nil", err)
+ return
+ }
+
+ driver.fails = true
+ if err := db.Ping(); err != pingError {
+ t.Errorf("err was %#v, expected pingError", err)
+ }
+}
+
+// Issue 18101.
+func TestTypedString(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ type Str string
+ var scanned Str
+
+ err := db.QueryRow("SELECT|people|name|name=?", "Alice").Scan(&scanned)
+ if err != nil {
+ t.Fatal(err)
+ }
+ expected := Str("Alice")
+ if scanned != expected {
+ t.Errorf("expected %+v, got %+v", expected, scanned)
+ }
+}
+
+func BenchmarkConcurrentDBExec(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentDBExecTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentStmtQuery(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentStmtQueryTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentStmtExec(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentStmtExecTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentTxQuery(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentTxQueryTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentTxExec(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentTxExecTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentTxStmtQueryTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentTxStmtExec(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentTxStmtExecTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentRandom(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentRandomTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkManyConcurrentQueries(b *testing.B) {
+ b.ReportAllocs()
+ // To see lock contention in Go 1.4, 16~ cores and 128~ goroutines are required.
+ const parallelism = 16
+
+ db := newTestDB(b, "magicquery")
+ defer closeDB(b, db)
+ db.SetMaxIdleConns(runtime.GOMAXPROCS(0) * parallelism)
+
+ stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer stmt.Close()
+
+ b.SetParallelism(parallelism)
+ b.RunParallel(func(pb *testing.PB) {
+ for pb.Next() {
+ rows, err := stmt.Query("sleep", 1)
+ if err != nil {
+ b.Error(err)
+ return
+ }
+ rows.Close()
+ }
+ })
+}
+
+func TestGrabConnAllocs(t *testing.T) {
+ testenv.SkipIfOptimizationOff(t)
+ if race.Enabled {
+ t.Skip("skipping allocation test when using race detector")
+ }
+ c := new(Conn)
+ ctx := context.Background()
+ n := int(testing.AllocsPerRun(1000, func() {
+ _, release, err := c.grabConn(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+ release(nil)
+ }))
+ if n > 0 {
+ t.Fatalf("Conn.grabConn allocated %v objects; want 0", n)
+ }
+}
+
+func BenchmarkGrabConn(b *testing.B) {
+ b.ReportAllocs()
+ c := new(Conn)
+ ctx := context.Background()
+ for i := 0; i < b.N; i++ {
+ _, release, err := c.grabConn(ctx)
+ if err != nil {
+ b.Fatal(err)
+ }
+ release(nil)
+ }
+}