diff options
Diffstat (limited to '')
-rw-r--r-- | src/lib/pgsql/pgsql_connection.cc | 560 |
1 files changed, 560 insertions, 0 deletions
diff --git a/src/lib/pgsql/pgsql_connection.cc b/src/lib/pgsql/pgsql_connection.cc new file mode 100644 index 0000000..6ec896d --- /dev/null +++ b/src/lib/pgsql/pgsql_connection.cc @@ -0,0 +1,560 @@ +// Copyright (C) 2016-2023 Internet Systems Consortium, Inc. ("ISC") +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include <config.h> + +#include <database/db_exceptions.h> +#include <database/db_log.h> +#include <pgsql/pgsql_connection.h> + +// PostgreSQL errors should be tested based on the SQL state code. Each state +// code is 5 decimal, ASCII, digits, the first two define the category of +// error, the last three are the specific error. PostgreSQL makes the state +// code as a char[5]. Macros for each code are defined in PostgreSQL's +// server/utils/errcodes.h, although they require a second macro, +// MAKE_SQLSTATE for completion. For example, duplicate key error as: +// +// #define ERRCODE_UNIQUE_VIOLATION MAKE_SQLSTATE('2','3','5','0','5') +// +// PostgreSQL deliberately omits the MAKE_SQLSTATE macro so callers can/must +// supply their own. We'll define it as an initialization list: +#define MAKE_SQLSTATE(ch1,ch2,ch3,ch4,ch5) {ch1,ch2,ch3,ch4,ch5} +// So we can use it like this: const char some_error[] = ERRCODE_xxxx; +#define PGSQL_STATECODE_LEN 5 +#include <utils/errcodes.h> + +#include <sstream> + +using namespace std; + +namespace isc { +namespace db { + +// Default connection timeout + +/// @todo: migrate this default timeout to src/bin/dhcpX/simple_parserX.cc +const int PGSQL_DEFAULT_CONNECTION_TIMEOUT = 5; // seconds + +const char PgSqlConnection::DUPLICATE_KEY[] = ERRCODE_UNIQUE_VIOLATION; +const char PgSqlConnection::NULL_KEY[] = ERRCODE_NOT_NULL_VIOLATION; + +bool PgSqlConnection::warned_about_tls = false; + +PgSqlResult::PgSqlResult(PGresult *result) + : result_(result), rows_(0), cols_(0) { + if (!result) { + // Certain failures, like a loss of connectivity, can return a + // null PGresult and we still need to be able to create a PgSqlResult. + // We'll set row and col counts to -1 to prevent anyone going off the + // rails. + rows_ = -1; + cols_ = -1; + } else { + rows_ = PQntuples(result); + cols_ = PQnfields(result); + } +} + +void +PgSqlResult::rowCheck(int row) const { + if (row < 0 || row >= rows_) { + isc_throw (db::DbOperationError, "row: " << row + << ", out of range: 0.." << rows_); + } +} + +PgSqlResult::~PgSqlResult() { + if (result_) { + PQclear(result_); + } +} + +void +PgSqlResult::colCheck(int col) const { + if (col < 0 || col >= cols_) { + isc_throw (DbOperationError, "col: " << col + << ", out of range: 0.." << cols_); + } +} + +void +PgSqlResult::rowColCheck(int row, int col) const { + rowCheck(row); + colCheck(col); +} + +std::string +PgSqlResult::getColumnLabel(const int col) const { + const char* label = NULL; + try { + colCheck(col); + label = PQfname(result_, col); + } catch (...) { + std::ostringstream os; + os << "Unknown column:" << col; + return (os.str()); + } + + return (label); +} + +PgSqlTransaction::PgSqlTransaction(PgSqlConnection& conn) + : conn_(conn), committed_(false) { + conn_.startTransaction(); +} + +PgSqlTransaction::~PgSqlTransaction() { + // If commit() wasn't explicitly called, rollback. + if (!committed_) { + conn_.rollback(); + } +} + +void +PgSqlTransaction::commit() { + conn_.commit(); + committed_ = true; +} + +PgSqlConnection::~PgSqlConnection() { + if (conn_) { + // Deallocate the prepared queries. + if (PQstatus(conn_) == CONNECTION_OK) { + PgSqlResult r(PQexec(conn_, "DEALLOCATE all")); + if (PQresultStatus(r) != PGRES_COMMAND_OK) { + // Highly unlikely but we'll log it and go on. + DB_LOG_ERROR(PGSQL_DEALLOC_ERROR) + .arg(PQerrorMessage(conn_)); + } + } + } +} + +std::pair<uint32_t, uint32_t> +PgSqlConnection::getVersion(const ParameterMap& parameters) { + // Get a connection. + PgSqlConnection conn(parameters); + + // Open the database. + conn.openDatabaseInternal(false); + + const char* version_sql = "SELECT version, minor FROM schema_version;"; + PgSqlResult r(PQexec(conn.conn_, version_sql)); + if (PQresultStatus(r) != PGRES_TUPLES_OK) { + isc_throw(DbOperationError, "unable to execute PostgreSQL statement <" + << version_sql << ", reason: " << PQerrorMessage(conn.conn_)); + } + + uint32_t version; + PgSqlExchange::getColumnValue(r, 0, 0, version); + + uint32_t minor; + PgSqlExchange::getColumnValue(r, 0, 1, minor); + + return (make_pair(version, minor)); +} + +void +PgSqlConnection::prepareStatement(const PgSqlTaggedStatement& statement) { + // Prepare all statements queries with all known fields datatype + PgSqlResult r(PQprepare(conn_, statement.name, statement.text, + statement.nbparams, statement.types)); + if (PQresultStatus(r) != PGRES_COMMAND_OK) { + isc_throw(DbOperationError, "unable to prepare PostgreSQL statement: " + << " name: " << statement.name + << ", reason: " << PQerrorMessage(conn_) + << ", text: " << statement.text); + } +} + +void +PgSqlConnection::prepareStatements(const PgSqlTaggedStatement* start_statement, + const PgSqlTaggedStatement* end_statement) { + // Created the PostgreSQL prepared statements. + for (const PgSqlTaggedStatement* tagged_statement = start_statement; + tagged_statement != end_statement; ++tagged_statement) { + prepareStatement(*tagged_statement); + } +} + +std::string +PgSqlConnection::getConnParameters() { + return (getConnParametersInternal(false)); +} + +std::string +PgSqlConnection::getConnParametersInternal(bool logging) { + string dbconnparameters; + string shost = "localhost"; + try { + shost = getParameter("host"); + } catch(...) { + // No host. Fine, we'll use "localhost" + } + + dbconnparameters += "host = '" + shost + "'" ; + + unsigned int port = 0; + try { + setIntParameterValue("port", 0, numeric_limits<uint16_t>::max(), port); + + } catch (const std::exception& ex) { + isc_throw(DbInvalidPort, ex.what()); + } + + // Add port to connection parameters when not default. + if (port > 0) { + std::ostringstream oss; + oss << port; + dbconnparameters += " port = " + oss.str(); + } + + string suser; + try { + suser = getParameter("user"); + dbconnparameters += " user = '" + suser + "'"; + } catch(...) { + // No user. Fine, we'll use NULL + } + + string spassword; + try { + spassword = getParameter("password"); + dbconnparameters += " password = '" + spassword + "'"; + } catch(...) { + // No password. Fine, we'll use NULL + } + + string sname; + try { + sname = getParameter("name"); + dbconnparameters += " dbname = '" + sname + "'"; + } catch(...) { + // No database name. Throw a "NoDatabaseName" exception + isc_throw(NoDatabaseName, "must specify a name for the database"); + } + + unsigned int connect_timeout = PGSQL_DEFAULT_CONNECTION_TIMEOUT; + unsigned int tcp_user_timeout = 0; + try { + // The timeout is only valid if greater than zero, as depending on the + // database, a zero timeout might signify something like "wait + // indefinitely". + setIntParameterValue("connect-timeout", 1, numeric_limits<int>::max(), connect_timeout); + // This timeout value can be 0, meaning that the database client will + // follow a default behavior. Earlier Postgres versions didn't have + // this parameter, so we allow 0 to skip setting them for these + // earlier versions. + setIntParameterValue("tcp-user-timeout", 0, numeric_limits<int>::max(), tcp_user_timeout); + + } catch (const std::exception& ex) { + isc_throw(DbInvalidTimeout, ex.what()); + } + + // Append connection timeout. + std::ostringstream oss; + oss << " connect_timeout = " << connect_timeout; + + if (tcp_user_timeout > 0) { +// tcp_user_timeout parameter is a PostgreSQL 12+ feature. +#ifdef HAVE_PGSQL_TCP_USER_TIMEOUT + oss << " tcp_user_timeout = " << tcp_user_timeout * 1000; + static_cast<void>(logging); +#else + if (logging) { + DB_LOG_WARN(PGSQL_TCP_USER_TIMEOUT_UNSUPPORTED).arg(); + } +#endif + } + dbconnparameters += oss.str(); + + return (dbconnparameters); +} + +void +PgSqlConnection::openDatabase() { + openDatabaseInternal(true); +} + +void +PgSqlConnection::openDatabaseInternal(bool logging) { + std::string dbconnparameters = getConnParametersInternal(logging); + // Connect to Postgres, saving the low level connection pointer + // in the holder object + PGconn* new_conn = PQconnectdb(dbconnparameters.c_str()); + if (!new_conn) { + isc_throw(DbOpenError, "could not allocate connection object"); + } + + if (PQstatus(new_conn) != CONNECTION_OK) { + // If we have a connection object, we have to call finish + // to release it, but grab the error message first. + std::string error_message = PQerrorMessage(new_conn); + PQfinish(new_conn); + isc_throw(DbOpenError, error_message); + } + + // We have a valid connection, so let's save it to our holder + conn_.setConnection(new_conn); +} + +bool +PgSqlConnection::compareError(const PgSqlResult& r, const char* error_state) { + const char* sqlstate = PQresultErrorField(r, PG_DIAG_SQLSTATE); + // PostgreSQL guarantees it will always be 5 characters long + return ((sqlstate != NULL) && + (memcmp(sqlstate, error_state, PGSQL_STATECODE_LEN) == 0)); +} + +void +PgSqlConnection::checkStatementError(const PgSqlResult& r, + PgSqlTaggedStatement& statement) { + int s = PQresultStatus(r); + if (s != PGRES_COMMAND_OK && s != PGRES_TUPLES_OK) { + // We're testing the first two chars of SQLSTATE, as this is the + // error class. Note, there is a severity field, but it can be + // misleadingly returned as fatal. However, a loss of connectivity + // can lead to a NULL sqlstate with a status of PGRES_FATAL_ERROR. + const char* sqlstate = PQresultErrorField(r, PG_DIAG_SQLSTATE); + if ((sqlstate == NULL) || + ((memcmp(sqlstate, "08", 2) == 0) || // Connection Exception + (memcmp(sqlstate, "53", 2) == 0) || // Insufficient resources + (memcmp(sqlstate, "54", 2) == 0) || // Program Limit exceeded + (memcmp(sqlstate, "57", 2) == 0) || // Operator intervention + (memcmp(sqlstate, "58", 2) == 0))) { // System error + DB_LOG_ERROR(PGSQL_FATAL_ERROR) + .arg(statement.name) + .arg(PQerrorMessage(conn_)) + .arg(sqlstate ? sqlstate : "<sqlstate null>"); + + // Mark this connection as no longer usable. + markUnusable(); + + // Start the connection recovery. + startRecoverDbConnection(); + + // We still need to throw so caller can error out of the current + // processing. + isc_throw(DbConnectionUnusable, + "fatal database error or connectivity lost"); + } + + // Failure: check for the special case of duplicate entry. + if (compareError(r, PgSqlConnection::DUPLICATE_KEY)) { + isc_throw(DuplicateEntry, "statement: " << statement.name + << ", reason: " << PQerrorMessage(conn_)); + } + + // Failure: check for the special case of null key violation. + if (compareError(r, PgSqlConnection::NULL_KEY)) { + isc_throw(NullKeyError, "statement: " << statement.name + << ", reason: " << PQerrorMessage(conn_)); + } + + // Apparently it wasn't fatal, so we throw with a helpful message. + const char* error_message = PQerrorMessage(conn_); + isc_throw(DbOperationError, "Statement exec failed for: " + << statement.name << ", status: " << s + << "sqlstate:[ " << (sqlstate ? sqlstate : "<null>") + << " ], reason: " << error_message); + } +} + +void +PgSqlConnection::startTransaction() { + // If it is nested transaction, do nothing. + if (++transaction_ref_count_ > 1) { + return; + } + + DB_LOG_DEBUG(DB_DBG_TRACE_DETAIL, PGSQL_START_TRANSACTION); + checkUnusable(); + PgSqlResult r(PQexec(conn_, "START TRANSACTION")); + if (PQresultStatus(r) != PGRES_COMMAND_OK) { + const char* error_message = PQerrorMessage(conn_); + isc_throw(DbOperationError, "unable to start transaction" + << error_message); + } +} + +bool +PgSqlConnection::isTransactionStarted() const { + return (transaction_ref_count_ > 0); +} + +void +PgSqlConnection::commit() { + if (transaction_ref_count_ <= 0) { + isc_throw(Unexpected, "commit called for not started transaction - coding error"); + } + + // When committing nested transaction, do nothing. + if (--transaction_ref_count_ > 0) { + return; + } + + DB_LOG_DEBUG(DB_DBG_TRACE_DETAIL, PGSQL_COMMIT); + checkUnusable(); + PgSqlResult r(PQexec(conn_, "COMMIT")); + if (PQresultStatus(r) != PGRES_COMMAND_OK) { + const char* error_message = PQerrorMessage(conn_); + isc_throw(DbOperationError, "commit failed: " << error_message); + } +} + +void +PgSqlConnection::rollback() { + if (transaction_ref_count_ <= 0) { + isc_throw(Unexpected, "rollback called for not started transaction - coding error"); + } + + // When rolling back nested transaction, do nothing. + if (--transaction_ref_count_ > 0) { + return; + } + + DB_LOG_DEBUG(DB_DBG_TRACE_DETAIL, PGSQL_ROLLBACK); + checkUnusable(); + PgSqlResult r(PQexec(conn_, "ROLLBACK")); + if (PQresultStatus(r) != PGRES_COMMAND_OK) { + const char* error_message = PQerrorMessage(conn_); + isc_throw(DbOperationError, "rollback failed: " << error_message); + } +} + +void +PgSqlConnection::createSavepoint(const std::string& name) { + if (transaction_ref_count_ <= 0) { + isc_throw(InvalidOperation, "no transaction, cannot create savepoint: " << name); + } + + DB_LOG_DEBUG(DB_DBG_TRACE_DETAIL, PGSQL_CREATE_SAVEPOINT).arg(name); + std::string sql("SAVEPOINT " + name); + executeSQL(sql); +} + +void +PgSqlConnection::rollbackToSavepoint(const std::string& name) { + if (transaction_ref_count_ <= 0) { + isc_throw(InvalidOperation, "no transaction, cannot rollback to savepoint: " << name); + } + + std::string sql("ROLLBACK TO SAVEPOINT " + name); + executeSQL(sql); +} + +void +PgSqlConnection::executeSQL(const std::string& sql) { + // Use a TaggedStatement so we can call checkStatementError and ensure + // we detect connectivity issues properly. + PgSqlTaggedStatement statement({0, {OID_NONE}, "run-statement", sql.c_str()}); + checkUnusable(); + PgSqlResult r(PQexec(conn_, statement.text)); + checkStatementError(r, statement); +} + +PgSqlResultPtr +PgSqlConnection::executePreparedStatement(PgSqlTaggedStatement& statement, + const PsqlBindArray& in_bindings) { + checkUnusable(); + + if (statement.nbparams != in_bindings.size()) { + isc_throw (InvalidOperation, "executePreparedStatement:" + << " expected: " << statement.nbparams + << " parameters, given: " << in_bindings.size() + << ", statement: " << statement.name + << ", SQL: " << statement.text); + } + + const char* const* values = 0; + const int* lengths = 0; + const int* formats = 0; + if (statement.nbparams > 0) { + values = static_cast<const char* const*>(&in_bindings.values_[0]); + lengths = static_cast<const int *>(&in_bindings.lengths_[0]); + formats = static_cast<const int *>(&in_bindings.formats_[0]); + } + + PgSqlResultPtr result_set; + result_set.reset(new PgSqlResult(PQexecPrepared(conn_, statement.name, statement.nbparams, + values, lengths, formats, 0))); + + checkStatementError(*result_set, statement); + return (result_set); +} + +void +PgSqlConnection::selectQuery(PgSqlTaggedStatement& statement, + const PsqlBindArray& in_bindings, + ConsumeResultRowFun process_result_row) { + // Execute the prepared statement. + PgSqlResultPtr result_set = executePreparedStatement(statement, in_bindings); + + // Iterate over the returned rows and invoke the row consumption + // function on each one. + int rows = result_set->getRows(); + for (int row = 0; row < rows; ++row) { + try { + process_result_row(*result_set, row); + } catch (const std::exception& ex) { + // Rethrow the exception with a bit more data. + isc_throw(BadValue, ex.what() << ". Statement is <" << + statement.text << ">"); + } + } +} + +void +PgSqlConnection::insertQuery(PgSqlTaggedStatement& statement, + const PsqlBindArray& in_bindings) { + // Execute the prepared statement. + PgSqlResultPtr result_set = executePreparedStatement(statement, in_bindings); +} + +uint64_t +PgSqlConnection::updateDeleteQuery(PgSqlTaggedStatement& statement, + const PsqlBindArray& in_bindings) { + // Execute the prepared statement. + PgSqlResultPtr result_set = executePreparedStatement(statement, in_bindings); + + return (boost::lexical_cast<int>(PQcmdTuples(*result_set))); +} + +template<typename T> +void +PgSqlConnection::setIntParameterValue(const std::string& name, int64_t min, int64_t max, T& value) { + string svalue; + try { + svalue = getParameter(name); + } catch (...) { + // Do nothing if the parameter is not present. + } + if (svalue.empty()) { + return; + } + try { + // Try to convert the value. + auto parsed_value = boost::lexical_cast<T>(svalue); + // Check if the value is within the specified range. + if ((parsed_value < min) || (parsed_value > max)) { + isc_throw(BadValue, "bad " << svalue << " value"); + } + // Everything is fine. Return the parsed value. + value = parsed_value; + + } catch (...) { + // We may end up here when lexical_cast fails or when the + // parsed value is not within the desired range. In both + // cases let's throw the same general error. + isc_throw(BadValue, name << " parameter (" << + svalue << ") must be an integer between " + << min << " and " << max); + } +} + + +} // end of isc::db namespace +} // end of isc namespace |