Diffstat (limited to 'pkg/icingadb/db.go')
-rw-r--r-- | pkg/icingadb/db.go | 730 |
1 files changed, 730 insertions, 0 deletions
diff --git a/pkg/icingadb/db.go b/pkg/icingadb/db.go
new file mode 100644
index 0000000..a9eed7f
--- /dev/null
+++ b/pkg/icingadb/db.go
@@ -0,0 +1,730 @@
+package icingadb
+
+import (
+    "context"
+    sqlDriver "database/sql/driver"
+    "fmt"
+    "github.com/go-sql-driver/mysql"
+    "github.com/icinga/icingadb/internal"
+    "github.com/icinga/icingadb/pkg/backoff"
+    "github.com/icinga/icingadb/pkg/com"
+    "github.com/icinga/icingadb/pkg/contracts"
+    "github.com/icinga/icingadb/pkg/driver"
+    "github.com/icinga/icingadb/pkg/logging"
+    "github.com/icinga/icingadb/pkg/periodic"
+    "github.com/icinga/icingadb/pkg/retry"
+    "github.com/icinga/icingadb/pkg/utils"
+    "github.com/jmoiron/sqlx"
+    "github.com/lib/pq"
+    "github.com/pkg/errors"
+    "golang.org/x/sync/errgroup"
+    "golang.org/x/sync/semaphore"
+    "reflect"
+    "strings"
+    "sync"
+    "time"
+)
+
+// DB is a wrapper around sqlx.DB with bulk execution,
+// statement building, streaming and logging capabilities.
+type DB struct {
+    *sqlx.DB
+
+    Options *Options
+
+    logger            *logging.Logger
+    tableSemaphores   map[string]*semaphore.Weighted
+    tableSemaphoresMu sync.Mutex
+}
+
+// Options define user configurable database options.
+type Options struct {
+    // Maximum number of open connections to the database.
+    MaxConnections int `yaml:"max_connections" default:"16"`
+
+    // Maximum number of connections per table,
+    // regardless of what the connection is actually doing,
+    // e.g. INSERT, UPDATE, DELETE.
+    MaxConnectionsPerTable int `yaml:"max_connections_per_table" default:"8"`
+
+    // MaxPlaceholdersPerStatement defines the maximum number of placeholders in an
+    // INSERT, UPDATE or DELETE statement. Theoretically, MySQL can handle up to 2^16-1 placeholders,
+    // but this increases the execution time of queries and thus reduces the number of queries
+    // that can be executed in parallel in a given time.
+    // The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism.
+    MaxPlaceholdersPerStatement int `yaml:"max_placeholders_per_statement" default:"8192"`
+
+    // MaxRowsPerTransaction defines the maximum number of rows per transaction.
+    // The default is 2^13, which in our tests showed the best performance in terms of execution time and parallelism.
+    MaxRowsPerTransaction int `yaml:"max_rows_per_transaction" default:"8192"`
+}
+
+// Validate checks constraints in the supplied database options and returns an error if they are violated.
+func (o *Options) Validate() error {
+    if o.MaxConnections == 0 {
+        return errors.New("max_connections cannot be 0. Configure a value greater than zero, or use -1 for no connection limit")
+    }
+    if o.MaxConnectionsPerTable < 1 {
+        return errors.New("max_connections_per_table must be at least 1")
+    }
+    if o.MaxPlaceholdersPerStatement < 1 {
+        return errors.New("max_placeholders_per_statement must be at least 1")
+    }
+    if o.MaxRowsPerTransaction < 1 {
+        return errors.New("max_rows_per_transaction must be at least 1")
+    }
+
+    return nil
+}
+
+// NewDb returns a new icingadb.DB wrapper for a pre-existing *sqlx.DB.
+func NewDb(db *sqlx.DB, logger *logging.Logger, options *Options) *DB {
+    return &DB{
+        DB:              db,
+        logger:          logger,
+        Options:         options,
+        tableSemaphores: make(map[string]*semaphore.Weighted),
+    }
+}
+
+const (
+    expectedMysqlSchemaVersion    = 3
+    expectedPostgresSchemaVersion = 1
+)
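As a usage sketch (not part of this change): wiring the wrapper up from a plain sqlx handle. The driver name, DSN and the surrounding function are assumptions for illustration; only NewDb, Options and Validate come from this file.

    package main

    import (
        "log"

        _ "github.com/go-sql-driver/mysql" // hypothetical: any driver registered with database/sql
        "github.com/icinga/icingadb/pkg/icingadb"
        "github.com/icinga/icingadb/pkg/logging"
        "github.com/jmoiron/sqlx"
    )

    func openDb(logger *logging.Logger) *icingadb.DB {
        opts := &icingadb.Options{
            MaxConnections:              16,
            MaxConnectionsPerTable:      8,
            MaxPlaceholdersPerStatement: 8192,
            MaxRowsPerTransaction:       8192,
        }
        if err := opts.Validate(); err != nil {
            log.Fatal(err)
        }

        // The DSN is a placeholder; real deployments derive it from their config.
        sqlxDb, err := sqlx.Open("mysql", "icingadb:secret@tcp(localhost:3306)/icingadb")
        if err != nil {
            log.Fatal(err)
        }
        sqlxDb.SetMaxOpenConns(opts.MaxConnections)

        return icingadb.NewDb(sqlxDb, logger, opts)
    }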
+
+// CheckSchema asserts that the database schema of the expected version is present.
+func (db *DB) CheckSchema(ctx context.Context) error {
+    var expectedDbSchemaVersion uint16
+    switch db.DriverName() {
+    case driver.MySQL:
+        expectedDbSchemaVersion = expectedMysqlSchemaVersion
+    case driver.PostgreSQL:
+        expectedDbSchemaVersion = expectedPostgresSchemaVersion
+    }
+
+    var version uint16
+
+    err := db.QueryRowxContext(ctx, "SELECT version FROM icingadb_schema ORDER BY id DESC LIMIT 1").Scan(&version)
+    if err != nil {
+        return errors.Wrap(err, "can't check database schema version")
+    }
+
+    if version != expectedDbSchemaVersion {
+        // Since these error messages are trivial and mostly caused by users, we don't need
+        // to print a stack trace here. However, since errors.Errorf() does this automatically,
+        // we need to use fmt instead.
+        return fmt.Errorf(
+            "unexpected database schema version: v%d (expected v%d), please make sure you have applied all database"+
+                " migrations after upgrading Icinga DB", version, expectedDbSchemaVersion,
+        )
+    }
+
+    return nil
+}
+
+// BuildColumns returns all columns of the given struct.
+func (db *DB) BuildColumns(subject interface{}) []string {
+    fields := db.Mapper.TypeMap(reflect.TypeOf(subject)).Names
+    columns := make([]string, 0, len(fields))
+    for _, f := range fields {
+        if f.Field.Tag == "" {
+            continue
+        }
+        columns = append(columns, f.Name)
+    }
+
+    return columns
+}
+
+// BuildDeleteStmt returns a DELETE statement for the given struct.
+func (db *DB) BuildDeleteStmt(from interface{}) string {
+    return fmt.Sprintf(
+        `DELETE FROM "%s" WHERE id IN (?)`,
+        utils.TableName(from),
+    )
+}
+
+// BuildInsertStmt returns an INSERT INTO statement for the given struct.
+func (db *DB) BuildInsertStmt(into interface{}) (string, int) {
+    columns := db.BuildColumns(into)
+
+    return fmt.Sprintf(
+        `INSERT INTO "%s" ("%s") VALUES (%s)`,
+        utils.TableName(into),
+        strings.Join(columns, `", "`),
+        fmt.Sprintf(":%s", strings.Join(columns, ", :")),
+    ), len(columns)
+}
+
+// BuildInsertIgnoreStmt returns an INSERT statement for the specified struct for
+// which the database ignores rows that have already been inserted.
+func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) {
+    table := utils.TableName(into)
+    columns := db.BuildColumns(into)
+    var clause string
+
+    switch db.DriverName() {
+    case driver.MySQL:
+        // MySQL treats UPDATE id = id as a no-op.
+        clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0])
+    case driver.PostgreSQL:
+        clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO NOTHING", table)
+    }
+
+    return fmt.Sprintf(
+        `INSERT INTO "%s" ("%s") VALUES (%s) %s`,
+        table,
+        strings.Join(columns, `", "`),
+        fmt.Sprintf(":%s", strings.Join(columns, ", :")),
+        clause,
+    ), len(columns)
+}
+
+// BuildSelectStmt returns a SELECT query that creates the FROM part from the given table struct
+// and the column list from the specified columns struct.
+func (db *DB) BuildSelectStmt(table interface{}, columns interface{}) string {
+    q := fmt.Sprintf(
+        `SELECT "%s" FROM "%s"`,
+        strings.Join(db.BuildColumns(columns), `", "`),
+        utils.TableName(table),
+    )
+
+    if scoper, ok := table.(contracts.Scoper); ok {
+        where, _ := db.BuildWhere(scoper.Scope())
+        q += ` WHERE ` + where
+    }
+
+    return q
+}
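To illustrate the builders above, a sketch continuing the example package, with a hypothetical two-column entity; the exact table and column names depend on utils.TableName and the mapper configured on the sqlx handle.

    // Pet is a hypothetical entity with an id and a name column.
    type Pet struct {
        Id   []byte `db:"id"`
        Name string `db:"name"`
    }

    func demoBuilders(db *icingadb.DB) {
        stmt, placeholders := db.BuildInsertStmt(&Pet{})
        // stmt:         INSERT INTO "pet" ("id", "name") VALUES (:id, :name)
        // placeholders: 2 (one per column)
        fmt.Println(stmt, placeholders)

        // DELETE always filters on id: DELETE FROM "pet" WHERE id IN (?)
        fmt.Println(db.BuildDeleteStmt(&Pet{}))
    }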
+
+// BuildUpdateStmt returns an UPDATE statement for the given struct.
+func (db *DB) BuildUpdateStmt(update interface{}) (string, int) {
+    columns := db.BuildColumns(update)
+    set := make([]string, 0, len(columns))
+
+    for _, col := range columns {
+        set = append(set, fmt.Sprintf(`"%s" = :%s`, col, col))
+    }
+
+    return fmt.Sprintf(
+        `UPDATE "%s" SET %s WHERE id = :id`,
+        utils.TableName(update),
+        strings.Join(set, ", "),
+    ), len(columns) + 1 // +1 because of WHERE id = :id
+}
+
+// BuildUpsertStmt returns an upsert statement for the given struct.
+func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders int) {
+    insertColumns := db.BuildColumns(subject)
+    table := utils.TableName(subject)
+    var updateColumns []string
+
+    if upserter, ok := subject.(contracts.Upserter); ok {
+        updateColumns = db.BuildColumns(upserter.Upsert())
+    } else {
+        updateColumns = insertColumns
+    }
+
+    var clause, setFormat string
+    switch db.DriverName() {
+    case driver.MySQL:
+        clause = "ON DUPLICATE KEY UPDATE"
+        setFormat = `"%[1]s" = VALUES("%[1]s")`
+    case driver.PostgreSQL:
+        clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO UPDATE SET", table)
+        setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
+    }
+
+    set := make([]string, 0, len(updateColumns))
+
+    for _, col := range updateColumns {
+        set = append(set, fmt.Sprintf(setFormat, col))
+    }
+
+    return fmt.Sprintf(
+        `INSERT INTO "%s" ("%s") VALUES (%s) %s %s`,
+        table,
+        strings.Join(insertColumns, `", "`),
+        fmt.Sprintf(":%s", strings.Join(insertColumns, ",:")),
+        clause,
+        strings.Join(set, ","),
+    ), len(insertColumns)
+}
+
+// BuildWhere returns a WHERE clause with named placeholder conditions built from the specified struct
+// combined with the AND operator.
+func (db *DB) BuildWhere(subject interface{}) (string, int) {
+    columns := db.BuildColumns(subject)
+    where := make([]string, 0, len(columns))
+    for _, col := range columns {
+        where = append(where, fmt.Sprintf(`"%s" = :%s`, col, col))
+    }
+
+    return strings.Join(where, ` AND `), len(columns)
+}
+
+// OnSuccess is a callback for successful (bulk) DML operations.
+type OnSuccess[T any] func(ctx context.Context, affectedRows []T) (err error)
+
+// OnSuccessIncrement returns an OnSuccess callback that adds the number of
+// successfully processed rows to the given counter.
+func OnSuccessIncrement[T any](counter *com.Counter) OnSuccess[T] {
+    return func(_ context.Context, rows []T) error {
+        counter.Add(uint64(len(rows)))
+        return nil
+    }
+}
+
+// OnSuccessSendTo returns an OnSuccess callback that streams the
+// successfully processed rows to the given channel.
+func OnSuccessSendTo[T any](ch chan<- T) OnSuccess[T] {
+    return func(ctx context.Context, rows []T) error {
+        for _, row := range rows {
+            select {
+            case ch <- row:
+            case <-ctx.Done():
+                return ctx.Err()
+            }
+        }
+
+        return nil
+    }
+}
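OnSuccess callbacks are plain functions over each successfully processed chunk, so besides the two helpers above, callers can supply their own. A sketch, continuing the example package; the Debugf call mirrors the logger usage elsewhere in this file.

    // chunkLogger is a hypothetical custom OnSuccess callback that reports
    // the size of every chunk whose query succeeded.
    func chunkLogger(logger *logging.Logger) icingadb.OnSuccess[contracts.Entity] {
        return func(_ context.Context, rows []contracts.Entity) error {
            logger.Debugf("committed chunk of %d rows", len(rows))
            return nil
        }
    }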
+
+// BulkExec bulk executes queries with a single slice placeholder in the form of `IN (?)`.
+// Takes in up to the number of arguments specified in count from the arg stream,
+// derives and expands a query and executes it with this set of arguments until the arg stream has been processed.
+// The derived queries are executed in a separate goroutine with a weighting of 1
+// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
+// Arguments for which the query ran successfully will be passed to onSuccess.
+func (db *DB) BulkExec(
+    ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan any, onSuccess ...OnSuccess[any],
+) error {
+    var counter com.Counter
+    defer db.log(ctx, query, &counter).Stop()
+
+    g, ctx := errgroup.WithContext(ctx)
+    // Use context from group.
+    bulk := com.Bulk(ctx, arg, count, com.NeverSplit[any])
+
+    g.Go(func() error {
+        g, ctx := errgroup.WithContext(ctx)
+
+        for b := range bulk {
+            if err := sem.Acquire(ctx, 1); err != nil {
+                return errors.Wrap(err, "can't acquire semaphore")
+            }
+
+            g.Go(func(b []interface{}) func() error {
+                return func() error {
+                    defer sem.Release(1)
+
+                    return retry.WithBackoff(
+                        ctx,
+                        func(context.Context) error {
+                            stmt, args, err := sqlx.In(query, b)
+                            if err != nil {
+                                return errors.Wrapf(err, "can't build placeholders for %q", query)
+                            }
+
+                            stmt = db.Rebind(stmt)
+                            _, err = db.ExecContext(ctx, stmt, args...)
+                            if err != nil {
+                                return internal.CantPerformQuery(err, query)
+                            }
+
+                            counter.Add(uint64(len(b)))
+
+                            for _, onSuccess := range onSuccess {
+                                if err := onSuccess(ctx, b); err != nil {
+                                    return err
+                                }
+                            }
+
+                            return nil
+                        },
+                        IsRetryable,
+                        backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
+                        retry.Settings{},
+                    )
+                }
+            }(b))
+        }
+
+        return g.Wait()
+    })
+
+    return g.Wait()
+}
+
+// NamedBulkExec bulk executes queries with named placeholders in a VALUES clause most likely
+// in the format INSERT ... VALUES. Takes in up to the number of entities specified in count
+// from the arg stream, derives and executes a new query with the VALUES clause expanded to
+// this set of arguments, until the arg stream has been processed.
+// The queries are executed in a separate goroutine with a weighting of 1
+// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
+// Entities for which the query ran successfully will be passed to onSuccess.
+func (db *DB) NamedBulkExec(
+    ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan contracts.Entity,
+    splitPolicyFactory com.BulkChunkSplitPolicyFactory[contracts.Entity], onSuccess ...OnSuccess[contracts.Entity],
+) error {
+    var counter com.Counter
+    defer db.log(ctx, query, &counter).Stop()
+
+    g, ctx := errgroup.WithContext(ctx)
+    bulk := com.Bulk(ctx, arg, count, splitPolicyFactory)
+
+    g.Go(func() error {
+        for {
+            select {
+            case b, ok := <-bulk:
+                if !ok {
+                    return nil
+                }
+
+                if err := sem.Acquire(ctx, 1); err != nil {
+                    return errors.Wrap(err, "can't acquire semaphore")
+                }
+
+                g.Go(func(b []contracts.Entity) func() error {
+                    return func() error {
+                        defer sem.Release(1)
+
+                        return retry.WithBackoff(
+                            ctx,
+                            func(ctx context.Context) error {
+                                _, err := db.NamedExecContext(ctx, query, b)
+                                if err != nil {
+                                    return internal.CantPerformQuery(err, query)
+                                }
+
+                                counter.Add(uint64(len(b)))
+
+                                for _, onSuccess := range onSuccess {
+                                    if err := onSuccess(ctx, b); err != nil {
+                                        return err
+                                    }
+                                }
+
+                                return nil
+                            },
+                            IsRetryable,
+                            backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
+                            retry.Settings{},
+                        )
+                    }
+                }(b))
+            case <-ctx.Done():
+                return ctx.Err()
+            }
+        }
+    })
+
+    return g.Wait()
+}
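A sketch of driving BulkExec by hand; real callers normally go through DeleteStreamed further down. The table name and ids are made up, and the semaphore here allows just one in-flight query.

    func demoBulkExec(ctx context.Context, db *icingadb.DB) error {
        ids := make(chan any, 3)
        for _, id := range []any{1, 2, 3} {
            ids <- id
        }
        close(ids) // BulkExec returns once the stream is drained

        sem := semaphore.NewWeighted(1)
        // Each chunk expands the single `IN (?)` slice placeholder,
        // here to: DELETE FROM "pet" WHERE id IN (?, ?, ?)
        return db.BulkExec(ctx, `DELETE FROM "pet" WHERE id IN (?)`, 1000, sem, ids)
    }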
+
+// NamedBulkExecTx bulk executes queries with named placeholders in separate transactions.
+// Takes in up to the number of entities specified in count from the arg stream and
+// executes a new transaction that runs a new query for each entity in this set of arguments,
+// until the arg stream has been processed.
+// The transactions are executed in a separate goroutine with a weighting of 1
+// and can be executed concurrently to the extent allowed by the semaphore passed in sem.
+func (db *DB) NamedBulkExecTx(
+    ctx context.Context, query string, count int, sem *semaphore.Weighted, arg <-chan contracts.Entity,
+) error {
+    var counter com.Counter
+    defer db.log(ctx, query, &counter).Stop()
+
+    g, ctx := errgroup.WithContext(ctx)
+    bulk := com.Bulk(ctx, arg, count, com.NeverSplit[contracts.Entity])
+
+    g.Go(func() error {
+        for {
+            select {
+            case b, ok := <-bulk:
+                if !ok {
+                    return nil
+                }
+
+                if err := sem.Acquire(ctx, 1); err != nil {
+                    return errors.Wrap(err, "can't acquire semaphore")
+                }
+
+                g.Go(func(b []contracts.Entity) func() error {
+                    return func() error {
+                        defer sem.Release(1)
+
+                        return retry.WithBackoff(
+                            ctx,
+                            func(ctx context.Context) error {
+                                tx, err := db.BeginTxx(ctx, nil)
+                                if err != nil {
+                                    return errors.Wrap(err, "can't start transaction")
+                                }
+                                // Also release the transaction on error paths;
+                                // after a successful Commit this is a no-op.
+                                defer func() { _ = tx.Rollback() }()
+
+                                stmt, err := tx.PrepareNamedContext(ctx, query)
+                                if err != nil {
+                                    return errors.Wrap(err, "can't prepare named statement with context in transaction")
+                                }
+                                defer stmt.Close()
+
+                                for _, arg := range b {
+                                    if _, err := stmt.ExecContext(ctx, arg); err != nil {
+                                        return errors.Wrap(err, "can't execute statement in transaction")
+                                    }
+                                }
+
+                                if err := tx.Commit(); err != nil {
+                                    return errors.Wrap(err, "can't commit transaction")
+                                }
+
+                                counter.Add(uint64(len(b)))
+
+                                return nil
+                            },
+                            IsRetryable,
+                            backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
+                            retry.Settings{},
+                        )
+                    }
+                }(b))
+            case <-ctx.Done():
+                return ctx.Err()
+            }
+        }
+    })
+
+    return g.Wait()
+}
+
+// BatchSizeByPlaceholders returns how often the specified number of placeholders fits
+// into Options.MaxPlaceholdersPerStatement, but at least 1.
+func (db *DB) BatchSizeByPlaceholders(n int) int {
+    s := db.Options.MaxPlaceholdersPerStatement / n
+    if s > 0 {
+        return s
+    }
+
+    return 1
+}
+
+// YieldAll executes the query with the supplied scope,
+// scans each resulting row into an entity returned by the factory function,
+// and streams them into a returned channel.
+func (db *DB) YieldAll(ctx context.Context, factoryFunc contracts.EntityFactoryFunc, query string, scope interface{}) (<-chan contracts.Entity, <-chan error) {
+    entities := make(chan contracts.Entity, 1)
+    g, ctx := errgroup.WithContext(ctx)
+
+    g.Go(func() error {
+        var counter com.Counter
+        defer db.log(ctx, query, &counter).Stop()
+        defer close(entities)
+
+        rows, err := db.NamedQueryContext(ctx, query, scope)
+        if err != nil {
+            return internal.CantPerformQuery(err, query)
+        }
+        defer rows.Close()
+
+        for rows.Next() {
+            e := factoryFunc()
+
+            if err := rows.StructScan(e); err != nil {
+                return errors.Wrapf(err, "can't store query result into a %T: %s", e, query)
+            }
+
+            select {
+            case entities <- e:
+                counter.Inc()
+            case <-ctx.Done():
+                return ctx.Err()
+            }
+        }
+
+        return nil
+    })
+
+    return entities, com.WaitAsync(g)
+}
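Worked example of the chunk arithmetic: with the default MaxPlaceholdersPerStatement of 8192, BatchSizeByPlaceholders(2) yields 4096 rows per generated statement for the two-column Pet entity above. A consumption sketch for YieldAll follows, assuming Pet implements contracts.Entity:

    func demoYieldAll(ctx context.Context, db *icingadb.DB) error {
        query := db.BuildSelectStmt(&Pet{}, &Pet{})

        // The query has no named parameters, so an empty scope suffices.
        entities, errs := db.YieldAll(ctx, func() contracts.Entity { return &Pet{} }, query, struct{}{})

        for e := range entities {
            fmt.Printf("fetched %v\n", e)
        }

        // The error channel yields the group result once streaming ends.
        return <-errs
    }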
+
+// CreateStreamed bulk creates the specified entities via NamedBulkExec.
+// The insert statement is created using BuildInsertStmt with the first entity from the entities stream.
+// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
+// concurrency is controlled via Options.MaxConnectionsPerTable.
+// Entities for which the query ran successfully will be passed to onSuccess.
+func (db *DB) CreateStreamed(
+    ctx context.Context, entities <-chan contracts.Entity, onSuccess ...OnSuccess[contracts.Entity],
+) error {
+    first, forward, err := com.CopyFirst(ctx, entities)
+    if first == nil {
+        return errors.Wrap(err, "can't copy first entity")
+    }
+
+    sem := db.GetSemaphoreForTable(utils.TableName(first))
+    stmt, placeholders := db.BuildInsertStmt(first)
+
+    return db.NamedBulkExec(
+        ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
+        forward, com.NeverSplit[contracts.Entity], onSuccess...,
+    )
+}
+
+// CreateIgnoreStreamed bulk creates the specified entities via NamedBulkExec.
+// The insert statement is created using BuildInsertIgnoreStmt with the first entity from the entities stream.
+// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
+// concurrency is controlled via Options.MaxConnectionsPerTable.
+// Entities for which the query ran successfully will be passed to onSuccess.
+func (db *DB) CreateIgnoreStreamed(
+    ctx context.Context, entities <-chan contracts.Entity, onSuccess ...OnSuccess[contracts.Entity],
+) error {
+    first, forward, err := com.CopyFirst(ctx, entities)
+    if first == nil {
+        return errors.Wrap(err, "can't copy first entity")
+    }
+
+    sem := db.GetSemaphoreForTable(utils.TableName(first))
+    stmt, placeholders := db.BuildInsertIgnoreStmt(first)
+
+    return db.NamedBulkExec(
+        ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
+        forward, com.SplitOnDupId[contracts.Entity], onSuccess...,
+    )
+}
+
+// UpsertStreamed bulk upserts the specified entities via NamedBulkExec.
+// The upsert statement is created using BuildUpsertStmt with the first entity from the entities stream.
+// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
+// concurrency is controlled via Options.MaxConnectionsPerTable.
+// Entities for which the query ran successfully will be passed to onSuccess.
+func (db *DB) UpsertStreamed(
+    ctx context.Context, entities <-chan contracts.Entity, onSuccess ...OnSuccess[contracts.Entity],
+) error {
+    first, forward, err := com.CopyFirst(ctx, entities)
+    if first == nil {
+        return errors.Wrap(err, "can't copy first entity")
+    }
+
+    sem := db.GetSemaphoreForTable(utils.TableName(first))
+    stmt, placeholders := db.BuildUpsertStmt(first)
+
+    return db.NamedBulkExec(
+        ctx, stmt, db.BatchSizeByPlaceholders(placeholders), sem,
+        forward, com.SplitOnDupId[contracts.Entity], onSuccess...,
+    )
+}
+
+// UpdateStreamed bulk updates the specified entities via NamedBulkExecTx.
+// The update statement is created using BuildUpdateStmt with the first entity from the entities stream.
+// Bulk size is controlled via Options.MaxRowsPerTransaction and
+// concurrency is controlled via Options.MaxConnectionsPerTable.
+func (db *DB) UpdateStreamed(ctx context.Context, entities <-chan contracts.Entity) error {
+    first, forward, err := com.CopyFirst(ctx, entities)
+    if first == nil {
+        return errors.Wrap(err, "can't copy first entity")
+    }
+
+    sem := db.GetSemaphoreForTable(utils.TableName(first))
+    stmt, _ := db.BuildUpdateStmt(first)
+
+    return db.NamedBulkExecTx(ctx, stmt, db.Options.MaxRowsPerTransaction, sem, forward)
+}
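Tying it together, a sketch of feeding UpsertStreamed from a slice and counting the upserted rows; everything used here is defined in this file apart from the hypothetical pets input.

    func demoUpsert(ctx context.Context, db *icingadb.DB, pets []contracts.Entity) error {
        entities := make(chan contracts.Entity, len(pets))
        for _, p := range pets {
            entities <- p
        }
        close(entities)

        var upserted com.Counter
        // Statement, chunk size and per-table concurrency are all derived
        // from the first entity that arrives on the channel.
        err := db.UpsertStreamed(ctx, entities, icingadb.OnSuccessIncrement[contracts.Entity](&upserted))
        fmt.Printf("upserted %d rows\n", upserted.Total())

        return err
    }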
+
+// DeleteStreamed bulk deletes the specified ids via BulkExec.
+// The delete statement is created using BuildDeleteStmt with the passed entityType.
+// Bulk size is controlled via Options.MaxPlaceholdersPerStatement and
+// concurrency is controlled via Options.MaxConnectionsPerTable.
+// IDs for which the query ran successfully will be passed to onSuccess.
+func (db *DB) DeleteStreamed(
+    ctx context.Context, entityType contracts.Entity, ids <-chan interface{}, onSuccess ...OnSuccess[any],
+) error {
+    sem := db.GetSemaphoreForTable(utils.TableName(entityType))
+    return db.BulkExec(
+        ctx, db.BuildDeleteStmt(entityType), db.Options.MaxPlaceholdersPerStatement, sem, ids, onSuccess...,
+    )
+}
+
+// Delete creates a channel from the specified ids and
+// bulk deletes them by passing the channel along with the entityType to DeleteStreamed.
+// IDs for which the query ran successfully will be passed to onSuccess.
+func (db *DB) Delete(
+    ctx context.Context, entityType contracts.Entity, ids []interface{}, onSuccess ...OnSuccess[any],
+) error {
+    idsCh := make(chan interface{}, len(ids))
+    for _, id := range ids {
+        idsCh <- id
+    }
+    close(idsCh)
+
+    return db.DeleteStreamed(ctx, entityType, idsCh, onSuccess...)
+}
+
+// GetSemaphoreForTable returns the semaphore limiting the number of concurrent
+// queries for the given table, creating it on first use.
+func (db *DB) GetSemaphoreForTable(table string) *semaphore.Weighted {
+    db.tableSemaphoresMu.Lock()
+    defer db.tableSemaphoresMu.Unlock()
+
+    if sem, ok := db.tableSemaphores[table]; ok {
+        return sem
+    } else {
+        sem = semaphore.NewWeighted(int64(db.Options.MaxConnectionsPerTable))
+        db.tableSemaphores[table] = sem
+        return sem
+    }
+}
+
+// log periodically logs the number of rows processed by the given query
+// while it runs and once more when the returned periodic job is stopped.
+func (db *DB) log(ctx context.Context, query string, counter *com.Counter) periodic.Stopper {
+    return periodic.Start(ctx, db.logger.Interval(), func(tick periodic.Tick) {
+        if count := counter.Reset(); count > 0 {
+            db.logger.Debugf("Executed %q with %d rows", query, count)
+        }
+    }, periodic.OnStop(func(tick periodic.Tick) {
+        db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed)
+    }))
+}
+
+// IsRetryable checks whether the given error is retryable.
+func IsRetryable(err error) bool {
+    if errors.Is(err, sqlDriver.ErrBadConn) {
+        return true
+    }
+
+    if errors.Is(err, mysql.ErrInvalidConn) {
+        return true
+    }
+
+    var e *mysql.MySQLError
+    if errors.As(err, &e) {
+        switch e.Number {
+        case 1053, 1205, 1213, 2006:
+            // 1053: Server shutdown in progress
+            // 1205: Lock wait timeout
+            // 1213: Deadlock found when trying to get lock
+            // 2006: MySQL server has gone away
+            return true
+        default:
+            return false
+        }
+    }
+
+    var pe *pq.Error
+    if errors.As(err, &pe) {
+        switch pe.Code {
+        case "08000", // connection_exception
+            "08006", // connection_failure
+            "08001", // sqlclient_unable_to_establish_sqlconnection
+            "08004", // sqlserver_rejected_establishment_of_sqlconnection
+            "40001", // serialization_failure
+            "40P01", // deadlock_detected
+            "54000", // program_limit_exceeded
+            "55006", // object_in_use
+            "55P03", // lock_not_available
+            "57P01", // admin_shutdown
+            "57P02", // crash_shutdown
+            "57P03", // cannot_connect_now
+            "58000", // system_error
+            "58030", // io_error
+            "XX000": // internal_error
+            return true
+        default:
+            if strings.HasPrefix(string(pe.Code), "53") {
+                // Class 53 - Insufficient Resources
+                return true
+            }
+        }
+    }
+
+    return false
+}
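IsRetryable is exported and can be combined with the same retry machinery this file uses internally; a sketch of wrapping a one-off statement, where only the query text is made up.

    func execOnce(ctx context.Context, db *icingadb.DB) error {
        return retry.WithBackoff(
            ctx,
            func(ctx context.Context) error {
                _, err := db.ExecContext(ctx, "UPDATE pet SET name = name")
                return err
            },
            icingadb.IsRetryable, // transient errors (deadlocks, lost connections) are retried
            backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second),
            retry.Settings{},
        )
    }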