/* Copyright (c) 2005-2018 Dovecot authors, see the included COPYING file */

#include "lib.h"
#include "array.h"
#include "istream.h"
#include "hex-binary.h"
#include "hash.h"
#include "str.h"
#include "sql-api-private.h"
#include "sql-db-cache.h"
#include "dict-private.h"
#include "dict-sql-settings.h"
#include "dict-sql.h"
#include "dict-sql-private.h"

/* NOTE(review): these two system header names were missing (empty
   #include lines); <unistd.h> and <fcntl.h> are assumed - verify against
   the upstream file. */
#include <unistd.h>
#include <fcntl.h>

#define DICT_SQL_MAX_UNUSED_CONNECTIONS 10

/* How sql_dict_where_build() matches the last pattern field. */
enum sql_recurse_type {
	/* exact match of all given fields */
	SQL_DICT_RECURSE_NONE,
	/* direct children only: LIKE 'x/%' AND NOT LIKE 'x/%/%' */
	SQL_DICT_RECURSE_ONE,
	/* all descendants: LIKE 'x/%' */
	SQL_DICT_RECURSE_FULL
};

/* One bound SQL statement parameter. value_type selects which of the
   value_* members is used (see sql_dict_statement_bind()). */
struct sql_dict_param {
	enum dict_sql_type value_type;

	const char *value_str;
	int64_t value_int64;
	const void *value_binary;
	size_t value_binary_size;
};
ARRAY_DEFINE_TYPE(sql_dict_param, struct sql_dict_param);

struct sql_dict_iterate_context {
	struct dict_iterate_context ctx;
	pool_t pool;

	enum dict_iterate_flags flags;
	const char *path;

	/* current SQL result; NULL while an async query is pending */
	struct sql_result *result;
	/* key of the current row, rebuilt for each row */
	string_t *key;
	/* map used by the current query */
	const struct dict_sql_map *map;
	/* how much of path / map->pattern was matched by the current map */
	size_t key_prefix_len, pattern_prefix_len;
	/* first pattern field index selected by the query; index of the next
	   map to try once the current one has no more rows */
	unsigned int sql_fields_start_idx, next_map_idx;
	bool destroyed;
	/* TRUE while inside a call that may invoke the SQL callback
	   synchronously, to suppress the async_callback in that case */
	bool synchronous_result;
	bool iter_query_sent;
	bool allow_null_map; /* allow next map to be NULL */
	const char *error;
};

struct sql_dict_inc_row {
	struct sql_dict_inc_row *prev;
	/* NOTE(review): presumably the number of rows an increment affected;
	   a value of 0 makes the commit report DICT_COMMIT_RET_NOTFOUND. */
	unsigned int rows;
};

/* A buffered set (value.str) or increment (value.diff) operation that has
   not yet been flushed into the SQL transaction. */
struct sql_dict_prev {
	const struct dict_sql_map *map;
	char *key;
	union {
		char *str;
		long long diff;
	} value;
};

struct sql_dict_transaction_context {
	struct dict_transaction_context ctx;

	struct sql_transaction_context *sql_ctx;

	pool_t inc_row_pool;
	struct sql_dict_inc_row *inc_row;

	ARRAY(struct sql_dict_prev) prev_inc;
	ARRAY(struct sql_dict_prev) prev_set;

	dict_transaction_commit_callback_t *async_callback;
	void *async_context;

	char *error;
};

static struct sql_db_cache *dict_sql_db_cache;

static void sql_dict_prev_inc_flush(struct sql_dict_transaction_context *ctx);
static void sql_dict_prev_set_flush(struct sql_dict_transaction_context *ctx);
static void sql_dict_prev_inc_free(struct sql_dict_transaction_context *ctx);
static void sql_dict_prev_set_free(struct sql_dict_transaction_context *ctx);

/* Create a new SQL dict. The dict-sql settings are read from uri and the
   SQL db handle is obtained through sql_db_cache_new().
   Returns 0 on success, -1 with *error_r set on failure. */
static int
sql_dict_init(struct dict *driver, const char *uri,
	      const struct dict_settings *set,
	      struct dict **dict_r, const char **error_r)
{
	struct sql_settings sql_set;
	struct sql_dict *dict;
	pool_t pool;

	pool = pool_alloconly_create("sql dict", 2048);
	dict = p_new(pool, struct sql_dict, 1);
	dict->pool = pool;
	dict->dict = *driver;
	dict->set = dict_sql_settings_read(uri, error_r);
	if (dict->set == NULL) {
		pool_unref(&pool);
		return -1;
	}
	i_zero(&sql_set);
	sql_set.driver = driver->name;
	sql_set.connect_string = dict->set->connect;
	sql_set.event_parent = set->event_parent;

	if (sql_db_cache_new(dict_sql_db_cache, &sql_set, &dict->db,
			     error_r) < 0) {
		pool_unref(&pool);
		return -1;
	}

	*dict_r = &dict->dict;
	return 0;
}

/* Drop the SQL db reference and free the dict's memory pool. */
static void sql_dict_deinit(struct dict *_dict)
{
	struct sql_dict *dict = (struct sql_dict *)_dict;

	sql_unref(&dict->db);
	pool_unref(&dict->pool);
}

/* Wait on the underlying SQL db via sql_wait(). */
static void sql_dict_wait(struct dict *_dict)
{
	struct sql_dict *dict = (struct sql_dict *)_dict;
	sql_wait(dict->db);
}

/* Try to match path to map->pattern. For example pattern="shared/x/$/$/y"
   and path="shared/x/1/2/y", this is match and pattern_values=[1, 2].

   partial_ok allows path to match only a prefix of the pattern (used by
   iteration). With partial_ok, recurse accepts any remaining pattern;
   without it only a single trailing $variable may remain. On return
   *pat_len_r and *path_len_r tell how much of the pattern and the path
   were consumed. */
static bool
dict_sql_map_match(const struct dict_sql_map *map, const char *path,
		   ARRAY_TYPE(const_string) *pattern_values, size_t *pat_len_r,
		   size_t *path_len_r, bool partial_ok, bool recurse)
{
	const char *path_start = path;
	const char *pat, *field, *p;
	size_t len;

	array_clear(pattern_values);
	pat = map->pattern;
	while (*pat != '\0' && *path != '\0') {
		if (*pat == '$') {
			/* variable */
			pat++;
			if (*pat == '\0') {
				/* pattern ended with this variable,
				   it'll match the rest of the path */
				len = strlen(path);
				if (partial_ok) {
					/* iterating - the last field never
					   matches fully. if there's a trailing
					   '/', drop it. */
					pat--;
					if (path[len-1] == '/') {
						field = t_strndup(path, len-1);
						array_push_back(pattern_values,
								&field);
					} else {
						array_push_back(pattern_values,
								&path);
					}
				} else {
					array_push_back(pattern_values, &path);
					path += len;
				}
				*path_len_r = path - path_start;
				*pat_len_r = pat - map->pattern;
				return TRUE;
			}
			/* pattern matches until the next '/' in path */
			p = strchr(path, '/');
			if (p != NULL) {
				field = t_strdup_until(path, p);
				array_push_back(pattern_values, &field);
				path = p;
			} else {
				/* no '/' anymore, but it'll still match a
				   partial */
				array_push_back(pattern_values, &path);
				path += strlen(path);
				pat++;
			}
		} else if (*pat == *path) {
			pat++;
			path++;
		} else {
			return FALSE;
		}
	}

	*path_len_r = path - path_start;
	*pat_len_r = pat - map->pattern;

	if (*pat == '\0')
		return *path == '\0';
	else if (!partial_ok)
		return FALSE;
	else {
		/* partial matches must end with '/'. */
		if (pat != map->pattern && pat[-1] != '/')
			return FALSE;
		/* if we're not recursing, there should be only one $variable
		   left. */
		if (recurse)
			return TRUE;
		return pat[0] == '$' && strchr(pat, '/') == NULL;
	}
}

/* Return the first map whose pattern fully matches path, filling
   pattern_values (initialized here with t_array_init()) with the values of
   its $variables. Returns NULL if no map matches. */
static const struct dict_sql_map *
sql_dict_find_map(struct sql_dict *dict, const char *path,
		  ARRAY_TYPE(const_string) *pattern_values)
{
	const struct dict_sql_map *maps;
	unsigned int i, count;
	size_t len;

	t_array_init(pattern_values, dict->set->max_pattern_fields_count);
	maps = array_get(&dict->set->maps, &count);
	for (i = 0; i < count; i++) {
		if (dict_sql_map_match(&maps[i], path, pattern_values,
				       &len, &len, FALSE, FALSE))
			return &maps[i];
	}
	return NULL;
}

/* Bind param to stmt's column_idx using the binder matching its type. */
static void
sql_dict_statement_bind(struct sql_statement *stmt, unsigned int column_idx,
			const struct sql_dict_param *param)
{
	switch (param->value_type) {
	case DICT_SQL_TYPE_STRING:
		sql_statement_bind_str(stmt, column_idx, param->value_str);
		break;
	case DICT_SQL_TYPE_INT:
	case DICT_SQL_TYPE_UINT:
		sql_statement_bind_int64(stmt, column_idx, param->value_int64);
		break;
	case DICT_SQL_TYPE_HEXBLOB:
		sql_statement_bind_binary(stmt, column_idx, param->value_binary,
					  param->value_binary_size);
		break;
	}
}

/* Create a statement for query and bind params to it in array order.
   A prepared statement is used when the backend advertises
   SQL_DB_FLAG_PREP_STATEMENTS, otherwise a plain statement. */
static struct sql_statement *
sql_dict_statement_init(struct sql_dict *dict, const char *query,
			const ARRAY_TYPE(sql_dict_param) *params)
{
	struct sql_statement *stmt;
	struct sql_prepared_statement *prep_stmt;
	const struct sql_dict_param *param;

	if ((sql_get_flags(dict->db) & SQL_DB_FLAG_PREP_STATEMENTS) != 0) {
		prep_stmt = sql_prepared_statement_init(dict->db, query);
		stmt = sql_statement_init_prepared(prep_stmt);
		sql_prepared_statement_unref(&prep_stmt);
	} else {
		/* Prepared statements not supported by the backend.
		   Just use regular statements to avoid wasting memory. */
		stmt = sql_statement_init(dict->db, query);
	}

	array_foreach(params, param) {
		sql_dict_statement_bind(stmt, array_foreach_idx(params, param),
					param);
	}
	return stmt;
}

/* Convert value followed by value_suffix into a parameter of value_type
   and append it to params. The parameter slot is appended before
   validation, so it is left in params even on failure.

   INT and UINT values may not have a suffix. UINT rejects a leading '-'
   and is parsed with str_to_int64(), so values above INT64_MAX are
   rejected too. HEXBLOB values are hex-decoded and the suffix is appended
   to the decoded bytes.
   Returns 0 on success, -1 with *error_r set on failure. */
static int
sql_dict_value_get(const struct dict_sql_map *map,
		   enum dict_sql_type value_type, const char *field_name,
		   const char *value, const char *value_suffix,
		   ARRAY_TYPE(sql_dict_param) *params, const char **error_r)
{
	struct sql_dict_param *param;
	buffer_t *buf;

	param = array_append_space(params);
	param->value_type = value_type;

	switch (value_type) {
	case DICT_SQL_TYPE_STRING:
		if (value_suffix[0] != '\0')
			value = t_strconcat(value, value_suffix, NULL);
		param->value_str = value;
		return 0;
	case DICT_SQL_TYPE_INT:
		if (value_suffix[0] != '\0' ||
		    str_to_int64(value, &param->value_int64) < 0) {
			*error_r = t_strdup_printf(
				"%s field's value isn't 64bit signed integer: %s%s (in pattern: %s)",
				field_name, value, value_suffix, map->pattern);
			return -1;
		}
		return 0;
	case DICT_SQL_TYPE_UINT:
		if (value_suffix[0] != '\0' || value[0] == '-' ||
		    str_to_int64(value, &param->value_int64) < 0) {
			*error_r = t_strdup_printf(
				"%s field's value isn't 64bit unsigned integer: %s%s (in pattern: %s)",
				field_name, value, value_suffix, map->pattern);
			return -1;
		}
		return 0;
	case DICT_SQL_TYPE_HEXBLOB:
		break;
	}

	buf = t_buffer_create(strlen(value)/2);
	if (hex_to_binary(value, buf) < 0) {
		/* we shouldn't get untrusted input here. it's also a bit
		   annoying to handle this error. */
		*error_r = t_strdup_printf("%s field's value isn't hexblob: %s (in pattern: %s)",
					   field_name, value, map->pattern);
		return -1;
	}
	str_append(buf, value_suffix);
	param->value_binary = buf->data;
	param->value_binary_size = buf->used;
	return 0;
}

/* sql_dict_value_get() using a pattern field's name and type. */
static int
sql_dict_field_get_value(const struct dict_sql_map *map,
			 const struct dict_sql_field *field,
			 const char *value, const char *value_suffix,
			 ARRAY_TYPE(sql_dict_param) *params,
			 const char **error_r)
{
	return sql_dict_value_get(map, field->value_type, field->name,
				  value, value_suffix, params, error_r);
}

/* Append a " WHERE ..." clause for map to query. The pattern values (and
   the username, if add_username) are appended to params, and each gets a
   '?' placeholder. values_arr may hold fewer values than map has pattern
   fields when called from the iteration code; recurse_type decides how
   the remaining field is matched.

   Nothing is appended when there are no values and no username to filter
   on. Returns 0 on success, -1 with *error_r set on failure. */
static int
sql_dict_where_build(const char *username, const struct dict_sql_map *map,
		     const ARRAY_TYPE(const_string) *values_arr,
		     bool add_username, enum sql_recurse_type recurse_type,
		     string_t *query, ARRAY_TYPE(sql_dict_param) *params,
		     const char **error_r)
{
	const struct dict_sql_field *pattern_fields;
	const char *const *pattern_values;
	unsigned int i, count, count2, exact_count;
	size_t conditions_start;

	pattern_fields = array_get(&map->pattern_fields, &count);
	pattern_values = array_get(values_arr, &count2);
	/* if we came here from iteration code there may be fewer
	   pattern_values */
	i_assert(count2 <= count);

	if (count2 == 0 && !add_username) {
		/* we want everything */
		return 0;
	}

	str_append(query, " WHERE");
	/* Remember where the conditions start, so the username condition
	   below knows whether it has to be joined with " AND". */
	conditions_start = str_len(query);

	exact_count = count == count2 && recurse_type != SQL_DICT_RECURSE_NONE ?
		count2-1 : count2;
	if (exact_count != array_count(values_arr)) {
		*error_r = t_strdup_printf("Key continues past the matched pattern %s", map->pattern);
		return -1;
	}
	/* NOTE(review): past the check above exact_count == count2, so after
	   the loop i == count2 and the "i < count2" branches in the switch
	   below are currently unreachable. Confirm whether that is intended. */

	for (i = 0; i < exact_count; i++) {
		if (i > 0)
			str_append(query, " AND");
		str_printfa(query, " %s = ?", pattern_fields[i].name);
		if (sql_dict_field_get_value(map, &pattern_fields[i],
					     pattern_values[i], "",
					     params, error_r) < 0)
			return -1;
	}
	switch (recurse_type) {
	case SQL_DICT_RECURSE_NONE:
		break;
	case SQL_DICT_RECURSE_ONE:
		if (i > 0)
			str_append(query, " AND");
		if (i < count2) {
			str_printfa(query, " %s LIKE ?", pattern_fields[i].name);
			if (sql_dict_field_get_value(map, &pattern_fields[i],
						     pattern_values[i], "/%",
						     params, error_r) < 0)
				return -1;
			str_printfa(query, " AND %s NOT LIKE ?",
				    pattern_fields[i].name);
			if (sql_dict_field_get_value(map, &pattern_fields[i],
						     pattern_values[i], "/%/%",
						     params, error_r) < 0)
				return -1;
		} else {
			/* only values that contain no '/' */
			str_printfa(query, " %s LIKE '%%' AND "
				    "%s NOT LIKE '%%/%%'",
				    pattern_fields[i].name,
				    pattern_fields[i].name);
		}
		break;
	case SQL_DICT_RECURSE_FULL:
		if (i < count2) {
			if (i > 0)
				str_append(query, " AND");
			/* every bound parameter needs its own '?' */
			str_printfa(query, " %s LIKE ?",
				    pattern_fields[i].name);
			if (sql_dict_field_get_value(map, &pattern_fields[i],
						     pattern_values[i], "/%",
						     params, error_r) < 0)
				return -1;
		}
		break;
	}
	if (add_username) {
		struct sql_dict_param *param = array_append_space(params);

		/* Join with " AND" whenever any condition was already
		   written. Note that SQL_DICT_RECURSE_ONE writes conditions
		   even when count2 == 0. */
		if (str_len(query) > conditions_start)
			str_append(query, " AND");
		str_printfa(query, " %s = ?", map->username_field);
		param->value_type = DICT_SQL_TYPE_STRING;
		param->value_str = t_strdup(username);
	}
	return 0;
}

/* Build the SELECT statement for looking up a single key. *map_r is set to
   the matching map (NULL if none). Private keys are additionally filtered
   by set->username. Returns 0 on success, -1 with *error_r set. */
static int
sql_lookup_get_query(struct sql_dict *dict,
		     const struct dict_op_settings *set,
		     const char *key,
		     const struct dict_sql_map **map_r,
		     struct sql_statement **stmt_r,
		     const char **error_r)
{
	const struct dict_sql_map *map;
	ARRAY_TYPE(const_string) pattern_values;
	const char *error;

	map = *map_r = sql_dict_find_map(dict, key, &pattern_values);
	if (map == NULL) {
		*error_r = t_strdup_printf(
			"sql dict lookup: Invalid/unmapped key: %s", key);
		return -1;
	}

	string_t *query = t_str_new(256);
	ARRAY_TYPE(sql_dict_param) params;
	t_array_init(&params, 4);
	str_printfa(query, "SELECT %s FROM %s",
		    map->value_field, map->table);
	if (sql_dict_where_build(set->username, map, &pattern_values,
				 key[0] == DICT_PATH_PRIVATE[0],
				 SQL_DICT_RECURSE_NONE, query,
				 &params, &error) < 0) {
		*error_r = t_strdup_printf(
			"sql dict lookup: Failed to lookup key %s: %s", key, error);
		return -1;
	}
	*stmt_r = sql_dict_statement_init(dict, str_c(query), &params);
	return 0;
}

/* Return result column result_idx as a string allocated from pool.
   STRING/INT/UINT columns are returned as-is ("" for SQL NULL); HEXBLOB
   columns are returned hex-encoded. */
static const char *
sql_dict_result_unescape(enum dict_sql_type type, pool_t pool,
			 struct sql_result *result, unsigned int result_idx)
{
	const unsigned char *data;
	size_t size;
	const char *value;
	string_t *str;

	switch (type) {
	case DICT_SQL_TYPE_STRING:
	case DICT_SQL_TYPE_INT:
	case DICT_SQL_TYPE_UINT:
		value = sql_result_get_field_value(result, result_idx);
		return value == NULL ? "" : p_strdup(pool, value);
	case DICT_SQL_TYPE_HEXBLOB:
		break;
	}

	data = sql_result_get_field_value_binary(result, result_idx, &size);
	str = str_new(pool, size*2 + 1);
	binary_to_hex_append(str, data, size);
	return str_c(str);
}

/* Return the first value column (column 0) of the current row. */
static const char *
sql_dict_result_unescape_value(const struct dict_sql_map *map, pool_t pool,
			       struct sql_result *result)
{
	return sql_dict_result_unescape(map->value_types[0], pool, result, 0);
}

/* Return all map->values_count value columns of the current row. One extra
   array slot is allocated, presumably as a NULL terminator. */
static const char *const *
sql_dict_result_unescape_values(const struct dict_sql_map *map, pool_t pool,
				struct sql_result *result)
{
	const char **values;
	unsigned int i;

	values = p_new(pool, const char *, map->values_count + 1);
	for (i = 0; i < map->values_count; i++) {
		values[i] = sql_dict_result_unescape(map->value_types[i],
						     pool, result, i);
	}
	return values;
}

/* Return result column result_idx, decoded by the type of pattern field
   sql_field_idx. */
static const char *
sql_dict_result_unescape_field(const struct dict_sql_map *map, pool_t pool,
			       struct sql_result *result,
			       unsigned int result_idx,
			       unsigned int sql_field_idx)
{
	const struct dict_sql_field *sql_field;

	sql_field = array_idx(&map->pattern_fields, sql_field_idx);
	return sql_dict_result_unescape(sql_field->value_type, pool,
					result, result_idx);
}

/* Synchronous lookup of key. Returns the sql_result_next_row() result:
   >0 with *value_r set when a row was found, 0 when not found, <0 with
   *error_r set on failure (-1 also if the query couldn't be built). */
static int sql_dict_lookup(struct dict *_dict, const struct dict_op_settings *set,
			   pool_t pool, const char *key,
			   const char **value_r, const char **error_r)
{
	struct sql_dict *dict = (struct sql_dict *)_dict;
	const struct dict_sql_map *map;
	struct sql_statement *stmt;
	struct sql_result *result = NULL;
	int ret;

	*value_r = NULL;

	if (sql_lookup_get_query(dict, set, key, &map, &stmt, error_r) < 0)
		return -1;

	result = sql_statement_query_s(&stmt);
	ret = sql_result_next_row(result);
	if (ret < 0) {
		*error_r = t_strdup_printf("dict sql lookup failed: %s",
					   sql_result_get_error(result));
	} else if (ret > 0) {
		*value_r = sql_dict_result_unescape_value(map, pool, result);
	}

	sql_result_unref(result);
	return ret;
}

/* State passed to sql_dict_lookup_async_callback(); freed by it. */
struct sql_dict_lookup_context {
	const struct dict_sql_map *map;
	dict_lookup_callback_t *callback;
	void *context;
};

/* Convert the async SQL result into a dict_lookup_result, call the
   caller's callback and free ctx. */
static void
sql_dict_lookup_async_callback(struct sql_result *sql_result,
			       struct sql_dict_lookup_context *ctx)
{
	struct dict_lookup_result result;

	i_zero(&result);
	result.ret = sql_result_next_row(sql_result);
	if (result.ret < 0)
		result.error = sql_result_get_error(sql_result);
	else if (result.ret > 0) {
		result.values = sql_dict_result_unescape_values(ctx->map,
			pool_datastack_create(), sql_result);
		result.value = result.values[0];
		if (result.value == NULL) {
			/* NULL value returned. we'll treat this as
			   "not found", which is probably what is usually
			   wanted. */
			result.ret = 0;
		}
	}
	ctx->callback(&result, ctx->context);

	i_free(ctx);
}

/* Asynchronous lookup of key. If the query can't be built, callback is
   called immediately with ret=-1. */
static void
sql_dict_lookup_async(struct dict *_dict,
		      const struct dict_op_settings *set,
		      const char *key,
		      dict_lookup_callback_t *callback, void *context)
{
	struct sql_dict *dict = (struct sql_dict *)_dict;
	const struct dict_sql_map *map;
	struct sql_dict_lookup_context *ctx;
	struct sql_statement *stmt;
	const char *error;

	if (sql_lookup_get_query(dict, set, key, &map, &stmt, &error) < 0) {
		struct dict_lookup_result result;

		i_zero(&result);
		result.ret = -1;
		result.error = error;
		callback(&result, context);
	} else {
		ctx = i_new(struct sql_dict_lookup_context, 1);
		ctx->callback = callback;
		ctx->context = context;
		ctx->map = map;
		sql_statement_query(&stmt, sql_dict_lookup_async_callback, ctx);
	}
}

/* Find the next map (starting from ctx->next_map_idx) that partially
   matches the iteration path. On success the match lengths and
   next_map_idx are stored in ctx and ctx->key is reset to the path.
   Returns NULL when no further map matches. */
static const struct dict_sql_map *
sql_dict_iterate_find_next_map(struct sql_dict_iterate_context *ctx,
			       ARRAY_TYPE(const_string) *pattern_values)
{
	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
	const struct dict_sql_map *maps;
	unsigned int i, count;
	size_t pat_len, path_len;
	bool recurse = (ctx->flags & DICT_ITERATE_FLAG_RECURSE) != 0;

	t_array_init(pattern_values, dict->set->max_pattern_fields_count);
	maps = array_get(&dict->set->maps, &count);
	for (i = ctx->next_map_idx; i < count; i++) {
		if (dict_sql_map_match(&maps[i], ctx->path,
				       pattern_values, &pat_len, &path_len,
				       TRUE, recurse) &&
		    (recurse ||
		     array_count(pattern_values)+1 >= array_count(&maps[i].pattern_fields))) {
			ctx->key_prefix_len = path_len;
			ctx->pattern_prefix_len = pat_len;
			ctx->next_map_idx = i + 1;

			str_truncate(ctx->key, 0);
			str_append(ctx->key, ctx->path);
			return &maps[i];
		}
	}
	return NULL;
}

/* Build the SELECT statement for the next matching map. It selects the
   value column(s) (unless DICT_ITERATE_FLAG_NO_VALUE) and the pattern
   fields not fixed by the path.
   Returns 1 with *stmt_r set, 0 when there are no more maps (only allowed
   once allow_null_map is set), -1 with *error_r set on failure. */
static int
sql_dict_iterate_build_next_query(struct sql_dict_iterate_context *ctx,
				  struct sql_statement **stmt_r,
				  const char **error_r)
{
	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
	const struct dict_op_settings_private *set = &ctx->ctx.set;
	const struct dict_sql_map *map;
	ARRAY_TYPE(const_string) pattern_values;
	const struct dict_sql_field *pattern_fields;
	enum sql_recurse_type recurse_type;
	unsigned int i, count;

	map = sql_dict_iterate_find_next_map(ctx, &pattern_values);
	/* NULL map is allowed if we have already done some lookups */
	if (map == NULL) {
		if (!ctx->allow_null_map) {
			*error_r = "Invalid/unmapped path";
			return -1;
		}
		return 0;
	}

	if (ctx->result != NULL) {
		sql_result_unref(ctx->result);
		ctx->result = NULL;
	}

	string_t *query = t_str_new(256);
	str_append(query, "SELECT ");
	if ((ctx->flags & DICT_ITERATE_FLAG_NO_VALUE) == 0)
		str_printfa(query, "%s,", map->value_field);

	/* get all missing fields */
	pattern_fields = array_get(&map->pattern_fields, &count);
	i = array_count(&pattern_values);
	if (i == count) {
		/* we always want to know the last field since we're
		   iterating its children */
		i_assert(i > 0);
		i--;
	}
	ctx->sql_fields_start_idx = i;
	for (; i < count; i++)
		str_printfa(query, "%s,", pattern_fields[i].name);
	/* drop the trailing ',' */
	str_truncate(query, str_len(query)-1);

	str_printfa(query, " FROM %s", map->table);

	if ((ctx->flags & DICT_ITERATE_FLAG_RECURSE) != 0)
		recurse_type = SQL_DICT_RECURSE_FULL;
	else if ((ctx->flags & DICT_ITERATE_FLAG_EXACT_KEY) != 0)
		recurse_type = SQL_DICT_RECURSE_NONE;
	else
		recurse_type = SQL_DICT_RECURSE_ONE;

	ARRAY_TYPE(sql_dict_param) params;
	t_array_init(&params, 4);
	bool add_username = (ctx->path[0] == DICT_PATH_PRIVATE[0]);
	if (sql_dict_where_build(set->username, map, &pattern_values,
				 add_username, recurse_type, query,
				 &params, error_r) < 0)
		return -1;

	if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_KEY) != 0) {
		str_append(query, " ORDER BY ");
		for (i = 0; i < count; i++) {
			str_printfa(query, "%s", pattern_fields[i].name);
			if (i < count-1)
				str_append_c(query, ',');
		}
	} else if ((ctx->flags & DICT_ITERATE_FLAG_SORT_BY_VALUE) != 0)
		str_printfa(query, " ORDER BY %s", map->value_field);

	if (ctx->ctx.max_rows > 0) {
		i_assert(ctx->ctx.row_count < ctx->ctx.max_rows);
		str_printfa(query, " LIMIT %"PRIu64,
			    ctx->ctx.max_rows - ctx->ctx.row_count);
	}

	*stmt_r = sql_dict_statement_init(dict, str_c(query), &params);
	ctx->map = map;
	return 1;
}

/* Async query/more callback: store a reference to the result unless the
   iterator was already deinitialized, notify the caller when called
   asynchronously, and drop the pool reference taken before the query. */
static void sql_dict_iterate_callback(struct sql_result *result,
				      struct sql_dict_iterate_context *ctx)
{
	if (!ctx->destroyed) {
		sql_result_ref(result);
		ctx->result = result;
		if (ctx->ctx.async_callback != NULL && !ctx->synchronous_result)
			ctx->ctx.async_callback(ctx->ctx.async_context);
	}

	pool_t pool_copy = ctx->pool;
	pool_unref(&pool_copy);
}

/* Build and send the query for the next map. Returns 1 if a query was
   sent, 0 if there are no more maps, -1 on failure (ctx->error set). */
static int sql_dict_iterate_next_query(struct sql_dict_iterate_context *ctx)
{
	struct sql_statement *stmt;
	const char *error;
	int ret;

	ret = sql_dict_iterate_build_next_query(ctx, &stmt, &error);
	if (ret <= 0) {
		/* this is expected error */
		if (ret == 0)
			return ret;
		/* failed */
		ctx->error = p_strdup_printf(ctx->pool,
			"sql dict iterate failed for %s: %s",
			ctx->path, error);
		return -1;
	}

	if ((ctx->flags & DICT_ITERATE_FLAG_ASYNC) == 0) {
		ctx->result = sql_statement_query_s(&stmt);
	} else {
		i_assert(ctx->result == NULL);
		ctx->synchronous_result = TRUE;
		/* keep ctx alive until sql_dict_iterate_callback() runs */
		pool_ref(ctx->pool);
		sql_statement_query(&stmt, sql_dict_iterate_callback, ctx);
		ctx->synchronous_result = FALSE;
	}
	return ret;
}

/* Allocate an iteration context for path. No query is sent until the first
   sql_dict_iterate() call. */
static struct dict_iterate_context *
sql_dict_iterate_init(struct dict *_dict,
		      const struct dict_op_settings *set ATTR_UNUSED,
		      const char *path, enum dict_iterate_flags flags)
{
	struct sql_dict_iterate_context *ctx;
	pool_t pool;

	pool = pool_alloconly_create("sql dict iterate", 512);
	ctx = p_new(pool, struct sql_dict_iterate_context, 1);
	ctx->ctx.dict = _dict;
	ctx->pool = pool;
	ctx->flags = flags;

	ctx->path = p_strdup(pool, path);

	ctx->key = str_new(pool, 256);
	return &ctx->ctx;
}

/* Return the next key (and its values unless DICT_ITERATE_FLAG_NO_VALUE).
   Returns FALSE when the iteration is finished, on error (ctx->error set),
   or when waiting for async results (_ctx->has_more set to TRUE). When the
   current map runs out of rows, the next matching map is queried. */
static bool sql_dict_iterate(struct dict_iterate_context *_ctx,
			     const char **key_r, const char *const **values_r)
{
	struct sql_dict_iterate_context *ctx =
		(struct sql_dict_iterate_context *)_ctx;
	const char *p, *value;
	unsigned int i, sql_field_i, count;
	int ret;

	_ctx->has_more = FALSE;
	if (ctx->error != NULL)
		return FALSE;
	if (!ctx->iter_query_sent) {
		ctx->iter_query_sent = TRUE;
		if (sql_dict_iterate_next_query(ctx) <= 0)
			return FALSE;
	}

	if (ctx->result == NULL) {
		/* wait for async lookup to finish */
		i_assert((ctx->flags & DICT_ITERATE_FLAG_ASYNC) != 0);
		_ctx->has_more = TRUE;
		return FALSE;
	}

	ret = sql_result_next_row(ctx->result);
	while (ret == SQL_RESULT_NEXT_MORE) {
		if ((ctx->flags & DICT_ITERATE_FLAG_ASYNC) == 0)
			sql_result_more_s(&ctx->result);
		else {
			/* get more results asynchronously */
			ctx->synchronous_result = TRUE;
			pool_ref(ctx->pool);
			sql_result_more(&ctx->result, sql_dict_iterate_callback, ctx);
			ctx->synchronous_result = FALSE;
			if (ctx->result == NULL) {
				_ctx->has_more = TRUE;
				return FALSE;
			}
		}
		ret = sql_result_next_row(ctx->result);
	}
	if (ret == 0) {
		/* see if there are more results in the next map.
		   don't do it if we're looking for an exact match, since we
		   already should have handled it. */
		if ((ctx->flags & DICT_ITERATE_FLAG_EXACT_KEY) != 0)
			return FALSE;
		ctx->iter_query_sent = FALSE;
		/* we have gotten *SOME* results, so can allow
		   unmapped next key now. */
		ctx->allow_null_map = TRUE;
		return sql_dict_iterate(_ctx, key_r, values_r);
	}
	if (ret < 0) {
		ctx->error = p_strdup_printf(ctx->pool,
			"dict sql iterate failed: %s",
			sql_result_get_error(ctx->result));
		return FALSE;
	}

	/* convert fetched row to dict key */
	str_truncate(ctx->key, ctx->key_prefix_len);
	if (ctx->key_prefix_len > 0 &&
	    str_c(ctx->key)[ctx->key_prefix_len-1] != '/')
		str_append_c(ctx->key, '/');

	count = sql_result_get_fields_count(ctx->result);
	/* the pattern field columns come after the value column(s) */
	i = (ctx->flags & DICT_ITERATE_FLAG_NO_VALUE) != 0 ? 0 :
		ctx->map->values_count;
	sql_field_i = ctx->sql_fields_start_idx;
	for (p = ctx->map->pattern + ctx->pattern_prefix_len; *p != '\0'; p++) {
		if (*p != '$')
			str_append_c(ctx->key, *p);
		else {
			i_assert(i < count);
			value = sql_dict_result_unescape_field(ctx->map,
					pool_datastack_create(), ctx->result,
					i, sql_field_i);
			if (value != NULL)
				str_append(ctx->key, value);
			i++; sql_field_i++;
		}
	}

	*key_r = str_c(ctx->key);
	if ((ctx->flags & DICT_ITERATE_FLAG_NO_VALUE) == 0) {
		*values_r = sql_dict_result_unescape_values(ctx->map,
				pool_datastack_create(), ctx->result);
	}
	return TRUE;
}

/* Finish the iteration. Returns -1 with *error_r set if the iteration
   failed, 0 otherwise. The context memory is released once any pending
   async callback has dropped its pool reference. */
static int sql_dict_iterate_deinit(struct dict_iterate_context *_ctx,
				   const char **error_r)
{
	struct sql_dict_iterate_context *ctx =
		(struct sql_dict_iterate_context *)_ctx;
	int ret = ctx->error != NULL ? -1 : 0;

	*error_r = t_strdup(ctx->error);
	if (ctx->result != NULL)
		sql_result_unref(ctx->result);
	ctx->destroyed = TRUE;

	pool_t pool_copy = ctx->pool;
	pool_unref(&pool_copy);
	return ret;
}

/* Start a new dict transaction backed by an SQL transaction. */
static struct dict_transaction_context *
sql_dict_transaction_init(struct dict *_dict)
{
	struct sql_dict *dict = (struct sql_dict *)_dict;
	struct sql_dict_transaction_context *ctx;

	ctx = i_new(struct sql_dict_transaction_context, 1);
	ctx->ctx.dict = _dict;
	ctx->sql_ctx = sql_transaction_begin(dict->db);

	return &ctx->ctx;
}

/* Free the transaction context and any unflushed set/inc operations. */
static void sql_dict_transaction_free(struct sql_dict_transaction_context *ctx)
{
	if (array_is_created(&ctx->prev_inc))
		sql_dict_prev_inc_free(ctx);
	if (array_is_created(&ctx->prev_set))
		sql_dict_prev_set_free(ctx);

	pool_unref(&ctx->inc_row_pool);
	i_free(ctx->error);
	i_free(ctx);
}

/* Returns TRUE if any recorded increment row has rows == 0. */
static bool
sql_dict_transaction_has_nonexistent(struct sql_dict_transaction_context *ctx)
{
	struct sql_dict_inc_row *inc_row;

	for (inc_row = ctx->inc_row; inc_row != NULL; inc_row = inc_row->prev) {
		i_assert(inc_row->rows != UINT_MAX);
		if (inc_row->rows == 0)
			return TRUE;
	}
	return FALSE;
}

/* Async commit callback: translate the SQL commit result into a dict
   commit result, report it (or log errors if there is no callback) and
   free the transaction. */
static void
sql_dict_transaction_commit_callback(const struct sql_commit_result *sql_result,
				     struct sql_dict_transaction_context *ctx)
{
	struct dict_commit_result result;

	i_zero(&result);
	if (sql_result->error == NULL)
		result.ret = sql_dict_transaction_has_nonexistent(ctx) ?
			DICT_COMMIT_RET_NOTFOUND : DICT_COMMIT_RET_OK;
	else {
		result.error = t_strdup_printf("sql dict: commit failed: %s",
					       sql_result->error);
		switch (sql_result->error_type) {
		case SQL_RESULT_ERROR_TYPE_UNKNOWN:
		default:
			result.ret = DICT_COMMIT_RET_FAILED;
			break;
		case SQL_RESULT_ERROR_TYPE_WRITE_UNCERTAIN:
			result.ret = DICT_COMMIT_RET_WRITE_UNCERTAIN;
			break;
		}
	}

	if (ctx->async_callback != NULL)
		ctx->async_callback(&result, ctx->async_context);
	else if (result.ret < 0)
		i_error("%s", result.error);
	sql_dict_transaction_free(ctx);
}

/* Flush pending set/inc operations and commit. An already failed or
   unchanged transaction is rolled back. For async commits callback is
   called later from sql_dict_transaction_commit_callback(); otherwise it
   is called before returning, after ctx has been freed. */
static void
sql_dict_transaction_commit(struct dict_transaction_context *_ctx, bool async,
			    dict_transaction_commit_callback_t *callback,
			    void *context)
{
	struct sql_dict_transaction_context *ctx =
		(struct sql_dict_transaction_context *)_ctx;
	const char *error;
	struct dict_commit_result result;

	/* flush any pending set/inc */
	if (array_is_created(&ctx->prev_inc))
		sql_dict_prev_inc_flush(ctx);
	if (array_is_created(&ctx->prev_set))
		sql_dict_prev_set_flush(ctx);

	/* note that the above calls might still set ctx->error */
	i_zero(&result);
	result.ret = DICT_COMMIT_RET_FAILED;
	/* copied to the data stack so it survives freeing ctx */
	result.error = t_strdup(ctx->error);

	if (ctx->error != NULL) {
		sql_transaction_rollback(&ctx->sql_ctx);
	} else if (!_ctx->changed) {
		/* nothing changed, no need to commit */
		sql_transaction_rollback(&ctx->sql_ctx);
		result.ret = DICT_COMMIT_RET_OK;
	} else if (async) {
		ctx->async_callback = callback;
		ctx->async_context = context;
		sql_transaction_commit(&ctx->sql_ctx,
			sql_dict_transaction_commit_callback, ctx);
		return;
	} else if (sql_transaction_commit_s(&ctx->sql_ctx, &error) < 0) {
		result.error = t_strdup_printf(
			"sql dict: commit failed: %s", error);
	} else {
		if (sql_dict_transaction_has_nonexistent(ctx))
			result.ret = DICT_COMMIT_RET_NOTFOUND;
		else
			result.ret = DICT_COMMIT_RET_OK;
	}
	sql_dict_transaction_free(ctx);

	callback(&result, context);
}

/* Roll back the SQL transaction and free the context. */
static void sql_dict_transaction_rollback(struct dict_transaction_context *_ctx)
{
	struct sql_dict_transaction_context *ctx =
		(struct sql_dict_transaction_context *)_ctx;

	sql_transaction_rollback(&ctx->sql_ctx);
	sql_dict_transaction_free(ctx);
}

/* sql_dict_statement_init() that also applies the transaction's timestamp
   when one was set. */
static struct sql_statement *
sql_dict_transaction_stmt_init(struct sql_dict_transaction_context *ctx,
			       const char *query,
			       const ARRAY_TYPE(sql_dict_param) *params)
{
	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
	struct sql_statement *stmt =
		sql_dict_statement_init(dict, query, params);

	if (ctx->ctx.timestamp.tv_sec != 0)
		sql_statement_set_timestamp(stmt, &ctx->ctx.timestamp);
	return stmt;
}

/* One value column to write, together with the map it belongs to. */
struct dict_sql_build_query_field {
	const struct dict_sql_map *map;
	const char *value;
};

/* Input for building INSERT/UPDATE queries for one table row. */
struct dict_sql_build_query {
	struct sql_dict *dict;

	ARRAY(struct dict_sql_build_query_field) fields;
	const ARRAY_TYPE(const_string) *pattern_values;
	bool add_username;
};

/* Build an "INSERT INTO table (values..., [username,] pattern fields...)
   VALUES (?, ...)" statement. If the backend supports
   ON DUPLICATE KEY UPDATE or ON CONFLICT ... DO UPDATE, the value columns
   are updated when the row already exists.
   Returns 0 with *stmt_r set, -1 with *error_r set on failure. */
static int sql_dict_set_query(struct sql_dict_transaction_context *ctx,
			      const struct dict_sql_build_query *build,
			      struct sql_statement **stmt_r,
			      const char **error_r)
{
	struct sql_dict *dict = build->dict;
	const struct dict_sql_build_query_field *fields;
	const struct dict_sql_field *pattern_fields;
	ARRAY_TYPE(sql_dict_param) params;
	const char *const *pattern_values;
	unsigned int i, field_count, count, count2;
	string_t *prefix, *suffix;

	fields = array_get(&build->fields, &field_count);
	i_assert(field_count > 0);

	t_array_init(&params, 4);
	prefix = t_str_new(64);
	suffix = t_str_new(256);
	/* SQL table is guaranteed to be the same for all fields.
	   Build all the SQL field names into prefix and '?' placeholders for
	   each value into the suffix. The actual field values will be added
	   into params[]. */
	str_printfa(prefix, "INSERT INTO %s", fields[0].map->table);
	str_append(prefix, " (");
	str_append(suffix, ") VALUES (");
	for (i = 0; i < field_count; i++) {
		if (i > 0) {
			str_append_c(prefix, ',');
			str_append_c(suffix, ',');
		}
		str_append(prefix, t_strcut(fields[i].map->value_field, ','));

		enum dict_sql_type value_type =
			fields[i].map->value_types[0];
		str_append_c(suffix, '?');
		if (sql_dict_value_get(fields[i].map,
				       value_type, "value", fields[i].value,
				       "", &params, error_r) < 0)
			return -1;
	}
	if (build->add_username) {
		struct sql_dict_param *param = array_append_space(&params);
		str_printfa(prefix, ",%s", fields[0].map->username_field);
		str_append(suffix, ",?");
		param->value_type = DICT_SQL_TYPE_STRING;
		param->value_str = ctx->ctx.set.username;
	}

	/* add the variable fields that were parsed from the path */
	pattern_fields = array_get(&fields[0].map->pattern_fields, &count);
	pattern_values = array_get(build->pattern_values, &count2);
	i_assert(count == count2);
	for (i = 0; i < count; i++) {
		str_printfa(prefix, ",%s", pattern_fields[i].name);
		str_append(suffix, ",?");
		if (sql_dict_field_get_value(fields[0].map, &pattern_fields[i],
					     pattern_values[i], "",
					     &params, error_r) < 0)
			return -1;
	}

	str_append_str(prefix, suffix);
	str_append_c(prefix, ')');

	enum sql_db_flags flags = sql_get_flags(dict->db);
	if ((flags & SQL_DB_FLAG_ON_DUPLICATE_KEY) != 0)
		str_append(prefix, " ON DUPLICATE KEY UPDATE ");
	else if ((flags & SQL_DB_FLAG_ON_CONFLICT_DO) != 0) {
		str_append(prefix, " ON CONFLICT (");
		for (i = 0; i < count; i++) {
			if (i > 0)
				str_append_c(prefix, ',');
			str_append(prefix, pattern_fields[i].name);
		}
		if (build->add_username) {
			if (count > 0)
				str_append_c(prefix, ',');
			str_append(prefix, fields[0].map->username_field);
		}
		str_append(prefix, ") DO UPDATE SET ");
	} else {
		*stmt_r = sql_dict_transaction_stmt_init(ctx, str_c(prefix),
							 &params);
		return 0;
	}

	/* If the row already exists, UPDATE it instead. The pattern_values
	   don't need to be updated here, because they are expected to be part
	   of the row's primary key. */
	for (i = 0; i < field_count; i++) {
		const char *first_value_field =
			t_strcut(fields[i].map->value_field, ',');
		if (i > 0)
			str_append_c(prefix, ',');
		str_append(prefix, first_value_field);
		str_append_c(prefix, '=');

		enum dict_sql_type value_type =
			fields[i].map->value_types[0];
		str_append_c(prefix, '?');
		if (sql_dict_value_get(fields[i].map,
				       value_type, "value", fields[i].value,
				       "", &params, error_r) < 0)
			return -1;
	}
	*stmt_r = sql_dict_transaction_stmt_init(ctx, str_c(prefix), &params);
	return 0;
}

/* Build "UPDATE table SET field=field+?, ... WHERE ..." for increments.
   This function writes only the '?' placeholders for the increments; it
   appends just the WHERE parameters to params.
   NOTE(review): the increment values are presumably appended to params by
   the caller beforehand - confirm in sql_dict_prev_inc_flush().
   Returns 0 with *query_r set, -1 with *error_r set on failure. */
static int sql_dict_update_query(const struct dict_sql_build_query *build,
				 const struct dict_op_settings_private *set,
				 const char **query_r,
				 ARRAY_TYPE(sql_dict_param) *params,
				 const char **error_r)
{
	const struct dict_sql_build_query_field *fields;
	unsigned int i, field_count;
	string_t *query;

	fields = array_get(&build->fields, &field_count);
	i_assert(field_count > 0);

	query = t_str_new(64);
	str_printfa(query, "UPDATE %s SET ", fields[0].map->table);

	for (i = 0; i < field_count; i++) {
		const char *first_value_field =
			t_strcut(fields[i].map->value_field, ',');
		if (i > 0)
			str_append_c(query, ',');
		str_printfa(query, "%s=%s+?", first_value_field,
			    first_value_field);
	}

	if (sql_dict_where_build(set->username, fields[0].map,
				 build->pattern_values, build->add_username,
				 SQL_DICT_RECURSE_NONE, query, params,
				 error_r) < 0)
		return -1;
	*query_r = str_c(query);
	return 0;
}

/* Free all buffered set operations and the prev_set array itself. */
static void sql_dict_prev_set_free(struct sql_dict_transaction_context *ctx)
{
	struct sql_dict_prev *prev_set;

	array_foreach_modifiable(&ctx->prev_set, prev_set) {
		i_free(prev_set->value.str);
		i_free(prev_set->key);
	}
	array_free(&ctx->prev_set);
}

/* Write all buffered set operations as a single INSERT statement into the
   SQL transaction, then free them. On failure ctx->error is set. If
   ctx->error was already set, the buffered sets are just discarded. */
static void sql_dict_prev_set_flush(struct sql_dict_transaction_context *ctx)
{
	struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict;
	const struct sql_dict_prev *prev_sets;
	unsigned int count;
	struct sql_statement *stmt;
	ARRAY_TYPE(const_string) pattern_values;
	struct dict_sql_build_query build;
	struct dict_sql_build_query_field *field;
	const char *error;

	i_assert(array_is_created(&ctx->prev_set));

	if (ctx->error != NULL) {
		sql_dict_prev_set_free(ctx);
		return;
	}

	prev_sets = array_get(&ctx->prev_set, &count);
	i_assert(count > 0);

	/* Get the variable values from the dict path. We already verified
	   that these are all exactly the same for everything in prev_sets. */
	if (sql_dict_find_map(dict, prev_sets[0].key, &pattern_values) == NULL)
		i_unreached(); /* this was already checked */

	i_zero(&build);
	build.dict = dict;
	build.pattern_values = &pattern_values;
	build.add_username = (prev_sets[0].key[0] == DICT_PATH_PRIVATE[0]);

	/* build.fields[] is used to get the map { value_field } for the
	   SQL field names, as well as the values for them.

	   Example: INSERT INTO ... (build.fields[0].map->value_field,
	   ...[1], ...) VALUES (build.fields[0].value, ...[1], ...) */
	t_array_init(&build.fields, count);
	for (unsigned int i = 0; i < count; i++) {
		i_assert(build.add_username ==
			 (prev_sets[i].key[0] == DICT_PATH_PRIVATE[0]));
		field = array_append_space(&build.fields);
		field->map = prev_sets[i].map;
		field->value = prev_sets[i].value.str;
	}

	if (sql_dict_set_query(ctx, &build, &stmt, &error) < 0) {
		ctx->error = i_strdup_printf(
			"dict-sql: Failed to set %u fields (first %s): %s",
			count, prev_sets[0].key, error);
	} else {
		sql_update_stmt(ctx->sql_ctx, &stmt);
	}
	sql_dict_prev_set_free(ctx);
}

static void sql_dict_unset(struct dict_transaction_context *_ctx,
			   const char *key)
{
	struct sql_dict_transaction_context *ctx =
		(struct sql_dict_transaction_context *)_ctx;
	struct sql_dict *dict = (struct sql_dict *)_ctx->dict;
	const struct dict_op_settings_private *set = &_ctx->set;
	const struct dict_sql_map *map;
	ARRAY_TYPE(const_string) pattern_values;
	string_t *query = t_str_new(256);
	ARRAY_TYPE(sql_dict_param) params;
	const char *error;

	if (ctx->error != NULL)
		return;

	/* In theory we could unset one of the previous set/incs in this same transaction, so flush
them first. */ if (array_is_created(&ctx->prev_inc)) sql_dict_prev_inc_flush(ctx); if (array_is_created(&ctx->prev_set)) sql_dict_prev_set_flush(ctx); map = sql_dict_find_map(dict, key, &pattern_values); if (map == NULL) { ctx->error = i_strdup_printf("dict-sql: Invalid/unmapped key: %s", key); return; } str_printfa(query, "DELETE FROM %s", map->table); t_array_init(¶ms, 4); if (sql_dict_where_build(set->username, map, &pattern_values, key[0] == DICT_PATH_PRIVATE[0], SQL_DICT_RECURSE_NONE, query, ¶ms, &error) < 0) { ctx->error = i_strdup_printf( "dict-sql: Failed to delete %s: %s", key, error); } else { struct sql_statement *stmt = sql_dict_transaction_stmt_init(ctx, str_c(query), ¶ms); sql_update_stmt(ctx->sql_ctx, &stmt); } } static unsigned int * sql_dict_next_inc_row(struct sql_dict_transaction_context *ctx) { struct sql_dict_inc_row *row; if (ctx->inc_row_pool == NULL) { ctx->inc_row_pool = pool_alloconly_create("sql dict inc rows", 128); } row = p_new(ctx->inc_row_pool, struct sql_dict_inc_row, 1); row->prev = ctx->inc_row; row->rows = UINT_MAX; ctx->inc_row = row; return &row->rows; } static void sql_dict_prev_inc_free(struct sql_dict_transaction_context *ctx) { struct sql_dict_prev *prev_inc; array_foreach_modifiable(&ctx->prev_inc, prev_inc) i_free(prev_inc->key); array_free(&ctx->prev_inc); } static void sql_dict_prev_inc_flush(struct sql_dict_transaction_context *ctx) { struct sql_dict *dict = (struct sql_dict *)ctx->ctx.dict; const struct dict_op_settings_private *set = &ctx->ctx.set; const struct sql_dict_prev *prev_incs; unsigned int count; ARRAY_TYPE(const_string) pattern_values; struct dict_sql_build_query build; struct dict_sql_build_query_field *field; ARRAY_TYPE(sql_dict_param) params; struct sql_dict_param *param; const char *query, *error; i_assert(array_is_created(&ctx->prev_inc)); if (ctx->error != NULL) { sql_dict_prev_inc_free(ctx); return; } prev_incs = array_get(&ctx->prev_inc, &count); i_assert(count > 0); /* Get the variable values from 
the dict path. We already verified that these are all exactly the same for everything in prev_incs. */ if (sql_dict_find_map(dict, prev_incs[0].key, &pattern_values) == NULL) i_unreached(); /* this was already checked */ i_zero(&build); build.dict = dict; build.pattern_values = &pattern_values; build.add_username = (prev_incs[0].key[0] == DICT_PATH_PRIVATE[0]); /* build.fields[] is an array of maps, which are used to get the map { value_field } for the SQL field names. params[] specifies the list of values to use for each field. Example: UPDATE .. SET build.fields[0].map->value_field = ...->value_field + params[0]->value_int64, ...[1]... */ t_array_init(&build.fields, count); t_array_init(¶ms, count); for (unsigned int i = 0; i < count; i++) { i_assert(build.add_username == (prev_incs[i].key[0] == DICT_PATH_PRIVATE[0])); field = array_append_space(&build.fields); field->map = prev_incs[i].map; field->value = NULL; /* unused */ param = array_append_space(¶ms); param->value_type = DICT_SQL_TYPE_INT; param->value_int64 = prev_incs[i].value.diff; } if (sql_dict_update_query(&build, set, &query, ¶ms, &error) < 0) { ctx->error = i_strdup_printf( "dict-sql: Failed to increase %u fields (first %s): %s", count, prev_incs[0].key, error); } else { struct sql_statement *stmt = sql_dict_transaction_stmt_init(ctx, query, ¶ms); sql_update_stmt_get_rows(ctx->sql_ctx, &stmt, sql_dict_next_inc_row(ctx)); } sql_dict_prev_inc_free(ctx); } static bool sql_dict_maps_are_mergeable(struct sql_dict *dict, const struct sql_dict_prev *prev1, const struct dict_sql_map *map2, const char *map2_key, const ARRAY_TYPE(const_string) *map2_pattern_values) { const struct dict_sql_map *map3; ARRAY_TYPE(const_string) map1_pattern_values; /* sql table names must equal */ if (strcmp(prev1->map->table, map2->table) != 0) return FALSE; /* private vs shared prefix must equal */ if (prev1->key[0] != map2_key[0]) return FALSE; if (prev1->key[0] == DICT_PATH_PRIVATE[0]) { /* for private keys, username must 
equal */ if (strcmp(prev1->map->username_field, map2->username_field) != 0) return FALSE; } /* variable values in the paths must equal exactly */ map3 = sql_dict_find_map(dict, prev1->key, &map1_pattern_values); i_assert(map3 == prev1->map); return array_equal_fn(&map1_pattern_values, map2_pattern_values, i_strcmp_p); } static void sql_dict_set(struct dict_transaction_context *_ctx, const char *key, const char *value) { struct sql_dict_transaction_context *ctx = (struct sql_dict_transaction_context *)_ctx; struct sql_dict *dict = (struct sql_dict *)_ctx->dict; const struct dict_sql_map *map; ARRAY_TYPE(const_string) pattern_values; if (ctx->error != NULL) return; /* In theory we could set the previous inc in this same transaction, so flush it first. */ if (array_is_created(&ctx->prev_inc)) sql_dict_prev_inc_flush(ctx); map = sql_dict_find_map(dict, key, &pattern_values); if (map == NULL) { ctx->error = i_strdup_printf( "sql dict set: Invalid/unmapped key: %s", key); return; } if (array_is_created(&ctx->prev_set) && !sql_dict_maps_are_mergeable(dict, array_front(&ctx->prev_set), map, key, &pattern_values)) { /* couldn't merge to the previous set - flush it */ sql_dict_prev_set_flush(ctx); } if (!array_is_created(&ctx->prev_set)) i_array_init(&ctx->prev_set, 4); /* Either this is the first set, or this can be merged with the previous set. */ struct sql_dict_prev *prev_set = array_append_space(&ctx->prev_set); prev_set->map = map; prev_set->key = i_strdup(key); prev_set->value.str = i_strdup(value); } static void sql_dict_atomic_inc(struct dict_transaction_context *_ctx, const char *key, long long diff) { struct sql_dict_transaction_context *ctx = (struct sql_dict_transaction_context *)_ctx; struct sql_dict *dict = (struct sql_dict *)_ctx->dict; const struct dict_sql_map *map; ARRAY_TYPE(const_string) pattern_values; if (ctx->error != NULL) return; /* In theory we could inc the previous set in this same transaction, so flush it first. 
*/ if (array_is_created(&ctx->prev_set)) sql_dict_prev_set_flush(ctx); map = sql_dict_find_map(dict, key, &pattern_values); if (map == NULL) { ctx->error = i_strdup_printf( "sql dict atomic inc: Invalid/unmapped key: %s", key); return; } if (array_is_created(&ctx->prev_inc) && !sql_dict_maps_are_mergeable(dict, array_front(&ctx->prev_inc), map, key, &pattern_values)) { /* couldn't merge to the previous inc - flush it */ sql_dict_prev_inc_flush(ctx); } if (!array_is_created(&ctx->prev_inc)) i_array_init(&ctx->prev_inc, 4); /* Either this is the first inc, or this can be merged with the previous inc. */ struct sql_dict_prev *prev_inc = array_append_space(&ctx->prev_inc); prev_inc->map = map; prev_inc->key = i_strdup(key); prev_inc->value.diff = diff; } static struct dict sql_dict = { .name = "sql", { .init = sql_dict_init, .deinit = sql_dict_deinit, .wait = sql_dict_wait, .lookup = sql_dict_lookup, .iterate_init = sql_dict_iterate_init, .iterate = sql_dict_iterate, .iterate_deinit = sql_dict_iterate_deinit, .transaction_init = sql_dict_transaction_init, .transaction_commit = sql_dict_transaction_commit, .transaction_rollback = sql_dict_transaction_rollback, .set = sql_dict_set, .unset = sql_dict_unset, .atomic_inc = sql_dict_atomic_inc, .lookup_async = sql_dict_lookup_async, } }; static struct dict *dict_sql_drivers; void dict_sql_register(void) { const struct sql_db *const *drivers; unsigned int i, count; dict_sql_db_cache = sql_db_cache_init(DICT_SQL_MAX_UNUSED_CONNECTIONS); /* @UNSAFE */ drivers = array_get(&sql_drivers, &count); dict_sql_drivers = i_new(struct dict, count + 1); for (i = 0; i < count; i++) { dict_sql_drivers[i] = sql_dict; dict_sql_drivers[i].name = drivers[i]->name; dict_driver_register(&dict_sql_drivers[i]); } } void dict_sql_unregister(void) { int i; for (i = 0; dict_sql_drivers[i].name != NULL; i++) dict_driver_unregister(&dict_sql_drivers[i]); i_free(dict_sql_drivers); sql_db_cache_deinit(&dict_sql_db_cache); dict_sql_settings_deinit(); }