summaryrefslogtreecommitdiffstats
path: root/src/lib-sql/driver-pgsql.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib-sql/driver-pgsql.c')
-rw-r--r--src/lib-sql/driver-pgsql.c1344
1 files changed, 1344 insertions, 0 deletions
diff --git a/src/lib-sql/driver-pgsql.c b/src/lib-sql/driver-pgsql.c
new file mode 100644
index 0000000..63188c0
--- /dev/null
+++ b/src/lib-sql/driver-pgsql.c
@@ -0,0 +1,1344 @@
+/* Copyright (c) 2004-2018 Dovecot authors, see the included COPYING file */
+
+#include "lib.h"
+#include "array.h"
+#include "ioloop.h"
+#include "hex-binary.h"
+#include "str.h"
+#include "time-util.h"
+#include "sql-api-private.h"
+#include "llist.h"
+
+#ifdef BUILD_PGSQL
+#include <libpq-fe.h>
+
+#define PGSQL_DNS_WARN_MSECS 500
+
+struct pgsql_db {
+ struct sql_db api;
+
+ pool_t pool;
+ char *connect_string;
+ char *host;
+ PGconn *pg;
+
+ struct io *io;
+ struct timeout *to_connect;
+ enum io_condition io_dir;
+
+ struct pgsql_result *pending_results;
+ struct pgsql_result *cur_result;
+ struct ioloop *ioloop, *orig_ioloop;
+ struct sql_result *sync_result;
+
+ bool (*next_callback)(void *);
+ void *next_context;
+
+ char *error;
+ const char *connect_state;
+
+ bool fatal_error:1;
+};
+
+struct pgsql_binary_value {
+ unsigned char *value;
+ size_t size;
+};
+
+struct pgsql_result {
+ struct sql_result api;
+
+ struct pgsql_result *prev, *next;
+
+ PGresult *pgres;
+ struct timeout *to;
+
+ unsigned int rownum, rows;
+ unsigned int fields_count;
+ const char **fields;
+ const char **values;
+ char *query;
+
+ ARRAY(struct pgsql_binary_value) binary_values;
+
+ sql_query_callback_t *callback;
+ void *context;
+
+ bool timeout:1;
+};
+
+struct pgsql_transaction_context {
+ struct sql_transaction_context ctx;
+ int refcount;
+
+ sql_commit_callback_t *callback;
+ void *context;
+
+ pool_t query_pool;
+ const char *error;
+
+ bool failed:1;
+};
+
+extern const struct sql_db driver_pgsql_db;
+extern const struct sql_result driver_pgsql_result;
+
+static void result_finish(struct pgsql_result *result);
+static void
+transaction_update_callback(struct sql_result *result,
+ struct sql_transaction_query *query);
+
+static struct event_category event_category_pgsql = {
+ .parent = &event_category_sql,
+ .name = "pgsql"
+};
+
+static void driver_pgsql_set_state(struct pgsql_db *db, enum sql_db_state state)
+{
+ i_assert(state == SQL_DB_STATE_BUSY || db->cur_result == NULL);
+
+ /* switch back to original ioloop in case the caller wants to
+ add/remove timeouts */
+ if (db->ioloop != NULL)
+ io_loop_set_current(db->orig_ioloop);
+ sql_db_set_state(&db->api, state);
+ if (db->ioloop != NULL)
+ io_loop_set_current(db->ioloop);
+}
+
+static bool driver_pgsql_next_callback(struct pgsql_db *db)
+{
+ bool (*next_callback)(void *) = db->next_callback;
+ void *next_context = db->next_context;
+
+ if (next_callback == NULL)
+ return FALSE;
+
+ db->next_callback = NULL;
+ db->next_context = NULL;
+ return next_callback(next_context);
+}
+
+static void driver_pgsql_stop_io(struct pgsql_db *db)
+{
+ if (db->io != NULL) {
+ io_remove(&db->io);
+ db->io_dir = 0;
+ }
+}
+
+static void driver_pgsql_close(struct pgsql_db *db)
+{
+ db->io_dir = 0;
+ db->fatal_error = FALSE;
+
+ driver_pgsql_stop_io(db);
+
+ PQfinish(db->pg);
+ db->pg = NULL;
+
+ timeout_remove(&db->to_connect);
+
+ driver_pgsql_set_state(db, SQL_DB_STATE_DISCONNECTED);
+
+ if (db->ioloop != NULL) {
+ /* running a sync query, stop it */
+ io_loop_stop(db->ioloop);
+ }
+ driver_pgsql_next_callback(db);
+}
+
+static const char *last_error(struct pgsql_db *db)
+{
+ const char *msg;
+ size_t len;
+
+ msg = PQerrorMessage(db->pg);
+ if (msg == NULL)
+ return "(no error set)";
+
+ /* Error message should contain trailing \n, we don't want it */
+ len = strlen(msg);
+ return len == 0 || msg[len-1] != '\n' ? msg :
+ t_strndup(msg, len-1);
+}
+
+static void connect_callback(struct pgsql_db *db)
+{
+ enum io_condition io_dir = 0;
+ int ret;
+
+ driver_pgsql_stop_io(db);
+
+ while ((ret = PQconnectPoll(db->pg)) == PGRES_POLLING_ACTIVE)
+ ;
+
+ switch (ret) {
+ case PGRES_POLLING_READING:
+ db->connect_state = "wait for input";
+ io_dir = IO_READ;
+ break;
+ case PGRES_POLLING_WRITING:
+ db->connect_state = "wait for output";
+ io_dir = IO_WRITE;
+ break;
+ case PGRES_POLLING_OK:
+ break;
+ case PGRES_POLLING_FAILED:
+ e_error(db->api.event, "Connect failed to database %s: %s (state: %s)",
+ PQdb(db->pg), last_error(db), db->connect_state);
+ driver_pgsql_close(db);
+ return;
+ }
+
+ if (io_dir != 0) {
+ db->io = io_add(PQsocket(db->pg), io_dir, connect_callback, db);
+ db->io_dir = io_dir;
+ }
+
+ if (io_dir == 0) {
+ db->connect_state = "connected";
+ timeout_remove(&db->to_connect);
+ if (PQserverVersion(db->pg) >= 90500) {
+ /* v9.5+ */
+ db->api.flags |= SQL_DB_FLAG_ON_CONFLICT_DO;
+ }
+ driver_pgsql_set_state(db, SQL_DB_STATE_IDLE);
+ if (db->ioloop != NULL) {
+ /* driver_pgsql_sync_init() waiting for connection to
+ finish */
+ io_loop_stop(db->ioloop);
+ }
+ }
+}
+
+static void driver_pgsql_connect_timeout(struct pgsql_db *db)
+{
+ unsigned int secs = ioloop_time - db->api.last_connect_try;
+
+ e_error(db->api.event, "Connect failed: Timeout after %u seconds (state: %s)",
+ secs, db->connect_state);
+ driver_pgsql_close(db);
+}
+
+static int driver_pgsql_connect(struct sql_db *_db)
+{
+ struct pgsql_db *db = (struct pgsql_db *)_db;
+ struct timeval tv_start;
+ int msecs;
+
+ i_assert(db->api.state == SQL_DB_STATE_DISCONNECTED);
+
+ io_loop_time_refresh();
+ tv_start = ioloop_timeval;
+
+ db->pg = PQconnectStart(db->connect_string);
+ if (db->pg == NULL) {
+ i_fatal("pgsql: PQconnectStart() failed (out of memory)");
+ }
+
+ if (PQstatus(db->pg) == CONNECTION_BAD) {
+ e_error(_db->event, "Connect failed to database %s: %s",
+ PQdb(db->pg), last_error(db));
+ driver_pgsql_close(db);
+ return -1;
+ }
+ /* PQconnectStart() blocks on host name resolving. Log a warning if
+ it takes too long. Also don't include time spent on that in the
+ connect timeout (by refreshing ioloop time). */
+ io_loop_time_refresh();
+ msecs = timeval_diff_msecs(&ioloop_timeval, &tv_start);
+ if (msecs > PGSQL_DNS_WARN_MSECS) {
+ e_warning(_db->event, "DNS lookup took %d.%03d s",
+ msecs/1000, msecs % 1000);
+ }
+
+ /* nonblocking connecting begins. */
+ if (PQsetnonblocking(db->pg, 1) < 0)
+ e_error(_db->event, "PQsetnonblocking() failed");
+ i_assert(db->to_connect == NULL);
+ db->to_connect = timeout_add(SQL_CONNECT_TIMEOUT_SECS * 1000,
+ driver_pgsql_connect_timeout, db);
+ db->connect_state = "connecting";
+ db->io = io_add(PQsocket(db->pg), IO_WRITE, connect_callback, db);
+ db->io_dir = IO_WRITE;
+ driver_pgsql_set_state(db, SQL_DB_STATE_CONNECTING);
+ return 0;
+}
+
+static void driver_pgsql_disconnect(struct sql_db *_db)
+{
+ struct pgsql_db *db = (struct pgsql_db *)_db;
+
+ if (db->cur_result != NULL && db->cur_result->to != NULL) {
+ driver_pgsql_stop_io(db);
+ result_finish(db->cur_result);
+ }
+
+ _db->no_reconnect = TRUE;
+ driver_pgsql_close(db);
+ _db->no_reconnect = FALSE;
+}
+
+static void driver_pgsql_free(struct pgsql_db **_db)
+{
+ struct pgsql_db *db = *_db;
+ *_db = NULL;
+
+ event_unref(&db->api.event);
+ i_free(db->connect_string);
+ i_free(db->host);
+ i_free(db->error);
+ array_free(&db->api.module_contexts);
+ i_free(db);
+}
+
+static enum sql_db_flags driver_pgsql_get_flags(struct sql_db *db)
+{
+ switch (db->state) {
+ case SQL_DB_STATE_DISCONNECTED:
+ if (sql_connect(db) < 0)
+ break;
+ /* fall through */
+ case SQL_DB_STATE_CONNECTING:
+ /* Wait for connection to finish, so we can get the flags
+ reliably. */
+ sql_wait(db);
+ break;
+ case SQL_DB_STATE_IDLE:
+ case SQL_DB_STATE_BUSY:
+ break;
+ }
+ return db->flags;
+}
+
+static int driver_pgsql_init_full_v(const struct sql_settings *set,
+ struct sql_db **db_r, const char **error_r ATTR_UNUSED)
+{
+ struct pgsql_db *db;
+
+ db = i_new(struct pgsql_db, 1);
+ db->connect_string = i_strdup(set->connect_string);
+ db->api = driver_pgsql_db;
+ db->api.event = event_create(set->event_parent);
+ event_add_category(db->api.event, &event_category_pgsql);
+
+ /* NOTE: Connection string will be parsed by pgsql itself
+ We only pick the host part here */
+ T_BEGIN {
+ const char *const *arg = t_strsplit(db->connect_string, " ");
+
+ for (; *arg != NULL; arg++) {
+ if (str_begins(*arg, "host="))
+ db->host = i_strdup(*arg + 5);
+
+ }
+ } T_END;
+
+ event_set_append_log_prefix(db->api.event, t_strdup_printf("pgsql(%s): ", db->host));
+
+ *db_r = &db->api;
+ return 0;
+}
+
+static void driver_pgsql_deinit_v(struct sql_db *_db)
+{
+ struct pgsql_db *db = (struct pgsql_db *)_db;
+
+ driver_pgsql_disconnect(_db);
+ driver_pgsql_free(&db);
+}
+
+static void driver_pgsql_set_idle(struct pgsql_db *db)
+{
+ i_assert(db->api.state == SQL_DB_STATE_BUSY);
+
+ if (db->fatal_error)
+ driver_pgsql_close(db);
+ else if (!driver_pgsql_next_callback(db))
+ driver_pgsql_set_state(db, SQL_DB_STATE_IDLE);
+}
+
+static void consume_results(struct pgsql_db *db)
+{
+ PGresult *pgres;
+
+ driver_pgsql_stop_io(db);
+
+ while (PQconsumeInput(db->pg) != 0) {
+ if (PQisBusy(db->pg) != 0) {
+ db->io = io_add(PQsocket(db->pg), IO_READ,
+ consume_results, db);
+ db->io_dir = IO_READ;
+ return;
+ }
+
+ pgres = PQgetResult(db->pg);
+ if (pgres == NULL)
+ break;
+ PQclear(pgres);
+ }
+
+ if (PQstatus(db->pg) == CONNECTION_BAD)
+ driver_pgsql_close(db);
+ else
+ driver_pgsql_set_idle(db);
+}
+
+static void driver_pgsql_result_free(struct sql_result *_result)
+{
+ struct pgsql_db *db = (struct pgsql_db *)_result->db;
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+ bool success;
+
+ i_assert(!result->api.callback);
+ i_assert(db->cur_result == result);
+ i_assert(result->callback == NULL);
+
+ if (_result == db->sync_result)
+ db->sync_result = NULL;
+ db->cur_result = NULL;
+
+ success = result->pgres != NULL && !db->fatal_error;
+ if (result->pgres != NULL) {
+ PQclear(result->pgres);
+ result->pgres = NULL;
+ }
+
+ if (success) {
+ /* we'll have to read the rest of the results as well */
+ i_assert(db->io == NULL);
+ consume_results(db);
+ } else {
+ driver_pgsql_set_idle(db);
+ }
+
+ if (array_is_created(&result->binary_values)) {
+ struct pgsql_binary_value *value;
+
+ array_foreach_modifiable(&result->binary_values, value)
+ PQfreemem(value->value);
+ array_free(&result->binary_values);
+ }
+
+ event_unref(&result->api.event);
+ i_free(result->query);
+ i_free(result->fields);
+ i_free(result->values);
+ i_free(result);
+}
+
+static void result_finish(struct pgsql_result *result)
+{
+ struct pgsql_db *db = (struct pgsql_db *)result->api.db;
+ bool free_result = TRUE;
+ int duration;
+
+ i_assert(db->io == NULL);
+ timeout_remove(&result->to);
+ DLLIST_REMOVE(&db->pending_results, result);
+
+ /* if connection to server was lost, we don't yet see that the
+ connection is bad. we only see the fatal error, so assume it also
+ means disconnection. */
+ if (PQstatus(db->pg) == CONNECTION_BAD || result->pgres == NULL ||
+ PQresultStatus(result->pgres) == PGRES_FATAL_ERROR)
+ db->fatal_error = TRUE;
+
+ if (db->fatal_error) {
+ result->api.failed = TRUE;
+ result->api.failed_try_retry = TRUE;
+ }
+
+ /* emit event */
+ if (result->api.failed) {
+ const char *error = result->timeout ? "Timed out" : last_error(db);
+ struct event_passthrough *e =
+ sql_query_finished_event(&db->api, result->api.event,
+ result->query, TRUE, &duration);
+ e->add_str("error", error);
+ e_debug(e->event(), SQL_QUERY_FINISHED_FMT": %s", result->query,
+ duration, error);
+ } else {
+ e_debug(sql_query_finished_event(&db->api, result->api.event,
+ result->query, FALSE, &duration)->
+ event(),
+ SQL_QUERY_FINISHED_FMT, result->query, duration);
+ }
+ result->api.callback = TRUE;
+ T_BEGIN {
+ if (result->callback != NULL)
+ result->callback(&result->api, result->context);
+ } T_END;
+ result->api.callback = FALSE;
+
+ free_result = db->sync_result != &result->api;
+ if (db->ioloop != NULL)
+ io_loop_stop(db->ioloop);
+
+ i_assert(!free_result || result->api.refcount > 0);
+ result->callback = NULL;
+ if (free_result)
+ sql_result_unref(&result->api);
+}
+
+static void get_result(struct pgsql_result *result)
+{
+ struct pgsql_db *db = (struct pgsql_db *)result->api.db;
+
+ driver_pgsql_stop_io(db);
+
+ if (PQconsumeInput(db->pg) == 0) {
+ result_finish(result);
+ return;
+ }
+
+ if (PQisBusy(db->pg) != 0) {
+ db->io = io_add(PQsocket(db->pg), IO_READ,
+ get_result, result);
+ db->io_dir = IO_READ;
+ return;
+ }
+
+ result->pgres = PQgetResult(db->pg);
+ result_finish(result);
+}
+
+static void flush_callback(struct pgsql_result *result)
+{
+ struct pgsql_db *db = (struct pgsql_db *)result->api.db;
+ int ret;
+
+ driver_pgsql_stop_io(db);
+
+ ret = PQflush(db->pg);
+ if (ret > 0) {
+ db->io = io_add(PQsocket(db->pg), IO_WRITE,
+ flush_callback, result);
+ db->io_dir = IO_WRITE;
+ return;
+ }
+
+ if (ret < 0) {
+ result_finish(result);
+ } else {
+ /* all flushed */
+ get_result(result);
+ }
+}
+
+static void query_timeout(struct pgsql_result *result)
+{
+ struct pgsql_db *db = (struct pgsql_db *)result->api.db;
+
+ driver_pgsql_stop_io(db);
+
+ result->timeout = TRUE;
+ result_finish(result);
+}
+
+static void do_query(struct pgsql_result *result, const char *query)
+{
+ struct pgsql_db *db = (struct pgsql_db *)result->api.db;
+ int ret;
+
+ i_assert(SQL_DB_IS_READY(&db->api));
+ i_assert(db->cur_result == NULL);
+ i_assert(db->io == NULL);
+
+ driver_pgsql_set_state(db, SQL_DB_STATE_BUSY);
+ db->cur_result = result;
+ DLLIST_PREPEND(&db->pending_results, result);
+ result->to = timeout_add(SQL_QUERY_TIMEOUT_SECS * 1000,
+ query_timeout, result);
+ result->query = i_strdup(query);
+
+ if (PQsendQuery(db->pg, query) == 0 ||
+ (ret = PQflush(db->pg)) < 0) {
+ /* failed to send query */
+ result_finish(result);
+ return;
+ }
+
+ if (ret > 0) {
+ /* write blocks */
+ db->io = io_add(PQsocket(db->pg), IO_WRITE,
+ flush_callback, result);
+ db->io_dir = IO_WRITE;
+ } else {
+ get_result(result);
+ }
+}
+
+static const char *
+driver_pgsql_escape_string(struct sql_db *_db, const char *string)
+{
+ struct pgsql_db *db = (struct pgsql_db *)_db;
+ size_t len = strlen(string);
+ char *to;
+
+#ifdef HAVE_PQESCAPE_STRING_CONN
+ if (db->api.state == SQL_DB_STATE_DISCONNECTED) {
+ /* try connecting again */
+ (void)sql_connect(&db->api);
+ }
+ if (db->api.state != SQL_DB_STATE_DISCONNECTED) {
+ int error;
+
+ to = t_buffer_get(len * 2 + 1);
+ len = PQescapeStringConn(db->pg, to, string, len, &error);
+ } else
+#endif
+ {
+ to = t_buffer_get(len * 2 + 1);
+ len = PQescapeString(to, string, len);
+ }
+ t_buffer_alloc(len + 1);
+ return to;
+}
+
+static void exec_callback(struct sql_result *_result,
+ void *context ATTR_UNUSED)
+{
+ struct pgsql_result *result = (struct pgsql_result*)_result;
+ result_finish(result);
+}
+
+static void driver_pgsql_exec(struct sql_db *db, const char *query)
+{
+ struct pgsql_result *result;
+
+ result = i_new(struct pgsql_result, 1);
+ result->api = driver_pgsql_result;
+ result->api.db = db;
+ result->api.refcount = 1;
+ result->api.event = event_create(db->event);
+ result->callback = exec_callback;
+ do_query(result, query);
+}
+
+static void driver_pgsql_query(struct sql_db *db, const char *query,
+ sql_query_callback_t *callback, void *context)
+{
+ struct pgsql_result *result;
+
+ result = i_new(struct pgsql_result, 1);
+ result->api = driver_pgsql_result;
+ result->api.db = db;
+ result->api.refcount = 1;
+ result->api.event = event_create(db->event);
+ result->callback = callback;
+ result->context = context;
+ do_query(result, query);
+}
+
+static void pgsql_query_s_callback(struct sql_result *result, void *context)
+{
+ struct pgsql_db *db = context;
+
+ db->sync_result = result;
+}
+
+static void driver_pgsql_sync_init(struct pgsql_db *db)
+{
+ bool add_to_connect;
+
+ db->orig_ioloop = current_ioloop;
+ if (db->io == NULL) {
+ db->ioloop = io_loop_create();
+ return;
+ }
+
+ i_assert(db->api.state == SQL_DB_STATE_CONNECTING);
+
+ /* have to move our existing I/O and timeout handlers to new I/O loop */
+ io_remove(&db->io);
+
+ add_to_connect = (db->to_connect != NULL);
+ timeout_remove(&db->to_connect);
+
+ db->ioloop = io_loop_create();
+ if (add_to_connect) {
+ db->to_connect = timeout_add(SQL_CONNECT_TIMEOUT_SECS * 1000,
+ driver_pgsql_connect_timeout, db);
+ }
+ db->io = io_add(PQsocket(db->pg), db->io_dir, connect_callback, db);
+ /* wait for connecting to finish */
+ io_loop_run(db->ioloop);
+}
+
+static void driver_pgsql_sync_deinit(struct pgsql_db *db)
+{
+ io_loop_destroy(&db->ioloop);
+}
+
+static struct sql_result *
+driver_pgsql_sync_query(struct pgsql_db *db, const char *query)
+{
+ struct sql_result *result;
+
+ i_assert(db->sync_result == NULL);
+
+ switch (db->api.state) {
+ case SQL_DB_STATE_CONNECTING:
+ case SQL_DB_STATE_BUSY:
+ i_unreached();
+ case SQL_DB_STATE_DISCONNECTED:
+ sql_not_connected_result.refcount++;
+ return &sql_not_connected_result;
+ case SQL_DB_STATE_IDLE:
+ break;
+ }
+
+ driver_pgsql_query(&db->api, query, pgsql_query_s_callback, db);
+ if (db->sync_result == NULL)
+ io_loop_run(db->ioloop);
+
+ i_assert(db->io == NULL);
+
+ result = db->sync_result;
+ if (result == &sql_not_connected_result) {
+ /* we don't end up in pgsql's free function, so sync_result
+ won't be set to NULL if we don't do it here. */
+ db->sync_result = NULL;
+ } else if (result == NULL) {
+ result = &sql_not_connected_result;
+ result->refcount++;
+ }
+
+ i_assert(db->io == NULL);
+ return result;
+}
+
+static struct sql_result *
+driver_pgsql_query_s(struct sql_db *_db, const char *query)
+{
+ struct pgsql_db *db = (struct pgsql_db *)_db;
+ struct sql_result *result;
+
+ driver_pgsql_sync_init(db);
+ result = driver_pgsql_sync_query(db, query);
+ driver_pgsql_sync_deinit(db);
+ return result;
+}
+
+static int driver_pgsql_result_next_row(struct sql_result *_result)
+{
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+ struct pgsql_db *db = (struct pgsql_db *)_result->db;
+
+ if (result->rows != 0) {
+ /* second time we're here */
+ if (++result->rownum < result->rows)
+ return 1;
+
+ /* end of this packet. see if there's more. FIXME: this may
+ block, but the current API doesn't provide a non-blocking
+ way to do this.. */
+ PQclear(result->pgres);
+ result->pgres = PQgetResult(db->pg);
+ if (result->pgres == NULL)
+ return 0;
+ }
+
+ if (result->pgres == NULL) {
+ _result->failed = TRUE;
+ return -1;
+ }
+
+ switch (PQresultStatus(result->pgres)) {
+ case PGRES_COMMAND_OK:
+ /* no rows returned */
+ return 0;
+ case PGRES_TUPLES_OK:
+ result->rows = PQntuples(result->pgres);
+ return result->rows > 0 ? 1 : 0;
+ case PGRES_EMPTY_QUERY:
+ case PGRES_NONFATAL_ERROR:
+ /* nonfatal error */
+ _result->failed = TRUE;
+ return -1;
+ default:
+ /* treat as fatal error */
+ _result->failed = TRUE;
+ db->fatal_error = TRUE;
+ return -1;
+ }
+}
+
+static void driver_pgsql_result_fetch_fields(struct pgsql_result *result)
+{
+ unsigned int i;
+
+ if (result->fields != NULL)
+ return;
+
+ /* @UNSAFE */
+ result->fields_count = PQnfields(result->pgres);
+ result->fields = i_new(const char *, result->fields_count);
+ for (i = 0; i < result->fields_count; i++)
+ result->fields[i] = PQfname(result->pgres, i);
+}
+
+static unsigned int
+driver_pgsql_result_get_fields_count(struct sql_result *_result)
+{
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+
+ driver_pgsql_result_fetch_fields(result);
+ return result->fields_count;
+}
+
+static const char *
+driver_pgsql_result_get_field_name(struct sql_result *_result, unsigned int idx)
+{
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+
+ driver_pgsql_result_fetch_fields(result);
+ i_assert(idx < result->fields_count);
+ return result->fields[idx];
+}
+
+static int driver_pgsql_result_find_field(struct sql_result *_result,
+ const char *field_name)
+{
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+ unsigned int i;
+
+ driver_pgsql_result_fetch_fields(result);
+ for (i = 0; i < result->fields_count; i++) {
+ if (strcmp(result->fields[i], field_name) == 0)
+ return i;
+ }
+ return -1;
+}
+
+static const char *
+driver_pgsql_result_get_field_value(struct sql_result *_result,
+ unsigned int idx)
+{
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+
+ if (PQgetisnull(result->pgres, result->rownum, idx) != 0)
+ return NULL;
+
+ return PQgetvalue(result->pgres, result->rownum, idx);
+}
+
+static const unsigned char *
+driver_pgsql_result_get_field_value_binary(struct sql_result *_result,
+ unsigned int idx, size_t *size_r)
+{
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+ const char *value;
+ struct pgsql_binary_value *binary_value;
+
+ if (PQgetisnull(result->pgres, result->rownum, idx) != 0) {
+ *size_r = 0;
+ return NULL;
+ }
+
+ value = PQgetvalue(result->pgres, result->rownum, idx);
+
+ if (!array_is_created(&result->binary_values))
+ i_array_init(&result->binary_values, idx + 1);
+
+ binary_value = array_idx_get_space(&result->binary_values, idx);
+ if (binary_value->value == NULL) {
+ binary_value->value =
+ PQunescapeBytea((const unsigned char *)value,
+ &binary_value->size);
+ }
+
+ *size_r = binary_value->size;
+ return binary_value->value;
+}
+
+static const char *
+driver_pgsql_result_find_field_value(struct sql_result *result,
+ const char *field_name)
+{
+ int idx;
+
+ idx = driver_pgsql_result_find_field(result, field_name);
+ if (idx < 0)
+ return NULL;
+ return driver_pgsql_result_get_field_value(result, idx);
+}
+
+static const char *const *
+driver_pgsql_result_get_values(struct sql_result *_result)
+{
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+ unsigned int i;
+
+ if (result->values == NULL) {
+ driver_pgsql_result_fetch_fields(result);
+ result->values = i_new(const char *, result->fields_count);
+ }
+
+ /* @UNSAFE */
+ for (i = 0; i < result->fields_count; i++) {
+ result->values[i] =
+ driver_pgsql_result_get_field_value(_result, i);
+ }
+
+ return result->values;
+}
+
+static const char *driver_pgsql_result_get_error(struct sql_result *_result)
+{
+ struct pgsql_result *result = (struct pgsql_result *)_result;
+ struct pgsql_db *db = (struct pgsql_db *)_result->db;
+ const char *msg;
+ size_t len;
+
+ i_free_and_null(db->error);
+
+ if (result->timeout) {
+ db->error = i_strdup("Query timed out");
+ } else if (result->pgres == NULL) {
+ /* connection error */
+ db->error = i_strdup(last_error(db));
+ } else {
+ msg = PQresultErrorMessage(result->pgres);
+ if (msg == NULL)
+ return "(no error set)";
+
+ /* Error message should contain trailing \n, we don't want it */
+ len = strlen(msg);
+ db->error = len == 0 || msg[len-1] != '\n' ?
+ i_strdup(msg) : i_strndup(msg, len-1);
+ }
+ return db->error;
+}
+
+static struct sql_transaction_context *
+driver_pgsql_transaction_begin(struct sql_db *db)
+{
+ struct pgsql_transaction_context *ctx;
+
+ ctx = i_new(struct pgsql_transaction_context, 1);
+ ctx->ctx.db = db;
+ ctx->ctx.event = event_create(db->event);
+ /* we need to be able to handle multiple open transactions, so at least
+ for now just keep them in memory until commit time. */
+ ctx->query_pool = pool_alloconly_create("pgsql transaction", 1024);
+ return &ctx->ctx;
+}
+
+static void
+driver_pgsql_transaction_free(struct pgsql_transaction_context *ctx)
+{
+ pool_unref(&ctx->query_pool);
+ event_unref(&ctx->ctx.event);
+ i_free(ctx);
+}
+
+static void
+transaction_commit_callback(struct sql_result *result,
+ struct pgsql_transaction_context *ctx)
+{
+ struct sql_commit_result commit_result;
+
+ i_zero(&commit_result);
+ if (sql_result_next_row(result) < 0) {
+ commit_result.error = sql_result_get_error(result);
+ commit_result.error_type = sql_result_get_error_type(result);
+ }
+ ctx->callback(&commit_result, ctx->context);
+ driver_pgsql_transaction_free(ctx);
+}
+
+static bool transaction_send_next(void *context)
+{
+ struct pgsql_transaction_context *ctx = context;
+
+ i_assert(!ctx->failed);
+
+ if (ctx->ctx.db->state == SQL_DB_STATE_BUSY) {
+ /* kludgy.. */
+ ctx->ctx.db->state = SQL_DB_STATE_IDLE;
+ } else if (!SQL_DB_IS_READY(ctx->ctx.db)) {
+ struct sql_commit_result commit_result = {
+ .error = "Not connected"
+ };
+ ctx->callback(&commit_result, ctx->context);
+ return FALSE;
+ }
+
+ if (ctx->ctx.head != NULL) {
+ struct sql_transaction_query *query = ctx->ctx.head;
+
+ ctx->ctx.head = ctx->ctx.head->next;
+ sql_query(ctx->ctx.db, query->query,
+ transaction_update_callback, query);
+ } else {
+ sql_query(ctx->ctx.db, "COMMIT",
+ transaction_commit_callback, ctx);
+ }
+ return TRUE;
+}
+
+static void
+transaction_commit_error_callback(struct pgsql_transaction_context *ctx,
+ struct sql_result *result)
+{
+ struct sql_commit_result commit_result;
+
+ i_zero(&commit_result);
+ commit_result.error = sql_result_get_error(result);
+ commit_result.error_type = sql_result_get_error_type(result);
+ e_debug(sql_transaction_finished_event(&ctx->ctx)->
+ add_str("error", commit_result.error)->event(),
+ "Transaction failed: %s", commit_result.error);
+ ctx->callback(&commit_result, ctx->context);
+}
+
+static void
+transaction_begin_callback(struct sql_result *result,
+ struct pgsql_transaction_context *ctx)
+{
+ struct pgsql_db *db = (struct pgsql_db *)result->db;
+
+ i_assert(result->db == ctx->ctx.db);
+
+ if (sql_result_next_row(result) < 0) {
+ transaction_commit_error_callback(ctx, result);
+ driver_pgsql_transaction_free(ctx);
+ return;
+ }
+ i_assert(db->next_callback == NULL);
+ db->next_callback = transaction_send_next;
+ db->next_context = ctx;
+}
+
+static void
+transaction_update_callback(struct sql_result *result,
+ struct sql_transaction_query *query)
+{
+ struct pgsql_transaction_context *ctx =
+ (struct pgsql_transaction_context *)query->trans;
+ struct pgsql_db *db = (struct pgsql_db *)result->db;
+
+ if (sql_result_next_row(result) < 0) {
+ transaction_commit_error_callback(ctx, result);
+ driver_pgsql_transaction_free(ctx);
+ return;
+ }
+
+ if (query->affected_rows != NULL) {
+ struct pgsql_result *pg_result = (struct pgsql_result *)result;
+
+ if (str_to_uint(PQcmdTuples(pg_result->pgres),
+ query->affected_rows) < 0)
+ i_unreached();
+ }
+ i_assert(db->next_callback == NULL);
+ db->next_callback = transaction_send_next;
+ db->next_context = ctx;
+}
+
+static void
+transaction_trans_query_callback(struct sql_result *result,
+ struct sql_transaction_query *query)
+{
+ struct pgsql_transaction_context *ctx =
+ (struct pgsql_transaction_context *)query->trans;
+ struct sql_commit_result commit_result;
+
+ if (sql_result_next_row(result) < 0) {
+ transaction_commit_error_callback(ctx, result);
+ driver_pgsql_transaction_free(ctx);
+ return;
+ }
+
+ if (query->affected_rows != NULL) {
+ struct pgsql_result *pg_result = (struct pgsql_result *)result;
+
+ if (str_to_uint(PQcmdTuples(pg_result->pgres),
+ query->affected_rows) < 0)
+ i_unreached();
+ }
+ e_debug(sql_transaction_finished_event(&ctx->ctx)->event(),
+ "Transaction committed");
+ i_zero(&commit_result);
+ ctx->callback(&commit_result, ctx->context);
+ driver_pgsql_transaction_free(ctx);
+}
+
+static void
+driver_pgsql_transaction_commit(struct sql_transaction_context *_ctx,
+ sql_commit_callback_t *callback, void *context)
+{
+ struct pgsql_transaction_context *ctx =
+ (struct pgsql_transaction_context *)_ctx;
+ struct sql_commit_result result;
+
+ i_zero(&result);
+ ctx->callback = callback;
+ ctx->context = context;
+
+ if (ctx->failed || _ctx->head == NULL) {
+ if (ctx->failed) {
+ result.error = ctx->error;
+ e_debug(sql_transaction_finished_event(_ctx)->
+ add_str("error", ctx->error)->event(),
+ "Transaction failed: %s", ctx->error);
+ } else {
+ e_debug(sql_transaction_finished_event(_ctx)->event(),
+ "Transaction committed");
+ }
+ callback(&result, context);
+ driver_pgsql_transaction_free(ctx);
+ } else if (_ctx->head->next == NULL) {
+ /* just a single query, send it */
+ sql_query(_ctx->db, _ctx->head->query,
+ transaction_trans_query_callback, _ctx->head);
+ } else {
+ /* multiple queries, use a transaction */
+ i_assert(_ctx->db->v.query == driver_pgsql_query);
+ sql_query(_ctx->db, "BEGIN", transaction_begin_callback, ctx);
+ }
+}
+
+static void
+commit_multi_fail(struct pgsql_transaction_context *ctx,
+ struct sql_result *result, const char *query)
+{
+ ctx->failed = TRUE;
+ ctx->error = t_strdup_printf("%s (query: %s)",
+ sql_result_get_error(result), query);
+ sql_result_unref(result);
+}
+
+static struct sql_result *
+driver_pgsql_transaction_commit_multi(struct pgsql_transaction_context *ctx)
+{
+ struct pgsql_db *db = (struct pgsql_db *)ctx->ctx.db;
+ struct sql_result *result;
+ struct sql_transaction_query *query;
+
+ result = driver_pgsql_sync_query(db, "BEGIN");
+ if (sql_result_next_row(result) < 0) {
+ commit_multi_fail(ctx, result, "BEGIN");
+ return NULL;
+ }
+ sql_result_unref(result);
+
+ /* send queries */
+ for (query = ctx->ctx.head; query != NULL; query = query->next) {
+ result = driver_pgsql_sync_query(db, query->query);
+ if (sql_result_next_row(result) < 0) {
+ commit_multi_fail(ctx, result, query->query);
+ break;
+ }
+ if (query->affected_rows != NULL) {
+ struct pgsql_result *pg_result =
+ (struct pgsql_result *)result;
+
+ if (str_to_uint(PQcmdTuples(pg_result->pgres),
+ query->affected_rows) < 0)
+ i_unreached();
+ }
+ sql_result_unref(result);
+ }
+
+ return driver_pgsql_sync_query(db, ctx->failed ?
+ "ROLLBACK" : "COMMIT");
+}
+
+static void
+driver_pgsql_try_commit_s(struct pgsql_transaction_context *ctx,
+ const char **error_r)
+{
+ struct sql_transaction_context *_ctx = &ctx->ctx;
+ struct pgsql_db *db = (struct pgsql_db *)_ctx->db;
+ struct sql_transaction_query *single_query = NULL;
+ struct sql_result *result;
+
+ if (_ctx->head->next == NULL) {
+ /* just a single query, send it */
+ single_query = _ctx->head;
+ result = sql_query_s(_ctx->db, single_query->query);
+ } else {
+ /* multiple queries, use a transaction */
+ driver_pgsql_sync_init(db);
+ result = driver_pgsql_transaction_commit_multi(ctx);
+ driver_pgsql_sync_deinit(db);
+ }
+
+ if (ctx->failed) {
+ i_assert(ctx->error != NULL);
+ e_debug(sql_transaction_finished_event(_ctx)->
+ add_str("error", ctx->error)->event(),
+ "Transaction failed: %s", ctx->error);
+ *error_r = ctx->error;
+ } else if (result != NULL) {
+ if (sql_result_next_row(result) < 0)
+ *error_r = sql_result_get_error(result);
+ else if (single_query != NULL &&
+ single_query->affected_rows != NULL) {
+ struct pgsql_result *pg_result =
+ (struct pgsql_result *)result;
+
+ if (str_to_uint(PQcmdTuples(pg_result->pgres),
+ single_query->affected_rows) < 0)
+ i_unreached();
+ }
+ }
+
+ if (!ctx->failed) {
+ e_debug(sql_transaction_finished_event(_ctx)->event(),
+ "Transaction committed");
+ }
+
+ if (result != NULL)
+ sql_result_unref(result);
+}
+
+static int
+driver_pgsql_transaction_commit_s(struct sql_transaction_context *_ctx,
+ const char **error_r)
+{
+ struct pgsql_transaction_context *ctx =
+ (struct pgsql_transaction_context *)_ctx;
+ struct pgsql_db *db = (struct pgsql_db *)_ctx->db;
+
+ *error_r = NULL;
+
+ if (_ctx->head != NULL) {
+ driver_pgsql_try_commit_s(ctx, error_r);
+ if (_ctx->db->state == SQL_DB_STATE_DISCONNECTED) {
+ *error_r = t_strdup(*error_r);
+ e_info(db->api.event, "Disconnected from database, "
+ "retrying commit");
+ if (sql_connect(_ctx->db) >= 0) {
+ ctx->failed = FALSE;
+ *error_r = NULL;
+ driver_pgsql_try_commit_s(ctx, error_r);
+ }
+ }
+ }
+
+ driver_pgsql_transaction_free(ctx);
+ return *error_r == NULL ? 0 : -1;
+}
+
+static void
+driver_pgsql_transaction_rollback(struct sql_transaction_context *_ctx)
+{
+ struct pgsql_transaction_context *ctx =
+ (struct pgsql_transaction_context *)_ctx;
+ e_debug(sql_transaction_finished_event(_ctx)->
+ add_str("error", "Rolled back")->event(),
+ "Transaction rolled back");
+
+ driver_pgsql_transaction_free(ctx);
+}
+
+static void
+driver_pgsql_update(struct sql_transaction_context *_ctx, const char *query,
+ unsigned int *affected_rows)
+{
+ struct pgsql_transaction_context *ctx =
+ (struct pgsql_transaction_context *)_ctx;
+
+ sql_transaction_add_query(_ctx, ctx->query_pool, query, affected_rows);
+}
+
+static const char *
+driver_pgsql_escape_blob(struct sql_db *_db ATTR_UNUSED,
+ const unsigned char *data, size_t size)
+{
+ string_t *str = t_str_new(128);
+
+ str_append(str, "E'\\\\x");
+ binary_to_hex_append(str, data, size);
+ str_append_c(str, '\'');
+ return str_c(str);
+}
+
+static bool driver_pgsql_have_work(struct pgsql_db *db)
+{
+ return db->next_callback != NULL || db->pending_results != NULL ||
+ db->api.state == SQL_DB_STATE_CONNECTING;
+}
+
+static void driver_pgsql_wait(struct sql_db *_db)
+{
+ struct pgsql_db *db = (struct pgsql_db *)_db;
+
+ if (!driver_pgsql_have_work(db))
+ return;
+
+ db->orig_ioloop = current_ioloop;
+ db->ioloop = io_loop_create();
+ db->io = io_loop_move_io(&db->io);
+ while (driver_pgsql_have_work(db))
+ io_loop_run(db->ioloop);
+
+ io_loop_set_current(db->orig_ioloop);
+ db->io = io_loop_move_io(&db->io);
+ io_loop_set_current(db->ioloop);
+ io_loop_destroy(&db->ioloop);
+}
+
+const struct sql_db driver_pgsql_db = {
+ .name = "pgsql",
+ .flags = SQL_DB_FLAG_POOLED,
+
+ .v = {
+ .get_flags = driver_pgsql_get_flags,
+ .init_full = driver_pgsql_init_full_v,
+ .deinit = driver_pgsql_deinit_v,
+ .connect = driver_pgsql_connect,
+ .disconnect = driver_pgsql_disconnect,
+ .escape_string = driver_pgsql_escape_string,
+ .exec = driver_pgsql_exec,
+ .query = driver_pgsql_query,
+ .query_s = driver_pgsql_query_s,
+ .wait = driver_pgsql_wait,
+
+ .transaction_begin = driver_pgsql_transaction_begin,
+ .transaction_commit = driver_pgsql_transaction_commit,
+ .transaction_commit_s = driver_pgsql_transaction_commit_s,
+ .transaction_rollback = driver_pgsql_transaction_rollback,
+
+ .update = driver_pgsql_update,
+
+ .escape_blob = driver_pgsql_escape_blob,
+ }
+};
+
+const struct sql_result driver_pgsql_result = {
+ .v = {
+ .free = driver_pgsql_result_free,
+ .next_row = driver_pgsql_result_next_row,
+ .get_fields_count = driver_pgsql_result_get_fields_count,
+ .get_field_name = driver_pgsql_result_get_field_name,
+ .find_field = driver_pgsql_result_find_field,
+ .get_field_value = driver_pgsql_result_get_field_value,
+ .get_field_value_binary = driver_pgsql_result_get_field_value_binary,
+ .find_field_value = driver_pgsql_result_find_field_value,
+ .get_values = driver_pgsql_result_get_values,
+ .get_error = driver_pgsql_result_get_error,
+ }
+};
+
+const char *driver_pgsql_version = DOVECOT_ABI_VERSION;
+
+void driver_pgsql_init(void);
+void driver_pgsql_deinit(void);
+
+void driver_pgsql_init(void)
+{
+ sql_driver_register(&driver_pgsql_db);
+}
+
+void driver_pgsql_deinit(void)
+{
+ sql_driver_unregister(&driver_pgsql_db);
+}
+
+#endif