diff options
Diffstat (limited to '')
-rw-r--r-- | src/auth/userdb-sql.c | 319 |
1 files changed, 319 insertions, 0 deletions
diff --git a/src/auth/userdb-sql.c b/src/auth/userdb-sql.c new file mode 100644 index 0000000..3a87640 --- /dev/null +++ b/src/auth/userdb-sql.c @@ -0,0 +1,319 @@ +/* Copyright (c) 2004-2018 Dovecot authors, see the included COPYING file */ + +#include "auth-common.h" +#include "userdb.h" + +#ifdef USERDB_SQL + +#include "auth-cache.h" +#include "db-sql.h" + +#include <string.h> + +struct sql_userdb_module { + struct userdb_module module; + + struct db_sql_connection *conn; +}; + +struct userdb_sql_request { + struct auth_request *auth_request; + userdb_callback_t *callback; +}; + +struct sql_userdb_iterate_context { + struct userdb_iterate_context ctx; + struct sql_result *result; + bool freed:1; + bool call_iter:1; +}; + +static void userdb_sql_iterate_next(struct userdb_iterate_context *_ctx); +static int userdb_sql_iterate_deinit(struct userdb_iterate_context *_ctx); + +static void +sql_query_get_result(struct sql_result *result, + struct auth_request *auth_request) +{ + const char *name, *value; + unsigned int i, fields_count; + + fields_count = sql_result_get_fields_count(result); + for (i = 0; i < fields_count; i++) { + name = sql_result_get_field_name(result, i); + value = sql_result_get_field_value(result, i); + + if (*name != '\0' && value != NULL) { + auth_request_set_userdb_field(auth_request, + name, value); + } + } +} + +static void sql_query_callback(struct sql_result *sql_result, + struct userdb_sql_request *sql_request) +{ + struct auth_request *auth_request = sql_request->auth_request; + struct userdb_module *_module = auth_request->userdb->userdb; + struct sql_userdb_module *module = + (struct sql_userdb_module *)_module; + enum userdb_result result = USERDB_RESULT_INTERNAL_FAILURE; + int ret; + + ret = sql_result_next_row(sql_result); + if (ret >= 0) + db_sql_success(module->conn); + if (ret < 0) { + if (!module->conn->default_user_query) { + e_error(authdb_event(auth_request), + "User query failed: %s", + sql_result_get_error(sql_result)); + } else { + e_error(authdb_event(auth_request), + "User query failed: %s " + "(using built-in default user_query: %s)", + sql_result_get_error(sql_result), + module->conn->set.user_query); + } + } else if (ret == 0) { + result = USERDB_RESULT_USER_UNKNOWN; + auth_request_log_unknown_user(auth_request, AUTH_SUBSYS_DB); + } else { + sql_query_get_result(sql_result, auth_request); + result = USERDB_RESULT_OK; + } + + sql_request->callback(result, auth_request); + auth_request_unref(&auth_request); + i_free(sql_request); +} + +static const char * +userdb_sql_escape(const char *str, const struct auth_request *auth_request) +{ + struct userdb_module *_module = auth_request->userdb->userdb; + struct sql_userdb_module *module = + (struct sql_userdb_module *)_module; + + return sql_escape_string(module->conn->db, str); +} + +static void userdb_sql_lookup(struct auth_request *auth_request, + userdb_callback_t *callback) +{ + struct userdb_module *_module = auth_request->userdb->userdb; + struct sql_userdb_module *module = + (struct sql_userdb_module *)_module; + struct userdb_sql_request *sql_request; + const char *query, *error; + + if (t_auth_request_var_expand(module->conn->set.user_query, + auth_request, userdb_sql_escape, + &query, &error) <= 0) { + e_error(authdb_event(auth_request), + "Failed to expand user_query=%s: %s", + module->conn->set.user_query, error); + callback(USERDB_RESULT_INTERNAL_FAILURE, auth_request); + return; + } + + auth_request_ref(auth_request); + sql_request = i_new(struct userdb_sql_request, 1); + sql_request->callback = callback; + sql_request->auth_request = auth_request; + + e_debug(authdb_event(auth_request), "%s", query); + + sql_query(module->conn->db, query, + sql_query_callback, sql_request); +} + +static void sql_iter_query_callback(struct sql_result *sql_result, + struct sql_userdb_iterate_context *ctx) +{ + ctx->result = sql_result; + sql_result_ref(sql_result); + + if (ctx->freed) + (void)userdb_sql_iterate_deinit(&ctx->ctx); + else if (ctx->call_iter) + userdb_sql_iterate_next(&ctx->ctx); +} + +static struct userdb_iterate_context * +userdb_sql_iterate_init(struct auth_request *auth_request, + userdb_iter_callback_t *callback, void *context) +{ + struct userdb_module *_module = auth_request->userdb->userdb; + struct sql_userdb_module *module = + (struct sql_userdb_module *)_module; + struct sql_userdb_iterate_context *ctx; + const char *query, *error; + + if (t_auth_request_var_expand(module->conn->set.iterate_query, + auth_request, userdb_sql_escape, + &query, &error) <= 0) { + e_error(authdb_event(auth_request), + "Failed to expand iterate_query=%s: %s", + module->conn->set.iterate_query, error); + } + + ctx = i_new(struct sql_userdb_iterate_context, 1); + ctx->ctx.auth_request = auth_request; + ctx->ctx.callback = callback; + ctx->ctx.context = context; + auth_request_ref(auth_request); + + sql_query(module->conn->db, query, + sql_iter_query_callback, ctx); + e_debug(authdb_event(auth_request), "%s", query); + return &ctx->ctx; +} + +static int userdb_sql_iterate_get_user(struct sql_userdb_iterate_context *ctx, + const char **user_r) +{ + const char *domain; + int idx; + + /* try user first */ + idx = sql_result_find_field(ctx->result, "user"); + if (idx == 0) { + *user_r = sql_result_get_field_value(ctx->result, idx); + return 0; + } + + /* username [+ domain]? */ + idx = sql_result_find_field(ctx->result, "username"); + if (idx < 0) { + /* no user or username, fail */ + return -1; + } + + *user_r = sql_result_get_field_value(ctx->result, idx); + if (*user_r == NULL) + return 0; + + domain = sql_result_find_field_value(ctx->result, "domain"); + if (domain != NULL) + *user_r = t_strconcat(*user_r, "@", domain, NULL); + return 0; +} + +static void userdb_sql_iterate_next(struct userdb_iterate_context *_ctx) +{ + struct sql_userdb_iterate_context *ctx = + (struct sql_userdb_iterate_context *)_ctx; + struct userdb_module *_module = _ctx->auth_request->userdb->userdb; + struct sql_userdb_module *module = (struct sql_userdb_module *)_module; + const char *user; + int ret; + + if (ctx->result == NULL) { + /* query not finished yet */ + ctx->call_iter = TRUE; + return; + } + + ret = sql_result_next_row(ctx->result); + if (ret >= 0) + db_sql_success(module->conn); + if (ret > 0) { + if (userdb_sql_iterate_get_user(ctx, &user) < 0) + e_error(authdb_event(_ctx->auth_request), + "sql: Iterate query didn't return 'user' field"); + else if (user == NULL) + e_error(authdb_event(_ctx->auth_request), + "sql: Iterate query returned NULL user"); + else { + _ctx->callback(user, _ctx->context); + return; + } + _ctx->failed = TRUE; + } else if (ret < 0) { + if (!module->conn->default_iterate_query) { + e_error(authdb_event(_ctx->auth_request), + "sql: Iterate query failed: %s", + sql_result_get_error(ctx->result)); + } else { + e_error(authdb_event(_ctx->auth_request), + "sql: Iterate query failed: %s " + "(using built-in default iterate_query: %s)", + sql_result_get_error(ctx->result), + module->conn->set.iterate_query); + } + _ctx->failed = TRUE; + } + _ctx->callback(NULL, _ctx->context); +} + +static int userdb_sql_iterate_deinit(struct userdb_iterate_context *_ctx) +{ + struct sql_userdb_iterate_context *ctx = + (struct sql_userdb_iterate_context *)_ctx; + int ret = _ctx->failed ? -1 : 0; + + auth_request_unref(&_ctx->auth_request); + if (ctx->result == NULL) { + /* sql query hasn't finished yet */ + ctx->freed = TRUE; + } else { + if (ctx->result != NULL) + sql_result_unref(ctx->result); + i_free(ctx); + } + return ret; +} + +static struct userdb_module * +userdb_sql_preinit(pool_t pool, const char *args) +{ + struct sql_userdb_module *module; + + module = p_new(pool, struct sql_userdb_module, 1); + module->conn = db_sql_init(args, TRUE); + + module->module.default_cache_key = + auth_cache_parse_key(pool, module->conn->set.user_query); + return &module->module; +} + +static void userdb_sql_init(struct userdb_module *_module) +{ + struct sql_userdb_module *module = + (struct sql_userdb_module *)_module; + enum sql_db_flags flags; + + flags = sql_get_flags(module->conn->db); + _module->blocking = (flags & SQL_DB_FLAG_BLOCKING) != 0; + + if (!_module->blocking || worker) + db_sql_connect(module->conn); +} + +static void userdb_sql_deinit(struct userdb_module *_module) +{ + struct sql_userdb_module *module = + (struct sql_userdb_module *)_module; + + db_sql_unref(&module->conn); +} + +struct userdb_module_interface userdb_sql = { + "sql", + + userdb_sql_preinit, + userdb_sql_init, + userdb_sql_deinit, + + userdb_sql_lookup, + + userdb_sql_iterate_init, + userdb_sql_iterate_next, + userdb_sql_iterate_deinit +}; +#else +struct userdb_module_interface userdb_sql = { + .name = "sql" +}; +#endif |