diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-16 19:23:18 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-16 19:23:18 +0000 |
commit | 43a123c1ae6613b3efeed291fa552ecd909d3acf (patch) | |
tree | fd92518b7024bc74031f78a1cf9e454b65e73665 /src/database/sql/sql.go | |
parent | Initial commit. (diff) | |
download | golang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.tar.xz golang-1.20-43a123c1ae6613b3efeed291fa552ecd909d3acf.zip |
Adding upstream version 1.20.14.upstream/1.20.14upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/database/sql/sql.go')
-rw-r--r-- | src/database/sql/sql.go | 3406 |
1 files changed, 3406 insertions, 0 deletions
diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go new file mode 100644 index 0000000..ad17eb3 --- /dev/null +++ b/src/database/sql/sql.go @@ -0,0 +1,3406 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package sql provides a generic interface around SQL (or SQL-like) +// databases. +// +// The sql package must be used in conjunction with a database driver. +// See https://golang.org/s/sqldrivers for a list of drivers. +// +// Drivers that do not support context cancellation will not return until +// after the query is completed. +// +// For usage examples, see the wiki page at +// https://golang.org/s/sqlwiki. +package sql + +import ( + "context" + "database/sql/driver" + "errors" + "fmt" + "io" + "reflect" + "runtime" + "sort" + "strconv" + "sync" + "sync/atomic" + "time" +) + +var ( + driversMu sync.RWMutex + drivers = make(map[string]driver.Driver) +) + +// nowFunc returns the current time; it's overridden in tests. +var nowFunc = time.Now + +// Register makes a database driver available by the provided name. +// If Register is called twice with the same name or if driver is nil, +// it panics. +func Register(name string, driver driver.Driver) { + driversMu.Lock() + defer driversMu.Unlock() + if driver == nil { + panic("sql: Register driver is nil") + } + if _, dup := drivers[name]; dup { + panic("sql: Register called twice for driver " + name) + } + drivers[name] = driver +} + +func unregisterAllDrivers() { + driversMu.Lock() + defer driversMu.Unlock() + // For tests. + drivers = make(map[string]driver.Driver) +} + +// Drivers returns a sorted list of the names of the registered drivers. +func Drivers() []string { + driversMu.RLock() + defer driversMu.RUnlock() + list := make([]string, 0, len(drivers)) + for name := range drivers { + list = append(list, name) + } + sort.Strings(list) + return list +} + +// A NamedArg is a named argument. NamedArg values may be used as +// arguments to Query or Exec and bind to the corresponding named +// parameter in the SQL statement. +// +// For a more concise way to create NamedArg values, see +// the Named function. +type NamedArg struct { + _NamedFieldsRequired struct{} + + // Name is the name of the parameter placeholder. + // + // If empty, the ordinal position in the argument list will be + // used. + // + // Name must omit any symbol prefix. + Name string + + // Value is the value of the parameter. + // It may be assigned the same value types as the query + // arguments. + Value any +} + +// Named provides a more concise way to create NamedArg values. +// +// Example usage: +// +// db.ExecContext(ctx, ` +// delete from Invoice +// where +// TimeCreated < @end +// and TimeCreated >= @start;`, +// sql.Named("start", startTime), +// sql.Named("end", endTime), +// ) +func Named(name string, value any) NamedArg { + // This method exists because the go1compat promise + // doesn't guarantee that structs don't grow more fields, + // so unkeyed struct literals are a vet error. Thus, we don't + // want to allow sql.NamedArg{name, value}. + return NamedArg{Name: name, Value: value} +} + +// IsolationLevel is the transaction isolation level used in TxOptions. +type IsolationLevel int + +// Various isolation levels that drivers may support in BeginTx. +// If a driver does not support a given isolation level an error may be returned. +// +// See https://en.wikipedia.org/wiki/Isolation_(database_systems)#Isolation_levels. +const ( + LevelDefault IsolationLevel = iota + LevelReadUncommitted + LevelReadCommitted + LevelWriteCommitted + LevelRepeatableRead + LevelSnapshot + LevelSerializable + LevelLinearizable +) + +// String returns the name of the transaction isolation level. +func (i IsolationLevel) String() string { + switch i { + case LevelDefault: + return "Default" + case LevelReadUncommitted: + return "Read Uncommitted" + case LevelReadCommitted: + return "Read Committed" + case LevelWriteCommitted: + return "Write Committed" + case LevelRepeatableRead: + return "Repeatable Read" + case LevelSnapshot: + return "Snapshot" + case LevelSerializable: + return "Serializable" + case LevelLinearizable: + return "Linearizable" + default: + return "IsolationLevel(" + strconv.Itoa(int(i)) + ")" + } +} + +var _ fmt.Stringer = LevelDefault + +// TxOptions holds the transaction options to be used in DB.BeginTx. +type TxOptions struct { + // Isolation is the transaction isolation level. + // If zero, the driver or database's default level is used. + Isolation IsolationLevel + ReadOnly bool +} + +// RawBytes is a byte slice that holds a reference to memory owned by +// the database itself. After a Scan into a RawBytes, the slice is only +// valid until the next call to Next, Scan, or Close. +type RawBytes []byte + +// NullString represents a string that may be null. +// NullString implements the Scanner interface so +// it can be used as a scan destination: +// +// var s NullString +// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&s) +// ... +// if s.Valid { +// // use s.String +// } else { +// // NULL value +// } +type NullString struct { + String string + Valid bool // Valid is true if String is not NULL +} + +// Scan implements the Scanner interface. +func (ns *NullString) Scan(value any) error { + if value == nil { + ns.String, ns.Valid = "", false + return nil + } + ns.Valid = true + return convertAssign(&ns.String, value) +} + +// Value implements the driver Valuer interface. +func (ns NullString) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return ns.String, nil +} + +// NullInt64 represents an int64 that may be null. +// NullInt64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullInt64 struct { + Int64 int64 + Valid bool // Valid is true if Int64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt64) Scan(value any) error { + if value == nil { + n.Int64, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Int64, value) +} + +// Value implements the driver Valuer interface. +func (n NullInt64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Int64, nil +} + +// NullInt32 represents an int32 that may be null. +// NullInt32 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullInt32 struct { + Int32 int32 + Valid bool // Valid is true if Int32 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt32) Scan(value any) error { + if value == nil { + n.Int32, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Int32, value) +} + +// Value implements the driver Valuer interface. +func (n NullInt32) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Int32), nil +} + +// NullInt16 represents an int16 that may be null. +// NullInt16 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullInt16 struct { + Int16 int16 + Valid bool // Valid is true if Int16 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullInt16) Scan(value any) error { + if value == nil { + n.Int16, n.Valid = 0, false + return nil + } + err := convertAssign(&n.Int16, value) + n.Valid = err == nil + return err +} + +// Value implements the driver Valuer interface. +func (n NullInt16) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Int16), nil +} + +// NullByte represents a byte that may be null. +// NullByte implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullByte struct { + Byte byte + Valid bool // Valid is true if Byte is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullByte) Scan(value any) error { + if value == nil { + n.Byte, n.Valid = 0, false + return nil + } + err := convertAssign(&n.Byte, value) + n.Valid = err == nil + return err +} + +// Value implements the driver Valuer interface. +func (n NullByte) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return int64(n.Byte), nil +} + +// NullFloat64 represents a float64 that may be null. +// NullFloat64 implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullFloat64 struct { + Float64 float64 + Valid bool // Valid is true if Float64 is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullFloat64) Scan(value any) error { + if value == nil { + n.Float64, n.Valid = 0, false + return nil + } + n.Valid = true + return convertAssign(&n.Float64, value) +} + +// Value implements the driver Valuer interface. +func (n NullFloat64) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Float64, nil +} + +// NullBool represents a bool that may be null. +// NullBool implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullBool struct { + Bool bool + Valid bool // Valid is true if Bool is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullBool) Scan(value any) error { + if value == nil { + n.Bool, n.Valid = false, false + return nil + } + n.Valid = true + return convertAssign(&n.Bool, value) +} + +// Value implements the driver Valuer interface. +func (n NullBool) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Bool, nil +} + +// NullTime represents a time.Time that may be null. +// NullTime implements the Scanner interface so +// it can be used as a scan destination, similar to NullString. +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullTime) Scan(value any) error { + if value == nil { + n.Time, n.Valid = time.Time{}, false + return nil + } + n.Valid = true + return convertAssign(&n.Time, value) +} + +// Value implements the driver Valuer interface. +func (n NullTime) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Time, nil +} + +// Scanner is an interface used by Scan. +type Scanner interface { + // Scan assigns a value from a database driver. + // + // The src value will be of one of the following types: + // + // int64 + // float64 + // bool + // []byte + // string + // time.Time + // nil - for NULL values + // + // An error should be returned if the value cannot be stored + // without loss of information. + // + // Reference types such as []byte are only valid until the next call to Scan + // and should not be retained. Their underlying memory is owned by the driver. + // If retention is necessary, copy their values before the next call to Scan. + Scan(src any) error +} + +// Out may be used to retrieve OUTPUT value parameters from stored procedures. +// +// Not all drivers and databases support OUTPUT value parameters. +// +// Example usage: +// +// var outArg string +// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", sql.Out{Dest: &outArg})) +type Out struct { + _NamedFieldsRequired struct{} + + // Dest is a pointer to the value that will be set to the result of the + // stored procedure's OUTPUT parameter. + Dest any + + // In is whether the parameter is an INOUT parameter. If so, the input value to the stored + // procedure is the dereferenced value of Dest's pointer, which is then replaced with + // the output value. + In bool +} + +// ErrNoRows is returned by Scan when QueryRow doesn't return a +// row. In such a case, QueryRow returns a placeholder *Row value that +// defers this error until a Scan. +var ErrNoRows = errors.New("sql: no rows in result set") + +// DB is a database handle representing a pool of zero or more +// underlying connections. It's safe for concurrent use by multiple +// goroutines. +// +// The sql package creates and frees connections automatically; it +// also maintains a free pool of idle connections. If the database has +// a concept of per-connection state, such state can be reliably observed +// within a transaction (Tx) or connection (Conn). Once DB.Begin is called, the +// returned Tx is bound to a single connection. Once Commit or +// Rollback is called on the transaction, that transaction's +// connection is returned to DB's idle connection pool. The pool size +// can be controlled with SetMaxIdleConns. +type DB struct { + // Total time waited for new connections. + waitDuration atomic.Int64 + + connector driver.Connector + // numClosed is an atomic counter which represents a total number of + // closed connections. Stmt.openStmt checks it before cleaning closed + // connections in Stmt.css. + numClosed atomic.Uint64 + + mu sync.Mutex // protects following fields + freeConn []*driverConn // free connections ordered by returnedAt oldest to newest + connRequests map[uint64]chan connRequest + nextRequest uint64 // Next key to use in connRequests. + numOpen int // number of opened and pending open connections + // Used to signal the need for new connections + // a goroutine running connectionOpener() reads on this chan and + // maybeOpenNewConnections sends on the chan (one send per needed connection) + // It is closed during db.Close(). The close tells the connectionOpener + // goroutine to exit. + openerCh chan struct{} + closed bool + dep map[finalCloser]depSet + lastPut map[*driverConn]string // stacktrace of last conn's put; debug only + maxIdleCount int // zero means defaultMaxIdleConns; negative means 0 + maxOpen int // <= 0 means unlimited + maxLifetime time.Duration // maximum amount of time a connection may be reused + maxIdleTime time.Duration // maximum amount of time a connection may be idle before being closed + cleanerCh chan struct{} + waitCount int64 // Total number of connections waited for. + maxIdleClosed int64 // Total number of connections closed due to idle count. + maxIdleTimeClosed int64 // Total number of connections closed due to idle time. + maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit. + + stop func() // stop cancels the connection opener. +} + +// connReuseStrategy determines how (*DB).conn returns database connections. +type connReuseStrategy uint8 + +const ( + // alwaysNewConn forces a new connection to the database. + alwaysNewConn connReuseStrategy = iota + // cachedOrNewConn returns a cached connection, if available, else waits + // for one to become available (if MaxOpenConns has been reached) or + // creates a new database connection. + cachedOrNewConn +) + +// driverConn wraps a driver.Conn with a mutex, to +// be held during all calls into the Conn. (including any calls onto +// interfaces returned via that Conn, such as calls on Tx, Stmt, +// Result, Rows) +type driverConn struct { + db *DB + createdAt time.Time + + sync.Mutex // guards following + ci driver.Conn + needReset bool // The connection session should be reset before use if true. + closed bool + finalClosed bool // ci.Close has been called + openStmt map[*driverStmt]bool + + // guarded by db.mu + inUse bool + returnedAt time.Time // Time the connection was created or returned. + onPut []func() // code (with db.mu held) run when conn is next returned + dbmuClosed bool // same as closed, but guarded by db.mu, for removeClosedStmtLocked +} + +func (dc *driverConn) releaseConn(err error) { + dc.db.putConn(dc, err, true) +} + +func (dc *driverConn) removeOpenStmt(ds *driverStmt) { + dc.Lock() + defer dc.Unlock() + delete(dc.openStmt, ds) +} + +func (dc *driverConn) expired(timeout time.Duration) bool { + if timeout <= 0 { + return false + } + return dc.createdAt.Add(timeout).Before(nowFunc()) +} + +// resetSession checks if the driver connection needs the +// session to be reset and if required, resets it. +func (dc *driverConn) resetSession(ctx context.Context) error { + dc.Lock() + defer dc.Unlock() + + if !dc.needReset { + return nil + } + if cr, ok := dc.ci.(driver.SessionResetter); ok { + return cr.ResetSession(ctx) + } + return nil +} + +// validateConnection checks if the connection is valid and can +// still be used. It also marks the session for reset if required. +func (dc *driverConn) validateConnection(needsReset bool) bool { + dc.Lock() + defer dc.Unlock() + + if needsReset { + dc.needReset = true + } + if cv, ok := dc.ci.(driver.Validator); ok { + return cv.IsValid() + } + return true +} + +// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of +// the prepared statements in a pool. +func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) { + si, err := ctxDriverPrepare(ctx, dc.ci, query) + if err != nil { + return nil, err + } + ds := &driverStmt{Locker: dc, si: si} + + // No need to manage open statements if there is a single connection grabber. + if cg != nil { + return ds, nil + } + + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once. + if dc.openStmt == nil { + dc.openStmt = make(map[*driverStmt]bool) + } + dc.openStmt[ds] = true + return ds, nil +} + +// the dc.db's Mutex is held. +func (dc *driverConn) closeDBLocked() func() error { + dc.Lock() + defer dc.Unlock() + if dc.closed { + return func() error { return errors.New("sql: duplicate driverConn close") } + } + dc.closed = true + return dc.db.removeDepLocked(dc, dc) +} + +func (dc *driverConn) Close() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + + // And now updates that require holding dc.mu.Lock. + dc.db.mu.Lock() + dc.dbmuClosed = true + fn := dc.db.removeDepLocked(dc, dc) + dc.db.mu.Unlock() + return fn() +} + +func (dc *driverConn) finalClose() error { + var err error + + // Each *driverStmt has a lock to the dc. Copy the list out of the dc + // before calling close on each stmt. + var openStmt []*driverStmt + withLock(dc, func() { + openStmt = make([]*driverStmt, 0, len(dc.openStmt)) + for ds := range dc.openStmt { + openStmt = append(openStmt, ds) + } + dc.openStmt = nil + }) + for _, ds := range openStmt { + ds.Close() + } + withLock(dc, func() { + dc.finalClosed = true + err = dc.ci.Close() + dc.ci = nil + }) + + dc.db.mu.Lock() + dc.db.numOpen-- + dc.db.maybeOpenNewConnections() + dc.db.mu.Unlock() + + dc.db.numClosed.Add(1) + return err +} + +// driverStmt associates a driver.Stmt with the +// *driverConn from which it came, so the driverConn's lock can be +// held during calls. +type driverStmt struct { + sync.Locker // the *driverConn + si driver.Stmt + closed bool + closeErr error // return value of previous Close call +} + +// Close ensures driver.Stmt is only closed once and always returns the same +// result. +func (ds *driverStmt) Close() error { + ds.Lock() + defer ds.Unlock() + if ds.closed { + return ds.closeErr + } + ds.closed = true + ds.closeErr = ds.si.Close() + return ds.closeErr +} + +// depSet is a finalCloser's outstanding dependencies +type depSet map[any]bool // set of true bools + +// The finalCloser interface is used by (*DB).addDep and related +// dependency reference counting. +type finalCloser interface { + // finalClose is called when the reference count of an object + // goes to zero. (*DB).mu is not held while calling it. + finalClose() error +} + +// addDep notes that x now depends on dep, and x's finalClose won't be +// called until all of x's dependencies are removed with removeDep. +func (db *DB) addDep(x finalCloser, dep any) { + db.mu.Lock() + defer db.mu.Unlock() + db.addDepLocked(x, dep) +} + +func (db *DB) addDepLocked(x finalCloser, dep any) { + if db.dep == nil { + db.dep = make(map[finalCloser]depSet) + } + xdep := db.dep[x] + if xdep == nil { + xdep = make(depSet) + db.dep[x] = xdep + } + xdep[dep] = true +} + +// removeDep notes that x no longer depends on dep. +// If x still has dependencies, nil is returned. +// If x no longer has any dependencies, its finalClose method will be +// called and its error value will be returned. +func (db *DB) removeDep(x finalCloser, dep any) error { + db.mu.Lock() + fn := db.removeDepLocked(x, dep) + db.mu.Unlock() + return fn() +} + +func (db *DB) removeDepLocked(x finalCloser, dep any) func() error { + xdep, ok := db.dep[x] + if !ok { + panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x)) + } + + l0 := len(xdep) + delete(xdep, dep) + + switch len(xdep) { + case l0: + // Nothing removed. Shouldn't happen. + panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x)) + case 0: + // No more dependencies. + delete(db.dep, x) + return x.finalClose + default: + // Dependencies remain. + return func() error { return nil } + } +} + +// This is the size of the connectionOpener request chan (DB.openerCh). +// This value should be larger than the maximum typical value +// used for db.maxOpen. If maxOpen is significantly larger than +// connectionRequestQueueSize then it is possible for ALL calls into the *DB +// to block until the connectionOpener can satisfy the backlog of requests. +var connectionRequestQueueSize = 1000000 + +type dsnConnector struct { + dsn string + driver driver.Driver +} + +func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) { + return t.driver.Open(t.dsn) +} + +func (t dsnConnector) Driver() driver.Driver { + return t.driver +} + +// OpenDB opens a database using a Connector, allowing drivers to +// bypass a string based data source name. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See https://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// OpenDB may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. +// +// The returned DB is safe for concurrent use by multiple goroutines +// and maintains its own pool of idle connections. Thus, the OpenDB +// function should be called just once. It is rarely necessary to +// close a DB. +func OpenDB(c driver.Connector) *DB { + ctx, cancel := context.WithCancel(context.Background()) + db := &DB{ + connector: c, + openerCh: make(chan struct{}, connectionRequestQueueSize), + lastPut: make(map[*driverConn]string), + connRequests: make(map[uint64]chan connRequest), + stop: cancel, + } + + go db.connectionOpener(ctx) + + return db +} + +// Open opens a database specified by its database driver name and a +// driver-specific data source name, usually consisting of at least a +// database name and connection information. +// +// Most users will open a database via a driver-specific connection +// helper function that returns a *DB. No database drivers are included +// in the Go standard library. See https://golang.org/s/sqldrivers for +// a list of third-party drivers. +// +// Open may just validate its arguments without creating a connection +// to the database. To verify that the data source name is valid, call +// Ping. +// +// The returned DB is safe for concurrent use by multiple goroutines +// and maintains its own pool of idle connections. Thus, the Open +// function should be called just once. It is rarely necessary to +// close a DB. +func Open(driverName, dataSourceName string) (*DB, error) { + driversMu.RLock() + driveri, ok := drivers[driverName] + driversMu.RUnlock() + if !ok { + return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) + } + + if driverCtx, ok := driveri.(driver.DriverContext); ok { + connector, err := driverCtx.OpenConnector(dataSourceName) + if err != nil { + return nil, err + } + return OpenDB(connector), nil + } + + return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil +} + +func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error { + var err error + if pinger, ok := dc.ci.(driver.Pinger); ok { + withLock(dc, func() { + err = pinger.Ping(ctx) + }) + } + release(err) + return err +} + +// PingContext verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *DB) PingContext(ctx context.Context) error { + var dc *driverConn + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + dc, err = db.conn(ctx, strategy) + return err + }) + + if err != nil { + return err + } + + return db.pingDC(ctx, dc, dc.releaseConn) +} + +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +// +// Ping uses context.Background internally; to specify the context, use +// PingContext. +func (db *DB) Ping() error { + return db.PingContext(context.Background()) +} + +// Close closes the database and prevents new queries from starting. +// Close then waits for all queries that have started processing on the server +// to finish. +// +// It is rare to Close a DB, as the DB handle is meant to be +// long-lived and shared between many goroutines. +func (db *DB) Close() error { + db.mu.Lock() + if db.closed { // Make DB.Close idempotent + db.mu.Unlock() + return nil + } + if db.cleanerCh != nil { + close(db.cleanerCh) + } + var err error + fns := make([]func() error, 0, len(db.freeConn)) + for _, dc := range db.freeConn { + fns = append(fns, dc.closeDBLocked()) + } + db.freeConn = nil + db.closed = true + for _, req := range db.connRequests { + close(req) + } + db.mu.Unlock() + for _, fn := range fns { + err1 := fn() + if err1 != nil { + err = err1 + } + } + db.stop() + if c, ok := db.connector.(io.Closer); ok { + err1 := c.Close() + if err1 != nil { + err = err1 + } + } + return err +} + +const defaultMaxIdleConns = 2 + +func (db *DB) maxIdleConnsLocked() int { + n := db.maxIdleCount + switch { + case n == 0: + // TODO(bradfitz): ask driver, if supported, for its default preference + return defaultMaxIdleConns + case n < 0: + return 0 + default: + return n + } +} + +func (db *DB) shortestIdleTimeLocked() time.Duration { + if db.maxIdleTime <= 0 { + return db.maxLifetime + } + if db.maxLifetime <= 0 { + return db.maxIdleTime + } + + min := db.maxIdleTime + if min > db.maxLifetime { + min = db.maxLifetime + } + return min +} + +// SetMaxIdleConns sets the maximum number of connections in the idle +// connection pool. +// +// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns, +// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit. +// +// If n <= 0, no idle connections are retained. +// +// The default max idle connections is currently 2. This may change in +// a future release. +func (db *DB) SetMaxIdleConns(n int) { + db.mu.Lock() + if n > 0 { + db.maxIdleCount = n + } else { + // No idle connections. + db.maxIdleCount = -1 + } + // Make sure maxIdle doesn't exceed maxOpen + if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen { + db.maxIdleCount = db.maxOpen + } + var closing []*driverConn + idleCount := len(db.freeConn) + maxIdle := db.maxIdleConnsLocked() + if idleCount > maxIdle { + closing = db.freeConn[maxIdle:] + db.freeConn = db.freeConn[:maxIdle] + } + db.maxIdleClosed += int64(len(closing)) + db.mu.Unlock() + for _, c := range closing { + c.Close() + } +} + +// SetMaxOpenConns sets the maximum number of open connections to the database. +// +// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than +// MaxIdleConns, then MaxIdleConns will be reduced to match the new +// MaxOpenConns limit. +// +// If n <= 0, then there is no limit on the number of open connections. +// The default is 0 (unlimited). +func (db *DB) SetMaxOpenConns(n int) { + db.mu.Lock() + db.maxOpen = n + if n < 0 { + db.maxOpen = 0 + } + syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen + db.mu.Unlock() + if syncMaxIdle { + db.SetMaxIdleConns(n) + } +} + +// SetConnMaxLifetime sets the maximum amount of time a connection may be reused. +// +// Expired connections may be closed lazily before reuse. +// +// If d <= 0, connections are not closed due to a connection's age. +func (db *DB) SetConnMaxLifetime(d time.Duration) { + if d < 0 { + d = 0 + } + db.mu.Lock() + // Wake cleaner up when lifetime is shortened. + if d > 0 && d < db.maxLifetime && db.cleanerCh != nil { + select { + case db.cleanerCh <- struct{}{}: + default: + } + } + db.maxLifetime = d + db.startCleanerLocked() + db.mu.Unlock() +} + +// SetConnMaxIdleTime sets the maximum amount of time a connection may be idle. +// +// Expired connections may be closed lazily before reuse. +// +// If d <= 0, connections are not closed due to a connection's idle time. +func (db *DB) SetConnMaxIdleTime(d time.Duration) { + if d < 0 { + d = 0 + } + db.mu.Lock() + defer db.mu.Unlock() + + // Wake cleaner up when idle time is shortened. + if d > 0 && d < db.maxIdleTime && db.cleanerCh != nil { + select { + case db.cleanerCh <- struct{}{}: + default: + } + } + db.maxIdleTime = d + db.startCleanerLocked() +} + +// startCleanerLocked starts connectionCleaner if needed. +func (db *DB) startCleanerLocked() { + if (db.maxLifetime > 0 || db.maxIdleTime > 0) && db.numOpen > 0 && db.cleanerCh == nil { + db.cleanerCh = make(chan struct{}, 1) + go db.connectionCleaner(db.shortestIdleTimeLocked()) + } +} + +func (db *DB) connectionCleaner(d time.Duration) { + const minInterval = time.Second + + if d < minInterval { + d = minInterval + } + t := time.NewTimer(d) + + for { + select { + case <-t.C: + case <-db.cleanerCh: // maxLifetime was changed or db was closed. + } + + db.mu.Lock() + + d = db.shortestIdleTimeLocked() + if db.closed || db.numOpen == 0 || d <= 0 { + db.cleanerCh = nil + db.mu.Unlock() + return + } + + d, closing := db.connectionCleanerRunLocked(d) + db.mu.Unlock() + for _, c := range closing { + c.Close() + } + + if d < minInterval { + d = minInterval + } + + if !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(d) + } +} + +// connectionCleanerRunLocked removes connections that should be closed from +// freeConn and returns them along side an updated duration to the next check +// if a quicker check is required to ensure connections are checked appropriately. +func (db *DB) connectionCleanerRunLocked(d time.Duration) (time.Duration, []*driverConn) { + var idleClosing int64 + var closing []*driverConn + if db.maxIdleTime > 0 { + // As freeConn is ordered by returnedAt process + // in reverse order to minimise the work needed. + idleSince := nowFunc().Add(-db.maxIdleTime) + last := len(db.freeConn) - 1 + for i := last; i >= 0; i-- { + c := db.freeConn[i] + if c.returnedAt.Before(idleSince) { + i++ + closing = db.freeConn[:i:i] + db.freeConn = db.freeConn[i:] + idleClosing = int64(len(closing)) + db.maxIdleTimeClosed += idleClosing + break + } + } + + if len(db.freeConn) > 0 { + c := db.freeConn[0] + if d2 := c.returnedAt.Sub(idleSince); d2 < d { + // Ensure idle connections are cleaned up as soon as + // possible. + d = d2 + } + } + } + + if db.maxLifetime > 0 { + expiredSince := nowFunc().Add(-db.maxLifetime) + for i := 0; i < len(db.freeConn); i++ { + c := db.freeConn[i] + if c.createdAt.Before(expiredSince) { + closing = append(closing, c) + + last := len(db.freeConn) - 1 + // Use slow delete as order is required to ensure + // connections are reused least idle time first. + copy(db.freeConn[i:], db.freeConn[i+1:]) + db.freeConn[last] = nil + db.freeConn = db.freeConn[:last] + i-- + } else if d2 := c.createdAt.Sub(expiredSince); d2 < d { + // Prevent connections sitting the freeConn when they + // have expired by updating our next deadline d. + d = d2 + } + } + db.maxLifetimeClosed += int64(len(closing)) - idleClosing + } + + return d, closing +} + +// DBStats contains database statistics. +type DBStats struct { + MaxOpenConnections int // Maximum number of open connections to the database. + + // Pool Status + OpenConnections int // The number of established connections both in use and idle. + InUse int // The number of connections currently in use. + Idle int // The number of idle connections. + + // Counters + WaitCount int64 // The total number of connections waited for. + WaitDuration time.Duration // The total time blocked waiting for a new connection. + MaxIdleClosed int64 // The total number of connections closed due to SetMaxIdleConns. + MaxIdleTimeClosed int64 // The total number of connections closed due to SetConnMaxIdleTime. + MaxLifetimeClosed int64 // The total number of connections closed due to SetConnMaxLifetime. +} + +// Stats returns database statistics. +func (db *DB) Stats() DBStats { + wait := db.waitDuration.Load() + + db.mu.Lock() + defer db.mu.Unlock() + + stats := DBStats{ + MaxOpenConnections: db.maxOpen, + + Idle: len(db.freeConn), + OpenConnections: db.numOpen, + InUse: db.numOpen - len(db.freeConn), + + WaitCount: db.waitCount, + WaitDuration: time.Duration(wait), + MaxIdleClosed: db.maxIdleClosed, + MaxIdleTimeClosed: db.maxIdleTimeClosed, + MaxLifetimeClosed: db.maxLifetimeClosed, + } + return stats +} + +// Assumes db.mu is locked. +// If there are connRequests and the connection limit hasn't been reached, +// then tell the connectionOpener to open new connections. +func (db *DB) maybeOpenNewConnections() { + numRequests := len(db.connRequests) + if db.maxOpen > 0 { + numCanOpen := db.maxOpen - db.numOpen + if numRequests > numCanOpen { + numRequests = numCanOpen + } + } + for numRequests > 0 { + db.numOpen++ // optimistically + numRequests-- + if db.closed { + return + } + db.openerCh <- struct{}{} + } +} + +// Runs in a separate goroutine, opens new connections when requested. +func (db *DB) connectionOpener(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-db.openerCh: + db.openNewConnection(ctx) + } + } +} + +// Open one new connection +func (db *DB) openNewConnection(ctx context.Context) { + // maybeOpenNewConnections has already executed db.numOpen++ before it sent + // on db.openerCh. This function must execute db.numOpen-- if the + // connection fails or is closed before returning. + ci, err := db.connector.Connect(ctx) + db.mu.Lock() + defer db.mu.Unlock() + if db.closed { + if err == nil { + ci.Close() + } + db.numOpen-- + return + } + if err != nil { + db.numOpen-- + db.putConnDBLocked(nil, err) + db.maybeOpenNewConnections() + return + } + dc := &driverConn{ + db: db, + createdAt: nowFunc(), + returnedAt: nowFunc(), + ci: ci, + } + if db.putConnDBLocked(dc, err) { + db.addDepLocked(dc, dc) + } else { + db.numOpen-- + ci.Close() + } +} + +// connRequest represents one request for a new connection +// When there are no idle connections available, DB.conn will create +// a new connRequest and put it on the db.connRequests list. +type connRequest struct { + conn *driverConn + err error +} + +var errDBClosed = errors.New("sql: database is closed") + +// nextRequestKeyLocked returns the next connection request key. +// It is assumed that nextRequest will not overflow. +func (db *DB) nextRequestKeyLocked() uint64 { + next := db.nextRequest + db.nextRequest++ + return next +} + +// conn returns a newly-opened or cached *driverConn. +func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) { + db.mu.Lock() + if db.closed { + db.mu.Unlock() + return nil, errDBClosed + } + // Check if the context is expired. + select { + default: + case <-ctx.Done(): + db.mu.Unlock() + return nil, ctx.Err() + } + lifetime := db.maxLifetime + + // Prefer a free connection, if possible. + last := len(db.freeConn) - 1 + if strategy == cachedOrNewConn && last >= 0 { + // Reuse the lowest idle time connection so we can close + // connections which remain idle as soon as possible. + conn := db.freeConn[last] + db.freeConn = db.freeConn[:last] + conn.inUse = true + if conn.expired(lifetime) { + db.maxLifetimeClosed++ + db.mu.Unlock() + conn.Close() + return nil, driver.ErrBadConn + } + db.mu.Unlock() + + // Reset the session if required. + if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) { + conn.Close() + return nil, err + } + + return conn, nil + } + + // Out of free connections or we were asked not to use one. If we're not + // allowed to open any more connections, make a request and wait. + if db.maxOpen > 0 && db.numOpen >= db.maxOpen { + // Make the connRequest channel. It's buffered so that the + // connectionOpener doesn't block while waiting for the req to be read. + req := make(chan connRequest, 1) + reqKey := db.nextRequestKeyLocked() + db.connRequests[reqKey] = req + db.waitCount++ + db.mu.Unlock() + + waitStart := nowFunc() + + // Timeout the connection request with the context. + select { + case <-ctx.Done(): + // Remove the connection request and ensure no value has been sent + // on it after removing. + db.mu.Lock() + delete(db.connRequests, reqKey) + db.mu.Unlock() + + db.waitDuration.Add(int64(time.Since(waitStart))) + + select { + default: + case ret, ok := <-req: + if ok && ret.conn != nil { + db.putConn(ret.conn, ret.err, false) + } + } + return nil, ctx.Err() + case ret, ok := <-req: + db.waitDuration.Add(int64(time.Since(waitStart))) + + if !ok { + return nil, errDBClosed + } + // Only check if the connection is expired if the strategy is cachedOrNewConns. + // If we require a new connection, just re-use the connection without looking + // at the expiry time. If it is expired, it will be checked when it is placed + // back into the connection pool. + // This prioritizes giving a valid connection to a client over the exact connection + // lifetime, which could expire exactly after this point anyway. + if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) { + db.mu.Lock() + db.maxLifetimeClosed++ + db.mu.Unlock() + ret.conn.Close() + return nil, driver.ErrBadConn + } + if ret.conn == nil { + return nil, ret.err + } + + // Reset the session if required. + if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) { + ret.conn.Close() + return nil, err + } + return ret.conn, ret.err + } + } + + db.numOpen++ // optimistically + db.mu.Unlock() + ci, err := db.connector.Connect(ctx) + if err != nil { + db.mu.Lock() + db.numOpen-- // correct for earlier optimism + db.maybeOpenNewConnections() + db.mu.Unlock() + return nil, err + } + db.mu.Lock() + dc := &driverConn{ + db: db, + createdAt: nowFunc(), + returnedAt: nowFunc(), + ci: ci, + inUse: true, + } + db.addDepLocked(dc, dc) + db.mu.Unlock() + return dc, nil +} + +// putConnHook is a hook for testing. +var putConnHook func(*DB, *driverConn) + +// noteUnusedDriverStatement notes that ds is no longer used and should +// be closed whenever possible (when c is next not in use), unless c is +// already closed. +func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) { + db.mu.Lock() + defer db.mu.Unlock() + if c.inUse { + c.onPut = append(c.onPut, func() { + ds.Close() + }) + } else { + c.Lock() + fc := c.finalClosed + c.Unlock() + if !fc { + ds.Close() + } + } +} + +// debugGetPut determines whether getConn & putConn calls' stack traces +// are returned for more verbose crashes. +const debugGetPut = false + +// putConn adds a connection to the db's free pool. +// err is optionally the last error that occurred on this connection. +func (db *DB) putConn(dc *driverConn, err error, resetSession bool) { + if !errors.Is(err, driver.ErrBadConn) { + if !dc.validateConnection(resetSession) { + err = driver.ErrBadConn + } + } + db.mu.Lock() + if !dc.inUse { + db.mu.Unlock() + if debugGetPut { + fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc]) + } + panic("sql: connection returned that was never out") + } + + if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) { + db.maxLifetimeClosed++ + err = driver.ErrBadConn + } + if debugGetPut { + db.lastPut[dc] = stack() + } + dc.inUse = false + dc.returnedAt = nowFunc() + + for _, fn := range dc.onPut { + fn() + } + dc.onPut = nil + + if errors.Is(err, driver.ErrBadConn) { + // Don't reuse bad connections. + // Since the conn is considered bad and is being discarded, treat it + // as closed. Don't decrement the open count here, finalClose will + // take care of that. + db.maybeOpenNewConnections() + db.mu.Unlock() + dc.Close() + return + } + if putConnHook != nil { + putConnHook(db, dc) + } + added := db.putConnDBLocked(dc, nil) + db.mu.Unlock() + + if !added { + dc.Close() + return + } +} + +// Satisfy a connRequest or put the driverConn in the idle pool and return true +// or return false. +// putConnDBLocked will satisfy a connRequest if there is one, or it will +// return the *driverConn to the freeConn list if err == nil and the idle +// connection limit will not be exceeded. +// If err != nil, the value of dc is ignored. +// If err == nil, then dc must not equal nil. +// If a connRequest was fulfilled or the *driverConn was placed in the +// freeConn list, then true is returned, otherwise false is returned. +func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { + if db.closed { + return false + } + if db.maxOpen > 0 && db.numOpen > db.maxOpen { + return false + } + if c := len(db.connRequests); c > 0 { + var req chan connRequest + var reqKey uint64 + for reqKey, req = range db.connRequests { + break + } + delete(db.connRequests, reqKey) // Remove from pending requests. + if err == nil { + dc.inUse = true + } + req <- connRequest{ + conn: dc, + err: err, + } + return true + } else if err == nil && !db.closed { + if db.maxIdleConnsLocked() > len(db.freeConn) { + db.freeConn = append(db.freeConn, dc) + db.startCleanerLocked() + return true + } + db.maxIdleClosed++ + } + return false +} + +// maxBadConnRetries is the number of maximum retries if the driver returns +// driver.ErrBadConn to signal a broken connection before forcing a new +// connection to be opened. +const maxBadConnRetries = 2 + +func (db *DB) retry(fn func(strategy connReuseStrategy) error) error { + for i := int64(0); i < maxBadConnRetries; i++ { + err := fn(cachedOrNewConn) + // retry if err is driver.ErrBadConn + if err == nil || !errors.Is(err, driver.ErrBadConn) { + return err + } + } + + return fn(alwaysNewConn) +} + +// PrepareContext creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + var stmt *Stmt + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + stmt, err = db.prepare(ctx, query, strategy) + return err + }) + + return stmt, err +} + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +// +// Prepare uses context.Background internally; to specify the context, use +// PrepareContext. +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) { + // TODO: check if db.driver supports an optional + // driver.Preparer interface and call that instead, if so, + // otherwise we make a prepared statement that's bound + // to a connection, and to execute this prepared statement + // we either need to use this connection (if it's free), else + // get a new connection + re-prepare + execute on that one. + dc, err := db.conn(ctx, strategy) + if err != nil { + return nil, err + } + return db.prepareDC(ctx, dc, dc.releaseConn, nil, query) +} + +// prepareDC prepares a query on the driverConn and calls release before +// returning. When cg == nil it implies that a connection pool is used, and +// when cg != nil only a single driver connection is used. +func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) { + var ds *driverStmt + var err error + defer func() { + release(err) + }() + withLock(dc, func() { + ds, err = dc.prepareLocked(ctx, cg, query) + }) + if err != nil { + return nil, err + } + stmt := &Stmt{ + db: db, + query: query, + cg: cg, + cgds: ds, + } + + // When cg == nil this statement will need to keep track of various + // connections they are prepared on and record the stmt dependency on + // the DB. + if cg == nil { + stmt.css = []connStmt{{dc, ds}} + stmt.lastNumClosed = db.numClosed.Load() + db.addDep(stmt, stmt) + } + return stmt, nil +} + +// ExecContext executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + var res Result + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + res, err = db.exec(ctx, query, args, strategy) + return err + }) + + return res, err +} + +// Exec executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +// +// Exec uses context.Background internally; to specify the context, use +// ExecContext. +func (db *DB) Exec(query string, args ...any) (Result, error) { + return db.ExecContext(context.Background(), query, args...) +} + +func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) { + dc, err := db.conn(ctx, strategy) + if err != nil { + return nil, err + } + return db.execDC(ctx, dc, dc.releaseConn, query, args) +} + +func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []any) (res Result, err error) { + defer func() { + release(err) + }() + execerCtx, ok := dc.ci.(driver.ExecerContext) + var execer driver.Execer + if !ok { + execer, ok = dc.ci.(driver.Execer) + } + if ok { + var nvdargs []driver.NamedValue + var resi driver.Result + withLock(dc, func() { + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs) + }) + if err != driver.ErrSkip { + if err != nil { + return nil, err + } + return driverResult{dc, resi}, nil + } + } + + var si driver.Stmt + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) + if err != nil { + return nil, err + } + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() + return resultFromStatement(ctx, dc.ci, ds, args...) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + var rows *Rows + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + rows, err = db.query(ctx, query, args, strategy) + return err + }) + + return rows, err +} + +// Query executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +// +// Query uses context.Background internally; to specify the context, use +// QueryContext. +func (db *DB) Query(query string, args ...any) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +func (db *DB) query(ctx context.Context, query string, args []any, strategy connReuseStrategy) (*Rows, error) { + dc, err := db.conn(ctx, strategy) + if err != nil { + return nil, err + } + + return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args) +} + +// queryDC executes a query on the given connection. +// The connection gets released by the releaseConn function. +// The ctx context is from a query method and the txctx context is from an +// optional transaction context. +func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []any) (*Rows, error) { + queryerCtx, ok := dc.ci.(driver.QueryerContext) + var queryer driver.Queryer + if !ok { + queryer, ok = dc.ci.(driver.Queryer) + } + if ok { + var nvdargs []driver.NamedValue + var rowsi driver.Rows + var err error + withLock(dc, func() { + nvdargs, err = driverArgsConnLocked(dc.ci, nil, args) + if err != nil { + return + } + rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs) + }) + if err != driver.ErrSkip { + if err != nil { + releaseConn(err) + return nil, err + } + // Note: ownership of dc passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + releaseConn: releaseConn, + rowsi: rowsi, + } + rows.initContextClose(ctx, txctx) + return rows, nil + } + } + + var si driver.Stmt + var err error + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, query) + }) + if err != nil { + releaseConn(err) + return nil, err + } + + ds := &driverStmt{Locker: dc, si: si} + rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...) + if err != nil { + ds.Close() + releaseConn(err) + return nil, err + } + + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows := &Rows{ + dc: dc, + releaseConn: releaseConn, + rowsi: rowsi, + closeStmt: ds, + } + rows.initContextClose(ctx, txctx) + return rows, nil +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + rows, err := db.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// QueryRow uses context.Background internally; to specify the context, use +// QueryRowContext. +func (db *DB) QueryRow(query string, args ...any) *Row { + return db.QueryRowContext(context.Background(), query, args...) +} + +// BeginTx starts a transaction. +// +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginTx is canceled. +// +// The provided TxOptions is optional and may be nil if defaults should be used. +// If a non-default isolation level is used that the driver doesn't support, +// an error will be returned. +func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { + var tx *Tx + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + tx, err = db.begin(ctx, opts, strategy) + return err + }) + + return tx, err +} + +// Begin starts a transaction. The default isolation level is dependent on +// the driver. +// +// Begin uses context.Background internally; to specify the context, use +// BeginTx. +func (db *DB) Begin() (*Tx, error) { + return db.BeginTx(context.Background(), nil) +} + +func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) { + dc, err := db.conn(ctx, strategy) + if err != nil { + return nil, err + } + return db.beginDC(ctx, dc, dc.releaseConn, opts) +} + +// beginDC starts a transaction. The provided dc must be valid and ready to use. +func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) { + var txi driver.Tx + keepConnOnRollback := false + withLock(dc, func() { + _, hasSessionResetter := dc.ci.(driver.SessionResetter) + _, hasConnectionValidator := dc.ci.(driver.Validator) + keepConnOnRollback = hasSessionResetter && hasConnectionValidator + txi, err = ctxDriverBegin(ctx, opts, dc.ci) + }) + if err != nil { + release(err) + return nil, err + } + + // Schedule the transaction to rollback when the context is canceled. + // The cancel function in Tx will be called after done is set to true. + ctx, cancel := context.WithCancel(ctx) + tx = &Tx{ + db: db, + dc: dc, + releaseConn: release, + txi: txi, + cancel: cancel, + keepConnOnRollback: keepConnOnRollback, + ctx: ctx, + } + go tx.awaitDone() + return tx, nil +} + +// Driver returns the database's underlying driver. +func (db *DB) Driver() driver.Driver { + return db.connector.Driver() +} + +// ErrConnDone is returned by any operation that is performed on a connection +// that has already been returned to the connection pool. +var ErrConnDone = errors.New("sql: connection is already closed") + +// Conn returns a single connection by either opening a new connection +// or returning an existing connection from the connection pool. Conn will +// block until either a connection is returned or ctx is canceled. +// Queries run on the same Conn will be run in the same database session. +// +// Every Conn must be returned to the database pool after use by +// calling Conn.Close. +func (db *DB) Conn(ctx context.Context) (*Conn, error) { + var dc *driverConn + var err error + + err = db.retry(func(strategy connReuseStrategy) error { + dc, err = db.conn(ctx, strategy) + return err + }) + + if err != nil { + return nil, err + } + + conn := &Conn{ + db: db, + dc: dc, + } + return conn, nil +} + +type releaseConn func(error) + +// Conn represents a single database connection rather than a pool of database +// connections. Prefer running queries from DB unless there is a specific +// need for a continuous single database connection. +// +// A Conn must call Close to return the connection to the database pool +// and may do so concurrently with a running query. +// +// After a call to Close, all operations on the +// connection fail with ErrConnDone. +type Conn struct { + db *DB + + // closemu prevents the connection from closing while there + // is an active query. It is held for read during queries + // and exclusively during close. + closemu sync.RWMutex + + // dc is owned until close, at which point + // it's returned to the connection pool. + dc *driverConn + + // done transitions from 0 to 1 exactly once, on close. + // Once done, all operations fail with ErrConnDone. + // Use atomic operations on value when checking value. + done int32 +} + +// grabConn takes a context to implement stmtConnGrabber +// but the context is not used. +func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) { + if atomic.LoadInt32(&c.done) != 0 { + return nil, nil, ErrConnDone + } + c.closemu.RLock() + return c.dc, c.closemuRUnlockCondReleaseConn, nil +} + +// PingContext verifies the connection to the database is still alive. +func (c *Conn) PingContext(ctx context.Context) error { + dc, release, err := c.grabConn(ctx) + if err != nil { + return err + } + return c.db.pingDC(ctx, dc, release) +} + +// ExecContext executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (c *Conn) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.execDC(ctx, dc, release, query, args) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (c *Conn) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.queryDC(ctx, nil, dc, release, query, args) +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + rows, err := c.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + +// PrepareContext creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.prepareDC(ctx, dc, release, c, query) +} + +// Raw executes f exposing the underlying driver connection for the +// duration of f. The driverConn must not be used outside of f. +// +// Once f returns and err is not driver.ErrBadConn, the Conn will continue to be usable +// until Conn.Close is called. +func (c *Conn) Raw(f func(driverConn any) error) (err error) { + var dc *driverConn + var release releaseConn + + // grabConn takes a context to implement stmtConnGrabber, but the context is not used. + dc, release, err = c.grabConn(nil) + if err != nil { + return + } + fPanic := true + dc.Mutex.Lock() + defer func() { + dc.Mutex.Unlock() + + // If f panics fPanic will remain true. + // Ensure an error is passed to release so the connection + // may be discarded. + if fPanic { + err = driver.ErrBadConn + } + release(err) + }() + err = f(dc.ci) + fPanic = false + + return +} + +// BeginTx starts a transaction. +// +// The provided context is used until the transaction is committed or rolled back. +// If the context is canceled, the sql package will roll back +// the transaction. Tx.Commit will return an error if the context provided to +// BeginTx is canceled. +// +// The provided TxOptions is optional and may be nil if defaults should be used. +// If a non-default isolation level is used that the driver doesn't support, +// an error will be returned. +func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { + dc, release, err := c.grabConn(ctx) + if err != nil { + return nil, err + } + return c.db.beginDC(ctx, dc, release, opts) +} + +// closemuRUnlockCondReleaseConn read unlocks closemu +// as the sql operation is done with the dc. +func (c *Conn) closemuRUnlockCondReleaseConn(err error) { + c.closemu.RUnlock() + if errors.Is(err, driver.ErrBadConn) { + c.close(err) + } +} + +func (c *Conn) txCtx() context.Context { + return nil +} + +func (c *Conn) close(err error) error { + if !atomic.CompareAndSwapInt32(&c.done, 0, 1) { + return ErrConnDone + } + + // Lock around releasing the driver connection + // to ensure all queries have been stopped before doing so. + c.closemu.Lock() + defer c.closemu.Unlock() + + c.dc.releaseConn(err) + c.dc = nil + c.db = nil + return err +} + +// Close returns the connection to the connection pool. +// All operations after a Close will return with ErrConnDone. +// Close is safe to call concurrently with other operations and will +// block until all other operations finish. It may be useful to first +// cancel any used context and then call close directly after. +func (c *Conn) Close() error { + return c.close(nil) +} + +// Tx is an in-progress database transaction. +// +// A transaction must end with a call to Commit or Rollback. +// +// After a call to Commit or Rollback, all operations on the +// transaction fail with ErrTxDone. +// +// The statements prepared for a transaction by calling +// the transaction's Prepare or Stmt methods are closed +// by the call to Commit or Rollback. +type Tx struct { + db *DB + + // closemu prevents the transaction from closing while there + // is an active query. It is held for read during queries + // and exclusively during close. + closemu sync.RWMutex + + // dc is owned exclusively until Commit or Rollback, at which point + // it's returned with putConn. + dc *driverConn + txi driver.Tx + + // releaseConn is called once the Tx is closed to release + // any held driverConn back to the pool. + releaseConn func(error) + + // done transitions from false to true exactly once, on Commit + // or Rollback. once done, all operations fail with + // ErrTxDone. + done atomic.Bool + + // keepConnOnRollback is true if the driver knows + // how to reset the connection's session and if need be discard + // the connection. + keepConnOnRollback bool + + // All Stmts prepared for this transaction. These will be closed after the + // transaction has been committed or rolled back. + stmts struct { + sync.Mutex + v []*Stmt + } + + // cancel is called after done transitions from 0 to 1. + cancel func() + + // ctx lives for the life of the transaction. + ctx context.Context +} + +// awaitDone blocks until the context in Tx is canceled and rolls back +// the transaction if it's not already done. +func (tx *Tx) awaitDone() { + // Wait for either the transaction to be committed or rolled + // back, or for the associated context to be closed. + <-tx.ctx.Done() + + // Discard and close the connection used to ensure the + // transaction is closed and the resources are released. This + // rollback does nothing if the transaction has already been + // committed or rolled back. + // Do not discard the connection if the connection knows + // how to reset the session. + discardConnection := !tx.keepConnOnRollback + tx.rollback(discardConnection) +} + +func (tx *Tx) isDone() bool { + return tx.done.Load() +} + +// ErrTxDone is returned by any operation that is performed on a transaction +// that has already been committed or rolled back. +var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back") + +// close returns the connection to the pool and +// must only be called by Tx.rollback or Tx.Commit while +// tx is already canceled and won't be executed concurrently. +func (tx *Tx) close(err error) { + tx.releaseConn(err) + tx.dc = nil + tx.txi = nil +} + +// hookTxGrabConn specifies an optional hook to be called on +// a successful call to (*Tx).grabConn. For tests. +var hookTxGrabConn func() + +func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) { + select { + default: + case <-ctx.Done(): + return nil, nil, ctx.Err() + } + + // closemu.RLock must come before the check for isDone to prevent the Tx from + // closing while a query is executing. + tx.closemu.RLock() + if tx.isDone() { + tx.closemu.RUnlock() + return nil, nil, ErrTxDone + } + if hookTxGrabConn != nil { // test hook + hookTxGrabConn() + } + return tx.dc, tx.closemuRUnlockRelease, nil +} + +func (tx *Tx) txCtx() context.Context { + return tx.ctx +} + +// closemuRUnlockRelease is used as a func(error) method value in +// ExecContext and QueryContext. Unlocking in the releaseConn keeps +// the driver conn from being returned to the connection pool until +// the Rows has been closed. +func (tx *Tx) closemuRUnlockRelease(error) { + tx.closemu.RUnlock() +} + +// Closes all Stmts prepared for this transaction. +func (tx *Tx) closePrepared() { + tx.stmts.Lock() + defer tx.stmts.Unlock() + for _, stmt := range tx.stmts.v { + stmt.Close() + } +} + +// Commit commits the transaction. +func (tx *Tx) Commit() error { + // Check context first to avoid transaction leak. + // If put it behind tx.done CompareAndSwap statement, we can't ensure + // the consistency between tx.done and the real COMMIT operation. + select { + default: + case <-tx.ctx.Done(): + if tx.done.Load() { + return ErrTxDone + } + return tx.ctx.Err() + } + if !tx.done.CompareAndSwap(false, true) { + return ErrTxDone + } + + // Cancel the Tx to release any active R-closemu locks. + // This is safe to do because tx.done has already transitioned + // from 0 to 1. Hold the W-closemu lock prior to rollback + // to ensure no other connection has an active query. + tx.cancel() + tx.closemu.Lock() + tx.closemu.Unlock() + + var err error + withLock(tx.dc, func() { + err = tx.txi.Commit() + }) + if !errors.Is(err, driver.ErrBadConn) { + tx.closePrepared() + } + tx.close(err) + return err +} + +var rollbackHook func() + +// rollback aborts the transaction and optionally forces the pool to discard +// the connection. +func (tx *Tx) rollback(discardConn bool) error { + if !tx.done.CompareAndSwap(false, true) { + return ErrTxDone + } + + if rollbackHook != nil { + rollbackHook() + } + + // Cancel the Tx to release any active R-closemu locks. + // This is safe to do because tx.done has already transitioned + // from 0 to 1. Hold the W-closemu lock prior to rollback + // to ensure no other connection has an active query. + tx.cancel() + tx.closemu.Lock() + tx.closemu.Unlock() + + var err error + withLock(tx.dc, func() { + err = tx.txi.Rollback() + }) + if !errors.Is(err, driver.ErrBadConn) { + tx.closePrepared() + } + if discardConn { + err = driver.ErrBadConn + } + tx.close(err) + return err +} + +// Rollback aborts the transaction. +func (tx *Tx) Rollback() error { + return tx.rollback(false) +} + +// PrepareContext creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +// +// The provided context will be used for the preparation of the context, not +// for the execution of the returned statement. The returned statement +// will run in the transaction context. +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return nil, err + } + + stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query) + if err != nil { + return nil, err + } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, stmt) + tx.stmts.Unlock() + return stmt, nil +} + +// Prepare creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +// +// Prepare uses context.Background internally; to specify the context, use +// PrepareContext. +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} + +// StmtContext returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203) +// +// The provided context is used for the preparation of the statement, not for the +// execution of the statement. +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return &Stmt{stickyErr: err} + } + defer release(nil) + + if tx.db != stmt.db { + return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} + } + var si driver.Stmt + var parentStmt *Stmt + stmt.mu.Lock() + if stmt.closed || stmt.cg != nil { + // If the statement has been closed or already belongs to a + // transaction, we can't reuse it in this connection. + // Since tx.StmtContext should never need to be called with a + // Stmt already belonging to tx, we ignore this edge case and + // re-prepare the statement in this case. No need to add + // code-complexity for this. + stmt.mu.Unlock() + withLock(dc, func() { + si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) + }) + if err != nil { + return &Stmt{stickyErr: err} + } + } else { + stmt.removeClosedStmtLocked() + // See if the statement has already been prepared on this connection, + // and reuse it if possible. + for _, v := range stmt.css { + if v.dc == dc { + si = v.ds.si + break + } + } + + stmt.mu.Unlock() + + if si == nil { + var ds *driverStmt + withLock(dc, func() { + ds, err = stmt.prepareOnConnLocked(ctx, dc) + }) + if err != nil { + return &Stmt{stickyErr: err} + } + si = ds.si + } + parentStmt = stmt + } + + txs := &Stmt{ + db: tx.db, + cg: tx, + cgds: &driverStmt{ + Locker: dc, + si: si, + }, + parentStmt: parentStmt, + query: stmt.query, + } + if parentStmt != nil { + tx.db.addDep(parentStmt, txs) + } + tx.stmts.Lock() + tx.stmts.v = append(tx.stmts.v, txs) + tx.stmts.Unlock() + return txs +} + +// Stmt returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +// +// The returned statement operates within the transaction and will be closed +// when the transaction has been committed or rolled back. +// +// Stmt uses context.Background internally; to specify the context, use +// StmtContext. +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + return tx.StmtContext(context.Background(), stmt) +} + +// ExecContext executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (Result, error) { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return nil, err + } + return tx.db.execDC(ctx, dc, release, query, args) +} + +// Exec executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +// +// Exec uses context.Background internally; to specify the context, use +// ExecContext. +func (tx *Tx) Exec(query string, args ...any) (Result, error) { + return tx.ExecContext(context.Background(), query, args...) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) { + dc, release, err := tx.grabConn(ctx) + if err != nil { + return nil, err + } + + return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args) +} + +// Query executes a query that returns rows, typically a SELECT. +// +// Query uses context.Background internally; to specify the context, use +// QueryContext. +func (tx *Tx) Query(query string, args ...any) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *Row { + rows, err := tx.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + +// QueryRow executes a query that is expected to return at most one row. +// QueryRow always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// QueryRow uses context.Background internally; to specify the context, use +// QueryRowContext. +func (tx *Tx) QueryRow(query string, args ...any) *Row { + return tx.QueryRowContext(context.Background(), query, args...) +} + +// connStmt is a prepared statement on a particular connection. +type connStmt struct { + dc *driverConn + ds *driverStmt +} + +// stmtConnGrabber represents a Tx or Conn that will return the underlying +// driverConn and release function. +type stmtConnGrabber interface { + // grabConn returns the driverConn and the associated release function + // that must be called when the operation completes. + grabConn(context.Context) (*driverConn, releaseConn, error) + + // txCtx returns the transaction context if available. + // The returned context should be selected on along with + // any query context when awaiting a cancel. + txCtx() context.Context +} + +var ( + _ stmtConnGrabber = &Tx{} + _ stmtConnGrabber = &Conn{} +) + +// Stmt is a prepared statement. +// A Stmt is safe for concurrent use by multiple goroutines. +// +// If a Stmt is prepared on a Tx or Conn, it will be bound to a single +// underlying connection forever. If the Tx or Conn closes, the Stmt will +// become unusable and all operations will return an error. +// If a Stmt is prepared on a DB, it will remain usable for the lifetime of the +// DB. When the Stmt needs to execute on a new underlying connection, it will +// prepare itself on the new connection automatically. +type Stmt struct { + // Immutable: + db *DB // where we came from + query string // that created the Stmt + stickyErr error // if non-nil, this error is returned for all operations + + closemu sync.RWMutex // held exclusively during close, for read otherwise. + + // If Stmt is prepared on a Tx or Conn then cg is present and will + // only ever grab a connection from cg. + // If cg is nil then the Stmt must grab an arbitrary connection + // from db and determine if it must prepare the stmt again by + // inspecting css. + cg stmtConnGrabber + cgds *driverStmt + + // parentStmt is set when a transaction-specific statement + // is requested from an identical statement prepared on the same + // conn. parentStmt is used to track the dependency of this statement + // on its originating ("parent") statement so that parentStmt may + // be closed by the user without them having to know whether or not + // any transactions are still using it. + parentStmt *Stmt + + mu sync.Mutex // protects the rest of the fields + closed bool + + // css is a list of underlying driver statement interfaces + // that are valid on particular connections. This is only + // used if cg == nil and one is found that has idle + // connections. If cg != nil, cgds is always used. + css []connStmt + + // lastNumClosed is copied from db.numClosed when Stmt is created + // without tx and closed connections in css are removed. + lastNumClosed uint64 +} + +// ExecContext executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) { + s.closemu.RLock() + defer s.closemu.RUnlock() + + var res Result + err := s.db.retry(func(strategy connReuseStrategy) error { + dc, releaseConn, ds, err := s.connStmt(ctx, strategy) + if err != nil { + return err + } + + res, err = resultFromStatement(ctx, dc.ci, ds, args...) + releaseConn(err) + return err + }) + + return res, err +} + +// Exec executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +// +// Exec uses context.Background internally; to specify the context, use +// ExecContext. +func (s *Stmt) Exec(args ...any) (Result, error) { + return s.ExecContext(context.Background(), args...) +} + +func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (Result, error) { + ds.Lock() + defer ds.Unlock() + + dargs, err := driverArgsConnLocked(ci, ds, args) + if err != nil { + return nil, err + } + + resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) + if err != nil { + return nil, err + } + return driverResult{ds.Locker, resi}, nil +} + +// removeClosedStmtLocked removes closed conns in s.css. +// +// To avoid lock contention on DB.mu, we do it only when +// s.db.numClosed - s.lastNum is large enough. +func (s *Stmt) removeClosedStmtLocked() { + t := len(s.css)/2 + 1 + if t > 10 { + t = 10 + } + dbClosed := s.db.numClosed.Load() + if dbClosed-s.lastNumClosed < uint64(t) { + return + } + + s.db.mu.Lock() + for i := 0; i < len(s.css); i++ { + if s.css[i].dc.dbmuClosed { + s.css[i] = s.css[len(s.css)-1] + s.css = s.css[:len(s.css)-1] + i-- + } + } + s.db.mu.Unlock() + s.lastNumClosed = dbClosed +} + +// connStmt returns a free driver connection on which to execute the +// statement, a function to call to release the connection, and a +// statement bound to that connection. +func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) { + if err = s.stickyErr; err != nil { + return + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + err = errors.New("sql: statement is closed") + return + } + + // In a transaction or connection, we always use the connection that the + // stmt was created on. + if s.cg != nil { + s.mu.Unlock() + dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection. + if err != nil { + return + } + return dc, releaseConn, s.cgds, nil + } + + s.removeClosedStmtLocked() + s.mu.Unlock() + + dc, err = s.db.conn(ctx, strategy) + if err != nil { + return nil, nil, nil, err + } + + s.mu.Lock() + for _, v := range s.css { + if v.dc == dc { + s.mu.Unlock() + return dc, dc.releaseConn, v.ds, nil + } + } + s.mu.Unlock() + + // No luck; we need to prepare the statement on this connection + withLock(dc, func() { + ds, err = s.prepareOnConnLocked(ctx, dc) + }) + if err != nil { + dc.releaseConn(err) + return nil, nil, nil, err + } + + return dc, dc.releaseConn, ds, nil +} + +// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of +// open connStmt on the statement. It assumes the caller is holding the lock on dc. +func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) { + si, err := dc.prepareLocked(ctx, s.cg, s.query) + if err != nil { + return nil, err + } + cs := connStmt{dc, si} + s.mu.Lock() + s.css = append(s.css, cs) + s.mu.Unlock() + return cs.ds, nil +} + +// QueryContext executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) { + s.closemu.RLock() + defer s.closemu.RUnlock() + + var rowsi driver.Rows + var rows *Rows + + err := s.db.retry(func(strategy connReuseStrategy) error { + dc, releaseConn, ds, err := s.connStmt(ctx, strategy) + if err != nil { + return err + } + + rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...) + if err == nil { + // Note: ownership of ci passes to the *Rows, to be freed + // with releaseConn. + rows = &Rows{ + dc: dc, + rowsi: rowsi, + // releaseConn set below + } + // addDep must be added before initContextClose or it could attempt + // to removeDep before it has been added. + s.db.addDep(s, rows) + + // releaseConn must be set before initContextClose or it could + // release the connection before it is set. + rows.releaseConn = func(err error) { + releaseConn(err) + s.db.removeDep(s, rows) + } + var txctx context.Context + if s.cg != nil { + txctx = s.cg.txCtx() + } + rows.initContextClose(ctx, txctx) + return nil + } + + releaseConn(err) + return err + }) + + return rows, err +} + +// Query executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +// +// Query uses context.Background internally; to specify the context, use +// QueryContext. +func (s *Stmt) Query(args ...any) (*Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...any) (driver.Rows, error) { + ds.Lock() + defer ds.Unlock() + dargs, err := driverArgsConnLocked(ci, ds, args) + if err != nil { + return nil, err + } + return ctxDriverStmtQuery(ctx, ds.si, dargs) +} + +// QueryRowContext executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +func (s *Stmt) QueryRowContext(ctx context.Context, args ...any) *Row { + rows, err := s.QueryContext(ctx, args...) + if err != nil { + return &Row{err: err} + } + return &Row{rows: rows} +} + +// QueryRow executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// Example usage: +// +// var name string +// err := nameByUseridStmt.QueryRow(id).Scan(&name) +// +// QueryRow uses context.Background internally; to specify the context, use +// QueryRowContext. +func (s *Stmt) QueryRow(args ...any) *Row { + return s.QueryRowContext(context.Background(), args...) +} + +// Close closes the statement. +func (s *Stmt) Close() error { + s.closemu.Lock() + defer s.closemu.Unlock() + + if s.stickyErr != nil { + return s.stickyErr + } + s.mu.Lock() + if s.closed { + s.mu.Unlock() + return nil + } + s.closed = true + txds := s.cgds + s.cgds = nil + + s.mu.Unlock() + + if s.cg == nil { + return s.db.removeDep(s, s) + } + + if s.parentStmt != nil { + // If parentStmt is set, we must not close s.txds since it's stored + // in the css array of the parentStmt. + return s.db.removeDep(s.parentStmt, s) + } + return txds.Close() +} + +func (s *Stmt) finalClose() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.css != nil { + for _, v := range s.css { + s.db.noteUnusedDriverStatement(v.dc, v.ds) + v.dc.removeOpenStmt(v.ds) + } + s.css = nil + } + return nil +} + +// Rows is the result of a query. Its cursor starts before the first row +// of the result set. Use Next to advance from row to row. +type Rows struct { + dc *driverConn // owned; must call releaseConn when closed to release + releaseConn func(error) + rowsi driver.Rows + cancel func() // called when Rows is closed, may be nil. + closeStmt *driverStmt // if non-nil, statement to Close on close + + // closemu prevents Rows from closing while there + // is an active streaming result. It is held for read during non-close operations + // and exclusively during close. + // + // closemu guards lasterr and closed. + closemu sync.RWMutex + closed bool + lasterr error // non-nil only if closed is true + + // lastcols is only used in Scan, Next, and NextResultSet which are expected + // not to be called concurrently. + lastcols []driver.Value +} + +// lasterrOrErrLocked returns either lasterr or the provided err. +// rs.closemu must be read-locked. +func (rs *Rows) lasterrOrErrLocked(err error) error { + if rs.lasterr != nil && rs.lasterr != io.EOF { + return rs.lasterr + } + return err +} + +// bypassRowsAwaitDone is only used for testing. +// If true, it will not close the Rows automatically from the context. +var bypassRowsAwaitDone = false + +func (rs *Rows) initContextClose(ctx, txctx context.Context) { + if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) { + return + } + if bypassRowsAwaitDone { + return + } + ctx, rs.cancel = context.WithCancel(ctx) + go rs.awaitDone(ctx, txctx) +} + +// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided +// from the query context and is canceled when the query Rows is closed. +// If the query was issued in a transaction, the transaction's context +// is also provided in txctx to ensure Rows is closed if the Tx is closed. +func (rs *Rows) awaitDone(ctx, txctx context.Context) { + var txctxDone <-chan struct{} + if txctx != nil { + txctxDone = txctx.Done() + } + select { + case <-ctx.Done(): + case <-txctxDone: + } + rs.close(ctx.Err()) +} + +// Next prepares the next result row for reading with the Scan method. It +// returns true on success, or false if there is no next result row or an error +// happened while preparing it. Err should be consulted to distinguish between +// the two cases. +// +// Every call to Scan, even the first one, must be preceded by a call to Next. +func (rs *Rows) Next() bool { + var doClose, ok bool + withLock(rs.closemu.RLocker(), func() { + doClose, ok = rs.nextLocked() + }) + if doClose { + rs.Close() + } + return ok +} + +func (rs *Rows) nextLocked() (doClose, ok bool) { + if rs.closed { + return false, false + } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + + if rs.lastcols == nil { + rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) + } + + rs.lasterr = rs.rowsi.Next(rs.lastcols) + if rs.lasterr != nil { + // Close the connection if there is a driver error. + if rs.lasterr != io.EOF { + return true, false + } + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + return true, false + } + // The driver is at the end of the current result set. + // Test to see if there is another result set after the current one. + // Only close Rows if there is no further result sets to read. + if !nextResultSet.HasNextResultSet() { + doClose = true + } + return doClose, false + } + return false, true +} + +// NextResultSet prepares the next result set for reading. It reports whether +// there is further result sets, or false if there is no further result set +// or if there is an error advancing to it. The Err method should be consulted +// to distinguish between the two cases. +// +// After calling NextResultSet, the Next method should always be called before +// scanning. If there are further result sets they may not have rows in the result +// set. +func (rs *Rows) NextResultSet() bool { + var doClose bool + defer func() { + if doClose { + rs.Close() + } + }() + rs.closemu.RLock() + defer rs.closemu.RUnlock() + + if rs.closed { + return false + } + + rs.lastcols = nil + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + doClose = true + return false + } + + // Lock the driver connection before calling the driver interface + // rowsi to prevent a Tx from rolling back the connection at the same time. + rs.dc.Lock() + defer rs.dc.Unlock() + + rs.lasterr = nextResultSet.NextResultSet() + if rs.lasterr != nil { + doClose = true + return false + } + return true +} + +// Err returns the error, if any, that was encountered during iteration. +// Err may be called after an explicit or implicit Close. +func (rs *Rows) Err() error { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + return rs.lasterrOrErrLocked(nil) +} + +var errRowsClosed = errors.New("sql: Rows are closed") +var errNoRows = errors.New("sql: no Rows available") + +// Columns returns the column names. +// Columns returns an error if the rows are closed. +func (rs *Rows) Columns() ([]string, error) { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + if rs.closed { + return nil, rs.lasterrOrErrLocked(errRowsClosed) + } + if rs.rowsi == nil { + return nil, rs.lasterrOrErrLocked(errNoRows) + } + rs.dc.Lock() + defer rs.dc.Unlock() + + return rs.rowsi.Columns(), nil +} + +// ColumnTypes returns column information such as column type, length, +// and nullable. Some information may not be available from some drivers. +func (rs *Rows) ColumnTypes() ([]*ColumnType, error) { + rs.closemu.RLock() + defer rs.closemu.RUnlock() + if rs.closed { + return nil, rs.lasterrOrErrLocked(errRowsClosed) + } + if rs.rowsi == nil { + return nil, rs.lasterrOrErrLocked(errNoRows) + } + rs.dc.Lock() + defer rs.dc.Unlock() + + return rowsColumnInfoSetupConnLocked(rs.rowsi), nil +} + +// ColumnType contains the name and type of a column. +type ColumnType struct { + name string + + hasNullable bool + hasLength bool + hasPrecisionScale bool + + nullable bool + length int64 + databaseType string + precision int64 + scale int64 + scanType reflect.Type +} + +// Name returns the name or alias of the column. +func (ci *ColumnType) Name() string { + return ci.name +} + +// Length returns the column type length for variable length column types such +// as text and binary field types. If the type length is unbounded the value will +// be math.MaxInt64 (any database limits will still apply). +// If the column type is not variable length, such as an int, or if not supported +// by the driver ok is false. +func (ci *ColumnType) Length() (length int64, ok bool) { + return ci.length, ci.hasLength +} + +// DecimalSize returns the scale and precision of a decimal type. +// If not applicable or if not supported ok is false. +func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) { + return ci.precision, ci.scale, ci.hasPrecisionScale +} + +// ScanType returns a Go type suitable for scanning into using Rows.Scan. +// If a driver does not support this property ScanType will return +// the type of an empty interface. +func (ci *ColumnType) ScanType() reflect.Type { + return ci.scanType +} + +// Nullable reports whether the column may be null. +// If a driver does not support this property ok will be false. +func (ci *ColumnType) Nullable() (nullable, ok bool) { + return ci.nullable, ci.hasNullable +} + +// DatabaseTypeName returns the database system name of the column type. If an empty +// string is returned, then the driver type name is not supported. +// Consult your driver documentation for a list of driver data types. Length specifiers +// are not included. +// Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", +// "INT", and "BIGINT". +func (ci *ColumnType) DatabaseTypeName() string { + return ci.databaseType +} + +func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType { + names := rowsi.Columns() + + list := make([]*ColumnType, len(names)) + for i := range list { + ci := &ColumnType{ + name: names[i], + } + list[i] = ci + + if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok { + ci.scanType = prop.ColumnTypeScanType(i) + } else { + ci.scanType = reflect.TypeOf(new(any)).Elem() + } + if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok { + ci.databaseType = prop.ColumnTypeDatabaseTypeName(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok { + ci.length, ci.hasLength = prop.ColumnTypeLength(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok { + ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i) + } + if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok { + ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i) + } + } + return list +} + +// Scan copies the columns in the current row into the values pointed +// at by dest. The number of values in dest must be the same as the +// number of columns in Rows. +// +// Scan converts columns read from the database into the following +// common Go types and special types provided by the sql package: +// +// *string +// *[]byte +// *int, *int8, *int16, *int32, *int64 +// *uint, *uint8, *uint16, *uint32, *uint64 +// *bool +// *float32, *float64 +// *interface{} +// *RawBytes +// *Rows (cursor value) +// any type implementing Scanner (see Scanner docs) +// +// In the most simple case, if the type of the value from the source +// column is an integer, bool or string type T and dest is of type *T, +// Scan simply assigns the value through the pointer. +// +// Scan also converts between string and numeric types, as long as no +// information would be lost. While Scan stringifies all numbers +// scanned from numeric database columns into *string, scans into +// numeric types are checked for overflow. For example, a float64 with +// value 300 or a string with value "300" can scan into a uint16, but +// not into a uint8, though float64(255) or "255" can scan into a +// uint8. One exception is that scans of some float64 numbers to +// strings may lose information when stringifying. In general, scan +// floating point columns into *float64. +// +// If a dest argument has type *[]byte, Scan saves in that argument a +// copy of the corresponding data. The copy is owned by the caller and +// can be modified and held indefinitely. The copy can be avoided by +// using an argument of type *RawBytes instead; see the documentation +// for RawBytes for restrictions on its use. +// +// If an argument has type *interface{}, Scan copies the value +// provided by the underlying driver without conversion. When scanning +// from a source value of type []byte to *interface{}, a copy of the +// slice is made and the caller owns the result. +// +// Source values of type time.Time may be scanned into values of type +// *time.Time, *interface{}, *string, or *[]byte. When converting to +// the latter two, time.RFC3339Nano is used. +// +// Source values of type bool may be scanned into types *bool, +// *interface{}, *string, *[]byte, or *RawBytes. +// +// For scanning into *bool, the source may be true, false, 1, 0, or +// string inputs parseable by strconv.ParseBool. +// +// Scan can also convert a cursor returned from a query, such as +// "select cursor(select * from my_table) from dual", into a +// *Rows value that can itself be scanned from. The parent +// select query will close any cursor *Rows if the parent *Rows is closed. +// +// If any of the first arguments implementing Scanner returns an error, +// that error will be wrapped in the returned error. +func (rs *Rows) Scan(dest ...any) error { + rs.closemu.RLock() + + if rs.lasterr != nil && rs.lasterr != io.EOF { + rs.closemu.RUnlock() + return rs.lasterr + } + if rs.closed { + err := rs.lasterrOrErrLocked(errRowsClosed) + rs.closemu.RUnlock() + return err + } + rs.closemu.RUnlock() + + if rs.lastcols == nil { + return errors.New("sql: Scan called without calling Next") + } + if len(dest) != len(rs.lastcols) { + return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) + } + for i, sv := range rs.lastcols { + err := convertAssignRows(dest[i], sv, rs) + if err != nil { + return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err) + } + } + return nil +} + +// rowsCloseHook returns a function so tests may install the +// hook through a test only mutex. +var rowsCloseHook = func() func(*Rows, *error) { return nil } + +// Close closes the Rows, preventing further enumeration. If Next is called +// and returns false and there are no further result sets, +// the Rows are closed automatically and it will suffice to check the +// result of Err. Close is idempotent and does not affect the result of Err. +func (rs *Rows) Close() error { + return rs.close(nil) +} + +func (rs *Rows) close(err error) error { + rs.closemu.Lock() + defer rs.closemu.Unlock() + + if rs.closed { + return nil + } + rs.closed = true + + if rs.lasterr == nil { + rs.lasterr = err + } + + withLock(rs.dc, func() { + err = rs.rowsi.Close() + }) + if fn := rowsCloseHook(); fn != nil { + fn(rs, &err) + } + if rs.cancel != nil { + rs.cancel() + } + + if rs.closeStmt != nil { + rs.closeStmt.Close() + } + rs.releaseConn(err) + + rs.lasterr = rs.lasterrOrErrLocked(err) + return err +} + +// Row is the result of calling QueryRow to select a single row. +type Row struct { + // One of these two will be non-nil: + err error // deferred error for easy chaining + rows *Rows +} + +// Scan copies the columns from the matched row into the values +// pointed at by dest. See the documentation on Rows.Scan for details. +// If more than one row matches the query, +// Scan uses the first row and discards the rest. If no row matches +// the query, Scan returns ErrNoRows. +func (r *Row) Scan(dest ...any) error { + if r.err != nil { + return r.err + } + + // TODO(bradfitz): for now we need to defensively clone all + // []byte that the driver returned (not permitting + // *RawBytes in Rows.Scan), since we're about to close + // the Rows in our defer, when we return from this function. + // the contract with the driver.Next(...) interface is that it + // can return slices into read-only temporary memory that's + // only valid until the next Scan/Close. But the TODO is that + // for a lot of drivers, this copy will be unnecessary. We + // should provide an optional interface for drivers to + // implement to say, "don't worry, the []bytes that I return + // from Next will not be modified again." (for instance, if + // they were obtained from the network anyway) But for now we + // don't care. + defer r.rows.Close() + for _, dp := range dest { + if _, ok := dp.(*RawBytes); ok { + return errors.New("sql: RawBytes isn't allowed on Row.Scan") + } + } + + if !r.rows.Next() { + if err := r.rows.Err(); err != nil { + return err + } + return ErrNoRows + } + err := r.rows.Scan(dest...) + if err != nil { + return err + } + // Make sure the query can be processed to completion with no errors. + return r.rows.Close() +} + +// Err provides a way for wrapping packages to check for +// query errors without calling Scan. +// Err returns the error, if any, that was encountered while running the query. +// If this error is not nil, this error will also be returned from Scan. +func (r *Row) Err() error { + return r.err +} + +// A Result summarizes an executed SQL command. +type Result interface { + // LastInsertId returns the integer generated by the database + // in response to a command. Typically this will be from an + // "auto increment" column when inserting a new row. Not all + // databases support this feature, and the syntax of such + // statements varies. + LastInsertId() (int64, error) + + // RowsAffected returns the number of rows affected by an + // update, insert, or delete. Not every database or database + // driver may support this. + RowsAffected() (int64, error) +} + +type driverResult struct { + sync.Locker // the *driverConn + resi driver.Result +} + +func (dr driverResult) LastInsertId() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.LastInsertId() +} + +func (dr driverResult) RowsAffected() (int64, error) { + dr.Lock() + defer dr.Unlock() + return dr.resi.RowsAffected() +} + +func stack() string { + var buf [2 << 10]byte + return string(buf[:runtime.Stack(buf[:], false)]) +} + +// withLock runs while holding lk. +func withLock(lk sync.Locker, fn func()) { + lk.Lock() + defer lk.Unlock() // in case fn panics + fn() +} |