summaryrefslogtreecommitdiffstats
path: root/src/common/scram-common.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/common/scram-common.c')
-rw-r--r--src/common/scram-common.c330
1 files changed, 330 insertions, 0 deletions
diff --git a/src/common/scram-common.c b/src/common/scram-common.c
new file mode 100644
index 0000000..6448564
--- /dev/null
+++ b/src/common/scram-common.c
@@ -0,0 +1,330 @@
+/*-------------------------------------------------------------------------
+ * scram-common.c
+ * Shared frontend/backend code for SCRAM authentication
+ *
+ * This contains the common low-level functions needed in both frontend and
+ * backend, for implement the Salted Challenge Response Authentication
+ * Mechanism (SCRAM), per IETF's RFC 5802.
+ *
+ * Portions Copyright (c) 2017-2023, PostgreSQL Global Development Group
+ *
+ * IDENTIFICATION
+ * src/common/scram-common.c
+ *
+ *-------------------------------------------------------------------------
+ */
+#ifndef FRONTEND
+#include "postgres.h"
+#else
+#include "postgres_fe.h"
+#endif
+
+#include "common/base64.h"
+#include "common/hmac.h"
+#include "common/scram-common.h"
+#ifndef FRONTEND
+#include "miscadmin.h"
+#endif
+#include "port/pg_bswap.h"
+
+/*
+ * Calculate SaltedPassword.
+ *
+ * The password should already be normalized by SASLprep. Returns 0 on
+ * success, -1 on failure with *errstr pointing to a message about the
+ * error details.
+ */
+int
+scram_SaltedPassword(const char *password,
+ pg_cryptohash_type hash_type, int key_length,
+ const char *salt, int saltlen, int iterations,
+ uint8 *result, const char **errstr)
+{
+ int password_len = strlen(password);
+ uint32 one = pg_hton32(1);
+ int i,
+ j;
+ uint8 Ui[SCRAM_MAX_KEY_LEN];
+ uint8 Ui_prev[SCRAM_MAX_KEY_LEN];
+ pg_hmac_ctx *hmac_ctx = pg_hmac_create(hash_type);
+
+ if (hmac_ctx == NULL)
+ {
+ *errstr = pg_hmac_error(NULL); /* returns OOM */
+ return -1;
+ }
+
+ /*
+ * Iterate hash calculation of HMAC entry using given salt. This is
+ * essentially PBKDF2 (see RFC2898) with HMAC() as the pseudorandom
+ * function.
+ */
+
+ /* First iteration */
+ if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
+ pg_hmac_update(hmac_ctx, (uint8 *) salt, saltlen) < 0 ||
+ pg_hmac_update(hmac_ctx, (uint8 *) &one, sizeof(uint32)) < 0 ||
+ pg_hmac_final(hmac_ctx, Ui_prev, key_length) < 0)
+ {
+ *errstr = pg_hmac_error(hmac_ctx);
+ pg_hmac_free(hmac_ctx);
+ return -1;
+ }
+
+ memcpy(result, Ui_prev, key_length);
+
+ /* Subsequent iterations */
+ for (i = 2; i <= iterations; i++)
+ {
+#ifndef FRONTEND
+ /*
+ * Make sure that this is interruptible as scram_iterations could be
+ * set to a large value.
+ */
+ CHECK_FOR_INTERRUPTS();
+#endif
+
+ if (pg_hmac_init(hmac_ctx, (uint8 *) password, password_len) < 0 ||
+ pg_hmac_update(hmac_ctx, (uint8 *) Ui_prev, key_length) < 0 ||
+ pg_hmac_final(hmac_ctx, Ui, key_length) < 0)
+ {
+ *errstr = pg_hmac_error(hmac_ctx);
+ pg_hmac_free(hmac_ctx);
+ return -1;
+ }
+
+ for (j = 0; j < key_length; j++)
+ result[j] ^= Ui[j];
+ memcpy(Ui_prev, Ui, key_length);
+ }
+
+ pg_hmac_free(hmac_ctx);
+ return 0;
+}
+
+
+/*
+ * Calculate hash for a NULL-terminated string. (The NULL terminator is
+ * not included in the hash). Returns 0 on success, -1 on failure with *errstr
+ * pointing to a message about the error details.
+ */
+int
+scram_H(const uint8 *input, pg_cryptohash_type hash_type, int key_length,
+ uint8 *result, const char **errstr)
+{
+ pg_cryptohash_ctx *ctx;
+
+ ctx = pg_cryptohash_create(hash_type);
+ if (ctx == NULL)
+ {
+ *errstr = pg_cryptohash_error(NULL); /* returns OOM */
+ return -1;
+ }
+
+ if (pg_cryptohash_init(ctx) < 0 ||
+ pg_cryptohash_update(ctx, input, key_length) < 0 ||
+ pg_cryptohash_final(ctx, result, key_length) < 0)
+ {
+ *errstr = pg_cryptohash_error(ctx);
+ pg_cryptohash_free(ctx);
+ return -1;
+ }
+
+ pg_cryptohash_free(ctx);
+ return 0;
+}
+
+/*
+ * Calculate ClientKey. Returns 0 on success, -1 on failure with *errstr
+ * pointing to a message about the error details.
+ */
+int
+scram_ClientKey(const uint8 *salted_password,
+ pg_cryptohash_type hash_type, int key_length,
+ uint8 *result, const char **errstr)
+{
+ pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
+
+ if (ctx == NULL)
+ {
+ *errstr = pg_hmac_error(NULL); /* returns OOM */
+ return -1;
+ }
+
+ if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
+ pg_hmac_update(ctx, (uint8 *) "Client Key", strlen("Client Key")) < 0 ||
+ pg_hmac_final(ctx, result, key_length) < 0)
+ {
+ *errstr = pg_hmac_error(ctx);
+ pg_hmac_free(ctx);
+ return -1;
+ }
+
+ pg_hmac_free(ctx);
+ return 0;
+}
+
+/*
+ * Calculate ServerKey. Returns 0 on success, -1 on failure with *errstr
+ * pointing to a message about the error details.
+ */
+int
+scram_ServerKey(const uint8 *salted_password,
+ pg_cryptohash_type hash_type, int key_length,
+ uint8 *result, const char **errstr)
+{
+ pg_hmac_ctx *ctx = pg_hmac_create(hash_type);
+
+ if (ctx == NULL)
+ {
+ *errstr = pg_hmac_error(NULL); /* returns OOM */
+ return -1;
+ }
+
+ if (pg_hmac_init(ctx, salted_password, key_length) < 0 ||
+ pg_hmac_update(ctx, (uint8 *) "Server Key", strlen("Server Key")) < 0 ||
+ pg_hmac_final(ctx, result, key_length) < 0)
+ {
+ *errstr = pg_hmac_error(ctx);
+ pg_hmac_free(ctx);
+ return -1;
+ }
+
+ pg_hmac_free(ctx);
+ return 0;
+}
+
+
+/*
+ * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
+ *
+ * The password should already have been processed with SASLprep, if necessary!
+ *
+ * If iterations is 0, default number of iterations is used. The result is
+ * palloc'd or malloc'd, so caller is responsible for freeing it.
+ *
+ * On error, returns NULL and sets *errstr to point to a message about the
+ * error details.
+ */
+char *
+scram_build_secret(pg_cryptohash_type hash_type, int key_length,
+ const char *salt, int saltlen, int iterations,
+ const char *password, const char **errstr)
+{
+ uint8 salted_password[SCRAM_MAX_KEY_LEN];
+ uint8 stored_key[SCRAM_MAX_KEY_LEN];
+ uint8 server_key[SCRAM_MAX_KEY_LEN];
+ char *result;
+ char *p;
+ int maxlen;
+ int encoded_salt_len;
+ int encoded_stored_len;
+ int encoded_server_len;
+ int encoded_result;
+
+ /* Only this hash method is supported currently */
+ Assert(hash_type == PG_SHA256);
+
+ Assert(iterations > 0);
+
+ /* Calculate StoredKey and ServerKey */
+ if (scram_SaltedPassword(password, hash_type, key_length,
+ salt, saltlen, iterations,
+ salted_password, errstr) < 0 ||
+ scram_ClientKey(salted_password, hash_type, key_length,
+ stored_key, errstr) < 0 ||
+ scram_H(stored_key, hash_type, key_length,
+ stored_key, errstr) < 0 ||
+ scram_ServerKey(salted_password, hash_type, key_length,
+ server_key, errstr) < 0)
+ {
+ /* errstr is filled already here */
+#ifdef FRONTEND
+ return NULL;
+#else
+ elog(ERROR, "could not calculate stored key and server key: %s",
+ *errstr);
+#endif
+ }
+
+ /*----------
+ * The format is:
+ * SCRAM-SHA-256$<iteration count>:<salt>$<StoredKey>:<ServerKey>
+ *----------
+ */
+ encoded_salt_len = pg_b64_enc_len(saltlen);
+ encoded_stored_len = pg_b64_enc_len(key_length);
+ encoded_server_len = pg_b64_enc_len(key_length);
+
+ maxlen = strlen("SCRAM-SHA-256") + 1
+ + 10 + 1 /* iteration count */
+ + encoded_salt_len + 1 /* Base64-encoded salt */
+ + encoded_stored_len + 1 /* Base64-encoded StoredKey */
+ + encoded_server_len + 1; /* Base64-encoded ServerKey */
+
+#ifdef FRONTEND
+ result = malloc(maxlen);
+ if (!result)
+ {
+ *errstr = _("out of memory");
+ return NULL;
+ }
+#else
+ result = palloc(maxlen);
+#endif
+
+ p = result + sprintf(result, "SCRAM-SHA-256$%d:", iterations);
+
+ /* salt */
+ encoded_result = pg_b64_encode(salt, saltlen, p, encoded_salt_len);
+ if (encoded_result < 0)
+ {
+ *errstr = _("could not encode salt");
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "%s", *errstr);
+#endif
+ }
+ p += encoded_result;
+ *(p++) = '$';
+
+ /* stored key */
+ encoded_result = pg_b64_encode((char *) stored_key, key_length, p,
+ encoded_stored_len);
+ if (encoded_result < 0)
+ {
+ *errstr = _("could not encode stored key");
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "%s", *errstr);
+#endif
+ }
+
+ p += encoded_result;
+ *(p++) = ':';
+
+ /* server key */
+ encoded_result = pg_b64_encode((char *) server_key, key_length, p,
+ encoded_server_len);
+ if (encoded_result < 0)
+ {
+ *errstr = _("could not encode server key");
+#ifdef FRONTEND
+ free(result);
+ return NULL;
+#else
+ elog(ERROR, "%s", *errstr);
+#endif
+ }
+
+ p += encoded_result;
+ *(p++) = '\0';
+
+ Assert(p - result <= maxlen);
+
+ return result;
+}