summaryrefslogtreecommitdiffstats
path: root/src/backend/libpq
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/backend/libpq/Makefile39
-rw-r--r--src/backend/libpq/README.SSL82
-rw-r--r--src/backend/libpq/auth-scram.c1445
-rw-r--r--src/backend/libpq/auth.c3492
-rw-r--r--src/backend/libpq/be-fsstubs.c860
-rw-r--r--src/backend/libpq/be-gssapi-common.c94
-rw-r--r--src/backend/libpq/be-secure-common.c195
-rw-r--r--src/backend/libpq/be-secure-gssapi.c733
-rw-r--r--src/backend/libpq/be-secure-openssl.c1526
-rw-r--r--src/backend/libpq/be-secure.c345
-rw-r--r--src/backend/libpq/crypt.c290
-rw-r--r--src/backend/libpq/hba.c3166
-rw-r--r--src/backend/libpq/ifaddr.c594
-rw-r--r--src/backend/libpq/pg_hba.conf.sample94
-rw-r--r--src/backend/libpq/pg_ident.conf.sample42
-rw-r--r--src/backend/libpq/pqcomm.c1976
-rw-r--r--src/backend/libpq/pqformat.c643
-rw-r--r--src/backend/libpq/pqmq.c313
-rw-r--r--src/backend/libpq/pqsignal.c148
19 files changed, 16077 insertions, 0 deletions
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile
new file mode 100644
index 0000000..8d1d16b
--- /dev/null
+++ b/src/backend/libpq/Makefile
@@ -0,0 +1,39 @@
+#-------------------------------------------------------------------------
+#
+# Makefile--
+# Makefile for libpq subsystem (backend half of libpq interface)
+#
+# IDENTIFICATION
+# src/backend/libpq/Makefile
+#
+#-------------------------------------------------------------------------
+
+subdir = src/backend/libpq
+top_builddir = ../../..
+include $(top_builddir)/src/Makefile.global
+
+# be-fsstubs is here for historical reasons, probably belongs elsewhere
+
+OBJS = \
+ auth-scram.o \
+ auth.o \
+ be-fsstubs.o \
+ be-secure-common.o \
+ be-secure.o \
+ crypt.o \
+ hba.o \
+ ifaddr.o \
+ pqcomm.o \
+ pqformat.o \
+ pqmq.o \
+ pqsignal.o
+
+ifeq ($(with_ssl),openssl)
+OBJS += be-secure-openssl.o
+endif
+
+ifeq ($(with_gssapi),yes)
+OBJS += be-gssapi-common.o be-secure-gssapi.o
+endif
+
+include $(top_srcdir)/src/backend/common.mk
diff --git a/src/backend/libpq/README.SSL b/src/backend/libpq/README.SSL
new file mode 100644
index 0000000..d84a434
--- /dev/null
+++ b/src/backend/libpq/README.SSL
@@ -0,0 +1,82 @@
+src/backend/libpq/README.SSL
+
+SSL
+===
+
+>From the servers perspective:
+
+
+ Receives StartupPacket
+ |
+ |
+ (Is SSL_NEGOTIATE_CODE?) ----------- Normal startup
+ | No
+ |
+ | Yes
+ |
+ |
+ (Server compiled with USE_SSL?) ------- Send 'N'
+ | No |
+ | |
+ | Yes Normal startup
+ |
+ |
+ Send 'S'
+ |
+ |
+ Establish SSL
+ |
+ |
+ Normal startup
+
+
+
+
+
+>From the clients perspective (v6.6 client _with_ SSL):
+
+
+ Connect
+ |
+ |
+ Send packet with SSL_NEGOTIATE_CODE
+ |
+ |
+ Receive single char ------- 'S' -------- Establish SSL
+ | |
+ | '<else>' |
+ | Normal startup
+ |
+ |
+ Is it 'E' for error ------------------- Retry connection
+ | Yes without SSL
+ | No
+ |
+ Is it 'N' for normal ------------------- Normal startup
+ | Yes
+ |
+ Fail with unknown
+
+---------------------------------------------------------------------------
+
+Ephemeral DH
+============
+
+Since the server static private key ($DataDir/server.key) will
+normally be stored unencrypted so that the database backend can
+restart automatically, it is important that we select an algorithm
+that continues to provide confidentiality even if the attacker has the
+server's private key. Ephemeral DH (EDH) keys provide this and more
+(Perfect Forward Secrecy aka PFS).
+
+N.B., the static private key should still be protected to the largest
+extent possible, to minimize the risk of impersonations.
+
+Another benefit of EDH is that it allows the backend and clients to
+use DSA keys. DSA keys can only provide digital signatures, not
+encryption, and are often acceptable in jurisdictions where RSA keys
+are unacceptable.
+
+The downside to EDH is that it makes it impossible to use ssldump(1)
+if there's a problem establishing an SSL session. In this case you'll
+need to temporarily disable EDH (see initialize_dh()).
diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c
new file mode 100644
index 0000000..f9e1026
--- /dev/null
+++ b/src/backend/libpq/auth-scram.c
@@ -0,0 +1,1445 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth-scram.c
+ * Server-side implementation of the SASL SCRAM-SHA-256 mechanism.
+ *
+ * See the following RFCs for more details:
+ * - RFC 5802: https://tools.ietf.org/html/rfc5802
+ * - RFC 5803: https://tools.ietf.org/html/rfc5803
+ * - RFC 7677: https://tools.ietf.org/html/rfc7677
+ *
+ * Here are some differences:
+ *
+ * - Username from the authentication exchange is not used. The client
+ * should send an empty string as the username.
+ *
+ * - If the password isn't valid UTF-8, or contains characters prohibited
+ * by the SASLprep profile, we skip the SASLprep pre-processing and use
+ * the raw bytes in calculating the hash.
+ *
+ * - If channel binding is used, the channel binding type is always
+ * "tls-server-end-point". The spec says the default is "tls-unique"
+ * (RFC 5802, section 6.1. Default Channel Binding), but there are some
+ * problems with that. Firstly, not all SSL libraries provide an API to
+ * get the TLS Finished message, required to use "tls-unique". Secondly,
+ * "tls-unique" is not specified for TLS v1.3, and as of this writing,
+ * it's not clear if there will be a replacement. We could support both
+ * "tls-server-end-point" and "tls-unique", but for our use case,
+ * "tls-unique" doesn't really have any advantages. The main advantage
+ * of "tls-unique" would be that it works even if the server doesn't
+ * have a certificate, but PostgreSQL requires a server certificate
+ * whenever SSL is used, anyway.
+ *
+ *
+ * The password stored in pg_authid consists of the iteration count, salt,
+ * StoredKey and ServerKey.
+ *
+ * SASLprep usage
+ * --------------
+ *
+ * One notable difference to the SCRAM specification is that while the
+ * specification dictates that the password is in UTF-8, and prohibits
+ * certain characters, we are more lenient. If the password isn't a valid
+ * UTF-8 string, or contains prohibited characters, the raw bytes are used
+ * to calculate the hash instead, without SASLprep processing. This is
+ * because PostgreSQL supports other encodings too, and the encoding being
+ * used during authentication is undefined (client_encoding isn't set until
+ * after authentication). In effect, we try to interpret the password as
+ * UTF-8 and apply SASLprep processing, but if it looks invalid, we assume
+ * that it's in some other encoding.
+ *
+ * In the worst case, we misinterpret a password that's in a different
+ * encoding as being Unicode, because it happens to consists entirely of
+ * valid UTF-8 bytes, and we apply Unicode normalization to it. As long
+ * as we do that consistently, that will not lead to failed logins.
+ * Fortunately, the UTF-8 byte sequences that are ignored by SASLprep
+ * don't correspond to any commonly used characters in any of the other
+ * supported encodings, so it should not lead to any significant loss in
+ * entropy, even if the normalization is incorrectly applied to a
+ * non-UTF-8 password.
+ *
+ * Error handling
+ * --------------
+ *
+ * Don't reveal user information to an unauthenticated client. We don't
+ * want an attacker to be able to probe whether a particular username is
+ * valid. In SCRAM, the server has to read the salt and iteration count
+ * from the user's stored secret, and send it to the client. To avoid
+ * revealing whether a user exists, when the client tries to authenticate
+ * with a username that doesn't exist, or doesn't have a valid SCRAM
+ * secret in pg_authid, we create a fake salt and iteration count
+ * on-the-fly, and proceed with the authentication with that. In the end,
+ * we'll reject the attempt, as if an incorrect password was given. When
+ * we are performing a "mock" authentication, the 'doomed' flag in
+ * scram_state is set.
+ *
+ * In the error messages, avoid printing strings from the client, unless
+ * you check that they are pure ASCII. We don't want an unauthenticated
+ * attacker to be able to spam the logs with characters that are not valid
+ * to the encoding being used, whatever that is. We cannot avoid that in
+ * general, after logging in, but let's do what we can here.
+ *
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/auth-scram.c
+ *
+ *-------------------------------------------------------------------------
+ */
+#include "postgres.h"
+
+#include <unistd.h>
+
+#include "access/xlog.h"
+#include "catalog/pg_authid.h"
+#include "catalog/pg_control.h"
+#include "common/base64.h"
+#include "common/hmac.h"
+#include "common/saslprep.h"
+#include "common/scram-common.h"
+#include "common/sha2.h"
+#include "libpq/auth.h"
+#include "libpq/crypt.h"
+#include "libpq/scram.h"
+#include "miscadmin.h"
+#include "utils/builtins.h"
+#include "utils/timestamp.h"
+
+/*
+ * Status data for a SCRAM authentication exchange. This should be kept
+ * internal to this file.
+ */
+typedef enum
+{
+ SCRAM_AUTH_INIT,
+ SCRAM_AUTH_SALT_SENT,
+ SCRAM_AUTH_FINISHED
+} scram_state_enum;
+
+typedef struct
+{
+ scram_state_enum state;
+
+ const char *username; /* username from startup packet */
+
+ Port *port;
+ bool channel_binding_in_use;
+
+ int iterations;
+ char *salt; /* base64-encoded */
+ uint8 StoredKey[SCRAM_KEY_LEN];
+ uint8 ServerKey[SCRAM_KEY_LEN];
+
+ /* Fields of the first message from client */
+ char cbind_flag;
+ char *client_first_message_bare;
+ char *client_username;
+ char *client_nonce;
+
+ /* Fields from the last message from client */
+ char *client_final_message_without_proof;
+ char *client_final_nonce;
+ char ClientProof[SCRAM_KEY_LEN];
+
+ /* Fields generated in the server */
+ char *server_first_message;
+ char *server_nonce;
+
+ /*
+ * If something goes wrong during the authentication, or we are performing
+ * a "mock" authentication (see comments at top of file), the 'doomed'
+ * flag is set. A reason for the failure, for the server log, is put in
+ * 'logdetail'.
+ */
+ bool doomed;
+ char *logdetail;
+} scram_state;
+
+static void read_client_first_message(scram_state *state, const char *input);
+static void read_client_final_message(scram_state *state, const char *input);
+static char *build_server_first_message(scram_state *state);
+static char *build_server_final_message(scram_state *state);
+static bool verify_client_proof(scram_state *state);
+static bool verify_final_nonce(scram_state *state);
+static void mock_scram_secret(const char *username, int *iterations,
+ char **salt, uint8 *stored_key, uint8 *server_key);
+static bool is_scram_printable(char *p);
+static char *sanitize_char(char c);
+static char *sanitize_str(const char *s);
+static char *scram_mock_salt(const char *username);
+
+/*
+ * pg_be_scram_get_mechanisms
+ *
+ * Get a list of SASL mechanisms that this module supports.
+ *
+ * For the convenience of building the FE/BE packet that lists the
+ * mechanisms, the names are appended to the given StringInfo buffer,
+ * separated by '\0' bytes.
+ */
+void
+pg_be_scram_get_mechanisms(Port *port, StringInfo buf)
+{
+ /*
+ * Advertise the mechanisms in decreasing order of importance. So the
+ * channel-binding variants go first, if they are supported. Channel
+ * binding is only supported with SSL, and only if the SSL implementation
+ * has a function to get the certificate's hash.
+ */
+#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
+ if (port->ssl_in_use)
+ {
+ appendStringInfoString(buf, SCRAM_SHA_256_PLUS_NAME);
+ appendStringInfoChar(buf, '\0');
+ }
+#endif
+ appendStringInfoString(buf, SCRAM_SHA_256_NAME);
+ appendStringInfoChar(buf, '\0');
+}
+
+/*
+ * pg_be_scram_init
+ *
+ * Initialize a new SCRAM authentication exchange status tracker. This
+ * needs to be called before doing any exchange. It will be filled later
+ * after the beginning of the exchange with authentication information.
+ *
+ * 'selected_mech' identifies the SASL mechanism that the client selected.
+ * It should be one of the mechanisms that we support, as returned by
+ * pg_be_scram_get_mechanisms().
+ *
+ * 'shadow_pass' is the role's stored secret, from pg_authid.rolpassword.
+ * The username was provided by the client in the startup message, and is
+ * available in port->user_name. If 'shadow_pass' is NULL, we still perform
+ * an authentication exchange, but it will fail, as if an incorrect password
+ * was given.
+ */
+void *
+pg_be_scram_init(Port *port,
+ const char *selected_mech,
+ const char *shadow_pass)
+{
+ scram_state *state;
+ bool got_secret;
+
+ state = (scram_state *) palloc0(sizeof(scram_state));
+ state->port = port;
+ state->state = SCRAM_AUTH_INIT;
+
+ /*
+ * Parse the selected mechanism.
+ *
+ * Note that if we don't support channel binding, either because the SSL
+ * implementation doesn't support it or we're not using SSL at all, we
+ * would not have advertised the PLUS variant in the first place. If the
+ * client nevertheless tries to select it, it's a protocol violation like
+ * selecting any other SASL mechanism we don't support.
+ */
+#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
+ if (strcmp(selected_mech, SCRAM_SHA_256_PLUS_NAME) == 0 && port->ssl_in_use)
+ state->channel_binding_in_use = true;
+ else
+#endif
+ if (strcmp(selected_mech, SCRAM_SHA_256_NAME) == 0)
+ state->channel_binding_in_use = false;
+ else
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("client selected an invalid SASL authentication mechanism")));
+
+ /*
+ * Parse the stored secret.
+ */
+ if (shadow_pass)
+ {
+ int password_type = get_password_type(shadow_pass);
+
+ if (password_type == PASSWORD_TYPE_SCRAM_SHA_256)
+ {
+ if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt,
+ state->StoredKey, state->ServerKey))
+ got_secret = true;
+ else
+ {
+ /*
+ * The password looked like a SCRAM secret, but could not be
+ * parsed.
+ */
+ ereport(LOG,
+ (errmsg("invalid SCRAM secret for user \"%s\"",
+ state->port->user_name)));
+ got_secret = false;
+ }
+ }
+ else
+ {
+ /*
+ * The user doesn't have SCRAM secret. (You cannot do SCRAM
+ * authentication with an MD5 hash.)
+ */
+ state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM secret."),
+ state->port->user_name);
+ got_secret = false;
+ }
+ }
+ else
+ {
+ /*
+ * The caller requested us to perform a dummy authentication. This is
+ * considered normal, since the caller requested it, so don't set log
+ * detail.
+ */
+ got_secret = false;
+ }
+
+ /*
+ * If the user did not have a valid SCRAM secret, we still go through the
+ * motions with a mock one, and fail as if the client supplied an
+ * incorrect password. This is to avoid revealing information to an
+ * attacker.
+ */
+ if (!got_secret)
+ {
+ mock_scram_secret(state->port->user_name, &state->iterations,
+ &state->salt, state->StoredKey, state->ServerKey);
+ state->doomed = true;
+ }
+
+ return state;
+}
+
+/*
+ * Continue a SCRAM authentication exchange.
+ *
+ * 'input' is the SCRAM payload sent by the client. On the first call,
+ * 'input' contains the "Initial Client Response" that the client sent as
+ * part of the SASLInitialResponse message, or NULL if no Initial Client
+ * Response was given. (The SASL specification distinguishes between an
+ * empty response and non-existing one.) On subsequent calls, 'input'
+ * cannot be NULL. For convenience in this function, the caller must
+ * ensure that there is a null terminator at input[inputlen].
+ *
+ * The next message to send to client is saved in 'output', for a length
+ * of 'outputlen'. In the case of an error, optionally store a palloc'd
+ * string at *logdetail that will be sent to the postmaster log (but not
+ * the client).
+ */
+int
+pg_be_scram_exchange(void *opaq, const char *input, int inputlen,
+ char **output, int *outputlen, char **logdetail)
+{
+ scram_state *state = (scram_state *) opaq;
+ int result;
+
+ *output = NULL;
+
+ /*
+ * If the client didn't include an "Initial Client Response" in the
+ * SASLInitialResponse message, send an empty challenge, to which the
+ * client will respond with the same data that usually comes in the
+ * Initial Client Response.
+ */
+ if (input == NULL)
+ {
+ Assert(state->state == SCRAM_AUTH_INIT);
+
+ *output = pstrdup("");
+ *outputlen = 0;
+ return SASL_EXCHANGE_CONTINUE;
+ }
+
+ /*
+ * Check that the input length agrees with the string length of the input.
+ * We can ignore inputlen after this.
+ */
+ if (inputlen == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("The message is empty.")));
+ if (inputlen != strlen(input))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Message length does not match input length.")));
+
+ switch (state->state)
+ {
+ case SCRAM_AUTH_INIT:
+
+ /*
+ * Initialization phase. Receive the first message from client
+ * and be sure that it parsed correctly. Then send the challenge
+ * to the client.
+ */
+ read_client_first_message(state, input);
+
+ /* prepare message to send challenge */
+ *output = build_server_first_message(state);
+
+ state->state = SCRAM_AUTH_SALT_SENT;
+ result = SASL_EXCHANGE_CONTINUE;
+ break;
+
+ case SCRAM_AUTH_SALT_SENT:
+
+ /*
+ * Final phase for the server. Receive the response to the
+ * challenge previously sent, verify, and let the client know that
+ * everything went well (or not).
+ */
+ read_client_final_message(state, input);
+
+ if (!verify_final_nonce(state))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("invalid SCRAM response"),
+ errdetail("Nonce does not match.")));
+
+ /*
+ * Now check the final nonce and the client proof.
+ *
+ * If we performed a "mock" authentication that we knew would fail
+ * from the get go, this is where we fail.
+ *
+ * The SCRAM specification includes an error code,
+ * "invalid-proof", for authentication failure, but it also allows
+ * erroring out in an application-specific way. We choose to do
+ * the latter, so that the error message for invalid password is
+ * the same for all authentication methods. The caller will call
+ * ereport(), when we return SASL_EXCHANGE_FAILURE with no output.
+ *
+ * NB: the order of these checks is intentional. We calculate the
+ * client proof even in a mock authentication, even though it's
+ * bound to fail, to thwart timing attacks to determine if a role
+ * with the given name exists or not.
+ */
+ if (!verify_client_proof(state) || state->doomed)
+ {
+ result = SASL_EXCHANGE_FAILURE;
+ break;
+ }
+
+ /* Build final message for client */
+ *output = build_server_final_message(state);
+
+ /* Success! */
+ result = SASL_EXCHANGE_SUCCESS;
+ state->state = SCRAM_AUTH_FINISHED;
+ break;
+
+ default:
+ elog(ERROR, "invalid SCRAM exchange state");
+ result = SASL_EXCHANGE_FAILURE;
+ }
+
+ if (result == SASL_EXCHANGE_FAILURE && state->logdetail && logdetail)
+ *logdetail = state->logdetail;
+
+ if (*output)
+ *outputlen = strlen(*output);
+
+ return result;
+}
+
+/*
+ * Construct a SCRAM secret, for storing in pg_authid.rolpassword.
+ *
+ * The result is palloc'd, so caller is responsible for freeing it.
+ */
+char *
+pg_be_scram_build_secret(const char *password)
+{
+ char *prep_password;
+ pg_saslprep_rc rc;
+ char saltbuf[SCRAM_DEFAULT_SALT_LEN];
+ char *result;
+
+ /*
+ * Normalize the password with SASLprep. If that doesn't work, because
+ * the password isn't valid UTF-8 or contains prohibited characters, just
+ * proceed with the original password. (See comments at top of file.)
+ */
+ rc = pg_saslprep(password, &prep_password);
+ if (rc == SASLPREP_SUCCESS)
+ password = (const char *) prep_password;
+
+ /* Generate random salt */
+ if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN))
+ ereport(ERROR,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("could not generate random salt")));
+
+ result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN,
+ SCRAM_DEFAULT_ITERATIONS, password);
+
+ if (prep_password)
+ pfree(prep_password);
+
+ return result;
+}
+
+/*
+ * Verify a plaintext password against a SCRAM secret. This is used when
+ * performing plaintext password authentication for a user that has a SCRAM
+ * secret stored in pg_authid.
+ */
+bool
+scram_verify_plain_password(const char *username, const char *password,
+ const char *secret)
+{
+ char *encoded_salt;
+ char *salt;
+ int saltlen;
+ int iterations;
+ uint8 salted_password[SCRAM_KEY_LEN];
+ uint8 stored_key[SCRAM_KEY_LEN];
+ uint8 server_key[SCRAM_KEY_LEN];
+ uint8 computed_key[SCRAM_KEY_LEN];
+ char *prep_password;
+ pg_saslprep_rc rc;
+
+ if (!parse_scram_secret(secret, &iterations, &encoded_salt,
+ stored_key, server_key))
+ {
+ /*
+ * The password looked like a SCRAM secret, but could not be parsed.
+ */
+ ereport(LOG,
+ (errmsg("invalid SCRAM secret for user \"%s\"", username)));
+ return false;
+ }
+
+ saltlen = pg_b64_dec_len(strlen(encoded_salt));
+ salt = palloc(saltlen);
+ saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt,
+ saltlen);
+ if (saltlen < 0)
+ {
+ ereport(LOG,
+ (errmsg("invalid SCRAM secret for user \"%s\"", username)));
+ return false;
+ }
+
+ /* Normalize the password */
+ rc = pg_saslprep(password, &prep_password);
+ if (rc == SASLPREP_SUCCESS)
+ password = prep_password;
+
+ /* Compute Server Key based on the user-supplied plaintext password */
+ if (scram_SaltedPassword(password, salt, saltlen, iterations,
+ salted_password) < 0 ||
+ scram_ServerKey(salted_password, computed_key) < 0)
+ {
+ elog(ERROR, "could not compute server key");
+ }
+
+ if (prep_password)
+ pfree(prep_password);
+
+ /*
+ * Compare the secret's Server Key with the one computed from the
+ * user-supplied password.
+ */
+ return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0;
+}
+
+
+/*
+ * Parse and validate format of given SCRAM secret.
+ *
+ * On success, the iteration count, salt, stored key, and server key are
+ * extracted from the secret, and returned to the caller. For 'stored_key'
+ * and 'server_key', the caller must pass pre-allocated buffers of size
+ * SCRAM_KEY_LEN. Salt is returned as a base64-encoded, null-terminated
+ * string. The buffer for the salt is palloc'd by this function.
+ *
+ * Returns true if the SCRAM secret has been parsed, and false otherwise.
+ */
+bool
+parse_scram_secret(const char *secret, int *iterations, char **salt,
+ uint8 *stored_key, uint8 *server_key)
+{
+ char *v;
+ char *p;
+ char *scheme_str;
+ char *salt_str;
+ char *iterations_str;
+ char *storedkey_str;
+ char *serverkey_str;
+ int decoded_len;
+ char *decoded_salt_buf;
+ char *decoded_stored_buf;
+ char *decoded_server_buf;
+
+ /*
+ * The secret is of form:
+ *
+ * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey>
+ */
+ v = pstrdup(secret);
+ if ((scheme_str = strtok(v, "$")) == NULL)
+ goto invalid_secret;
+ if ((iterations_str = strtok(NULL, ":")) == NULL)
+ goto invalid_secret;
+ if ((salt_str = strtok(NULL, "$")) == NULL)
+ goto invalid_secret;
+ if ((storedkey_str = strtok(NULL, ":")) == NULL)
+ goto invalid_secret;
+ if ((serverkey_str = strtok(NULL, "")) == NULL)
+ goto invalid_secret;
+
+ /* Parse the fields */
+ if (strcmp(scheme_str, "SCRAM-SHA-256") != 0)
+ goto invalid_secret;
+
+ errno = 0;
+ *iterations = strtol(iterations_str, &p, 10);
+ if (*p || errno != 0)
+ goto invalid_secret;
+
+ /*
+ * Verify that the salt is in Base64-encoded format, by decoding it,
+ * although we return the encoded version to the caller.
+ */
+ decoded_len = pg_b64_dec_len(strlen(salt_str));
+ decoded_salt_buf = palloc(decoded_len);
+ decoded_len = pg_b64_decode(salt_str, strlen(salt_str),
+ decoded_salt_buf, decoded_len);
+ if (decoded_len < 0)
+ goto invalid_secret;
+ *salt = pstrdup(salt_str);
+
+ /*
+ * Decode StoredKey and ServerKey.
+ */
+ decoded_len = pg_b64_dec_len(strlen(storedkey_str));
+ decoded_stored_buf = palloc(decoded_len);
+ decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str),
+ decoded_stored_buf, decoded_len);
+ if (decoded_len != SCRAM_KEY_LEN)
+ goto invalid_secret;
+ memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN);
+
+ decoded_len = pg_b64_dec_len(strlen(serverkey_str));
+ decoded_server_buf = palloc(decoded_len);
+ decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str),
+ decoded_server_buf, decoded_len);
+ if (decoded_len != SCRAM_KEY_LEN)
+ goto invalid_secret;
+ memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN);
+
+ return true;
+
+invalid_secret:
+ *salt = NULL;
+ return false;
+}
+
+/*
+ * Generate plausible SCRAM secret parameters for mock authentication.
+ *
+ * In a normal authentication, these are extracted from the secret
+ * stored in the server. This function generates values that look
+ * realistic, for when there is no stored secret.
+ *
+ * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
+ * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and
+ * the buffer for the salt is palloc'd by this function.
+ */
+static void
+mock_scram_secret(const char *username, int *iterations, char **salt,
+ uint8 *stored_key, uint8 *server_key)
+{
+ char *raw_salt;
+ char *encoded_salt;
+ int encoded_len;
+
+ /*
+ * Generate deterministic salt.
+ *
+ * Note that we cannot reveal any information to an attacker here so the
+ * error messages need to remain generic. This should never fail anyway
+ * as the salt generated for mock authentication uses the cluster's nonce
+ * value.
+ */
+ raw_salt = scram_mock_salt(username);
+ if (raw_salt == NULL)
+ elog(ERROR, "could not encode salt");
+
+ encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN);
+ /* don't forget the zero-terminator */
+ encoded_salt = (char *) palloc(encoded_len + 1);
+ encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt,
+ encoded_len);
+
+ if (encoded_len < 0)
+ elog(ERROR, "could not encode salt");
+ encoded_salt[encoded_len] = '\0';
+
+ *salt = encoded_salt;
+ *iterations = SCRAM_DEFAULT_ITERATIONS;
+
+ /* StoredKey and ServerKey are not used in a doomed authentication */
+ memset(stored_key, 0, SCRAM_KEY_LEN);
+ memset(server_key, 0, SCRAM_KEY_LEN);
+}
+
+/*
+ * Read the value in a given SCRAM exchange message for given attribute.
+ */
+static char *
+read_attr_value(char **input, char attr)
+{
+ char *begin = *input;
+ char *end;
+
+ if (*begin != attr)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Expected attribute \"%c\" but found \"%s\".",
+ attr, sanitize_char(*begin))));
+ begin++;
+
+ if (*begin != '=')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
+ begin++;
+
+ end = begin;
+ while (*end && *end != ',')
+ end++;
+
+ if (*end)
+ {
+ *end = '\0';
+ *input = end + 1;
+ }
+ else
+ *input = end;
+
+ return begin;
+}
+
+static bool
+is_scram_printable(char *p)
+{
+ /*------
+ * Printable characters, as defined by SCRAM spec: (RFC 5802)
+ *
+ * printable = %x21-2B / %x2D-7E
+ * ;; Printable ASCII except ",".
+ * ;; Note that any "printable" is also
+ * ;; a valid "value".
+ *------
+ */
+ for (; *p; p++)
+ {
+ if (*p < 0x21 || *p > 0x7E || *p == 0x2C /* comma */ )
+ return false;
+ }
+ return true;
+}
+
+/*
+ * Convert an arbitrary byte to printable form. For error messages.
+ *
+ * If it's a printable ASCII character, print it as a single character.
+ * otherwise, print it in hex.
+ *
+ * The returned pointer points to a static buffer.
+ */
+static char *
+sanitize_char(char c)
+{
+ static char buf[5];
+
+ if (c >= 0x21 && c <= 0x7E)
+ snprintf(buf, sizeof(buf), "'%c'", c);
+ else
+ snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c);
+ return buf;
+}
+
+/*
+ * Convert an arbitrary string to printable form, for error messages.
+ *
+ * Anything that's not a printable ASCII character is replaced with
+ * '?', and the string is truncated at 30 characters.
+ *
+ * The returned pointer points to a static buffer.
+ */
+static char *
+sanitize_str(const char *s)
+{
+ static char buf[30 + 1];
+ int i;
+
+ for (i = 0; i < sizeof(buf) - 1; i++)
+ {
+ char c = s[i];
+
+ if (c == '\0')
+ break;
+
+ if (c >= 0x21 && c <= 0x7E)
+ buf[i] = c;
+ else
+ buf[i] = '?';
+ }
+ buf[i] = '\0';
+ return buf;
+}
+
+/*
+ * Read the next attribute and value in a SCRAM exchange message.
+ *
+ * The attribute character is set in *attr_p, the attribute value is the
+ * return value.
+ */
+static char *
+read_any_attr(char **input, char *attr_p)
+{
+ char *begin = *input;
+ char *end;
+ char attr = *begin;
+
+ if (attr == '\0')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Attribute expected, but found end of string.")));
+
+ /*------
+ * attr-val = ALPHA "=" value
+ * ;; Generic syntax of any attribute sent
+ * ;; by server or client
+ *------
+ */
+ if (!((attr >= 'A' && attr <= 'Z') ||
+ (attr >= 'a' && attr <= 'z')))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Attribute expected, but found invalid character \"%s\".",
+ sanitize_char(attr))));
+ if (attr_p)
+ *attr_p = attr;
+ begin++;
+
+ if (*begin != '=')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Expected character \"=\" for attribute \"%c\".", attr)));
+ begin++;
+
+ end = begin;
+ while (*end && *end != ',')
+ end++;
+
+ if (*end)
+ {
+ *end = '\0';
+ *input = end + 1;
+ }
+ else
+ *input = end;
+
+ return begin;
+}
+
+/*
+ * Read and parse the first message from client in the context of a SCRAM
+ * authentication exchange message.
+ *
+ * At this stage, any errors will be reported directly with ereport(ERROR).
+ */
+static void
+read_client_first_message(scram_state *state, const char *input)
+{
+ char *p = pstrdup(input);
+ char *channel_binding_type;
+
+
+ /*------
+ * The syntax for the client-first-message is: (RFC 5802)
+ *
+ * saslname = 1*(value-safe-char / "=2C" / "=3D")
+ * ;; Conforms to <value>.
+ *
+ * authzid = "a=" saslname
+ * ;; Protocol specific.
+ *
+ * cb-name = 1*(ALPHA / DIGIT / "." / "-")
+ * ;; See RFC 5056, Section 7.
+ * ;; E.g., "tls-server-end-point" or
+ * ;; "tls-unique".
+ *
+ * gs2-cbind-flag = ("p=" cb-name) / "n" / "y"
+ * ;; "n" -> client doesn't support channel binding.
+ * ;; "y" -> client does support channel binding
+ * ;; but thinks the server does not.
+ * ;; "p" -> client requires channel binding.
+ * ;; The selected channel binding follows "p=".
+ *
+ * gs2-header = gs2-cbind-flag "," [ authzid ] ","
+ * ;; GS2 header for SCRAM
+ * ;; (the actual GS2 header includes an optional
+ * ;; flag to indicate that the GSS mechanism is not
+ * ;; "standard", but since SCRAM is "standard", we
+ * ;; don't include that flag).
+ *
+ * username = "n=" saslname
+ * ;; Usernames are prepared using SASLprep.
+ *
+ * reserved-mext = "m=" 1*(value-char)
+ * ;; Reserved for signaling mandatory extensions.
+ * ;; The exact syntax will be defined in
+ * ;; the future.
+ *
+ * nonce = "r=" c-nonce [s-nonce]
+ * ;; Second part provided by server.
+ *
+ * c-nonce = printable
+ *
+ * client-first-message-bare =
+ * [reserved-mext ","]
+ * username "," nonce ["," extensions]
+ *
+ * client-first-message =
+ * gs2-header client-first-message-bare
+ *
+ * For example:
+ * n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL
+ *
+ * The "n,," in the beginning means that the client doesn't support
+ * channel binding, and no authzid is given. "n=user" is the username.
+ * However, in PostgreSQL the username is sent in the startup packet, and
+ * the username in the SCRAM exchange is ignored. libpq always sends it
+ * as an empty string. The last part, "r=fyko+d2lbbFgONRv9qkxdawL" is
+ * the client nonce.
+ *------
+ */
+
+ /*
+ * Read gs2-cbind-flag. (For details see also RFC 5802 Section 6 "Channel
+ * Binding".)
+ */
+ state->cbind_flag = *p;
+ switch (*p)
+ {
+ case 'n':
+
+ /*
+ * The client does not support channel binding or has simply
+ * decided to not use it. In that case just let it go.
+ */
+ if (state->channel_binding_in_use)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
+
+ p++;
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Comma expected, but found character \"%s\".",
+ sanitize_char(*p))));
+ p++;
+ break;
+ case 'y':
+
+ /*
+ * The client supports channel binding and thinks that the server
+ * does not. In this case, the server must fail authentication if
+ * it supports channel binding.
+ */
+ if (state->channel_binding_in_use)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data.")));
+
+#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
+ if (state->port->ssl_in_use)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+ errmsg("SCRAM channel binding negotiation error"),
+ errdetail("The client supports SCRAM channel binding but thinks the server does not. "
+ "However, this server does support channel binding.")));
+#endif
+ p++;
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Comma expected, but found character \"%s\".",
+ sanitize_char(*p))));
+ p++;
+ break;
+ case 'p':
+
+ /*
+ * The client requires channel binding. Channel binding type
+ * follows, e.g., "p=tls-server-end-point".
+ */
+ if (!state->channel_binding_in_use)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data.")));
+
+ channel_binding_type = read_attr_value(&p, 'p');
+
+ /*
+ * The only channel binding type we support is
+ * tls-server-end-point.
+ */
+ if (strcmp(channel_binding_type, "tls-server-end-point") != 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("unsupported SCRAM channel-binding type \"%s\"",
+ sanitize_str(channel_binding_type))));
+ break;
+ default:
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Unexpected channel-binding flag \"%s\".",
+ sanitize_char(*p))));
+ }
+
+ /*
+ * Forbid optional authzid (authorization identity). We don't support it.
+ */
+ if (*p == 'a')
+ ereport(ERROR,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("client uses authorization identity, but it is not supported")));
+ if (*p != ',')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Unexpected attribute \"%s\" in client-first-message.",
+ sanitize_char(*p))));
+ p++;
+
+ state->client_first_message_bare = pstrdup(p);
+
+ /*
+ * Any mandatory extensions would go here. We don't support any.
+ *
+ * RFC 5802 specifies error code "e=extensions-not-supported" for this,
+ * but it can only be sent in the server-final message. We prefer to fail
+ * immediately (which the RFC also allows).
+ */
+ if (*p == 'm')
+ ereport(ERROR,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("client requires an unsupported SCRAM extension")));
+
+ /*
+ * Read username. Note: this is ignored. We use the username from the
+ * startup message instead, still it is kept around if provided as it
+ * proves to be useful for debugging purposes.
+ */
+ state->client_username = read_attr_value(&p, 'n');
+
+ /* read nonce and check that it is made of only printable characters */
+ state->client_nonce = read_attr_value(&p, 'r');
+ if (!is_scram_printable(state->client_nonce))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("non-printable characters in SCRAM nonce")));
+
+ /*
+ * There can be any number of optional extensions after this. We don't
+ * support any extensions, so ignore them.
+ */
+ while (*p != '\0')
+ read_any_attr(&p, NULL);
+
+ /* success! */
+}
+
+/*
+ * Verify the final nonce contained in the last message received from
+ * client in an exchange.
+ */
+static bool
+verify_final_nonce(scram_state *state)
+{
+ int client_nonce_len = strlen(state->client_nonce);
+ int server_nonce_len = strlen(state->server_nonce);
+ int final_nonce_len = strlen(state->client_final_nonce);
+
+ if (final_nonce_len != client_nonce_len + server_nonce_len)
+ return false;
+ if (memcmp(state->client_final_nonce, state->client_nonce, client_nonce_len) != 0)
+ return false;
+ if (memcmp(state->client_final_nonce + client_nonce_len, state->server_nonce, server_nonce_len) != 0)
+ return false;
+
+ return true;
+}
+
+/*
+ * Verify the client proof contained in the last message received from
+ * client in an exchange. Returns true if the verification is a success,
+ * or false for a failure.
+ */
+static bool
+verify_client_proof(scram_state *state)
+{
+ uint8 ClientSignature[SCRAM_KEY_LEN];
+ uint8 ClientKey[SCRAM_KEY_LEN];
+ uint8 client_StoredKey[SCRAM_KEY_LEN];
+ pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+ int i;
+
+ /*
+ * Calculate ClientSignature. Note that we don't log directly a failure
+ * here even when processing the calculations as this could involve a mock
+ * authentication.
+ */
+ if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 ||
+ pg_hmac_update(ctx,
+ (uint8 *) state->client_first_message_bare,
+ strlen(state->client_first_message_bare)) < 0 ||
+ pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
+ pg_hmac_update(ctx,
+ (uint8 *) state->server_first_message,
+ strlen(state->server_first_message)) < 0 ||
+ pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
+ pg_hmac_update(ctx,
+ (uint8 *) state->client_final_message_without_proof,
+ strlen(state->client_final_message_without_proof)) < 0 ||
+ pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0)
+ {
+ elog(ERROR, "could not calculate client signature");
+ }
+
+ pg_hmac_free(ctx);
+
+ /* Extract the ClientKey that the client calculated from the proof */
+ for (i = 0; i < SCRAM_KEY_LEN; i++)
+ ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i];
+
+ /* Hash it one more time, and compare with StoredKey */
+ if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey) < 0)
+ elog(ERROR, "could not hash stored key");
+
+ if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0)
+ return false;
+
+ return true;
+}
+
+/*
+ * Build the first server-side message sent to the client in a SCRAM
+ * communication exchange.
+ */
+static char *
+build_server_first_message(scram_state *state)
+{
+ /*------
+ * The syntax for the server-first-message is: (RFC 5802)
+ *
+ * server-first-message =
+ * [reserved-mext ","] nonce "," salt ","
+ * iteration-count ["," extensions]
+ *
+ * nonce = "r=" c-nonce [s-nonce]
+ * ;; Second part provided by server.
+ *
+ * c-nonce = printable
+ *
+ * s-nonce = printable
+ *
+ * salt = "s=" base64
+ *
+ * iteration-count = "i=" posit-number
+ * ;; A positive number.
+ *
+ * Example:
+ *
+ * r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096
+ *------
+ */
+
+ /*
+ * Per the spec, the nonce may consist of any printable ASCII characters.
+ * For convenience, however, we don't use the whole range available,
+ * rather, we generate some random bytes, and base64 encode them.
+ */
+ char raw_nonce[SCRAM_RAW_NONCE_LEN];
+ int encoded_len;
+
+ if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN))
+ ereport(ERROR,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("could not generate random nonce")));
+
+ encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN);
+ /* don't forget the zero-terminator */
+ state->server_nonce = palloc(encoded_len + 1);
+ encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN,
+ state->server_nonce, encoded_len);
+ if (encoded_len < 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_INTERNAL_ERROR),
+ errmsg("could not encode random nonce")));
+ state->server_nonce[encoded_len] = '\0';
+
+ state->server_first_message =
+ psprintf("r=%s%s,s=%s,i=%u",
+ state->client_nonce, state->server_nonce,
+ state->salt, state->iterations);
+
+ return pstrdup(state->server_first_message);
+}
+
+
+/*
+ * Read and parse the final message received from client.
+ */
+static void
+read_client_final_message(scram_state *state, const char *input)
+{
+ char attr;
+ char *channel_binding;
+ char *value;
+ char *begin,
+ *proof;
+ char *p;
+ char *client_proof;
+ int client_proof_len;
+
+ begin = p = pstrdup(input);
+
+ /*------
+ * The syntax for the server-first-message is: (RFC 5802)
+ *
+ * gs2-header = gs2-cbind-flag "," [ authzid ] ","
+ * ;; GS2 header for SCRAM
+ * ;; (the actual GS2 header includes an optional
+ * ;; flag to indicate that the GSS mechanism is not
+ * ;; "standard", but since SCRAM is "standard", we
+ * ;; don't include that flag).
+ *
+ * cbind-input = gs2-header [ cbind-data ]
+ * ;; cbind-data MUST be present for
+ * ;; gs2-cbind-flag of "p" and MUST be absent
+ * ;; for "y" or "n".
+ *
+ * channel-binding = "c=" base64
+ * ;; base64 encoding of cbind-input.
+ *
+ * proof = "p=" base64
+ *
+ * client-final-message-without-proof =
+ * channel-binding "," nonce [","
+ * extensions]
+ *
+ * client-final-message =
+ * client-final-message-without-proof "," proof
+ *------
+ */
+
+ /*
+ * Read channel binding. This repeats the channel-binding flags and is
+ * then followed by the actual binding data depending on the type.
+ */
+ channel_binding = read_attr_value(&p, 'c');
+ if (state->channel_binding_in_use)
+ {
+#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH
+ const char *cbind_data = NULL;
+ size_t cbind_data_len = 0;
+ size_t cbind_header_len;
+ char *cbind_input;
+ size_t cbind_input_len;
+ char *b64_message;
+ int b64_message_len;
+
+ Assert(state->cbind_flag == 'p');
+
+ /* Fetch hash data of server's SSL certificate */
+ cbind_data = be_tls_get_certificate_hash(state->port,
+ &cbind_data_len);
+
+ /* should not happen */
+ if (cbind_data == NULL || cbind_data_len == 0)
+ elog(ERROR, "could not get server certificate hash");
+
+ cbind_header_len = strlen("p=tls-server-end-point,,"); /* p=type,, */
+ cbind_input_len = cbind_header_len + cbind_data_len;
+ cbind_input = palloc(cbind_input_len);
+ snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,");
+ memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len);
+
+ b64_message_len = pg_b64_enc_len(cbind_input_len);
+ /* don't forget the zero-terminator */
+ b64_message = palloc(b64_message_len + 1);
+ b64_message_len = pg_b64_encode(cbind_input, cbind_input_len,
+ b64_message, b64_message_len);
+ if (b64_message_len < 0)
+ elog(ERROR, "could not encode channel binding data");
+ b64_message[b64_message_len] = '\0';
+
+ /*
+ * Compare the value sent by the client with the value expected by the
+ * server.
+ */
+ if (strcmp(channel_binding, b64_message) != 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+ errmsg("SCRAM channel binding check failed")));
+#else
+ /* shouldn't happen, because we checked this earlier already */
+ elog(ERROR, "channel binding not supported by this build");
+#endif
+ }
+ else
+ {
+ /*
+ * If we are not using channel binding, the binding data is expected
+ * to always be "biws", which is "n,," base64-encoded, or "eSws",
+ * which is "y,,". We also have to check whether the flag is the same
+ * one that the client originally sent.
+ */
+ if (!(strcmp(channel_binding, "biws") == 0 && state->cbind_flag == 'n') &&
+ !(strcmp(channel_binding, "eSws") == 0 && state->cbind_flag == 'y'))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("unexpected SCRAM channel-binding attribute in client-final-message")));
+ }
+
+ state->client_final_nonce = read_attr_value(&p, 'r');
+
+ /* ignore optional extensions, read until we find "p" attribute */
+ do
+ {
+ proof = p - 1;
+ value = read_any_attr(&p, &attr);
+ } while (attr != 'p');
+
+ client_proof_len = pg_b64_dec_len(strlen(value));
+ client_proof = palloc(client_proof_len);
+ if (pg_b64_decode(value, strlen(value), client_proof,
+ client_proof_len) != SCRAM_KEY_LEN)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Malformed proof in client-final-message.")));
+ memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN);
+ pfree(client_proof);
+
+ if (*p != '\0')
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("malformed SCRAM message"),
+ errdetail("Garbage found at the end of client-final-message.")));
+
+ state->client_final_message_without_proof = palloc(proof - begin + 1);
+ memcpy(state->client_final_message_without_proof, input, proof - begin);
+ state->client_final_message_without_proof[proof - begin] = '\0';
+}
+
+/*
+ * Build the final server-side message of an exchange.
+ */
+static char *
+build_server_final_message(scram_state *state)
+{
+ uint8 ServerSignature[SCRAM_KEY_LEN];
+ char *server_signature_base64;
+ int siglen;
+ pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256);
+
+ /* calculate ServerSignature */
+ if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 ||
+ pg_hmac_update(ctx,
+ (uint8 *) state->client_first_message_bare,
+ strlen(state->client_first_message_bare)) < 0 ||
+ pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
+ pg_hmac_update(ctx,
+ (uint8 *) state->server_first_message,
+ strlen(state->server_first_message)) < 0 ||
+ pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 ||
+ pg_hmac_update(ctx,
+ (uint8 *) state->client_final_message_without_proof,
+ strlen(state->client_final_message_without_proof)) < 0 ||
+ pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0)
+ {
+ elog(ERROR, "could not calculate server signature");
+ }
+
+ pg_hmac_free(ctx);
+
+ siglen = pg_b64_enc_len(SCRAM_KEY_LEN);
+ /* don't forget the zero-terminator */
+ server_signature_base64 = palloc(siglen + 1);
+ siglen = pg_b64_encode((const char *) ServerSignature,
+ SCRAM_KEY_LEN, server_signature_base64,
+ siglen);
+ if (siglen < 0)
+ elog(ERROR, "could not encode server signature");
+ server_signature_base64[siglen] = '\0';
+
+ /*------
+ * The syntax for the server-final-message is: (RFC 5802)
+ *
+ * verifier = "v=" base64
+ * ;; base-64 encoded ServerSignature.
+ *
+ * server-final-message = (server-error / verifier)
+ * ["," extensions]
+ *
+ *------
+ */
+ return psprintf("v=%s", server_signature_base64);
+}
+
+
+/*
+ * Deterministically generate salt for mock authentication, using a SHA256
+ * hash based on the username and a cluster-level secret key. Returns a
+ * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
+ */
+static char *
+scram_mock_salt(const char *username)
+{
+ pg_cryptohash_ctx *ctx;
+ static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH];
+ char *mock_auth_nonce = GetMockAuthenticationNonce();
+
+ /*
+ * Generate salt using a SHA256 hash of the username and the cluster's
+ * mock authentication nonce. (This works as long as the salt length is
+ * not larger than the SHA256 digest length. If the salt is smaller, the
+ * caller will just ignore the extra data.)
+ */
+ StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN,
+ "salt length greater than SHA256 digest length");
+
+ ctx = pg_cryptohash_create(PG_SHA256);
+ if (pg_cryptohash_init(ctx) < 0 ||
+ pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 ||
+ pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 ||
+ pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0)
+ {
+ pg_cryptohash_free(ctx);
+ return NULL;
+ }
+ pg_cryptohash_free(ctx);
+
+ return (char *) sha_digest;
+}
diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c
new file mode 100644
index 0000000..be14f2f
--- /dev/null
+++ b/src/backend/libpq/auth.c
@@ -0,0 +1,3492 @@
+/*-------------------------------------------------------------------------
+ *
+ * auth.c
+ * Routines to handle network authentication
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/auth.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include <sys/param.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <unistd.h>
+#ifdef HAVE_SYS_SELECT_H
+#include <sys/select.h>
+#endif
+
+#include "commands/user.h"
+#include "common/ip.h"
+#include "common/md5.h"
+#include "common/scram-common.h"
+#include "libpq/auth.h"
+#include "libpq/crypt.h"
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
+#include "libpq/scram.h"
+#include "miscadmin.h"
+#include "port/pg_bswap.h"
+#include "postmaster/postmaster.h"
+#include "replication/walsender.h"
+#include "storage/ipc.h"
+#include "utils/guc.h"
+#include "utils/memutils.h"
+#include "utils/timestamp.h"
+
+/*----------------------------------------------------------------
+ * Global authentication functions
+ *----------------------------------------------------------------
+ */
+static void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata,
+ int extralen);
+static void auth_failed(Port *port, int status, char *logdetail);
+static char *recv_password_packet(Port *port);
+static void set_authn_id(Port *port, const char *id);
+
+
+/*----------------------------------------------------------------
+ * Password-based authentication methods (password, md5, and scram-sha-256)
+ *----------------------------------------------------------------
+ */
+static int CheckPasswordAuth(Port *port, char **logdetail);
+static int CheckPWChallengeAuth(Port *port, char **logdetail);
+
+static int CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail);
+static int CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail);
+
+
+/*----------------------------------------------------------------
+ * Ident authentication
+ *----------------------------------------------------------------
+ */
+/* Max size of username ident server can return (per RFC 1413) */
+#define IDENT_USERNAME_MAX 512
+
+/* Standard TCP port number for Ident service. Assigned by IANA */
+#define IDENT_PORT 113
+
+static int ident_inet(hbaPort *port);
+
+
+/*----------------------------------------------------------------
+ * Peer authentication
+ *----------------------------------------------------------------
+ */
+static int auth_peer(hbaPort *port);
+
+
+/*----------------------------------------------------------------
+ * PAM authentication
+ *----------------------------------------------------------------
+ */
+#ifdef USE_PAM
+#ifdef HAVE_PAM_PAM_APPL_H
+#include <pam/pam_appl.h>
+#endif
+#ifdef HAVE_SECURITY_PAM_APPL_H
+#include <security/pam_appl.h>
+#endif
+
+#define PGSQL_PAM_SERVICE "postgresql" /* Service name passed to PAM */
+
+static int CheckPAMAuth(Port *port, const char *user, const char *password);
+static int pam_passwd_conv_proc(int num_msg, const struct pam_message **msg,
+ struct pam_response **resp, void *appdata_ptr);
+
+static struct pam_conv pam_passw_conv = {
+ &pam_passwd_conv_proc,
+ NULL
+};
+
+static const char *pam_passwd = NULL; /* Workaround for Solaris 2.6
+ * brokenness */
+static Port *pam_port_cludge; /* Workaround for passing "Port *port" into
+ * pam_passwd_conv_proc */
+static bool pam_no_password; /* For detecting no-password-given */
+#endif /* USE_PAM */
+
+
+/*----------------------------------------------------------------
+ * BSD authentication
+ *----------------------------------------------------------------
+ */
+#ifdef USE_BSD_AUTH
+#include <bsd_auth.h>
+
+static int CheckBSDAuth(Port *port, char *user);
+#endif /* USE_BSD_AUTH */
+
+
+/*----------------------------------------------------------------
+ * LDAP authentication
+ *----------------------------------------------------------------
+ */
+#ifdef USE_LDAP
+#ifndef WIN32
+/* We use a deprecated function to keep the codepath the same as win32. */
+#define LDAP_DEPRECATED 1
+#include <ldap.h>
+#else
+#include <winldap.h>
+
+/* Correct header from the Platform SDK */
+typedef
+ULONG (*__ldap_start_tls_sA) (IN PLDAP ExternalHandle,
+ OUT PULONG ServerReturnValue,
+ OUT LDAPMessage **result,
+ IN PLDAPControlA * ServerControls,
+ IN PLDAPControlA * ClientControls
+);
+#endif
+
+static int CheckLDAPAuth(Port *port);
+
+/* LDAP_OPT_DIAGNOSTIC_MESSAGE is the newer spelling */
+#ifndef LDAP_OPT_DIAGNOSTIC_MESSAGE
+#define LDAP_OPT_DIAGNOSTIC_MESSAGE LDAP_OPT_ERROR_STRING
+#endif
+
+#endif /* USE_LDAP */
+
+/*----------------------------------------------------------------
+ * Cert authentication
+ *----------------------------------------------------------------
+ */
+#ifdef USE_SSL
+static int CheckCertAuth(Port *port);
+#endif
+
+
+/*----------------------------------------------------------------
+ * Kerberos and GSSAPI GUCs
+ *----------------------------------------------------------------
+ */
+char *pg_krb_server_keyfile;
+bool pg_krb_caseins_users;
+
+
+/*----------------------------------------------------------------
+ * GSSAPI Authentication
+ *----------------------------------------------------------------
+ */
+#ifdef ENABLE_GSS
+#include "libpq/be-gssapi-common.h"
+
+static int pg_GSS_checkauth(Port *port);
+static int pg_GSS_recvauth(Port *port);
+#endif /* ENABLE_GSS */
+
+
+/*----------------------------------------------------------------
+ * SSPI Authentication
+ *----------------------------------------------------------------
+ */
+#ifdef ENABLE_SSPI
+typedef SECURITY_STATUS
+ (WINAPI * QUERY_SECURITY_CONTEXT_TOKEN_FN) (PCtxtHandle, void **);
+static int pg_SSPI_recvauth(Port *port);
+static int pg_SSPI_make_upn(char *accountname,
+ size_t accountnamesize,
+ char *domainname,
+ size_t domainnamesize,
+ bool update_accountname);
+#endif
+
+/*----------------------------------------------------------------
+ * RADIUS Authentication
+ *----------------------------------------------------------------
+ */
+static int CheckRADIUSAuth(Port *port);
+static int PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd);
+
+
+/*
+ * Maximum accepted size of GSS and SSPI authentication tokens.
+ * We also use this as a limit on ordinary password packet lengths.
+ *
+ * Kerberos tickets are usually quite small, but the TGTs issued by Windows
+ * domain controllers include an authorization field known as the Privilege
+ * Attribute Certificate (PAC), which contains the user's Windows permissions
+ * (group memberships etc.). The PAC is copied into all tickets obtained on
+ * the basis of this TGT (even those issued by Unix realms which the Windows
+ * realm trusts), and can be several kB in size. The maximum token size
+ * accepted by Windows systems is determined by the MaxAuthToken Windows
+ * registry setting. Microsoft recommends that it is not set higher than
+ * 65535 bytes, so that seems like a reasonable limit for us as well.
+ */
+#define PG_MAX_AUTH_TOKEN_LENGTH 65535
+
+/*
+ * Maximum accepted size of SASL messages.
+ *
+ * The messages that the server or libpq generate are much smaller than this,
+ * but have some headroom.
+ */
+#define PG_MAX_SASL_MESSAGE_LENGTH 1024
+
+/*----------------------------------------------------------------
+ * Global authentication functions
+ *----------------------------------------------------------------
+ */
+
+/*
+ * This hook allows plugins to get control following client authentication,
+ * but before the user has been informed about the results. It could be used
+ * to record login events, insert a delay after failed authentication, etc.
+ */
+ClientAuthentication_hook_type ClientAuthentication_hook = NULL;
+
+/*
+ * Tell the user the authentication failed, but not (much about) why.
+ *
+ * There is a tradeoff here between security concerns and making life
+ * unnecessarily difficult for legitimate users. We would not, for example,
+ * want to report the password we were expecting to receive...
+ * But it seems useful to report the username and authorization method
+ * in use, and these are items that must be presumed known to an attacker
+ * anyway.
+ * Note that many sorts of failure report additional information in the
+ * postmaster log, which we hope is only readable by good guys. In
+ * particular, if logdetail isn't NULL, we send that string to the log.
+ */
+static void
+auth_failed(Port *port, int status, char *logdetail)
+{
+ const char *errstr;
+ char *cdetail;
+ int errcode_return = ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION;
+
+ /*
+ * If we failed due to EOF from client, just quit; there's no point in
+ * trying to send a message to the client, and not much point in logging
+ * the failure in the postmaster log. (Logging the failure might be
+ * desirable, were it not for the fact that libpq closes the connection
+ * unceremoniously if challenged for a password when it hasn't got one to
+ * send. We'll get a useless log entry for every psql connection under
+ * password auth, even if it's perfectly successful, if we log STATUS_EOF
+ * events.)
+ */
+ if (status == STATUS_EOF)
+ proc_exit(0);
+
+ switch (port->hba->auth_method)
+ {
+ case uaReject:
+ case uaImplicitReject:
+ errstr = gettext_noop("authentication failed for user \"%s\": host rejected");
+ break;
+ case uaTrust:
+ errstr = gettext_noop("\"trust\" authentication failed for user \"%s\"");
+ break;
+ case uaIdent:
+ errstr = gettext_noop("Ident authentication failed for user \"%s\"");
+ break;
+ case uaPeer:
+ errstr = gettext_noop("Peer authentication failed for user \"%s\"");
+ break;
+ case uaPassword:
+ case uaMD5:
+ case uaSCRAM:
+ errstr = gettext_noop("password authentication failed for user \"%s\"");
+ /* We use it to indicate if a .pgpass password failed. */
+ errcode_return = ERRCODE_INVALID_PASSWORD;
+ break;
+ case uaGSS:
+ errstr = gettext_noop("GSSAPI authentication failed for user \"%s\"");
+ break;
+ case uaSSPI:
+ errstr = gettext_noop("SSPI authentication failed for user \"%s\"");
+ break;
+ case uaPAM:
+ errstr = gettext_noop("PAM authentication failed for user \"%s\"");
+ break;
+ case uaBSD:
+ errstr = gettext_noop("BSD authentication failed for user \"%s\"");
+ break;
+ case uaLDAP:
+ errstr = gettext_noop("LDAP authentication failed for user \"%s\"");
+ break;
+ case uaCert:
+ errstr = gettext_noop("certificate authentication failed for user \"%s\"");
+ break;
+ case uaRADIUS:
+ errstr = gettext_noop("RADIUS authentication failed for user \"%s\"");
+ break;
+ default:
+ errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method");
+ break;
+ }
+
+ cdetail = psprintf(_("Connection matched pg_hba.conf line %d: \"%s\""),
+ port->hba->linenumber, port->hba->rawline);
+ if (logdetail)
+ logdetail = psprintf("%s\n%s", logdetail, cdetail);
+ else
+ logdetail = cdetail;
+
+ ereport(FATAL,
+ (errcode(errcode_return),
+ errmsg(errstr, port->user_name),
+ logdetail ? errdetail_log("%s", logdetail) : 0));
+
+ /* doesn't return */
+}
+
+
+/*
+ * Sets the authenticated identity for the current user. The provided string
+ * will be copied into the TopMemoryContext. The ID will be logged if
+ * log_connections is enabled.
+ *
+ * Auth methods should call this routine exactly once, as soon as the user is
+ * successfully authenticated, even if they have reasons to know that
+ * authorization will fail later.
+ *
+ * The provided string will be copied into TopMemoryContext, to match the
+ * lifetime of the Port, so it is safe to pass a string that is managed by an
+ * external library.
+ */
+static void
+set_authn_id(Port *port, const char *id)
+{
+ Assert(id);
+
+ if (port->authn_id)
+ {
+ /*
+ * An existing authn_id should never be overwritten; that means two
+ * authentication providers are fighting (or one is fighting itself).
+ * Don't leak any authn details to the client, but don't let the
+ * connection continue, either.
+ */
+ ereport(FATAL,
+ (errmsg("authentication identifier set more than once"),
+ errdetail_log("previous identifier: \"%s\"; new identifier: \"%s\"",
+ port->authn_id, id)));
+ }
+
+ port->authn_id = MemoryContextStrdup(TopMemoryContext, id);
+
+ if (Log_connections)
+ {
+ ereport(LOG,
+ errmsg("connection authenticated: identity=\"%s\" method=%s "
+ "(%s:%d)",
+ port->authn_id, hba_authname(port->hba->auth_method), HbaFileName,
+ port->hba->linenumber));
+ }
+}
+
+
+/*
+ * Client authentication starts here. If there is an error, this
+ * function does not return and the backend process is terminated.
+ */
+void
+ClientAuthentication(Port *port)
+{
+ int status = STATUS_ERROR;
+ char *logdetail = NULL;
+
+ /*
+ * Get the authentication method to use for this frontend/database
+ * combination. Note: we do not parse the file at this point; this has
+ * already been done elsewhere. hba.c dropped an error message into the
+ * server logfile if parsing the hba config file failed.
+ */
+ hba_getauthmethod(port);
+
+ CHECK_FOR_INTERRUPTS();
+
+ /*
+ * This is the first point where we have access to the hba record for the
+ * current connection, so perform any verifications based on the hba
+ * options field that should be done *before* the authentication here.
+ */
+ if (port->hba->clientcert != clientCertOff)
+ {
+ /* If we haven't loaded a root certificate store, fail */
+ if (!secure_loaded_verify_locations())
+ ereport(FATAL,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("client certificates can only be checked if a root certificate store is available")));
+
+ /*
+ * If we loaded a root certificate store, and if a certificate is
+ * present on the client, then it has been verified against our root
+ * certificate store, and the connection would have been aborted
+ * already if it didn't verify ok.
+ */
+ if (!port->peer_cert_valid)
+ ereport(FATAL,
+ (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+ errmsg("connection requires a valid client certificate")));
+ }
+
+ /*
+ * Now proceed to do the actual authentication check
+ */
+ switch (port->hba->auth_method)
+ {
+ case uaReject:
+
+ /*
+ * An explicit "reject" entry in pg_hba.conf. This report exposes
+ * the fact that there's an explicit reject entry, which is
+ * perhaps not so desirable from a security standpoint; but the
+ * message for an implicit reject could confuse the DBA a lot when
+ * the true situation is a match to an explicit reject. And we
+ * don't want to change the message for an implicit reject. As
+ * noted below, the additional information shown here doesn't
+ * expose anything not known to an attacker.
+ */
+ {
+ char hostinfo[NI_MAXHOST];
+ const char *encryption_state;
+
+ pg_getnameinfo_all(&port->raddr.addr, port->raddr.salen,
+ hostinfo, sizeof(hostinfo),
+ NULL, 0,
+ NI_NUMERICHOST);
+
+ encryption_state =
+#ifdef ENABLE_GSS
+ (port->gss && port->gss->enc) ? _("GSS encryption") :
+#endif
+#ifdef USE_SSL
+ port->ssl_in_use ? _("SSL encryption") :
+#endif
+ _("no encryption");
+
+ if (am_walsender && !am_db_walsender)
+ ereport(FATAL,
+ (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+ /* translator: last %s describes encryption state */
+ errmsg("pg_hba.conf rejects replication connection for host \"%s\", user \"%s\", %s",
+ hostinfo, port->user_name,
+ encryption_state)));
+ else
+ ereport(FATAL,
+ (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+ /* translator: last %s describes encryption state */
+ errmsg("pg_hba.conf rejects connection for host \"%s\", user \"%s\", database \"%s\", %s",
+ hostinfo, port->user_name,
+ port->database_name,
+ encryption_state)));
+ break;
+ }
+
+ case uaImplicitReject:
+
+ /*
+ * No matching entry, so tell the user we fell through.
+ *
+ * NOTE: the extra info reported here is not a security breach,
+ * because all that info is known at the frontend and must be
+ * assumed known to bad guys. We're merely helping out the less
+ * clueful good guys.
+ */
+ {
+ char hostinfo[NI_MAXHOST];
+ const char *encryption_state;
+
+ pg_getnameinfo_all(&port->raddr.addr, port->raddr.salen,
+ hostinfo, sizeof(hostinfo),
+ NULL, 0,
+ NI_NUMERICHOST);
+
+ encryption_state =
+#ifdef ENABLE_GSS
+ (port->gss && port->gss->enc) ? _("GSS encryption") :
+#endif
+#ifdef USE_SSL
+ port->ssl_in_use ? _("SSL encryption") :
+#endif
+ _("no encryption");
+
+#define HOSTNAME_LOOKUP_DETAIL(port) \
+ (port->remote_hostname ? \
+ (port->remote_hostname_resolv == +1 ? \
+ errdetail_log("Client IP address resolved to \"%s\", forward lookup matches.", \
+ port->remote_hostname) : \
+ port->remote_hostname_resolv == 0 ? \
+ errdetail_log("Client IP address resolved to \"%s\", forward lookup not checked.", \
+ port->remote_hostname) : \
+ port->remote_hostname_resolv == -1 ? \
+ errdetail_log("Client IP address resolved to \"%s\", forward lookup does not match.", \
+ port->remote_hostname) : \
+ port->remote_hostname_resolv == -2 ? \
+ errdetail_log("Could not translate client host name \"%s\" to IP address: %s.", \
+ port->remote_hostname, \
+ gai_strerror(port->remote_hostname_errcode)) : \
+ 0) \
+ : (port->remote_hostname_resolv == -2 ? \
+ errdetail_log("Could not resolve client IP address to a host name: %s.", \
+ gai_strerror(port->remote_hostname_errcode)) : \
+ 0))
+
+ if (am_walsender && !am_db_walsender)
+ ereport(FATAL,
+ (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+ /* translator: last %s describes encryption state */
+ errmsg("no pg_hba.conf entry for replication connection from host \"%s\", user \"%s\", %s",
+ hostinfo, port->user_name,
+ encryption_state),
+ HOSTNAME_LOOKUP_DETAIL(port)));
+ else
+ ereport(FATAL,
+ (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+ /* translator: last %s describes encryption state */
+ errmsg("no pg_hba.conf entry for host \"%s\", user \"%s\", database \"%s\", %s",
+ hostinfo, port->user_name,
+ port->database_name,
+ encryption_state),
+ HOSTNAME_LOOKUP_DETAIL(port)));
+ break;
+ }
+
+ case uaGSS:
+#ifdef ENABLE_GSS
+ /* We might or might not have the gss workspace already */
+ if (port->gss == NULL)
+ port->gss = (pg_gssinfo *)
+ MemoryContextAllocZero(TopMemoryContext,
+ sizeof(pg_gssinfo));
+ port->gss->auth = true;
+
+ /*
+ * If GSS state was set up while enabling encryption, we can just
+ * check the client's principal. Otherwise, ask for it.
+ */
+ if (port->gss->enc)
+ status = pg_GSS_checkauth(port);
+ else
+ {
+ sendAuthRequest(port, AUTH_REQ_GSS, NULL, 0);
+ status = pg_GSS_recvauth(port);
+ }
+#else
+ Assert(false);
+#endif
+ break;
+
+ case uaSSPI:
+#ifdef ENABLE_SSPI
+ if (port->gss == NULL)
+ port->gss = (pg_gssinfo *)
+ MemoryContextAllocZero(TopMemoryContext,
+ sizeof(pg_gssinfo));
+ sendAuthRequest(port, AUTH_REQ_SSPI, NULL, 0);
+ status = pg_SSPI_recvauth(port);
+#else
+ Assert(false);
+#endif
+ break;
+
+ case uaPeer:
+ status = auth_peer(port);
+ break;
+
+ case uaIdent:
+ status = ident_inet(port);
+ break;
+
+ case uaMD5:
+ case uaSCRAM:
+ status = CheckPWChallengeAuth(port, &logdetail);
+ break;
+
+ case uaPassword:
+ status = CheckPasswordAuth(port, &logdetail);
+ break;
+
+ case uaPAM:
+#ifdef USE_PAM
+ status = CheckPAMAuth(port, port->user_name, "");
+#else
+ Assert(false);
+#endif /* USE_PAM */
+ break;
+
+ case uaBSD:
+#ifdef USE_BSD_AUTH
+ status = CheckBSDAuth(port, port->user_name);
+#else
+ Assert(false);
+#endif /* USE_BSD_AUTH */
+ break;
+
+ case uaLDAP:
+#ifdef USE_LDAP
+ status = CheckLDAPAuth(port);
+#else
+ Assert(false);
+#endif
+ break;
+ case uaRADIUS:
+ status = CheckRADIUSAuth(port);
+ break;
+ case uaCert:
+ /* uaCert will be treated as if clientcert=verify-full (uaTrust) */
+ case uaTrust:
+ status = STATUS_OK;
+ break;
+ }
+
+ if ((status == STATUS_OK && port->hba->clientcert == clientCertFull)
+ || port->hba->auth_method == uaCert)
+ {
+ /*
+ * Make sure we only check the certificate if we use the cert method
+ * or verify-full option.
+ */
+#ifdef USE_SSL
+ status = CheckCertAuth(port);
+#else
+ Assert(false);
+#endif
+ }
+
+ if (ClientAuthentication_hook)
+ (*ClientAuthentication_hook) (port, status);
+
+ if (status == STATUS_OK)
+ sendAuthRequest(port, AUTH_REQ_OK, NULL, 0);
+ else
+ auth_failed(port, status, logdetail);
+}
+
+
+/*
+ * Send an authentication request packet to the frontend.
+ */
+static void
+sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extralen)
+{
+ StringInfoData buf;
+
+ CHECK_FOR_INTERRUPTS();
+
+ pq_beginmessage(&buf, 'R');
+ pq_sendint32(&buf, (int32) areq);
+ if (extralen > 0)
+ pq_sendbytes(&buf, extradata, extralen);
+
+ pq_endmessage(&buf);
+
+ /*
+ * Flush message so client will see it, except for AUTH_REQ_OK and
+ * AUTH_REQ_SASL_FIN, which need not be sent until we are ready for
+ * queries.
+ */
+ if (areq != AUTH_REQ_OK && areq != AUTH_REQ_SASL_FIN)
+ pq_flush();
+
+ CHECK_FOR_INTERRUPTS();
+}
+
+/*
+ * Collect password response packet from frontend.
+ *
+ * Returns NULL if couldn't get password, else palloc'd string.
+ */
+static char *
+recv_password_packet(Port *port)
+{
+ StringInfoData buf;
+ int mtype;
+
+ pq_startmsgread();
+
+ /* Expect 'p' message type */
+ mtype = pq_getbyte();
+ if (mtype != 'p')
+ {
+ /*
+ * If the client just disconnects without offering a password, don't
+ * make a log entry. This is legal per protocol spec and in fact
+ * commonly done by psql, so complaining just clutters the log.
+ */
+ if (mtype != EOF)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("expected password response, got message type %d",
+ mtype)));
+ return NULL; /* EOF or bad message type */
+ }
+
+ initStringInfo(&buf);
+ if (pq_getmessage(&buf, PG_MAX_AUTH_TOKEN_LENGTH)) /* receive password */
+ {
+ /* EOF - pq_getmessage already logged a suitable message */
+ pfree(buf.data);
+ return NULL;
+ }
+
+ /*
+ * Apply sanity check: password packet length should agree with length of
+ * contained string. Note it is safe to use strlen here because
+ * StringInfo is guaranteed to have an appended '\0'.
+ */
+ if (strlen(buf.data) + 1 != buf.len)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("invalid password packet size")));
+
+ /*
+ * Don't allow an empty password. Libpq treats an empty password the same
+ * as no password at all, and won't even try to authenticate. But other
+ * clients might, so allowing it would be confusing.
+ *
+ * Note that this only catches an empty password sent by the client in
+ * plaintext. There's also a check in CREATE/ALTER USER that prevents an
+ * empty string from being stored as a user's password in the first place.
+ * We rely on that for MD5 and SCRAM authentication, but we still need
+ * this check here, to prevent an empty password from being used with
+ * authentication methods that check the password against an external
+ * system, like PAM, LDAP and RADIUS.
+ */
+ if (buf.len == 1)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_PASSWORD),
+ errmsg("empty password returned by client")));
+
+ /* Do not echo password to logs, for security. */
+ elog(DEBUG5, "received password packet");
+
+ /*
+ * Return the received string. Note we do not attempt to do any
+ * character-set conversion on it; since we don't yet know the client's
+ * encoding, there wouldn't be much point.
+ */
+ return buf.data;
+}
+
+
+/*----------------------------------------------------------------
+ * Password-based authentication mechanisms
+ *----------------------------------------------------------------
+ */
+
+/*
+ * Plaintext password authentication.
+ */
+static int
+CheckPasswordAuth(Port *port, char **logdetail)
+{
+ char *passwd;
+ int result;
+ char *shadow_pass;
+
+ sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
+
+ passwd = recv_password_packet(port);
+ if (passwd == NULL)
+ return STATUS_EOF; /* client wouldn't send password */
+
+ shadow_pass = get_role_password(port->user_name, logdetail);
+ if (shadow_pass)
+ {
+ result = plain_crypt_verify(port->user_name, shadow_pass, passwd,
+ logdetail);
+ }
+ else
+ result = STATUS_ERROR;
+
+ if (shadow_pass)
+ pfree(shadow_pass);
+ pfree(passwd);
+
+ if (result == STATUS_OK)
+ set_authn_id(port, port->user_name);
+
+ return result;
+}
+
+/*
+ * MD5 and SCRAM authentication.
+ */
+static int
+CheckPWChallengeAuth(Port *port, char **logdetail)
+{
+ int auth_result;
+ char *shadow_pass;
+ PasswordType pwtype;
+
+ Assert(port->hba->auth_method == uaSCRAM ||
+ port->hba->auth_method == uaMD5);
+
+ /* First look up the user's password. */
+ shadow_pass = get_role_password(port->user_name, logdetail);
+
+ /*
+ * If the user does not exist, or has no password or it's expired, we
+ * still go through the motions of authentication, to avoid revealing to
+ * the client that the user didn't exist. If 'md5' is allowed, we choose
+ * whether to use 'md5' or 'scram-sha-256' authentication based on current
+ * password_encryption setting. The idea is that most genuine users
+ * probably have a password of that type, and if we pretend that this user
+ * had a password of that type, too, it "blends in" best.
+ */
+ if (!shadow_pass)
+ pwtype = Password_encryption;
+ else
+ pwtype = get_password_type(shadow_pass);
+
+ /*
+ * If 'md5' authentication is allowed, decide whether to perform 'md5' or
+ * 'scram-sha-256' authentication based on the type of password the user
+ * has. If it's an MD5 hash, we must do MD5 authentication, and if it's a
+ * SCRAM secret, we must do SCRAM authentication.
+ *
+ * If MD5 authentication is not allowed, always use SCRAM. If the user
+ * had an MD5 password, CheckSCRAMAuth() will fail.
+ */
+ if (port->hba->auth_method == uaMD5 && pwtype == PASSWORD_TYPE_MD5)
+ auth_result = CheckMD5Auth(port, shadow_pass, logdetail);
+ else
+ auth_result = CheckSCRAMAuth(port, shadow_pass, logdetail);
+
+ if (shadow_pass)
+ pfree(shadow_pass);
+
+ /*
+ * If get_role_password() returned error, return error, even if the
+ * authentication succeeded.
+ */
+ if (!shadow_pass)
+ {
+ Assert(auth_result != STATUS_OK);
+ return STATUS_ERROR;
+ }
+
+ if (auth_result == STATUS_OK)
+ set_authn_id(port, port->user_name);
+
+ return auth_result;
+}
+
+static int
+CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail)
+{
+ char md5Salt[4]; /* Password salt */
+ char *passwd;
+ int result;
+
+ if (Db_user_namespace)
+ ereport(FATAL,
+ (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION),
+ errmsg("MD5 authentication is not supported when \"db_user_namespace\" is enabled")));
+
+ /* include the salt to use for computing the response */
+ if (!pg_strong_random(md5Salt, 4))
+ {
+ ereport(LOG,
+ (errmsg("could not generate random MD5 salt")));
+ return STATUS_ERROR;
+ }
+
+ sendAuthRequest(port, AUTH_REQ_MD5, md5Salt, 4);
+
+ passwd = recv_password_packet(port);
+ if (passwd == NULL)
+ return STATUS_EOF; /* client wouldn't send password */
+
+ if (shadow_pass)
+ result = md5_crypt_verify(port->user_name, shadow_pass, passwd,
+ md5Salt, 4, logdetail);
+ else
+ result = STATUS_ERROR;
+
+ pfree(passwd);
+
+ return result;
+}
+
+static int
+CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail)
+{
+ StringInfoData sasl_mechs;
+ int mtype;
+ StringInfoData buf;
+ void *scram_opaq = NULL;
+ char *output = NULL;
+ int outputlen = 0;
+ const char *input;
+ int inputlen;
+ int result;
+ bool initial;
+
+ /*
+ * Send the SASL authentication request to user. It includes the list of
+ * authentication mechanisms that are supported.
+ */
+ initStringInfo(&sasl_mechs);
+
+ pg_be_scram_get_mechanisms(port, &sasl_mechs);
+ /* Put another '\0' to mark that list is finished. */
+ appendStringInfoChar(&sasl_mechs, '\0');
+
+ sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len);
+ pfree(sasl_mechs.data);
+
+ /*
+ * Loop through SASL message exchange. This exchange can consist of
+ * multiple messages sent in both directions. First message is always
+ * from the client. All messages from client to server are password
+ * packets (type 'p').
+ */
+ initial = true;
+ do
+ {
+ pq_startmsgread();
+ mtype = pq_getbyte();
+ if (mtype != 'p')
+ {
+ /* Only log error if client didn't disconnect. */
+ if (mtype != EOF)
+ {
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("expected SASL response, got message type %d",
+ mtype)));
+ }
+ else
+ return STATUS_EOF;
+ }
+
+ /* Get the actual SASL message */
+ initStringInfo(&buf);
+ if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH))
+ {
+ /* EOF - pq_getmessage already logged error */
+ pfree(buf.data);
+ return STATUS_ERROR;
+ }
+
+ elog(DEBUG4, "processing received SASL response of length %d", buf.len);
+
+ /*
+ * The first SASLInitialResponse message is different from the others.
+ * It indicates which SASL mechanism the client selected, and contains
+ * an optional Initial Client Response payload. The subsequent
+ * SASLResponse messages contain just the SASL payload.
+ */
+ if (initial)
+ {
+ const char *selected_mech;
+
+ selected_mech = pq_getmsgrawstring(&buf);
+
+ /*
+ * Initialize the status tracker for message exchanges.
+ *
+ * If the user doesn't exist, or doesn't have a valid password, or
+ * it's expired, we still go through the motions of SASL
+ * authentication, but tell the authentication method that the
+ * authentication is "doomed". That is, it's going to fail, no
+ * matter what.
+ *
+ * This is because we don't want to reveal to an attacker what
+ * usernames are valid, nor which users have a valid password.
+ */
+ scram_opaq = pg_be_scram_init(port, selected_mech, shadow_pass);
+
+ inputlen = pq_getmsgint(&buf, 4);
+ if (inputlen == -1)
+ input = NULL;
+ else
+ input = pq_getmsgbytes(&buf, inputlen);
+
+ initial = false;
+ }
+ else
+ {
+ inputlen = buf.len;
+ input = pq_getmsgbytes(&buf, buf.len);
+ }
+ pq_getmsgend(&buf);
+
+ /*
+ * The StringInfo guarantees that there's a \0 byte after the
+ * response.
+ */
+ Assert(input == NULL || input[inputlen] == '\0');
+
+ /*
+ * we pass 'logdetail' as NULL when doing a mock authentication,
+ * because we should already have a better error message in that case
+ */
+ result = pg_be_scram_exchange(scram_opaq, input, inputlen,
+ &output, &outputlen,
+ logdetail);
+
+ /* input buffer no longer used */
+ pfree(buf.data);
+
+ if (output)
+ {
+ /*
+ * Negotiation generated data to be sent to the client.
+ */
+ elog(DEBUG4, "sending SASL challenge of length %u", outputlen);
+
+ if (result == SASL_EXCHANGE_SUCCESS)
+ sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen);
+ else
+ sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen);
+
+ pfree(output);
+ }
+ } while (result == SASL_EXCHANGE_CONTINUE);
+
+ /* Oops, Something bad happened */
+ if (result != SASL_EXCHANGE_SUCCESS)
+ {
+ return STATUS_ERROR;
+ }
+
+ return STATUS_OK;
+}
+
+
+/*----------------------------------------------------------------
+ * GSSAPI authentication system
+ *----------------------------------------------------------------
+ */
+#ifdef ENABLE_GSS
+static int
+pg_GSS_recvauth(Port *port)
+{
+ OM_uint32 maj_stat,
+ min_stat,
+ lmin_s,
+ gflags;
+ int mtype;
+ StringInfoData buf;
+ gss_buffer_desc gbuf;
+
+ /*
+ * Use the configured keytab, if there is one. Unfortunately, Heimdal
+ * doesn't support the cred store extensions, so use the env var.
+ */
+ if (pg_krb_server_keyfile != NULL && pg_krb_server_keyfile[0] != '\0')
+ {
+ if (setenv("KRB5_KTNAME", pg_krb_server_keyfile, 1) != 0)
+ {
+ /* The only likely failure cause is OOM, so use that errcode */
+ ereport(FATAL,
+ (errcode(ERRCODE_OUT_OF_MEMORY),
+ errmsg("could not set environment: %m")));
+ }
+ }
+
+ /*
+ * We accept any service principal that's present in our keytab. This
+ * increases interoperability between kerberos implementations that see
+ * for example case sensitivity differently, while not really opening up
+ * any vector of attack.
+ */
+ port->gss->cred = GSS_C_NO_CREDENTIAL;
+
+ /*
+ * Initialize sequence with an empty context
+ */
+ port->gss->ctx = GSS_C_NO_CONTEXT;
+
+ /*
+ * Loop through GSSAPI message exchange. This exchange can consist of
+ * multiple messages sent in both directions. First message is always from
+ * the client. All messages from client to server are password packets
+ * (type 'p').
+ */
+ do
+ {
+ pq_startmsgread();
+
+ CHECK_FOR_INTERRUPTS();
+
+ mtype = pq_getbyte();
+ if (mtype != 'p')
+ {
+ /* Only log error if client didn't disconnect. */
+ if (mtype != EOF)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("expected GSS response, got message type %d",
+ mtype)));
+ return STATUS_ERROR;
+ }
+
+ /* Get the actual GSS token */
+ initStringInfo(&buf);
+ if (pq_getmessage(&buf, PG_MAX_AUTH_TOKEN_LENGTH))
+ {
+ /* EOF - pq_getmessage already logged error */
+ pfree(buf.data);
+ return STATUS_ERROR;
+ }
+
+ /* Map to GSSAPI style buffer */
+ gbuf.length = buf.len;
+ gbuf.value = buf.data;
+
+ elog(DEBUG4, "processing received GSS token of length %u",
+ (unsigned int) gbuf.length);
+
+ maj_stat = gss_accept_sec_context(&min_stat,
+ &port->gss->ctx,
+ port->gss->cred,
+ &gbuf,
+ GSS_C_NO_CHANNEL_BINDINGS,
+ &port->gss->name,
+ NULL,
+ &port->gss->outbuf,
+ &gflags,
+ NULL,
+ NULL);
+
+ /* gbuf no longer used */
+ pfree(buf.data);
+
+ elog(DEBUG5, "gss_accept_sec_context major: %d, "
+ "minor: %d, outlen: %u, outflags: %x",
+ maj_stat, min_stat,
+ (unsigned int) port->gss->outbuf.length, gflags);
+
+ CHECK_FOR_INTERRUPTS();
+
+ if (port->gss->outbuf.length != 0)
+ {
+ /*
+ * Negotiation generated data to be sent to the client.
+ */
+ elog(DEBUG4, "sending GSS response token of length %u",
+ (unsigned int) port->gss->outbuf.length);
+
+ sendAuthRequest(port, AUTH_REQ_GSS_CONT,
+ port->gss->outbuf.value, port->gss->outbuf.length);
+
+ gss_release_buffer(&lmin_s, &port->gss->outbuf);
+ }
+
+ if (maj_stat != GSS_S_COMPLETE && maj_stat != GSS_S_CONTINUE_NEEDED)
+ {
+ gss_delete_sec_context(&lmin_s, &port->gss->ctx, GSS_C_NO_BUFFER);
+ pg_GSS_error(_("accepting GSS security context failed"),
+ maj_stat, min_stat);
+ return STATUS_ERROR;
+ }
+
+ if (maj_stat == GSS_S_CONTINUE_NEEDED)
+ elog(DEBUG4, "GSS continue needed");
+
+ } while (maj_stat == GSS_S_CONTINUE_NEEDED);
+
+ if (port->gss->cred != GSS_C_NO_CREDENTIAL)
+ {
+ /*
+ * Release service principal credentials
+ */
+ gss_release_cred(&min_stat, &port->gss->cred);
+ }
+ return pg_GSS_checkauth(port);
+}
+
+/*
+ * Check whether the GSSAPI-authenticated user is allowed to connect as the
+ * claimed username.
+ */
+static int
+pg_GSS_checkauth(Port *port)
+{
+ int ret;
+ OM_uint32 maj_stat,
+ min_stat,
+ lmin_s;
+ gss_buffer_desc gbuf;
+ char *princ;
+
+ /*
+ * Get the name of the user that authenticated, and compare it to the pg
+ * username that was specified for the connection.
+ */
+ maj_stat = gss_display_name(&min_stat, port->gss->name, &gbuf, NULL);
+ if (maj_stat != GSS_S_COMPLETE)
+ {
+ pg_GSS_error(_("retrieving GSS user name failed"),
+ maj_stat, min_stat);
+ return STATUS_ERROR;
+ }
+
+ /*
+ * gbuf.value might not be null-terminated, so turn it into a regular
+ * null-terminated string.
+ */
+ princ = palloc(gbuf.length + 1);
+ memcpy(princ, gbuf.value, gbuf.length);
+ princ[gbuf.length] = '\0';
+ gss_release_buffer(&lmin_s, &gbuf);
+
+ /*
+ * Copy the original name of the authenticated principal into our backend
+ * memory for display later.
+ *
+ * This is also our authenticated identity. Set it now, rather than
+ * waiting for the usermap check below, because authentication has already
+ * succeeded and we want the log file to reflect that.
+ */
+ port->gss->princ = MemoryContextStrdup(TopMemoryContext, princ);
+ set_authn_id(port, princ);
+
+ /*
+ * Split the username at the realm separator
+ */
+ if (strchr(princ, '@'))
+ {
+ char *cp = strchr(princ, '@');
+
+ /*
+ * If we are not going to include the realm in the username that is
+ * passed to the ident map, destructively modify it here to remove the
+ * realm. Then advance past the separator to check the realm.
+ */
+ if (!port->hba->include_realm)
+ *cp = '\0';
+ cp++;
+
+ if (port->hba->krb_realm != NULL && strlen(port->hba->krb_realm))
+ {
+ /*
+ * Match the realm part of the name first
+ */
+ if (pg_krb_caseins_users)
+ ret = pg_strcasecmp(port->hba->krb_realm, cp);
+ else
+ ret = strcmp(port->hba->krb_realm, cp);
+
+ if (ret)
+ {
+ /* GSS realm does not match */
+ elog(DEBUG2,
+ "GSSAPI realm (%s) and configured realm (%s) don't match",
+ cp, port->hba->krb_realm);
+ pfree(princ);
+ return STATUS_ERROR;
+ }
+ }
+ }
+ else if (port->hba->krb_realm && strlen(port->hba->krb_realm))
+ {
+ elog(DEBUG2,
+ "GSSAPI did not return realm but realm matching was requested");
+ pfree(princ);
+ return STATUS_ERROR;
+ }
+
+ ret = check_usermap(port->hba->usermap, port->user_name, princ,
+ pg_krb_caseins_users);
+
+ pfree(princ);
+
+ return ret;
+}
+#endif /* ENABLE_GSS */
+
+
+/*----------------------------------------------------------------
+ * SSPI authentication system
+ *----------------------------------------------------------------
+ */
+#ifdef ENABLE_SSPI
+
+/*
+ * Generate an error for SSPI authentication. The caller should apply
+ * _() to errmsg to make it translatable.
+ */
+static void
+pg_SSPI_error(int severity, const char *errmsg, SECURITY_STATUS r)
+{
+ char sysmsg[256];
+
+ if (FormatMessage(FORMAT_MESSAGE_IGNORE_INSERTS |
+ FORMAT_MESSAGE_FROM_SYSTEM,
+ NULL, r, 0,
+ sysmsg, sizeof(sysmsg), NULL) == 0)
+ ereport(severity,
+ (errmsg_internal("%s", errmsg),
+ errdetail_internal("SSPI error %x", (unsigned int) r)));
+ else
+ ereport(severity,
+ (errmsg_internal("%s", errmsg),
+ errdetail_internal("%s (%x)", sysmsg, (unsigned int) r)));
+}
+
+static int
+pg_SSPI_recvauth(Port *port)
+{
+ int mtype;
+ StringInfoData buf;
+ SECURITY_STATUS r;
+ CredHandle sspicred;
+ CtxtHandle *sspictx = NULL,
+ newctx;
+ TimeStamp expiry;
+ ULONG contextattr;
+ SecBufferDesc inbuf;
+ SecBufferDesc outbuf;
+ SecBuffer OutBuffers[1];
+ SecBuffer InBuffers[1];
+ HANDLE token;
+ TOKEN_USER *tokenuser;
+ DWORD retlen;
+ char accountname[MAXPGPATH];
+ char domainname[MAXPGPATH];
+ DWORD accountnamesize = sizeof(accountname);
+ DWORD domainnamesize = sizeof(domainname);
+ SID_NAME_USE accountnameuse;
+ HMODULE secur32;
+ char *authn_id;
+
+ QUERY_SECURITY_CONTEXT_TOKEN_FN _QuerySecurityContextToken;
+
+ /*
+ * Acquire a handle to the server credentials.
+ */
+ r = AcquireCredentialsHandle(NULL,
+ "negotiate",
+ SECPKG_CRED_INBOUND,
+ NULL,
+ NULL,
+ NULL,
+ NULL,
+ &sspicred,
+ &expiry);
+ if (r != SEC_E_OK)
+ pg_SSPI_error(ERROR, _("could not acquire SSPI credentials"), r);
+
+ /*
+ * Loop through SSPI message exchange. This exchange can consist of
+ * multiple messages sent in both directions. First message is always from
+ * the client. All messages from client to server are password packets
+ * (type 'p').
+ */
+ do
+ {
+ pq_startmsgread();
+ mtype = pq_getbyte();
+ if (mtype != 'p')
+ {
+ if (sspictx != NULL)
+ {
+ DeleteSecurityContext(sspictx);
+ free(sspictx);
+ }
+ FreeCredentialsHandle(&sspicred);
+
+ /* Only log error if client didn't disconnect. */
+ if (mtype != EOF)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("expected SSPI response, got message type %d",
+ mtype)));
+ return STATUS_ERROR;
+ }
+
+ /* Get the actual SSPI token */
+ initStringInfo(&buf);
+ if (pq_getmessage(&buf, PG_MAX_AUTH_TOKEN_LENGTH))
+ {
+ /* EOF - pq_getmessage already logged error */
+ pfree(buf.data);
+ if (sspictx != NULL)
+ {
+ DeleteSecurityContext(sspictx);
+ free(sspictx);
+ }
+ FreeCredentialsHandle(&sspicred);
+ return STATUS_ERROR;
+ }
+
+ /* Map to SSPI style buffer */
+ inbuf.ulVersion = SECBUFFER_VERSION;
+ inbuf.cBuffers = 1;
+ inbuf.pBuffers = InBuffers;
+ InBuffers[0].pvBuffer = buf.data;
+ InBuffers[0].cbBuffer = buf.len;
+ InBuffers[0].BufferType = SECBUFFER_TOKEN;
+
+ /* Prepare output buffer */
+ OutBuffers[0].pvBuffer = NULL;
+ OutBuffers[0].BufferType = SECBUFFER_TOKEN;
+ OutBuffers[0].cbBuffer = 0;
+ outbuf.cBuffers = 1;
+ outbuf.pBuffers = OutBuffers;
+ outbuf.ulVersion = SECBUFFER_VERSION;
+
+ elog(DEBUG4, "processing received SSPI token of length %u",
+ (unsigned int) buf.len);
+
+ r = AcceptSecurityContext(&sspicred,
+ sspictx,
+ &inbuf,
+ ASC_REQ_ALLOCATE_MEMORY,
+ SECURITY_NETWORK_DREP,
+ &newctx,
+ &outbuf,
+ &contextattr,
+ NULL);
+
+ /* input buffer no longer used */
+ pfree(buf.data);
+
+ if (outbuf.cBuffers > 0 && outbuf.pBuffers[0].cbBuffer > 0)
+ {
+ /*
+ * Negotiation generated data to be sent to the client.
+ */
+ elog(DEBUG4, "sending SSPI response token of length %u",
+ (unsigned int) outbuf.pBuffers[0].cbBuffer);
+
+ port->gss->outbuf.length = outbuf.pBuffers[0].cbBuffer;
+ port->gss->outbuf.value = outbuf.pBuffers[0].pvBuffer;
+
+ sendAuthRequest(port, AUTH_REQ_GSS_CONT,
+ port->gss->outbuf.value, port->gss->outbuf.length);
+
+ FreeContextBuffer(outbuf.pBuffers[0].pvBuffer);
+ }
+
+ if (r != SEC_E_OK && r != SEC_I_CONTINUE_NEEDED)
+ {
+ if (sspictx != NULL)
+ {
+ DeleteSecurityContext(sspictx);
+ free(sspictx);
+ }
+ FreeCredentialsHandle(&sspicred);
+ pg_SSPI_error(ERROR,
+ _("could not accept SSPI security context"), r);
+ }
+
+ /*
+ * Overwrite the current context with the one we just received. If
+ * sspictx is NULL it was the first loop and we need to allocate a
+ * buffer for it. On subsequent runs, we can just overwrite the buffer
+ * contents since the size does not change.
+ */
+ if (sspictx == NULL)
+ {
+ sspictx = malloc(sizeof(CtxtHandle));
+ if (sspictx == NULL)
+ ereport(ERROR,
+ (errmsg("out of memory")));
+ }
+
+ memcpy(sspictx, &newctx, sizeof(CtxtHandle));
+
+ if (r == SEC_I_CONTINUE_NEEDED)
+ elog(DEBUG4, "SSPI continue needed");
+
+ } while (r == SEC_I_CONTINUE_NEEDED);
+
+
+ /*
+ * Release service principal credentials
+ */
+ FreeCredentialsHandle(&sspicred);
+
+
+ /*
+ * SEC_E_OK indicates that authentication is now complete.
+ *
+ * Get the name of the user that authenticated, and compare it to the pg
+ * username that was specified for the connection.
+ *
+ * MingW is missing the export for QuerySecurityContextToken in the
+ * secur32 library, so we have to load it dynamically.
+ */
+
+ secur32 = LoadLibrary("SECUR32.DLL");
+ if (secur32 == NULL)
+ ereport(ERROR,
+ (errmsg("could not load library \"%s\": error code %lu",
+ "SECUR32.DLL", GetLastError())));
+
+ _QuerySecurityContextToken = (QUERY_SECURITY_CONTEXT_TOKEN_FN) (pg_funcptr_t)
+ GetProcAddress(secur32, "QuerySecurityContextToken");
+ if (_QuerySecurityContextToken == NULL)
+ {
+ FreeLibrary(secur32);
+ ereport(ERROR,
+ (errmsg_internal("could not locate QuerySecurityContextToken in secur32.dll: error code %lu",
+ GetLastError())));
+ }
+
+ r = (_QuerySecurityContextToken) (sspictx, &token);
+ if (r != SEC_E_OK)
+ {
+ FreeLibrary(secur32);
+ pg_SSPI_error(ERROR,
+ _("could not get token from SSPI security context"), r);
+ }
+
+ FreeLibrary(secur32);
+
+ /*
+ * No longer need the security context, everything from here on uses the
+ * token instead.
+ */
+ DeleteSecurityContext(sspictx);
+ free(sspictx);
+
+ if (!GetTokenInformation(token, TokenUser, NULL, 0, &retlen) && GetLastError() != 122)
+ ereport(ERROR,
+ (errmsg_internal("could not get token information buffer size: error code %lu",
+ GetLastError())));
+
+ tokenuser = malloc(retlen);
+ if (tokenuser == NULL)
+ ereport(ERROR,
+ (errmsg("out of memory")));
+
+ if (!GetTokenInformation(token, TokenUser, tokenuser, retlen, &retlen))
+ ereport(ERROR,
+ (errmsg_internal("could not get token information: error code %lu",
+ GetLastError())));
+
+ CloseHandle(token);
+
+ if (!LookupAccountSid(NULL, tokenuser->User.Sid, accountname, &accountnamesize,
+ domainname, &domainnamesize, &accountnameuse))
+ ereport(ERROR,
+ (errmsg_internal("could not look up account SID: error code %lu",
+ GetLastError())));
+
+ free(tokenuser);
+
+ if (!port->hba->compat_realm)
+ {
+ int status = pg_SSPI_make_upn(accountname, sizeof(accountname),
+ domainname, sizeof(domainname),
+ port->hba->upn_username);
+
+ if (status != STATUS_OK)
+ /* Error already reported from pg_SSPI_make_upn */
+ return status;
+ }
+
+ /*
+ * We have all of the information necessary to construct the authenticated
+ * identity. Set it now, rather than waiting for check_usermap below,
+ * because authentication has already succeeded and we want the log file
+ * to reflect that.
+ */
+ if (port->hba->compat_realm)
+ {
+ /* SAM-compatible format. */
+ authn_id = psprintf("%s\\%s", domainname, accountname);
+ }
+ else
+ {
+ /* Kerberos principal format. */
+ authn_id = psprintf("%s@%s", accountname, domainname);
+ }
+
+ set_authn_id(port, authn_id);
+ pfree(authn_id);
+
+ /*
+ * Compare realm/domain if requested. In SSPI, always compare case
+ * insensitive.
+ */
+ if (port->hba->krb_realm && strlen(port->hba->krb_realm))
+ {
+ if (pg_strcasecmp(port->hba->krb_realm, domainname) != 0)
+ {
+ elog(DEBUG2,
+ "SSPI domain (%s) and configured domain (%s) don't match",
+ domainname, port->hba->krb_realm);
+
+ return STATUS_ERROR;
+ }
+ }
+
+ /*
+ * We have the username (without domain/realm) in accountname, compare to
+ * the supplied value. In SSPI, always compare case insensitive.
+ *
+ * If set to include realm, append it in <username>@<realm> format.
+ */
+ if (port->hba->include_realm)
+ {
+ char *namebuf;
+ int retval;
+
+ namebuf = psprintf("%s@%s", accountname, domainname);
+ retval = check_usermap(port->hba->usermap, port->user_name, namebuf, true);
+ pfree(namebuf);
+ return retval;
+ }
+ else
+ return check_usermap(port->hba->usermap, port->user_name, accountname, true);
+}
+
+/*
+ * Replaces the domainname with the Kerberos realm name,
+ * and optionally the accountname with the Kerberos user name.
+ */
+static int
+pg_SSPI_make_upn(char *accountname,
+ size_t accountnamesize,
+ char *domainname,
+ size_t domainnamesize,
+ bool update_accountname)
+{
+ char *samname;
+ char *upname = NULL;
+ char *p = NULL;
+ ULONG upnamesize = 0;
+ size_t upnamerealmsize;
+ BOOLEAN res;
+
+ /*
+ * Build SAM name (DOMAIN\user), then translate to UPN
+ * (user@kerberos.realm). The realm name is returned in lower case, but
+ * that is fine because in SSPI auth, string comparisons are always
+ * case-insensitive.
+ */
+
+ samname = psprintf("%s\\%s", domainname, accountname);
+ res = TranslateName(samname, NameSamCompatible, NameUserPrincipal,
+ NULL, &upnamesize);
+
+ if ((!res && GetLastError() != ERROR_INSUFFICIENT_BUFFER)
+ || upnamesize == 0)
+ {
+ pfree(samname);
+ ereport(LOG,
+ (errcode(ERRCODE_INVALID_ROLE_SPECIFICATION),
+ errmsg("could not translate name")));
+ return STATUS_ERROR;
+ }
+
+ /* upnamesize includes the terminating NUL. */
+ upname = palloc(upnamesize);
+
+ res = TranslateName(samname, NameSamCompatible, NameUserPrincipal,
+ upname, &upnamesize);
+
+ pfree(samname);
+ if (res)
+ p = strchr(upname, '@');
+
+ if (!res || p == NULL)
+ {
+ pfree(upname);
+ ereport(LOG,
+ (errcode(ERRCODE_INVALID_ROLE_SPECIFICATION),
+ errmsg("could not translate name")));
+ return STATUS_ERROR;
+ }
+
+ /* Length of realm name after the '@', including the NUL. */
+ upnamerealmsize = upnamesize - (p - upname + 1);
+
+ /* Replace domainname with realm name. */
+ if (upnamerealmsize > domainnamesize)
+ {
+ pfree(upname);
+ ereport(LOG,
+ (errcode(ERRCODE_INVALID_ROLE_SPECIFICATION),
+ errmsg("realm name too long")));
+ return STATUS_ERROR;
+ }
+
+ /* Length is now safe. */
+ strcpy(domainname, p + 1);
+
+ /* Replace account name as well (in case UPN != SAM)? */
+ if (update_accountname)
+ {
+ if ((p - upname + 1) > accountnamesize)
+ {
+ pfree(upname);
+ ereport(LOG,
+ (errcode(ERRCODE_INVALID_ROLE_SPECIFICATION),
+ errmsg("translated account name too long")));
+ return STATUS_ERROR;
+ }
+
+ *p = 0;
+ strcpy(accountname, upname);
+ }
+
+ pfree(upname);
+ return STATUS_OK;
+}
+#endif /* ENABLE_SSPI */
+
+
+
+/*----------------------------------------------------------------
+ * Ident authentication system
+ *----------------------------------------------------------------
+ */
+
+/*
+ * Parse the string "*ident_response" as a response from a query to an Ident
+ * server. If it's a normal response indicating a user name, return true
+ * and store the user name at *ident_user. If it's anything else,
+ * return false.
+ */
+static bool
+interpret_ident_response(const char *ident_response,
+ char *ident_user)
+{
+ const char *cursor = ident_response; /* Cursor into *ident_response */
+
+ /*
+ * Ident's response, in the telnet tradition, should end in crlf (\r\n).
+ */
+ if (strlen(ident_response) < 2)
+ return false;
+ else if (ident_response[strlen(ident_response) - 2] != '\r')
+ return false;
+ else
+ {
+ while (*cursor != ':' && *cursor != '\r')
+ cursor++; /* skip port field */
+
+ if (*cursor != ':')
+ return false;
+ else
+ {
+ /* We're positioned to colon before response type field */
+ char response_type[80];
+ int i; /* Index into *response_type */
+
+ cursor++; /* Go over colon */
+ while (pg_isblank(*cursor))
+ cursor++; /* skip blanks */
+ i = 0;
+ while (*cursor != ':' && *cursor != '\r' && !pg_isblank(*cursor) &&
+ i < (int) (sizeof(response_type) - 1))
+ response_type[i++] = *cursor++;
+ response_type[i] = '\0';
+ while (pg_isblank(*cursor))
+ cursor++; /* skip blanks */
+ if (strcmp(response_type, "USERID") != 0)
+ return false;
+ else
+ {
+ /*
+ * It's a USERID response. Good. "cursor" should be pointing
+ * to the colon that precedes the operating system type.
+ */
+ if (*cursor != ':')
+ return false;
+ else
+ {
+ cursor++; /* Go over colon */
+ /* Skip over operating system field. */
+ while (*cursor != ':' && *cursor != '\r')
+ cursor++;
+ if (*cursor != ':')
+ return false;
+ else
+ {
+ int i; /* Index into *ident_user */
+
+ cursor++; /* Go over colon */
+ while (pg_isblank(*cursor))
+ cursor++; /* skip blanks */
+ /* Rest of line is user name. Copy it over. */
+ i = 0;
+ while (*cursor != '\r' && i < IDENT_USERNAME_MAX)
+ ident_user[i++] = *cursor++;
+ ident_user[i] = '\0';
+ return true;
+ }
+ }
+ }
+ }
+ }
+}
+
+
+/*
+ * Talk to the ident server on "remote_addr" and find out who
+ * owns the tcp connection to "local_addr"
+ * If the username is successfully retrieved, check the usermap.
+ *
+ * XXX: Using WaitLatchOrSocket() and doing a CHECK_FOR_INTERRUPTS() if the
+ * latch was set would improve the responsiveness to timeouts/cancellations.
+ */
+static int
+ident_inet(hbaPort *port)
+{
+ const SockAddr remote_addr = port->raddr;
+ const SockAddr local_addr = port->laddr;
+ char ident_user[IDENT_USERNAME_MAX + 1];
+ pgsocket sock_fd = PGINVALID_SOCKET; /* for talking to Ident server */
+ int rc; /* Return code from a locally called function */
+ bool ident_return;
+ char remote_addr_s[NI_MAXHOST];
+ char remote_port[NI_MAXSERV];
+ char local_addr_s[NI_MAXHOST];
+ char local_port[NI_MAXSERV];
+ char ident_port[NI_MAXSERV];
+ char ident_query[80];
+ char ident_response[80 + IDENT_USERNAME_MAX];
+ struct addrinfo *ident_serv = NULL,
+ *la = NULL,
+ hints;
+
+ /*
+ * Might look a little weird to first convert it to text and then back to
+ * sockaddr, but it's protocol independent.
+ */
+ pg_getnameinfo_all(&remote_addr.addr, remote_addr.salen,
+ remote_addr_s, sizeof(remote_addr_s),
+ remote_port, sizeof(remote_port),
+ NI_NUMERICHOST | NI_NUMERICSERV);
+ pg_getnameinfo_all(&local_addr.addr, local_addr.salen,
+ local_addr_s, sizeof(local_addr_s),
+ local_port, sizeof(local_port),
+ NI_NUMERICHOST | NI_NUMERICSERV);
+
+ snprintf(ident_port, sizeof(ident_port), "%d", IDENT_PORT);
+ hints.ai_flags = AI_NUMERICHOST;
+ hints.ai_family = remote_addr.addr.ss_family;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = 0;
+ hints.ai_addrlen = 0;
+ hints.ai_canonname = NULL;
+ hints.ai_addr = NULL;
+ hints.ai_next = NULL;
+ rc = pg_getaddrinfo_all(remote_addr_s, ident_port, &hints, &ident_serv);
+ if (rc || !ident_serv)
+ {
+ /* we don't expect this to happen */
+ ident_return = false;
+ goto ident_inet_done;
+ }
+
+ hints.ai_flags = AI_NUMERICHOST;
+ hints.ai_family = local_addr.addr.ss_family;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = 0;
+ hints.ai_addrlen = 0;
+ hints.ai_canonname = NULL;
+ hints.ai_addr = NULL;
+ hints.ai_next = NULL;
+ rc = pg_getaddrinfo_all(local_addr_s, NULL, &hints, &la);
+ if (rc || !la)
+ {
+ /* we don't expect this to happen */
+ ident_return = false;
+ goto ident_inet_done;
+ }
+
+ sock_fd = socket(ident_serv->ai_family, ident_serv->ai_socktype,
+ ident_serv->ai_protocol);
+ if (sock_fd == PGINVALID_SOCKET)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ errmsg("could not create socket for Ident connection: %m")));
+ ident_return = false;
+ goto ident_inet_done;
+ }
+
+ /*
+ * Bind to the address which the client originally contacted, otherwise
+ * the ident server won't be able to match up the right connection. This
+ * is necessary if the PostgreSQL server is running on an IP alias.
+ */
+ rc = bind(sock_fd, la->ai_addr, la->ai_addrlen);
+ if (rc != 0)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ errmsg("could not bind to local address \"%s\": %m",
+ local_addr_s)));
+ ident_return = false;
+ goto ident_inet_done;
+ }
+
+ rc = connect(sock_fd, ident_serv->ai_addr,
+ ident_serv->ai_addrlen);
+ if (rc != 0)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ errmsg("could not connect to Ident server at address \"%s\", port %s: %m",
+ remote_addr_s, ident_port)));
+ ident_return = false;
+ goto ident_inet_done;
+ }
+
+ /* The query we send to the Ident server */
+ snprintf(ident_query, sizeof(ident_query), "%s,%s\r\n",
+ remote_port, local_port);
+
+ /* loop in case send is interrupted */
+ do
+ {
+ CHECK_FOR_INTERRUPTS();
+
+ rc = send(sock_fd, ident_query, strlen(ident_query), 0);
+ } while (rc < 0 && errno == EINTR);
+
+ if (rc < 0)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ errmsg("could not send query to Ident server at address \"%s\", port %s: %m",
+ remote_addr_s, ident_port)));
+ ident_return = false;
+ goto ident_inet_done;
+ }
+
+ do
+ {
+ CHECK_FOR_INTERRUPTS();
+
+ rc = recv(sock_fd, ident_response, sizeof(ident_response) - 1, 0);
+ } while (rc < 0 && errno == EINTR);
+
+ if (rc < 0)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ errmsg("could not receive response from Ident server at address \"%s\", port %s: %m",
+ remote_addr_s, ident_port)));
+ ident_return = false;
+ goto ident_inet_done;
+ }
+
+ ident_response[rc] = '\0';
+ ident_return = interpret_ident_response(ident_response, ident_user);
+ if (!ident_return)
+ ereport(LOG,
+ (errmsg("invalidly formatted response from Ident server: \"%s\"",
+ ident_response)));
+
+ident_inet_done:
+ if (sock_fd != PGINVALID_SOCKET)
+ closesocket(sock_fd);
+ if (ident_serv)
+ pg_freeaddrinfo_all(remote_addr.addr.ss_family, ident_serv);
+ if (la)
+ pg_freeaddrinfo_all(local_addr.addr.ss_family, la);
+
+ if (ident_return)
+ {
+ /*
+ * Success! Store the identity, then check the usermap. Note that
+ * setting the authenticated identity is done before checking the
+ * usermap, because at this point authentication has succeeded.
+ */
+ set_authn_id(port, ident_user);
+ return check_usermap(port->hba->usermap, port->user_name, ident_user, false);
+ }
+ return STATUS_ERROR;
+}
+
+
+/*----------------------------------------------------------------
+ * Peer authentication system
+ *----------------------------------------------------------------
+ */
+
+/*
+ * Ask kernel about the credentials of the connecting process,
+ * determine the symbolic name of the corresponding user, and check
+ * if valid per the usermap.
+ *
+ * Iff authorized, return STATUS_OK, otherwise return STATUS_ERROR.
+ */
+static int
+auth_peer(hbaPort *port)
+{
+ uid_t uid;
+ gid_t gid;
+#ifndef WIN32
+ struct passwd *pw;
+ int ret;
+#endif
+
+ if (getpeereid(port->sock, &uid, &gid) != 0)
+ {
+ /* Provide special error message if getpeereid is a stub */
+ if (errno == ENOSYS)
+ ereport(LOG,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("peer authentication is not supported on this platform")));
+ else
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ errmsg("could not get peer credentials: %m")));
+ return STATUS_ERROR;
+ }
+
+#ifndef WIN32
+ errno = 0; /* clear errno before call */
+ pw = getpwuid(uid);
+ if (!pw)
+ {
+ int save_errno = errno;
+
+ ereport(LOG,
+ (errmsg("could not look up local user ID %ld: %s",
+ (long) uid,
+ save_errno ? strerror(save_errno) : _("user does not exist"))));
+ return STATUS_ERROR;
+ }
+
+ /*
+ * Make a copy of static getpw*() result area; this is our authenticated
+ * identity. Set it before calling check_usermap, because authentication
+ * has already succeeded and we want the log file to reflect that.
+ */
+ set_authn_id(port, pw->pw_name);
+
+ ret = check_usermap(port->hba->usermap, port->user_name, port->authn_id, false);
+
+ return ret;
+#else
+ /* should have failed with ENOSYS above */
+ Assert(false);
+ return STATUS_ERROR;
+#endif
+}
+
+
+/*----------------------------------------------------------------
+ * PAM authentication system
+ *----------------------------------------------------------------
+ */
+#ifdef USE_PAM
+
+/*
+ * PAM conversation function
+ */
+
+static int
+pam_passwd_conv_proc(int num_msg, const struct pam_message **msg,
+ struct pam_response **resp, void *appdata_ptr)
+{
+ const char *passwd;
+ struct pam_response *reply;
+ int i;
+
+ if (appdata_ptr)
+ passwd = (char *) appdata_ptr;
+ else
+ {
+ /*
+ * Workaround for Solaris 2.6 where the PAM library is broken and does
+ * not pass appdata_ptr to the conversation routine
+ */
+ passwd = pam_passwd;
+ }
+
+ *resp = NULL; /* in case of error exit */
+
+ if (num_msg <= 0 || num_msg > PAM_MAX_NUM_MSG)
+ return PAM_CONV_ERR;
+
+ /*
+ * Explicitly not using palloc here - PAM will free this memory in
+ * pam_end()
+ */
+ if ((reply = calloc(num_msg, sizeof(struct pam_response))) == NULL)
+ {
+ ereport(LOG,
+ (errcode(ERRCODE_OUT_OF_MEMORY),
+ errmsg("out of memory")));
+ return PAM_CONV_ERR;
+ }
+
+ for (i = 0; i < num_msg; i++)
+ {
+ switch (msg[i]->msg_style)
+ {
+ case PAM_PROMPT_ECHO_OFF:
+ if (strlen(passwd) == 0)
+ {
+ /*
+ * Password wasn't passed to PAM the first time around -
+ * let's go ask the client to send a password, which we
+ * then stuff into PAM.
+ */
+ sendAuthRequest(pam_port_cludge, AUTH_REQ_PASSWORD, NULL, 0);
+ passwd = recv_password_packet(pam_port_cludge);
+ if (passwd == NULL)
+ {
+ /*
+ * Client didn't want to send password. We
+ * intentionally do not log anything about this,
+ * either here or at higher levels.
+ */
+ pam_no_password = true;
+ goto fail;
+ }
+ }
+ if ((reply[i].resp = strdup(passwd)) == NULL)
+ goto fail;
+ reply[i].resp_retcode = PAM_SUCCESS;
+ break;
+ case PAM_ERROR_MSG:
+ ereport(LOG,
+ (errmsg("error from underlying PAM layer: %s",
+ msg[i]->msg)));
+ /* FALL THROUGH */
+ case PAM_TEXT_INFO:
+ /* we don't bother to log TEXT_INFO messages */
+ if ((reply[i].resp = strdup("")) == NULL)
+ goto fail;
+ reply[i].resp_retcode = PAM_SUCCESS;
+ break;
+ default:
+ ereport(LOG,
+ (errmsg("unsupported PAM conversation %d/\"%s\"",
+ msg[i]->msg_style,
+ msg[i]->msg ? msg[i]->msg : "(none)")));
+ goto fail;
+ }
+ }
+
+ *resp = reply;
+ return PAM_SUCCESS;
+
+fail:
+ /* free up whatever we allocated */
+ for (i = 0; i < num_msg; i++)
+ {
+ if (reply[i].resp != NULL)
+ free(reply[i].resp);
+ }
+ free(reply);
+
+ return PAM_CONV_ERR;
+}
+
+
+/*
+ * Check authentication against PAM.
+ */
+static int
+CheckPAMAuth(Port *port, const char *user, const char *password)
+{
+ int retval;
+ pam_handle_t *pamh = NULL;
+
+ /*
+ * We can't entirely rely on PAM to pass through appdata --- it appears
+ * not to work on at least Solaris 2.6. So use these ugly static
+ * variables instead.
+ */
+ pam_passwd = password;
+ pam_port_cludge = port;
+ pam_no_password = false;
+
+ /*
+ * Set the application data portion of the conversation struct. This is
+ * later used inside the PAM conversation to pass the password to the
+ * authentication module.
+ */
+ pam_passw_conv.appdata_ptr = unconstify(char *, password); /* from password above,
+ * not allocated */
+
+ /* Optionally, one can set the service name in pg_hba.conf */
+ if (port->hba->pamservice && port->hba->pamservice[0] != '\0')
+ retval = pam_start(port->hba->pamservice, "pgsql@",
+ &pam_passw_conv, &pamh);
+ else
+ retval = pam_start(PGSQL_PAM_SERVICE, "pgsql@",
+ &pam_passw_conv, &pamh);
+
+ if (retval != PAM_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("could not create PAM authenticator: %s",
+ pam_strerror(pamh, retval))));
+ pam_passwd = NULL; /* Unset pam_passwd */
+ return STATUS_ERROR;
+ }
+
+ retval = pam_set_item(pamh, PAM_USER, user);
+
+ if (retval != PAM_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("pam_set_item(PAM_USER) failed: %s",
+ pam_strerror(pamh, retval))));
+ pam_passwd = NULL; /* Unset pam_passwd */
+ return STATUS_ERROR;
+ }
+
+ if (port->hba->conntype != ctLocal)
+ {
+ char hostinfo[NI_MAXHOST];
+ int flags;
+
+ if (port->hba->pam_use_hostname)
+ flags = 0;
+ else
+ flags = NI_NUMERICHOST | NI_NUMERICSERV;
+
+ retval = pg_getnameinfo_all(&port->raddr.addr, port->raddr.salen,
+ hostinfo, sizeof(hostinfo), NULL, 0,
+ flags);
+ if (retval != 0)
+ {
+ ereport(WARNING,
+ (errmsg_internal("pg_getnameinfo_all() failed: %s",
+ gai_strerror(retval))));
+ return STATUS_ERROR;
+ }
+
+ retval = pam_set_item(pamh, PAM_RHOST, hostinfo);
+
+ if (retval != PAM_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("pam_set_item(PAM_RHOST) failed: %s",
+ pam_strerror(pamh, retval))));
+ pam_passwd = NULL;
+ return STATUS_ERROR;
+ }
+ }
+
+ retval = pam_set_item(pamh, PAM_CONV, &pam_passw_conv);
+
+ if (retval != PAM_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("pam_set_item(PAM_CONV) failed: %s",
+ pam_strerror(pamh, retval))));
+ pam_passwd = NULL; /* Unset pam_passwd */
+ return STATUS_ERROR;
+ }
+
+ retval = pam_authenticate(pamh, 0);
+
+ if (retval != PAM_SUCCESS)
+ {
+ /* If pam_passwd_conv_proc saw EOF, don't log anything */
+ if (!pam_no_password)
+ ereport(LOG,
+ (errmsg("pam_authenticate failed: %s",
+ pam_strerror(pamh, retval))));
+ pam_passwd = NULL; /* Unset pam_passwd */
+ return pam_no_password ? STATUS_EOF : STATUS_ERROR;
+ }
+
+ retval = pam_acct_mgmt(pamh, 0);
+
+ if (retval != PAM_SUCCESS)
+ {
+ /* If pam_passwd_conv_proc saw EOF, don't log anything */
+ if (!pam_no_password)
+ ereport(LOG,
+ (errmsg("pam_acct_mgmt failed: %s",
+ pam_strerror(pamh, retval))));
+ pam_passwd = NULL; /* Unset pam_passwd */
+ return pam_no_password ? STATUS_EOF : STATUS_ERROR;
+ }
+
+ retval = pam_end(pamh, retval);
+
+ if (retval != PAM_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("could not release PAM authenticator: %s",
+ pam_strerror(pamh, retval))));
+ }
+
+ pam_passwd = NULL; /* Unset pam_passwd */
+
+ if (retval == PAM_SUCCESS)
+ set_authn_id(port, user);
+
+ return (retval == PAM_SUCCESS ? STATUS_OK : STATUS_ERROR);
+}
+#endif /* USE_PAM */
+
+
+/*----------------------------------------------------------------
+ * BSD authentication system
+ *----------------------------------------------------------------
+ */
+#ifdef USE_BSD_AUTH
+static int
+CheckBSDAuth(Port *port, char *user)
+{
+ char *passwd;
+ int retval;
+
+ /* Send regular password request to client, and get the response */
+ sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
+
+ passwd = recv_password_packet(port);
+ if (passwd == NULL)
+ return STATUS_EOF;
+
+ /*
+ * Ask the BSD auth system to verify password. Note that auth_userokay
+ * will overwrite the password string with zeroes, but it's just a
+ * temporary string so we don't care.
+ */
+ retval = auth_userokay(user, NULL, "auth-postgresql", passwd);
+
+ pfree(passwd);
+
+ if (!retval)
+ return STATUS_ERROR;
+
+ set_authn_id(port, user);
+ return STATUS_OK;
+}
+#endif /* USE_BSD_AUTH */
+
+
+/*----------------------------------------------------------------
+ * LDAP authentication system
+ *----------------------------------------------------------------
+ */
+#ifdef USE_LDAP
+
+static int errdetail_for_ldap(LDAP *ldap);
+
+/*
+ * Initialize a connection to the LDAP server, including setting up
+ * TLS if requested.
+ */
+static int
+InitializeLDAPConnection(Port *port, LDAP **ldap)
+{
+ const char *scheme;
+ int ldapversion = LDAP_VERSION3;
+ int r;
+
+ scheme = port->hba->ldapscheme;
+ if (scheme == NULL)
+ scheme = "ldap";
+#ifdef WIN32
+ if (strcmp(scheme, "ldaps") == 0)
+ *ldap = ldap_sslinit(port->hba->ldapserver, port->hba->ldapport, 1);
+ else
+ *ldap = ldap_init(port->hba->ldapserver, port->hba->ldapport);
+ if (!*ldap)
+ {
+ ereport(LOG,
+ (errmsg("could not initialize LDAP: error code %d",
+ (int) LdapGetLastError())));
+
+ return STATUS_ERROR;
+ }
+#else
+#ifdef HAVE_LDAP_INITIALIZE
+
+ /*
+ * OpenLDAP provides a non-standard extension ldap_initialize() that takes
+ * a list of URIs, allowing us to request "ldaps" instead of "ldap". It
+ * also provides ldap_domain2hostlist() to find LDAP servers automatically
+ * using DNS SRV. They were introduced in the same version, so for now we
+ * don't have an extra configure check for the latter.
+ */
+ {
+ StringInfoData uris;
+ char *hostlist = NULL;
+ char *p;
+ bool append_port;
+
+ /* We'll build a space-separated scheme://hostname:port list here */
+ initStringInfo(&uris);
+
+ /*
+ * If pg_hba.conf provided no hostnames, we can ask OpenLDAP to try to
+ * find some by extracting a domain name from the base DN and looking
+ * up DSN SRV records for _ldap._tcp.<domain>.
+ */
+ if (!port->hba->ldapserver || port->hba->ldapserver[0] == '\0')
+ {
+ char *domain;
+
+ /* ou=blah,dc=foo,dc=bar -> foo.bar */
+ if (ldap_dn2domain(port->hba->ldapbasedn, &domain))
+ {
+ ereport(LOG,
+ (errmsg("could not extract domain name from ldapbasedn")));
+ return STATUS_ERROR;
+ }
+
+ /* Look up a list of LDAP server hosts and port numbers */
+ if (ldap_domain2hostlist(domain, &hostlist))
+ {
+ ereport(LOG,
+ (errmsg("LDAP authentication could not find DNS SRV records for \"%s\"",
+ domain),
+ (errhint("Set an LDAP server name explicitly."))));
+ ldap_memfree(domain);
+ return STATUS_ERROR;
+ }
+ ldap_memfree(domain);
+
+ /* We have a space-separated list of host:port entries */
+ p = hostlist;
+ append_port = false;
+ }
+ else
+ {
+ /* We have a space-separated list of hosts from pg_hba.conf */
+ p = port->hba->ldapserver;
+ append_port = true;
+ }
+
+ /* Convert the list of host[:port] entries to full URIs */
+ do
+ {
+ size_t size;
+
+ /* Find the span of the next entry */
+ size = strcspn(p, " ");
+
+ /* Append a space separator if this isn't the first URI */
+ if (uris.len > 0)
+ appendStringInfoChar(&uris, ' ');
+
+ /* Append scheme://host:port */
+ appendStringInfoString(&uris, scheme);
+ appendStringInfoString(&uris, "://");
+ appendBinaryStringInfo(&uris, p, size);
+ if (append_port)
+ appendStringInfo(&uris, ":%d", port->hba->ldapport);
+
+ /* Step over this entry and any number of trailing spaces */
+ p += size;
+ while (*p == ' ')
+ ++p;
+ } while (*p);
+
+ /* Free memory from OpenLDAP if we looked up SRV records */
+ if (hostlist)
+ ldap_memfree(hostlist);
+
+ /* Finally, try to connect using the URI list */
+ r = ldap_initialize(ldap, uris.data);
+ pfree(uris.data);
+ if (r != LDAP_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("could not initialize LDAP: %s",
+ ldap_err2string(r))));
+
+ return STATUS_ERROR;
+ }
+ }
+#else
+ if (strcmp(scheme, "ldaps") == 0)
+ {
+ ereport(LOG,
+ (errmsg("ldaps not supported with this LDAP library")));
+
+ return STATUS_ERROR;
+ }
+ *ldap = ldap_init(port->hba->ldapserver, port->hba->ldapport);
+ if (!*ldap)
+ {
+ ereport(LOG,
+ (errmsg("could not initialize LDAP: %m")));
+
+ return STATUS_ERROR;
+ }
+#endif
+#endif
+
+ if ((r = ldap_set_option(*ldap, LDAP_OPT_PROTOCOL_VERSION, &ldapversion)) != LDAP_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("could not set LDAP protocol version: %s",
+ ldap_err2string(r)),
+ errdetail_for_ldap(*ldap)));
+ ldap_unbind(*ldap);
+ return STATUS_ERROR;
+ }
+
+ if (port->hba->ldaptls)
+ {
+#ifndef WIN32
+ if ((r = ldap_start_tls_s(*ldap, NULL, NULL)) != LDAP_SUCCESS)
+#else
+ static __ldap_start_tls_sA _ldap_start_tls_sA = NULL;
+
+ if (_ldap_start_tls_sA == NULL)
+ {
+ /*
+ * Need to load this function dynamically because it may not exist
+ * on Windows, and causes a load error for the whole exe if
+ * referenced.
+ */
+ HANDLE ldaphandle;
+
+ ldaphandle = LoadLibrary("WLDAP32.DLL");
+ if (ldaphandle == NULL)
+ {
+ /*
+ * should never happen since we import other files from
+ * wldap32, but check anyway
+ */
+ ereport(LOG,
+ (errmsg("could not load library \"%s\": error code %lu",
+ "WLDAP32.DLL", GetLastError())));
+ ldap_unbind(*ldap);
+ return STATUS_ERROR;
+ }
+ _ldap_start_tls_sA = (__ldap_start_tls_sA) (pg_funcptr_t) GetProcAddress(ldaphandle, "ldap_start_tls_sA");
+ if (_ldap_start_tls_sA == NULL)
+ {
+ ereport(LOG,
+ (errmsg("could not load function _ldap_start_tls_sA in wldap32.dll"),
+ errdetail("LDAP over SSL is not supported on this platform.")));
+ ldap_unbind(*ldap);
+ FreeLibrary(ldaphandle);
+ return STATUS_ERROR;
+ }
+
+ /*
+ * Leak LDAP handle on purpose, because we need the library to
+ * stay open. This is ok because it will only ever be leaked once
+ * per process and is automatically cleaned up on process exit.
+ */
+ }
+ if ((r = _ldap_start_tls_sA(*ldap, NULL, NULL, NULL, NULL)) != LDAP_SUCCESS)
+#endif
+ {
+ ereport(LOG,
+ (errmsg("could not start LDAP TLS session: %s",
+ ldap_err2string(r)),
+ errdetail_for_ldap(*ldap)));
+ ldap_unbind(*ldap);
+ return STATUS_ERROR;
+ }
+ }
+
+ return STATUS_OK;
+}
+
+/* Placeholders recognized by FormatSearchFilter. For now just one. */
+#define LPH_USERNAME "$username"
+#define LPH_USERNAME_LEN (sizeof(LPH_USERNAME) - 1)
+
+/* Not all LDAP implementations define this. */
+#ifndef LDAP_NO_ATTRS
+#define LDAP_NO_ATTRS "1.1"
+#endif
+
+/* Not all LDAP implementations define this. */
+#ifndef LDAPS_PORT
+#define LDAPS_PORT 636
+#endif
+
+/*
+ * Return a newly allocated C string copied from "pattern" with all
+ * occurrences of the placeholder "$username" replaced with "user_name".
+ */
+static char *
+FormatSearchFilter(const char *pattern, const char *user_name)
+{
+ StringInfoData output;
+
+ initStringInfo(&output);
+ while (*pattern != '\0')
+ {
+ if (strncmp(pattern, LPH_USERNAME, LPH_USERNAME_LEN) == 0)
+ {
+ appendStringInfoString(&output, user_name);
+ pattern += LPH_USERNAME_LEN;
+ }
+ else
+ appendStringInfoChar(&output, *pattern++);
+ }
+
+ return output.data;
+}
+
+/*
+ * Perform LDAP authentication
+ */
+static int
+CheckLDAPAuth(Port *port)
+{
+ char *passwd;
+ LDAP *ldap;
+ int r;
+ char *fulluser;
+ const char *server_name;
+
+#ifdef HAVE_LDAP_INITIALIZE
+
+ /*
+ * For OpenLDAP, allow empty hostname if we have a basedn. We'll look for
+ * servers with DNS SRV records via OpenLDAP library facilities.
+ */
+ if ((!port->hba->ldapserver || port->hba->ldapserver[0] == '\0') &&
+ (!port->hba->ldapbasedn || port->hba->ldapbasedn[0] == '\0'))
+ {
+ ereport(LOG,
+ (errmsg("LDAP server not specified, and no ldapbasedn")));
+ return STATUS_ERROR;
+ }
+#else
+ if (!port->hba->ldapserver || port->hba->ldapserver[0] == '\0')
+ {
+ ereport(LOG,
+ (errmsg("LDAP server not specified")));
+ return STATUS_ERROR;
+ }
+#endif
+
+ /*
+ * If we're using SRV records, we don't have a server name so we'll just
+ * show an empty string in error messages.
+ */
+ server_name = port->hba->ldapserver ? port->hba->ldapserver : "";
+
+ if (port->hba->ldapport == 0)
+ {
+ if (port->hba->ldapscheme != NULL &&
+ strcmp(port->hba->ldapscheme, "ldaps") == 0)
+ port->hba->ldapport = LDAPS_PORT;
+ else
+ port->hba->ldapport = LDAP_PORT;
+ }
+
+ sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
+
+ passwd = recv_password_packet(port);
+ if (passwd == NULL)
+ return STATUS_EOF; /* client wouldn't send password */
+
+ if (InitializeLDAPConnection(port, &ldap) == STATUS_ERROR)
+ {
+ /* Error message already sent */
+ pfree(passwd);
+ return STATUS_ERROR;
+ }
+
+ if (port->hba->ldapbasedn)
+ {
+ /*
+ * First perform an LDAP search to find the DN for the user we are
+ * trying to log in as.
+ */
+ char *filter;
+ LDAPMessage *search_message;
+ LDAPMessage *entry;
+ char *attributes[] = {LDAP_NO_ATTRS, NULL};
+ char *dn;
+ char *c;
+ int count;
+
+ /*
+ * Disallow any characters that we would otherwise need to escape,
+ * since they aren't really reasonable in a username anyway. Allowing
+ * them would make it possible to inject any kind of custom filters in
+ * the LDAP filter.
+ */
+ for (c = port->user_name; *c; c++)
+ {
+ if (*c == '*' ||
+ *c == '(' ||
+ *c == ')' ||
+ *c == '\\' ||
+ *c == '/')
+ {
+ ereport(LOG,
+ (errmsg("invalid character in user name for LDAP authentication")));
+ ldap_unbind(ldap);
+ pfree(passwd);
+ return STATUS_ERROR;
+ }
+ }
+
+ /*
+ * Bind with a pre-defined username/password (if available) for
+ * searching. If none is specified, this turns into an anonymous bind.
+ */
+ r = ldap_simple_bind_s(ldap,
+ port->hba->ldapbinddn ? port->hba->ldapbinddn : "",
+ port->hba->ldapbindpasswd ? port->hba->ldapbindpasswd : "");
+ if (r != LDAP_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("could not perform initial LDAP bind for ldapbinddn \"%s\" on server \"%s\": %s",
+ port->hba->ldapbinddn ? port->hba->ldapbinddn : "",
+ server_name,
+ ldap_err2string(r)),
+ errdetail_for_ldap(ldap)));
+ ldap_unbind(ldap);
+ pfree(passwd);
+ return STATUS_ERROR;
+ }
+
+ /* Build a custom filter or a single attribute filter? */
+ if (port->hba->ldapsearchfilter)
+ filter = FormatSearchFilter(port->hba->ldapsearchfilter, port->user_name);
+ else if (port->hba->ldapsearchattribute)
+ filter = psprintf("(%s=%s)", port->hba->ldapsearchattribute, port->user_name);
+ else
+ filter = psprintf("(uid=%s)", port->user_name);
+
+ r = ldap_search_s(ldap,
+ port->hba->ldapbasedn,
+ port->hba->ldapscope,
+ filter,
+ attributes,
+ 0,
+ &search_message);
+
+ if (r != LDAP_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("could not search LDAP for filter \"%s\" on server \"%s\": %s",
+ filter, server_name, ldap_err2string(r)),
+ errdetail_for_ldap(ldap)));
+ ldap_unbind(ldap);
+ pfree(passwd);
+ pfree(filter);
+ return STATUS_ERROR;
+ }
+
+ count = ldap_count_entries(ldap, search_message);
+ if (count != 1)
+ {
+ if (count == 0)
+ ereport(LOG,
+ (errmsg("LDAP user \"%s\" does not exist", port->user_name),
+ errdetail("LDAP search for filter \"%s\" on server \"%s\" returned no entries.",
+ filter, server_name)));
+ else
+ ereport(LOG,
+ (errmsg("LDAP user \"%s\" is not unique", port->user_name),
+ errdetail_plural("LDAP search for filter \"%s\" on server \"%s\" returned %d entry.",
+ "LDAP search for filter \"%s\" on server \"%s\" returned %d entries.",
+ count,
+ filter, server_name, count)));
+
+ ldap_unbind(ldap);
+ pfree(passwd);
+ pfree(filter);
+ ldap_msgfree(search_message);
+ return STATUS_ERROR;
+ }
+
+ entry = ldap_first_entry(ldap, search_message);
+ dn = ldap_get_dn(ldap, entry);
+ if (dn == NULL)
+ {
+ int error;
+
+ (void) ldap_get_option(ldap, LDAP_OPT_ERROR_NUMBER, &error);
+ ereport(LOG,
+ (errmsg("could not get dn for the first entry matching \"%s\" on server \"%s\": %s",
+ filter, server_name,
+ ldap_err2string(error)),
+ errdetail_for_ldap(ldap)));
+ ldap_unbind(ldap);
+ pfree(passwd);
+ pfree(filter);
+ ldap_msgfree(search_message);
+ return STATUS_ERROR;
+ }
+ fulluser = pstrdup(dn);
+
+ pfree(filter);
+ ldap_memfree(dn);
+ ldap_msgfree(search_message);
+
+ /* Unbind and disconnect from the LDAP server */
+ r = ldap_unbind_s(ldap);
+ if (r != LDAP_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("could not unbind after searching for user \"%s\" on server \"%s\"",
+ fulluser, server_name)));
+ pfree(passwd);
+ pfree(fulluser);
+ return STATUS_ERROR;
+ }
+
+ /*
+ * Need to re-initialize the LDAP connection, so that we can bind to
+ * it with a different username.
+ */
+ if (InitializeLDAPConnection(port, &ldap) == STATUS_ERROR)
+ {
+ pfree(passwd);
+ pfree(fulluser);
+
+ /* Error message already sent */
+ return STATUS_ERROR;
+ }
+ }
+ else
+ fulluser = psprintf("%s%s%s",
+ port->hba->ldapprefix ? port->hba->ldapprefix : "",
+ port->user_name,
+ port->hba->ldapsuffix ? port->hba->ldapsuffix : "");
+
+ r = ldap_simple_bind_s(ldap, fulluser, passwd);
+
+ if (r != LDAP_SUCCESS)
+ {
+ ereport(LOG,
+ (errmsg("LDAP login failed for user \"%s\" on server \"%s\": %s",
+ fulluser, server_name, ldap_err2string(r)),
+ errdetail_for_ldap(ldap)));
+ ldap_unbind(ldap);
+ pfree(passwd);
+ pfree(fulluser);
+ return STATUS_ERROR;
+ }
+
+ /* Save the original bind DN as the authenticated identity. */
+ set_authn_id(port, fulluser);
+
+ ldap_unbind(ldap);
+ pfree(passwd);
+ pfree(fulluser);
+
+ return STATUS_OK;
+}
+
+/*
+ * Add a detail error message text to the current error if one can be
+ * constructed from the LDAP 'diagnostic message'.
+ */
+static int
+errdetail_for_ldap(LDAP *ldap)
+{
+ char *message;
+ int rc;
+
+ rc = ldap_get_option(ldap, LDAP_OPT_DIAGNOSTIC_MESSAGE, &message);
+ if (rc == LDAP_SUCCESS && message != NULL)
+ {
+ errdetail("LDAP diagnostics: %s", message);
+ ldap_memfree(message);
+ }
+
+ return 0;
+}
+
+#endif /* USE_LDAP */
+
+
+/*----------------------------------------------------------------
+ * SSL client certificate authentication
+ *----------------------------------------------------------------
+ */
+#ifdef USE_SSL
+static int
+CheckCertAuth(Port *port)
+{
+ int status_check_usermap = STATUS_ERROR;
+ char *peer_username = NULL;
+
+ Assert(port->ssl);
+
+ /* select the correct field to compare */
+ switch (port->hba->clientcertname)
+ {
+ case clientCertDN:
+ peer_username = port->peer_dn;
+ break;
+ case clientCertCN:
+ peer_username = port->peer_cn;
+ }
+
+ /* Make sure we have received a username in the certificate */
+ if (peer_username == NULL ||
+ strlen(peer_username) <= 0)
+ {
+ ereport(LOG,
+ (errmsg("certificate authentication failed for user \"%s\": client certificate contains no user name",
+ port->user_name)));
+ return STATUS_ERROR;
+ }
+
+ if (port->hba->auth_method == uaCert)
+ {
+ /*
+ * For cert auth, the client's Subject DN is always our authenticated
+ * identity, even if we're only using its CN for authorization. Set
+ * it now, rather than waiting for check_usermap() below, because
+ * authentication has already succeeded and we want the log file to
+ * reflect that.
+ */
+ if (!port->peer_dn)
+ {
+ /*
+ * This should not happen as both peer_dn and peer_cn should be
+ * set in this context.
+ */
+ ereport(LOG,
+ (errmsg("certificate authentication failed for user \"%s\": unable to retrieve subject DN",
+ port->user_name)));
+ return STATUS_ERROR;
+ }
+
+ set_authn_id(port, port->peer_dn);
+ }
+
+ /* Just pass the certificate cn/dn to the usermap check */
+ status_check_usermap = check_usermap(port->hba->usermap, port->user_name, peer_username, false);
+ if (status_check_usermap != STATUS_OK)
+ {
+ /*
+ * If clientcert=verify-full was specified and the authentication
+ * method is other than uaCert, log the reason for rejecting the
+ * authentication.
+ */
+ if (port->hba->clientcert == clientCertFull && port->hba->auth_method != uaCert)
+ {
+ switch (port->hba->clientcertname)
+ {
+ case clientCertDN:
+ ereport(LOG,
+ (errmsg("certificate validation (clientcert=verify-full) failed for user \"%s\": DN mismatch",
+ port->user_name)));
+ break;
+ case clientCertCN:
+ ereport(LOG,
+ (errmsg("certificate validation (clientcert=verify-full) failed for user \"%s\": CN mismatch",
+ port->user_name)));
+ }
+ }
+ }
+ return status_check_usermap;
+}
+#endif
+
+
+/*----------------------------------------------------------------
+ * RADIUS authentication
+ *----------------------------------------------------------------
+ */
+
+/*
+ * RADIUS authentication is described in RFC2865 (and several others).
+ */
+
+#define RADIUS_VECTOR_LENGTH 16
+#define RADIUS_HEADER_LENGTH 20
+#define RADIUS_MAX_PASSWORD_LENGTH 128
+
+/* Maximum size of a RADIUS packet we will create or accept */
+#define RADIUS_BUFFER_SIZE 1024
+
+typedef struct
+{
+ uint8 attribute;
+ uint8 length;
+ uint8 data[FLEXIBLE_ARRAY_MEMBER];
+} radius_attribute;
+
+typedef struct
+{
+ uint8 code;
+ uint8 id;
+ uint16 length;
+ uint8 vector[RADIUS_VECTOR_LENGTH];
+ /* this is a bit longer than strictly necessary: */
+ char pad[RADIUS_BUFFER_SIZE - RADIUS_VECTOR_LENGTH];
+} radius_packet;
+
+/* RADIUS packet types */
+#define RADIUS_ACCESS_REQUEST 1
+#define RADIUS_ACCESS_ACCEPT 2
+#define RADIUS_ACCESS_REJECT 3
+
+/* RADIUS attributes */
+#define RADIUS_USER_NAME 1
+#define RADIUS_PASSWORD 2
+#define RADIUS_SERVICE_TYPE 6
+#define RADIUS_NAS_IDENTIFIER 32
+
+/* RADIUS service types */
+#define RADIUS_AUTHENTICATE_ONLY 8
+
+/* Seconds to wait - XXX: should be in a config variable! */
+#define RADIUS_TIMEOUT 3
+
+static void
+radius_add_attribute(radius_packet *packet, uint8 type, const unsigned char *data, int len)
+{
+ radius_attribute *attr;
+
+ if (packet->length + len > RADIUS_BUFFER_SIZE)
+ {
+ /*
+ * With remotely realistic data, this can never happen. But catch it
+ * just to make sure we don't overrun a buffer. We'll just skip adding
+ * the broken attribute, which will in the end cause authentication to
+ * fail.
+ */
+ elog(WARNING,
+ "adding attribute code %d with length %d to radius packet would create oversize packet, ignoring",
+ type, len);
+ return;
+ }
+
+ attr = (radius_attribute *) ((unsigned char *) packet + packet->length);
+ attr->attribute = type;
+ attr->length = len + 2; /* total size includes type and length */
+ memcpy(attr->data, data, len);
+ packet->length += attr->length;
+}
+
+static int
+CheckRADIUSAuth(Port *port)
+{
+ char *passwd;
+ ListCell *server,
+ *secrets,
+ *radiusports,
+ *identifiers;
+
+ /* Make sure struct alignment is correct */
+ Assert(offsetof(radius_packet, vector) == 4);
+
+ /* Verify parameters */
+ if (list_length(port->hba->radiusservers) < 1)
+ {
+ ereport(LOG,
+ (errmsg("RADIUS server not specified")));
+ return STATUS_ERROR;
+ }
+
+ if (list_length(port->hba->radiussecrets) < 1)
+ {
+ ereport(LOG,
+ (errmsg("RADIUS secret not specified")));
+ return STATUS_ERROR;
+ }
+
+ /* Send regular password request to client, and get the response */
+ sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0);
+
+ passwd = recv_password_packet(port);
+ if (passwd == NULL)
+ return STATUS_EOF; /* client wouldn't send password */
+
+ if (strlen(passwd) > RADIUS_MAX_PASSWORD_LENGTH)
+ {
+ ereport(LOG,
+ (errmsg("RADIUS authentication does not support passwords longer than %d characters", RADIUS_MAX_PASSWORD_LENGTH)));
+ pfree(passwd);
+ return STATUS_ERROR;
+ }
+
+ /*
+ * Loop over and try each server in order.
+ */
+ secrets = list_head(port->hba->radiussecrets);
+ radiusports = list_head(port->hba->radiusports);
+ identifiers = list_head(port->hba->radiusidentifiers);
+ foreach(server, port->hba->radiusservers)
+ {
+ int ret = PerformRadiusTransaction(lfirst(server),
+ lfirst(secrets),
+ radiusports ? lfirst(radiusports) : NULL,
+ identifiers ? lfirst(identifiers) : NULL,
+ port->user_name,
+ passwd);
+
+ /*------
+ * STATUS_OK = Login OK
+ * STATUS_ERROR = Login not OK, but try next server
+ * STATUS_EOF = Login not OK, and don't try next server
+ *------
+ */
+ if (ret == STATUS_OK)
+ {
+ set_authn_id(port, port->user_name);
+
+ pfree(passwd);
+ return STATUS_OK;
+ }
+ else if (ret == STATUS_EOF)
+ {
+ pfree(passwd);
+ return STATUS_ERROR;
+ }
+
+ /*
+ * secret, port and identifiers either have length 0 (use default),
+ * length 1 (use the same everywhere) or the same length as servers.
+ * So if the length is >1, we advance one step. In other cases, we
+ * don't and will then reuse the correct value.
+ */
+ if (list_length(port->hba->radiussecrets) > 1)
+ secrets = lnext(port->hba->radiussecrets, secrets);
+ if (list_length(port->hba->radiusports) > 1)
+ radiusports = lnext(port->hba->radiusports, radiusports);
+ if (list_length(port->hba->radiusidentifiers) > 1)
+ identifiers = lnext(port->hba->radiusidentifiers, identifiers);
+ }
+
+ /* No servers left to try, so give up */
+ pfree(passwd);
+ return STATUS_ERROR;
+}
+
+static int
+PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd)
+{
+ radius_packet radius_send_pack;
+ radius_packet radius_recv_pack;
+ radius_packet *packet = &radius_send_pack;
+ radius_packet *receivepacket = &radius_recv_pack;
+ char *radius_buffer = (char *) &radius_send_pack;
+ char *receive_buffer = (char *) &radius_recv_pack;
+ int32 service = pg_hton32(RADIUS_AUTHENTICATE_ONLY);
+ uint8 *cryptvector;
+ int encryptedpasswordlen;
+ uint8 encryptedpassword[RADIUS_MAX_PASSWORD_LENGTH];
+ uint8 *md5trailer;
+ int packetlength;
+ pgsocket sock;
+
+#ifdef HAVE_IPV6
+ struct sockaddr_in6 localaddr;
+ struct sockaddr_in6 remoteaddr;
+#else
+ struct sockaddr_in localaddr;
+ struct sockaddr_in remoteaddr;
+#endif
+ struct addrinfo hint;
+ struct addrinfo *serveraddrs;
+ int port;
+ ACCEPT_TYPE_ARG3 addrsize;
+ fd_set fdset;
+ struct timeval endtime;
+ int i,
+ j,
+ r;
+
+ /* Assign default values */
+ if (portstr == NULL)
+ portstr = "1812";
+ if (identifier == NULL)
+ identifier = "postgresql";
+
+ MemSet(&hint, 0, sizeof(hint));
+ hint.ai_socktype = SOCK_DGRAM;
+ hint.ai_family = AF_UNSPEC;
+ port = atoi(portstr);
+
+ r = pg_getaddrinfo_all(server, portstr, &hint, &serveraddrs);
+ if (r || !serveraddrs)
+ {
+ ereport(LOG,
+ (errmsg("could not translate RADIUS server name \"%s\" to address: %s",
+ server, gai_strerror(r))));
+ if (serveraddrs)
+ pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
+ return STATUS_ERROR;
+ }
+ /* XXX: add support for multiple returned addresses? */
+
+ /* Construct RADIUS packet */
+ packet->code = RADIUS_ACCESS_REQUEST;
+ packet->length = RADIUS_HEADER_LENGTH;
+ if (!pg_strong_random(packet->vector, RADIUS_VECTOR_LENGTH))
+ {
+ ereport(LOG,
+ (errmsg("could not generate random encryption vector")));
+ pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
+ return STATUS_ERROR;
+ }
+ packet->id = packet->vector[0];
+ radius_add_attribute(packet, RADIUS_SERVICE_TYPE, (const unsigned char *) &service, sizeof(service));
+ radius_add_attribute(packet, RADIUS_USER_NAME, (const unsigned char *) user_name, strlen(user_name));
+ radius_add_attribute(packet, RADIUS_NAS_IDENTIFIER, (const unsigned char *) identifier, strlen(identifier));
+
+ /*
+ * RADIUS password attributes are calculated as: e[0] = p[0] XOR
+ * MD5(secret + Request Authenticator) for the first group of 16 octets,
+ * and then: e[i] = p[i] XOR MD5(secret + e[i-1]) for the following ones
+ * (if necessary)
+ */
+ encryptedpasswordlen = ((strlen(passwd) + RADIUS_VECTOR_LENGTH - 1) / RADIUS_VECTOR_LENGTH) * RADIUS_VECTOR_LENGTH;
+ cryptvector = palloc(strlen(secret) + RADIUS_VECTOR_LENGTH);
+ memcpy(cryptvector, secret, strlen(secret));
+
+ /* for the first iteration, we use the Request Authenticator vector */
+ md5trailer = packet->vector;
+ for (i = 0; i < encryptedpasswordlen; i += RADIUS_VECTOR_LENGTH)
+ {
+ memcpy(cryptvector + strlen(secret), md5trailer, RADIUS_VECTOR_LENGTH);
+
+ /*
+ * .. and for subsequent iterations the result of the previous XOR
+ * (calculated below)
+ */
+ md5trailer = encryptedpassword + i;
+
+ if (!pg_md5_binary(cryptvector, strlen(secret) + RADIUS_VECTOR_LENGTH, encryptedpassword + i))
+ {
+ ereport(LOG,
+ (errmsg("could not perform MD5 encryption of password")));
+ pfree(cryptvector);
+ pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
+ return STATUS_ERROR;
+ }
+
+ for (j = i; j < i + RADIUS_VECTOR_LENGTH; j++)
+ {
+ if (j < strlen(passwd))
+ encryptedpassword[j] = passwd[j] ^ encryptedpassword[j];
+ else
+ encryptedpassword[j] = '\0' ^ encryptedpassword[j];
+ }
+ }
+ pfree(cryptvector);
+
+ radius_add_attribute(packet, RADIUS_PASSWORD, encryptedpassword, encryptedpasswordlen);
+
+ /* Length needs to be in network order on the wire */
+ packetlength = packet->length;
+ packet->length = pg_hton16(packet->length);
+
+ sock = socket(serveraddrs[0].ai_family, SOCK_DGRAM, 0);
+ if (sock == PGINVALID_SOCKET)
+ {
+ ereport(LOG,
+ (errmsg("could not create RADIUS socket: %m")));
+ pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
+ return STATUS_ERROR;
+ }
+
+ memset(&localaddr, 0, sizeof(localaddr));
+#ifdef HAVE_IPV6
+ localaddr.sin6_family = serveraddrs[0].ai_family;
+ localaddr.sin6_addr = in6addr_any;
+ if (localaddr.sin6_family == AF_INET6)
+ addrsize = sizeof(struct sockaddr_in6);
+ else
+ addrsize = sizeof(struct sockaddr_in);
+#else
+ localaddr.sin_family = serveraddrs[0].ai_family;
+ localaddr.sin_addr.s_addr = INADDR_ANY;
+ addrsize = sizeof(struct sockaddr_in);
+#endif
+
+ if (bind(sock, (struct sockaddr *) &localaddr, addrsize))
+ {
+ ereport(LOG,
+ (errmsg("could not bind local RADIUS socket: %m")));
+ closesocket(sock);
+ pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
+ return STATUS_ERROR;
+ }
+
+ if (sendto(sock, radius_buffer, packetlength, 0,
+ serveraddrs[0].ai_addr, serveraddrs[0].ai_addrlen) < 0)
+ {
+ ereport(LOG,
+ (errmsg("could not send RADIUS packet: %m")));
+ closesocket(sock);
+ pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
+ return STATUS_ERROR;
+ }
+
+ /* Don't need the server address anymore */
+ pg_freeaddrinfo_all(hint.ai_family, serveraddrs);
+
+ /*
+ * Figure out at what time we should time out. We can't just use a single
+ * call to select() with a timeout, since somebody can be sending invalid
+ * packets to our port thus causing us to retry in a loop and never time
+ * out.
+ *
+ * XXX: Using WaitLatchOrSocket() and doing a CHECK_FOR_INTERRUPTS() if
+ * the latch was set would improve the responsiveness to
+ * timeouts/cancellations.
+ */
+ gettimeofday(&endtime, NULL);
+ endtime.tv_sec += RADIUS_TIMEOUT;
+
+ while (true)
+ {
+ struct timeval timeout;
+ struct timeval now;
+ int64 timeoutval;
+
+ gettimeofday(&now, NULL);
+ timeoutval = (endtime.tv_sec * 1000000 + endtime.tv_usec) - (now.tv_sec * 1000000 + now.tv_usec);
+ if (timeoutval <= 0)
+ {
+ ereport(LOG,
+ (errmsg("timeout waiting for RADIUS response from %s",
+ server)));
+ closesocket(sock);
+ return STATUS_ERROR;
+ }
+ timeout.tv_sec = timeoutval / 1000000;
+ timeout.tv_usec = timeoutval % 1000000;
+
+ FD_ZERO(&fdset);
+ FD_SET(sock, &fdset);
+
+ r = select(sock + 1, &fdset, NULL, NULL, &timeout);
+ if (r < 0)
+ {
+ if (errno == EINTR)
+ continue;
+
+ /* Anything else is an actual error */
+ ereport(LOG,
+ (errmsg("could not check status on RADIUS socket: %m")));
+ closesocket(sock);
+ return STATUS_ERROR;
+ }
+ if (r == 0)
+ {
+ ereport(LOG,
+ (errmsg("timeout waiting for RADIUS response from %s",
+ server)));
+ closesocket(sock);
+ return STATUS_ERROR;
+ }
+
+ /*
+ * Attempt to read the response packet, and verify the contents.
+ *
+ * Any packet that's not actually a RADIUS packet, or otherwise does
+ * not validate as an explicit reject, is just ignored and we retry
+ * for another packet (until we reach the timeout). This is to avoid
+ * the possibility to denial-of-service the login by flooding the
+ * server with invalid packets on the port that we're expecting the
+ * RADIUS response on.
+ */
+
+ addrsize = sizeof(remoteaddr);
+ packetlength = recvfrom(sock, receive_buffer, RADIUS_BUFFER_SIZE, 0,
+ (struct sockaddr *) &remoteaddr, &addrsize);
+ if (packetlength < 0)
+ {
+ ereport(LOG,
+ (errmsg("could not read RADIUS response: %m")));
+ closesocket(sock);
+ return STATUS_ERROR;
+ }
+
+#ifdef HAVE_IPV6
+ if (remoteaddr.sin6_port != pg_hton16(port))
+#else
+ if (remoteaddr.sin_port != pg_hton16(port))
+#endif
+ {
+#ifdef HAVE_IPV6
+ ereport(LOG,
+ (errmsg("RADIUS response from %s was sent from incorrect port: %d",
+ server, pg_ntoh16(remoteaddr.sin6_port))));
+#else
+ ereport(LOG,
+ (errmsg("RADIUS response from %s was sent from incorrect port: %d",
+ server, pg_ntoh16(remoteaddr.sin_port))));
+#endif
+ continue;
+ }
+
+ if (packetlength < RADIUS_HEADER_LENGTH)
+ {
+ ereport(LOG,
+ (errmsg("RADIUS response from %s too short: %d", server, packetlength)));
+ continue;
+ }
+
+ if (packetlength != pg_ntoh16(receivepacket->length))
+ {
+ ereport(LOG,
+ (errmsg("RADIUS response from %s has corrupt length: %d (actual length %d)",
+ server, pg_ntoh16(receivepacket->length), packetlength)));
+ continue;
+ }
+
+ if (packet->id != receivepacket->id)
+ {
+ ereport(LOG,
+ (errmsg("RADIUS response from %s is to a different request: %d (should be %d)",
+ server, receivepacket->id, packet->id)));
+ continue;
+ }
+
+ /*
+ * Verify the response authenticator, which is calculated as
+ * MD5(Code+ID+Length+RequestAuthenticator+Attributes+Secret)
+ */
+ cryptvector = palloc(packetlength + strlen(secret));
+
+ memcpy(cryptvector, receivepacket, 4); /* code+id+length */
+ memcpy(cryptvector + 4, packet->vector, RADIUS_VECTOR_LENGTH); /* request
+ * authenticator, from
+ * original packet */
+ if (packetlength > RADIUS_HEADER_LENGTH) /* there may be no
+ * attributes at all */
+ memcpy(cryptvector + RADIUS_HEADER_LENGTH, receive_buffer + RADIUS_HEADER_LENGTH, packetlength - RADIUS_HEADER_LENGTH);
+ memcpy(cryptvector + packetlength, secret, strlen(secret));
+
+ if (!pg_md5_binary(cryptvector,
+ packetlength + strlen(secret),
+ encryptedpassword))
+ {
+ ereport(LOG,
+ (errmsg("could not perform MD5 encryption of received packet")));
+ pfree(cryptvector);
+ continue;
+ }
+ pfree(cryptvector);
+
+ if (memcmp(receivepacket->vector, encryptedpassword, RADIUS_VECTOR_LENGTH) != 0)
+ {
+ ereport(LOG,
+ (errmsg("RADIUS response from %s has incorrect MD5 signature",
+ server)));
+ continue;
+ }
+
+ if (receivepacket->code == RADIUS_ACCESS_ACCEPT)
+ {
+ closesocket(sock);
+ return STATUS_OK;
+ }
+ else if (receivepacket->code == RADIUS_ACCESS_REJECT)
+ {
+ closesocket(sock);
+ return STATUS_EOF;
+ }
+ else
+ {
+ ereport(LOG,
+ (errmsg("RADIUS response from %s has invalid code (%d) for user \"%s\"",
+ server, receivepacket->code, user_name)));
+ continue;
+ }
+ } /* while (true) */
+}
diff --git a/src/backend/libpq/be-fsstubs.c b/src/backend/libpq/be-fsstubs.c
new file mode 100644
index 0000000..63eaccc
--- /dev/null
+++ b/src/backend/libpq/be-fsstubs.c
@@ -0,0 +1,860 @@
+/*-------------------------------------------------------------------------
+ *
+ * be-fsstubs.c
+ * Builtin functions for open/close/read/write operations on large objects
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/be-fsstubs.c
+ *
+ * NOTES
+ * This should be moved to a more appropriate place. It is here
+ * for lack of a better place.
+ *
+ * These functions store LargeObjectDesc structs in a private MemoryContext,
+ * which means that large object descriptors hang around until we destroy
+ * the context at transaction end. It'd be possible to prolong the lifetime
+ * of the context so that LO FDs are good across transactions (for example,
+ * we could release the context only if we see that no FDs remain open).
+ * But we'd need additional state in order to do the right thing at the
+ * end of an aborted transaction. FDs opened during an aborted xact would
+ * still need to be closed, since they might not be pointing at valid
+ * relations at all. Locking semantics are also an interesting problem
+ * if LOs stay open across transactions. For now, we'll stick with the
+ * existing documented semantics of LO FDs: they're only good within a
+ * transaction.
+ *
+ * As of PostgreSQL 8.0, much of the angst expressed above is no longer
+ * relevant, and in fact it'd be pretty easy to allow LO FDs to stay
+ * open across transactions. (Snapshot relevancy would still be an issue.)
+ * However backwards compatibility suggests that we should stick to the
+ * status quo.
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include <fcntl.h>
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "access/xact.h"
+#include "libpq/be-fsstubs.h"
+#include "libpq/libpq-fs.h"
+#include "miscadmin.h"
+#include "storage/fd.h"
+#include "storage/large_object.h"
+#include "utils/acl.h"
+#include "utils/builtins.h"
+#include "utils/memutils.h"
+#include "utils/snapmgr.h"
+
+/* define this to enable debug logging */
+/* #define FSDB 1 */
+/* chunk size for lo_import/lo_export transfers */
+#define BUFSIZE 8192
+
+/*
+ * LO "FD"s are indexes into the cookies array.
+ *
+ * A non-null entry is a pointer to a LargeObjectDesc allocated in the
+ * LO private memory context "fscxt". The cookies array itself is also
+ * dynamically allocated in that context. Its current allocated size is
+ * cookies_size entries, of which any unused entries will be NULL.
+ */
+static LargeObjectDesc **cookies = NULL;
+static int cookies_size = 0;
+
+static bool lo_cleanup_needed = false;
+static MemoryContext fscxt = NULL;
+
+static int newLOfd(void);
+static void closeLOfd(int fd);
+static Oid lo_import_internal(text *filename, Oid lobjOid);
+
+
+/*****************************************************************************
+ * File Interfaces for Large Objects
+ *****************************************************************************/
+
+Datum
+be_lo_open(PG_FUNCTION_ARGS)
+{
+ Oid lobjId = PG_GETARG_OID(0);
+ int32 mode = PG_GETARG_INT32(1);
+ LargeObjectDesc *lobjDesc;
+ int fd;
+
+#ifdef FSDB
+ elog(DEBUG4, "lo_open(%u,%d)", lobjId, mode);
+#endif
+
+ /*
+ * Allocate a large object descriptor first. This will also create
+ * 'fscxt' if this is the first LO opened in this transaction.
+ */
+ fd = newLOfd();
+
+ lobjDesc = inv_open(lobjId, mode, fscxt);
+ lobjDesc->subid = GetCurrentSubTransactionId();
+
+ /*
+ * We must register the snapshot in TopTransaction's resowner so that it
+ * stays alive until the LO is closed rather than until the current portal
+ * shuts down.
+ */
+ if (lobjDesc->snapshot)
+ lobjDesc->snapshot = RegisterSnapshotOnOwner(lobjDesc->snapshot,
+ TopTransactionResourceOwner);
+
+ Assert(cookies[fd] == NULL);
+ cookies[fd] = lobjDesc;
+
+ PG_RETURN_INT32(fd);
+}
+
+Datum
+be_lo_close(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+
+ if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_UNDEFINED_OBJECT),
+ errmsg("invalid large-object descriptor: %d", fd)));
+
+#ifdef FSDB
+ elog(DEBUG4, "lo_close(%d)", fd);
+#endif
+
+ closeLOfd(fd);
+
+ PG_RETURN_INT32(0);
+}
+
+
+/*****************************************************************************
+ * Bare Read/Write operations --- these are not fmgr-callable!
+ *
+ * We assume the large object supports byte oriented reads and seeks so
+ * that our work is easier.
+ *
+ *****************************************************************************/
+
+int
+lo_read(int fd, char *buf, int len)
+{
+ int status;
+ LargeObjectDesc *lobj;
+
+ if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_UNDEFINED_OBJECT),
+ errmsg("invalid large-object descriptor: %d", fd)));
+ lobj = cookies[fd];
+
+ /*
+ * Check state. inv_read() would throw an error anyway, but we want the
+ * error to be about the FD's state not the underlying privilege; it might
+ * be that the privilege exists but user forgot to ask for read mode.
+ */
+ if ((lobj->flags & IFS_RDLOCK) == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
+ errmsg("large object descriptor %d was not opened for reading",
+ fd)));
+
+ status = inv_read(lobj, buf, len);
+
+ return status;
+}
+
+int
+lo_write(int fd, const char *buf, int len)
+{
+ int status;
+ LargeObjectDesc *lobj;
+
+ if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_UNDEFINED_OBJECT),
+ errmsg("invalid large-object descriptor: %d", fd)));
+ lobj = cookies[fd];
+
+ /* see comment in lo_read() */
+ if ((lobj->flags & IFS_WRLOCK) == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
+ errmsg("large object descriptor %d was not opened for writing",
+ fd)));
+
+ status = inv_write(lobj, buf, len);
+
+ return status;
+}
+
+Datum
+be_lo_lseek(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+ int32 offset = PG_GETARG_INT32(1);
+ int32 whence = PG_GETARG_INT32(2);
+ int64 status;
+
+ if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_UNDEFINED_OBJECT),
+ errmsg("invalid large-object descriptor: %d", fd)));
+
+ status = inv_seek(cookies[fd], offset, whence);
+
+ /* guard against result overflow */
+ if (status != (int32) status)
+ ereport(ERROR,
+ (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
+ errmsg("lo_lseek result out of range for large-object descriptor %d",
+ fd)));
+
+ PG_RETURN_INT32((int32) status);
+}
+
+Datum
+be_lo_lseek64(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+ int64 offset = PG_GETARG_INT64(1);
+ int32 whence = PG_GETARG_INT32(2);
+ int64 status;
+
+ if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_UNDEFINED_OBJECT),
+ errmsg("invalid large-object descriptor: %d", fd)));
+
+ status = inv_seek(cookies[fd], offset, whence);
+
+ PG_RETURN_INT64(status);
+}
+
+Datum
+be_lo_creat(PG_FUNCTION_ARGS)
+{
+ Oid lobjId;
+
+ lo_cleanup_needed = true;
+ lobjId = inv_create(InvalidOid);
+
+ PG_RETURN_OID(lobjId);
+}
+
+Datum
+be_lo_create(PG_FUNCTION_ARGS)
+{
+ Oid lobjId = PG_GETARG_OID(0);
+
+ lo_cleanup_needed = true;
+ lobjId = inv_create(lobjId);
+
+ PG_RETURN_OID(lobjId);
+}
+
+Datum
+be_lo_tell(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+ int64 offset;
+
+ if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_UNDEFINED_OBJECT),
+ errmsg("invalid large-object descriptor: %d", fd)));
+
+ offset = inv_tell(cookies[fd]);
+
+ /* guard against result overflow */
+ if (offset != (int32) offset)
+ ereport(ERROR,
+ (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE),
+ errmsg("lo_tell result out of range for large-object descriptor %d",
+ fd)));
+
+ PG_RETURN_INT32((int32) offset);
+}
+
+Datum
+be_lo_tell64(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+ int64 offset;
+
+ if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_UNDEFINED_OBJECT),
+ errmsg("invalid large-object descriptor: %d", fd)));
+
+ offset = inv_tell(cookies[fd]);
+
+ PG_RETURN_INT64(offset);
+}
+
+Datum
+be_lo_unlink(PG_FUNCTION_ARGS)
+{
+ Oid lobjId = PG_GETARG_OID(0);
+
+ /*
+ * Must be owner of the large object. It would be cleaner to check this
+ * in inv_drop(), but we want to throw the error before not after closing
+ * relevant FDs.
+ */
+ if (!lo_compat_privileges &&
+ !pg_largeobject_ownercheck(lobjId, GetUserId()))
+ ereport(ERROR,
+ (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
+ errmsg("must be owner of large object %u", lobjId)));
+
+ /*
+ * If there are any open LO FDs referencing that ID, close 'em.
+ */
+ if (fscxt != NULL)
+ {
+ int i;
+
+ for (i = 0; i < cookies_size; i++)
+ {
+ if (cookies[i] != NULL && cookies[i]->id == lobjId)
+ closeLOfd(i);
+ }
+ }
+
+ /*
+ * inv_drop does not create a need for end-of-transaction cleanup and
+ * hence we don't need to set lo_cleanup_needed.
+ */
+ PG_RETURN_INT32(inv_drop(lobjId));
+}
+
+/*****************************************************************************
+ * Read/Write using bytea
+ *****************************************************************************/
+
+Datum
+be_loread(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+ int32 len = PG_GETARG_INT32(1);
+ bytea *retval;
+ int totalread;
+
+ if (len < 0)
+ len = 0;
+
+ retval = (bytea *) palloc(VARHDRSZ + len);
+ totalread = lo_read(fd, VARDATA(retval), len);
+ SET_VARSIZE(retval, totalread + VARHDRSZ);
+
+ PG_RETURN_BYTEA_P(retval);
+}
+
+Datum
+be_lowrite(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+ bytea *wbuf = PG_GETARG_BYTEA_PP(1);
+ int bytestowrite;
+ int totalwritten;
+
+ bytestowrite = VARSIZE_ANY_EXHDR(wbuf);
+ totalwritten = lo_write(fd, VARDATA_ANY(wbuf), bytestowrite);
+ PG_RETURN_INT32(totalwritten);
+}
+
+/*****************************************************************************
+ * Import/Export of Large Object
+ *****************************************************************************/
+
+/*
+ * lo_import -
+ * imports a file as an (inversion) large object.
+ */
+Datum
+be_lo_import(PG_FUNCTION_ARGS)
+{
+ text *filename = PG_GETARG_TEXT_PP(0);
+
+ PG_RETURN_OID(lo_import_internal(filename, InvalidOid));
+}
+
+/*
+ * lo_import_with_oid -
+ * imports a file as an (inversion) large object specifying oid.
+ */
+Datum
+be_lo_import_with_oid(PG_FUNCTION_ARGS)
+{
+ text *filename = PG_GETARG_TEXT_PP(0);
+ Oid oid = PG_GETARG_OID(1);
+
+ PG_RETURN_OID(lo_import_internal(filename, oid));
+}
+
+static Oid
+lo_import_internal(text *filename, Oid lobjOid)
+{
+ int fd;
+ int nbytes,
+ tmp PG_USED_FOR_ASSERTS_ONLY;
+ char buf[BUFSIZE];
+ char fnamebuf[MAXPGPATH];
+ LargeObjectDesc *lobj;
+ Oid oid;
+
+ /*
+ * open the file to be read in
+ */
+ text_to_cstring_buffer(filename, fnamebuf, sizeof(fnamebuf));
+ fd = OpenTransientFile(fnamebuf, O_RDONLY | PG_BINARY);
+ if (fd < 0)
+ ereport(ERROR,
+ (errcode_for_file_access(),
+ errmsg("could not open server file \"%s\": %m",
+ fnamebuf)));
+
+ /*
+ * create an inversion object
+ */
+ lo_cleanup_needed = true;
+ oid = inv_create(lobjOid);
+
+ /*
+ * read in from the filesystem and write to the inversion object
+ */
+ lobj = inv_open(oid, INV_WRITE, CurrentMemoryContext);
+
+ while ((nbytes = read(fd, buf, BUFSIZE)) > 0)
+ {
+ tmp = inv_write(lobj, buf, nbytes);
+ Assert(tmp == nbytes);
+ }
+
+ if (nbytes < 0)
+ ereport(ERROR,
+ (errcode_for_file_access(),
+ errmsg("could not read server file \"%s\": %m",
+ fnamebuf)));
+
+ inv_close(lobj);
+
+ if (CloseTransientFile(fd) != 0)
+ ereport(ERROR,
+ (errcode_for_file_access(),
+ errmsg("could not close file \"%s\": %m",
+ fnamebuf)));
+
+ return oid;
+}
+
+/*
+ * lo_export -
+ * exports an (inversion) large object.
+ */
+Datum
+be_lo_export(PG_FUNCTION_ARGS)
+{
+ Oid lobjId = PG_GETARG_OID(0);
+ text *filename = PG_GETARG_TEXT_PP(1);
+ int fd;
+ int nbytes,
+ tmp;
+ char buf[BUFSIZE];
+ char fnamebuf[MAXPGPATH];
+ LargeObjectDesc *lobj;
+ mode_t oumask;
+
+ /*
+ * open the inversion object (no need to test for failure)
+ */
+ lo_cleanup_needed = true;
+ lobj = inv_open(lobjId, INV_READ, CurrentMemoryContext);
+
+ /*
+ * open the file to be written to
+ *
+ * Note: we reduce backend's normal 077 umask to the slightly friendlier
+ * 022. This code used to drop it all the way to 0, but creating
+ * world-writable export files doesn't seem wise.
+ */
+ text_to_cstring_buffer(filename, fnamebuf, sizeof(fnamebuf));
+ oumask = umask(S_IWGRP | S_IWOTH);
+ PG_TRY();
+ {
+ fd = OpenTransientFilePerm(fnamebuf, O_CREAT | O_WRONLY | O_TRUNC | PG_BINARY,
+ S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH);
+ }
+ PG_FINALLY();
+ {
+ umask(oumask);
+ }
+ PG_END_TRY();
+ if (fd < 0)
+ ereport(ERROR,
+ (errcode_for_file_access(),
+ errmsg("could not create server file \"%s\": %m",
+ fnamebuf)));
+
+ /*
+ * read in from the inversion file and write to the filesystem
+ */
+ while ((nbytes = inv_read(lobj, buf, BUFSIZE)) > 0)
+ {
+ tmp = write(fd, buf, nbytes);
+ if (tmp != nbytes)
+ ereport(ERROR,
+ (errcode_for_file_access(),
+ errmsg("could not write server file \"%s\": %m",
+ fnamebuf)));
+ }
+
+ if (CloseTransientFile(fd) != 0)
+ ereport(ERROR,
+ (errcode_for_file_access(),
+ errmsg("could not close file \"%s\": %m",
+ fnamebuf)));
+
+ inv_close(lobj);
+
+ PG_RETURN_INT32(1);
+}
+
+/*
+ * lo_truncate -
+ * truncate a large object to a specified length
+ */
+static void
+lo_truncate_internal(int32 fd, int64 len)
+{
+ LargeObjectDesc *lobj;
+
+ if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_UNDEFINED_OBJECT),
+ errmsg("invalid large-object descriptor: %d", fd)));
+ lobj = cookies[fd];
+
+ /* see comment in lo_read() */
+ if ((lobj->flags & IFS_WRLOCK) == 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
+ errmsg("large object descriptor %d was not opened for writing",
+ fd)));
+
+ inv_truncate(lobj, len);
+}
+
+Datum
+be_lo_truncate(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+ int32 len = PG_GETARG_INT32(1);
+
+ lo_truncate_internal(fd, len);
+ PG_RETURN_INT32(0);
+}
+
+Datum
+be_lo_truncate64(PG_FUNCTION_ARGS)
+{
+ int32 fd = PG_GETARG_INT32(0);
+ int64 len = PG_GETARG_INT64(1);
+
+ lo_truncate_internal(fd, len);
+ PG_RETURN_INT32(0);
+}
+
+/*
+ * AtEOXact_LargeObject -
+ * prepares large objects for transaction commit
+ */
+void
+AtEOXact_LargeObject(bool isCommit)
+{
+ int i;
+
+ if (!lo_cleanup_needed)
+ return; /* no LO operations in this xact */
+
+ /*
+ * Close LO fds and clear cookies array so that LO fds are no longer good.
+ * The memory context and resource owner holding them are going away at
+ * the end-of-transaction anyway, but on commit, we need to close them to
+ * avoid warnings about leaked resources at commit. On abort we can skip
+ * this step.
+ */
+ if (isCommit)
+ {
+ for (i = 0; i < cookies_size; i++)
+ {
+ if (cookies[i] != NULL)
+ closeLOfd(i);
+ }
+ }
+
+ /* Needn't actually pfree since we're about to zap context */
+ cookies = NULL;
+ cookies_size = 0;
+
+ /* Release the LO memory context to prevent permanent memory leaks. */
+ if (fscxt)
+ MemoryContextDelete(fscxt);
+ fscxt = NULL;
+
+ /* Give inv_api.c a chance to clean up, too */
+ close_lo_relation(isCommit);
+
+ lo_cleanup_needed = false;
+}
+
+/*
+ * AtEOSubXact_LargeObject
+ * Take care of large objects at subtransaction commit/abort
+ *
+ * Reassign LOs created/opened during a committing subtransaction
+ * to the parent subtransaction. On abort, just close them.
+ */
+void
+AtEOSubXact_LargeObject(bool isCommit, SubTransactionId mySubid,
+ SubTransactionId parentSubid)
+{
+ int i;
+
+ if (fscxt == NULL) /* no LO operations in this xact */
+ return;
+
+ for (i = 0; i < cookies_size; i++)
+ {
+ LargeObjectDesc *lo = cookies[i];
+
+ if (lo != NULL && lo->subid == mySubid)
+ {
+ if (isCommit)
+ lo->subid = parentSubid;
+ else
+ closeLOfd(i);
+ }
+ }
+}
+
+/*****************************************************************************
+ * Support routines for this file
+ *****************************************************************************/
+
+static int
+newLOfd(void)
+{
+ int i,
+ newsize;
+
+ lo_cleanup_needed = true;
+ if (fscxt == NULL)
+ fscxt = AllocSetContextCreate(TopMemoryContext,
+ "Filesystem",
+ ALLOCSET_DEFAULT_SIZES);
+
+ /* Try to find a free slot */
+ for (i = 0; i < cookies_size; i++)
+ {
+ if (cookies[i] == NULL)
+ return i;
+ }
+
+ /* No free slot, so make the array bigger */
+ if (cookies_size <= 0)
+ {
+ /* First time through, arbitrarily make 64-element array */
+ i = 0;
+ newsize = 64;
+ cookies = (LargeObjectDesc **)
+ MemoryContextAllocZero(fscxt, newsize * sizeof(LargeObjectDesc *));
+ cookies_size = newsize;
+ }
+ else
+ {
+ /* Double size of array */
+ i = cookies_size;
+ newsize = cookies_size * 2;
+ cookies = (LargeObjectDesc **)
+ repalloc(cookies, newsize * sizeof(LargeObjectDesc *));
+ MemSet(cookies + cookies_size, 0,
+ (newsize - cookies_size) * sizeof(LargeObjectDesc *));
+ cookies_size = newsize;
+ }
+
+ return i;
+}
+
+static void
+closeLOfd(int fd)
+{
+ LargeObjectDesc *lobj;
+
+ /*
+ * Make sure we do not try to free twice if this errors out for some
+ * reason. Better a leak than a crash.
+ */
+ lobj = cookies[fd];
+ cookies[fd] = NULL;
+
+ if (lobj->snapshot)
+ UnregisterSnapshotFromOwner(lobj->snapshot,
+ TopTransactionResourceOwner);
+ inv_close(lobj);
+}
+
+/*****************************************************************************
+ * Wrappers oriented toward SQL callers
+ *****************************************************************************/
+
+/*
+ * Read [offset, offset+nbytes) within LO; when nbytes is -1, read to end.
+ */
+static bytea *
+lo_get_fragment_internal(Oid loOid, int64 offset, int32 nbytes)
+{
+ LargeObjectDesc *loDesc;
+ int64 loSize;
+ int64 result_length;
+ int total_read PG_USED_FOR_ASSERTS_ONLY;
+ bytea *result = NULL;
+
+ lo_cleanup_needed = true;
+ loDesc = inv_open(loOid, INV_READ, CurrentMemoryContext);
+
+ /*
+ * Compute number of bytes we'll actually read, accommodating nbytes == -1
+ * and reads beyond the end of the LO.
+ */
+ loSize = inv_seek(loDesc, 0, SEEK_END);
+ if (loSize > offset)
+ {
+ if (nbytes >= 0 && nbytes <= loSize - offset)
+ result_length = nbytes; /* request is wholly inside LO */
+ else
+ result_length = loSize - offset; /* adjust to end of LO */
+ }
+ else
+ result_length = 0; /* request is wholly outside LO */
+
+ /*
+ * A result_length calculated from loSize may not fit in a size_t. Check
+ * that the size will satisfy this and subsequently-enforced size limits.
+ */
+ if (result_length > MaxAllocSize - VARHDRSZ)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
+ errmsg("large object read request is too large")));
+
+ result = (bytea *) palloc(VARHDRSZ + result_length);
+
+ inv_seek(loDesc, offset, SEEK_SET);
+ total_read = inv_read(loDesc, VARDATA(result), result_length);
+ Assert(total_read == result_length);
+ SET_VARSIZE(result, result_length + VARHDRSZ);
+
+ inv_close(loDesc);
+
+ return result;
+}
+
+/*
+ * Read entire LO
+ */
+Datum
+be_lo_get(PG_FUNCTION_ARGS)
+{
+ Oid loOid = PG_GETARG_OID(0);
+ bytea *result;
+
+ result = lo_get_fragment_internal(loOid, 0, -1);
+
+ PG_RETURN_BYTEA_P(result);
+}
+
+/*
+ * Read range within LO
+ */
+Datum
+be_lo_get_fragment(PG_FUNCTION_ARGS)
+{
+ Oid loOid = PG_GETARG_OID(0);
+ int64 offset = PG_GETARG_INT64(1);
+ int32 nbytes = PG_GETARG_INT32(2);
+ bytea *result;
+
+ if (nbytes < 0)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+ errmsg("requested length cannot be negative")));
+
+ result = lo_get_fragment_internal(loOid, offset, nbytes);
+
+ PG_RETURN_BYTEA_P(result);
+}
+
+/*
+ * Create LO with initial contents given by a bytea argument
+ */
+Datum
+be_lo_from_bytea(PG_FUNCTION_ARGS)
+{
+ Oid loOid = PG_GETARG_OID(0);
+ bytea *str = PG_GETARG_BYTEA_PP(1);
+ LargeObjectDesc *loDesc;
+ int written PG_USED_FOR_ASSERTS_ONLY;
+
+ lo_cleanup_needed = true;
+ loOid = inv_create(loOid);
+ loDesc = inv_open(loOid, INV_WRITE, CurrentMemoryContext);
+ written = inv_write(loDesc, VARDATA_ANY(str), VARSIZE_ANY_EXHDR(str));
+ Assert(written == VARSIZE_ANY_EXHDR(str));
+ inv_close(loDesc);
+
+ PG_RETURN_OID(loOid);
+}
+
+/*
+ * Update range within LO
+ */
+Datum
+be_lo_put(PG_FUNCTION_ARGS)
+{
+ Oid loOid = PG_GETARG_OID(0);
+ int64 offset = PG_GETARG_INT64(1);
+ bytea *str = PG_GETARG_BYTEA_PP(2);
+ LargeObjectDesc *loDesc;
+ int written PG_USED_FOR_ASSERTS_ONLY;
+
+ lo_cleanup_needed = true;
+ loDesc = inv_open(loOid, INV_WRITE, CurrentMemoryContext);
+
+ /* Permission check */
+ if (!lo_compat_privileges &&
+ pg_largeobject_aclcheck_snapshot(loDesc->id,
+ GetUserId(),
+ ACL_UPDATE,
+ loDesc->snapshot) != ACLCHECK_OK)
+ ereport(ERROR,
+ (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
+ errmsg("permission denied for large object %u",
+ loDesc->id)));
+
+ inv_seek(loDesc, offset, SEEK_SET);
+ written = inv_write(loDesc, VARDATA_ANY(str), VARSIZE_ANY_EXHDR(str));
+ Assert(written == VARSIZE_ANY_EXHDR(str));
+ inv_close(loDesc);
+
+ PG_RETURN_VOID();
+}
diff --git a/src/backend/libpq/be-gssapi-common.c b/src/backend/libpq/be-gssapi-common.c
new file mode 100644
index 0000000..38f58de
--- /dev/null
+++ b/src/backend/libpq/be-gssapi-common.c
@@ -0,0 +1,94 @@
+/*-------------------------------------------------------------------------
+ *
+ * be-gssapi-common.c
+ * Common code for GSSAPI authentication and encryption
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/be-gssapi-common.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include "libpq/be-gssapi-common.h"
+
+/*
+ * Fetch all errors of a specific type and append to "s" (buffer of size len).
+ * If we obtain more than one string, separate them with spaces.
+ * Call once for GSS_CODE and once for MECH_CODE.
+ */
+static void
+pg_GSS_error_int(char *s, size_t len, OM_uint32 stat, int type)
+{
+ gss_buffer_desc gmsg;
+ size_t i = 0;
+ OM_uint32 lmin_s,
+ msg_ctx = 0;
+
+ do
+ {
+ if (gss_display_status(&lmin_s, stat, type, GSS_C_NO_OID,
+ &msg_ctx, &gmsg) != GSS_S_COMPLETE)
+ break;
+ if (i > 0)
+ {
+ if (i < len)
+ s[i] = ' ';
+ i++;
+ }
+ if (i < len)
+ memcpy(s + i, gmsg.value, Min(len - i, gmsg.length));
+ i += gmsg.length;
+ gss_release_buffer(&lmin_s, &gmsg);
+ }
+ while (msg_ctx);
+
+ /* add nul termination */
+ if (i < len)
+ s[i] = '\0';
+ else
+ {
+ elog(COMMERROR, "incomplete GSS error report");
+ s[len - 1] = '\0';
+ }
+}
+
+/*
+ * Report the GSSAPI error described by maj_stat/min_stat.
+ *
+ * errmsg should be an already-translated primary error message.
+ * The GSSAPI info is appended as errdetail.
+ *
+ * The error is always reported with elevel COMMERROR; we daren't try to
+ * send it to the client, as that'd likely lead to infinite recursion
+ * when elog.c tries to write to the client.
+ *
+ * To avoid memory allocation, total error size is capped (at 128 bytes for
+ * each of major and minor). No known mechanisms will produce error messages
+ * beyond this cap.
+ */
+void
+pg_GSS_error(const char *errmsg,
+ OM_uint32 maj_stat, OM_uint32 min_stat)
+{
+ char msg_major[128],
+ msg_minor[128];
+
+ /* Fetch major status message */
+ pg_GSS_error_int(msg_major, sizeof(msg_major), maj_stat, GSS_C_GSS_CODE);
+
+ /* Fetch mechanism minor status message */
+ pg_GSS_error_int(msg_minor, sizeof(msg_minor), min_stat, GSS_C_MECH_CODE);
+
+ /*
+ * errmsg_internal, since translation of the first part must be done
+ * before calling this function anyway.
+ */
+ ereport(COMMERROR,
+ (errmsg_internal("%s", errmsg),
+ errdetail_internal("%s: %s", msg_major, msg_minor)));
+}
diff --git a/src/backend/libpq/be-secure-common.c b/src/backend/libpq/be-secure-common.c
new file mode 100644
index 0000000..7d082d7
--- /dev/null
+++ b/src/backend/libpq/be-secure-common.c
@@ -0,0 +1,195 @@
+/*-------------------------------------------------------------------------
+ *
+ * be-secure-common.c
+ *
+ * common implementation-independent SSL support code
+ *
+ * While be-secure.c contains the interfaces that the rest of the
+ * communications code calls, this file contains support routines that are
+ * used by the library-specific implementations such as be-secure-openssl.c.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/be-secure-common.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include <sys/stat.h>
+#include <unistd.h>
+
+#include "common/string.h"
+#include "libpq/libpq.h"
+#include "storage/fd.h"
+
+/*
+ * Run ssl_passphrase_command
+ *
+ * prompt will be substituted for %p. is_server_start determines the loglevel
+ * of error messages.
+ *
+ * The result will be put in buffer buf, which is of size size. The return
+ * value is the length of the actual result.
+ */
+int
+run_ssl_passphrase_command(const char *prompt, bool is_server_start, char *buf, int size)
+{
+ int loglevel = is_server_start ? ERROR : LOG;
+ StringInfoData command;
+ char *p;
+ FILE *fh;
+ int pclose_rc;
+ size_t len = 0;
+
+ Assert(prompt);
+ Assert(size > 0);
+ buf[0] = '\0';
+
+ initStringInfo(&command);
+
+ for (p = ssl_passphrase_command; *p; p++)
+ {
+ if (p[0] == '%')
+ {
+ switch (p[1])
+ {
+ case 'p':
+ appendStringInfoString(&command, prompt);
+ p++;
+ break;
+ case '%':
+ appendStringInfoChar(&command, '%');
+ p++;
+ break;
+ default:
+ appendStringInfoChar(&command, p[0]);
+ }
+ }
+ else
+ appendStringInfoChar(&command, p[0]);
+ }
+
+ fh = OpenPipeStream(command.data, "r");
+ if (fh == NULL)
+ {
+ ereport(loglevel,
+ (errcode_for_file_access(),
+ errmsg("could not execute command \"%s\": %m",
+ command.data)));
+ goto error;
+ }
+
+ if (!fgets(buf, size, fh))
+ {
+ if (ferror(fh))
+ {
+ explicit_bzero(buf, size);
+ ereport(loglevel,
+ (errcode_for_file_access(),
+ errmsg("could not read from command \"%s\": %m",
+ command.data)));
+ goto error;
+ }
+ }
+
+ pclose_rc = ClosePipeStream(fh);
+ if (pclose_rc == -1)
+ {
+ explicit_bzero(buf, size);
+ ereport(loglevel,
+ (errcode_for_file_access(),
+ errmsg("could not close pipe to external command: %m")));
+ goto error;
+ }
+ else if (pclose_rc != 0)
+ {
+ explicit_bzero(buf, size);
+ ereport(loglevel,
+ (errcode_for_file_access(),
+ errmsg("command \"%s\" failed",
+ command.data),
+ errdetail_internal("%s", wait_result_to_str(pclose_rc))));
+ goto error;
+ }
+
+ /* strip trailing newline and carriage return */
+ len = pg_strip_crlf(buf);
+
+error:
+ pfree(command.data);
+ return len;
+}
+
+
+/*
+ * Check permissions for SSL key files.
+ */
+bool
+check_ssl_key_file_permissions(const char *ssl_key_file, bool isServerStart)
+{
+ int loglevel = isServerStart ? FATAL : LOG;
+ struct stat buf;
+
+ if (stat(ssl_key_file, &buf) != 0)
+ {
+ ereport(loglevel,
+ (errcode_for_file_access(),
+ errmsg("could not access private key file \"%s\": %m",
+ ssl_key_file)));
+ return false;
+ }
+
+ /* Key file must be a regular file */
+ if (!S_ISREG(buf.st_mode))
+ {
+ ereport(loglevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("private key file \"%s\" is not a regular file",
+ ssl_key_file)));
+ return false;
+ }
+
+ /*
+ * Refuse to load key files owned by users other than us or root, and
+ * require no public access to the key file. If the file is owned by us,
+ * require mode 0600 or less. If owned by root, require 0640 or less to
+ * allow read access through either our gid or a supplementary gid that
+ * allows us to read system-wide certificates.
+ *
+ * Note that roughly similar checks are performed in
+ * src/interfaces/libpq/fe-secure-openssl.c so any changes here may need
+ * to be made there as well. The environment is different though; this
+ * code can assume that we're not running as root.
+ *
+ * Ideally we would do similar permissions checks on Windows, but it is
+ * not clear how that would work since Unix-style permissions may not be
+ * available.
+ */
+#if !defined(WIN32) && !defined(__CYGWIN__)
+ if (buf.st_uid != geteuid() && buf.st_uid != 0)
+ {
+ ereport(loglevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("private key file \"%s\" must be owned by the database user or root",
+ ssl_key_file)));
+ return false;
+ }
+
+ if ((buf.st_uid == geteuid() && buf.st_mode & (S_IRWXG | S_IRWXO)) ||
+ (buf.st_uid == 0 && buf.st_mode & (S_IWGRP | S_IXGRP | S_IRWXO)))
+ {
+ ereport(loglevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("private key file \"%s\" has group or world access",
+ ssl_key_file),
+ errdetail("File must have permissions u=rw (0600) or less if owned by the database user, or permissions u=rw,g=r (0640) or less if owned by root.")));
+ return false;
+ }
+#endif
+
+ return true;
+}
diff --git a/src/backend/libpq/be-secure-gssapi.c b/src/backend/libpq/be-secure-gssapi.c
new file mode 100644
index 0000000..316ca65
--- /dev/null
+++ b/src/backend/libpq/be-secure-gssapi.c
@@ -0,0 +1,733 @@
+/*-------------------------------------------------------------------------
+ *
+ * be-secure-gssapi.c
+ * GSSAPI encryption support
+ *
+ * Portions Copyright (c) 2018-2021, PostgreSQL Global Development Group
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/be-secure-gssapi.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include <unistd.h>
+
+#include "libpq/auth.h"
+#include "libpq/be-gssapi-common.h"
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
+#include "miscadmin.h"
+#include "pgstat.h"
+#include "utils/memutils.h"
+
+
+/*
+ * Handle the encryption/decryption of data using GSSAPI.
+ *
+ * In the encrypted data stream on the wire, we break up the data
+ * into packets where each packet starts with a uint32-size length
+ * word (in network byte order), then encrypted data of that length
+ * immediately following. Decryption yields the same data stream
+ * that would appear when not using encryption.
+ *
+ * Encrypted data typically ends up being larger than the same data
+ * unencrypted, so we use fixed-size buffers for handling the
+ * encryption/decryption which are larger than PQComm's buffer will
+ * typically be to minimize the times where we have to make multiple
+ * packets (and therefore multiple recv/send calls for a single
+ * read/write call to us).
+ *
+ * NOTE: The client and server have to agree on the max packet size,
+ * because we have to pass an entire packet to GSSAPI at a time and we
+ * don't want the other side to send arbitrarily huge packets as we
+ * would have to allocate memory for them to then pass them to GSSAPI.
+ *
+ * Therefore, these two #define's are effectively part of the protocol
+ * spec and can't ever be changed.
+ */
+#define PQ_GSS_SEND_BUFFER_SIZE 16384
+#define PQ_GSS_RECV_BUFFER_SIZE 16384
+
+/*
+ * Since we manage at most one GSS-encrypted connection per backend,
+ * we can just keep all this state in static variables. The char *
+ * variables point to buffers that are allocated once and re-used.
+ */
+static char *PqGSSSendBuffer; /* Encrypted data waiting to be sent */
+static int PqGSSSendLength; /* End of data available in PqGSSSendBuffer */
+static int PqGSSSendNext; /* Next index to send a byte from
+ * PqGSSSendBuffer */
+static int PqGSSSendConsumed; /* Number of *unencrypted* bytes consumed for
+ * current contents of PqGSSSendBuffer */
+
+static char *PqGSSRecvBuffer; /* Received, encrypted data */
+static int PqGSSRecvLength; /* End of data available in PqGSSRecvBuffer */
+
+static char *PqGSSResultBuffer; /* Decryption of data in gss_RecvBuffer */
+static int PqGSSResultLength; /* End of data available in PqGSSResultBuffer */
+static int PqGSSResultNext; /* Next index to read a byte from
+ * PqGSSResultBuffer */
+
+static uint32 PqGSSMaxPktSize; /* Maximum size we can encrypt and fit the
+ * results into our output buffer */
+
+
+/*
+ * Attempt to write len bytes of data from ptr to a GSSAPI-encrypted connection.
+ *
+ * The connection must be already set up for GSSAPI encryption (i.e., GSSAPI
+ * transport negotiation is complete).
+ *
+ * On success, returns the number of data bytes consumed (possibly less than
+ * len). On failure, returns -1 with errno set appropriately. For retryable
+ * errors, caller should call again (passing the same data) once the socket
+ * is ready.
+ *
+ * Dealing with fatal errors here is a bit tricky: we can't invoke elog(FATAL)
+ * since it would try to write to the client, probably resulting in infinite
+ * recursion. Instead, use elog(COMMERROR) to log extra info about the
+ * failure if necessary, and then return an errno indicating connection loss.
+ */
+ssize_t
+be_gssapi_write(Port *port, void *ptr, size_t len)
+{
+ OM_uint32 major,
+ minor;
+ gss_buffer_desc input,
+ output;
+ size_t bytes_sent = 0;
+ size_t bytes_to_encrypt;
+ size_t bytes_encrypted;
+ gss_ctx_id_t gctx = port->gss->ctx;
+
+ /*
+ * When we get a failure, we must not tell the caller we have successfully
+ * transmitted everything, else it won't retry. Hence a "success"
+ * (positive) return value must only count source bytes corresponding to
+ * fully-transmitted encrypted packets. The amount of source data
+ * corresponding to the current partly-transmitted packet is remembered in
+ * PqGSSSendConsumed. On a retry, the caller *must* be sending that data
+ * again, so if it offers a len less than that, something is wrong.
+ */
+ if (len < PqGSSSendConsumed)
+ {
+ elog(COMMERROR, "GSSAPI caller failed to retransmit all data needing to be retried");
+ errno = ECONNRESET;
+ return -1;
+ }
+ /* Discount whatever source data we already encrypted. */
+ bytes_to_encrypt = len - PqGSSSendConsumed;
+ bytes_encrypted = PqGSSSendConsumed;
+
+ /*
+ * Loop through encrypting data and sending it out until it's all done or
+ * secure_raw_write() complains (which would likely mean that the socket
+ * is non-blocking and the requested send() would block, or there was some
+ * kind of actual error).
+ */
+ while (bytes_to_encrypt || PqGSSSendLength)
+ {
+ int conf_state = 0;
+ uint32 netlen;
+
+ /*
+ * Check if we have data in the encrypted output buffer that needs to
+ * be sent (possibly left over from a previous call), and if so, try
+ * to send it. If we aren't able to, return that fact back up to the
+ * caller.
+ */
+ if (PqGSSSendLength)
+ {
+ ssize_t ret;
+ ssize_t amount = PqGSSSendLength - PqGSSSendNext;
+
+ ret = secure_raw_write(port, PqGSSSendBuffer + PqGSSSendNext, amount);
+ if (ret <= 0)
+ {
+ /*
+ * Report any previously-sent data; if there was none, reflect
+ * the secure_raw_write result up to our caller. When there
+ * was some, we're effectively assuming that any interesting
+ * failure condition will recur on the next try.
+ */
+ if (bytes_sent)
+ return bytes_sent;
+ return ret;
+ }
+
+ /*
+ * Check if this was a partial write, and if so, move forward that
+ * far in our buffer and try again.
+ */
+ if (ret != amount)
+ {
+ PqGSSSendNext += ret;
+ continue;
+ }
+
+ /* We've successfully sent whatever data was in that packet. */
+ bytes_sent += PqGSSSendConsumed;
+
+ /* All encrypted data was sent, our buffer is empty now. */
+ PqGSSSendLength = PqGSSSendNext = PqGSSSendConsumed = 0;
+ }
+
+ /*
+ * Check if there are any bytes left to encrypt. If not, we're done.
+ */
+ if (!bytes_to_encrypt)
+ break;
+
+ /*
+ * Check how much we are being asked to send, if it's too much, then
+ * we will have to loop and possibly be called multiple times to get
+ * through all the data.
+ */
+ if (bytes_to_encrypt > PqGSSMaxPktSize)
+ input.length = PqGSSMaxPktSize;
+ else
+ input.length = bytes_to_encrypt;
+
+ input.value = (char *) ptr + bytes_encrypted;
+
+ output.value = NULL;
+ output.length = 0;
+
+ /* Create the next encrypted packet */
+ major = gss_wrap(&minor, gctx, 1, GSS_C_QOP_DEFAULT,
+ &input, &conf_state, &output);
+ if (major != GSS_S_COMPLETE)
+ {
+ pg_GSS_error(_("GSSAPI wrap error"), major, minor);
+ errno = ECONNRESET;
+ return -1;
+ }
+ if (conf_state == 0)
+ {
+ ereport(COMMERROR,
+ (errmsg("outgoing GSSAPI message would not use confidentiality")));
+ errno = ECONNRESET;
+ return -1;
+ }
+ if (output.length > PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32))
+ {
+ ereport(COMMERROR,
+ (errmsg("server tried to send oversize GSSAPI packet (%zu > %zu)",
+ (size_t) output.length,
+ PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32))));
+ errno = ECONNRESET;
+ return -1;
+ }
+
+ bytes_encrypted += input.length;
+ bytes_to_encrypt -= input.length;
+ PqGSSSendConsumed += input.length;
+
+ /* 4 network-order bytes of length, then payload */
+ netlen = pg_hton32(output.length);
+ memcpy(PqGSSSendBuffer + PqGSSSendLength, &netlen, sizeof(uint32));
+ PqGSSSendLength += sizeof(uint32);
+
+ memcpy(PqGSSSendBuffer + PqGSSSendLength, output.value, output.length);
+ PqGSSSendLength += output.length;
+
+ /* Release buffer storage allocated by GSSAPI */
+ gss_release_buffer(&minor, &output);
+ }
+
+ /* If we get here, our counters should all match up. */
+ Assert(bytes_sent == len);
+ Assert(bytes_sent == bytes_encrypted);
+
+ return bytes_sent;
+}
+
+/*
+ * Read up to len bytes of data into ptr from a GSSAPI-encrypted connection.
+ *
+ * The connection must be already set up for GSSAPI encryption (i.e., GSSAPI
+ * transport negotiation is complete).
+ *
+ * Returns the number of data bytes read, or on failure, returns -1
+ * with errno set appropriately. For retryable errors, caller should call
+ * again once the socket is ready.
+ *
+ * We treat fatal errors the same as in be_gssapi_write(), even though the
+ * argument about infinite recursion doesn't apply here.
+ */
+ssize_t
+be_gssapi_read(Port *port, void *ptr, size_t len)
+{
+ OM_uint32 major,
+ minor;
+ gss_buffer_desc input,
+ output;
+ ssize_t ret;
+ size_t bytes_returned = 0;
+ gss_ctx_id_t gctx = port->gss->ctx;
+
+ /*
+ * The plan here is to read one incoming encrypted packet into
+ * PqGSSRecvBuffer, decrypt it into PqGSSResultBuffer, and then dole out
+ * data from there to the caller. When we exhaust the current input
+ * packet, read another.
+ */
+ while (bytes_returned < len)
+ {
+ int conf_state = 0;
+
+ /* Check if we have data in our buffer that we can return immediately */
+ if (PqGSSResultNext < PqGSSResultLength)
+ {
+ size_t bytes_in_buffer = PqGSSResultLength - PqGSSResultNext;
+ size_t bytes_to_copy = Min(bytes_in_buffer, len - bytes_returned);
+
+ /*
+ * Copy the data from our result buffer into the caller's buffer,
+ * at the point where we last left off filling their buffer.
+ */
+ memcpy((char *) ptr + bytes_returned, PqGSSResultBuffer + PqGSSResultNext, bytes_to_copy);
+ PqGSSResultNext += bytes_to_copy;
+ bytes_returned += bytes_to_copy;
+
+ /*
+ * At this point, we've either filled the caller's buffer or
+ * emptied our result buffer. Either way, return to caller. In
+ * the second case, we could try to read another encrypted packet,
+ * but the odds are good that there isn't one available. (If this
+ * isn't true, we chose too small a max packet size.) In any
+ * case, there's no harm letting the caller process the data we've
+ * already returned.
+ */
+ break;
+ }
+
+ /* Result buffer is empty, so reset buffer pointers */
+ PqGSSResultLength = PqGSSResultNext = 0;
+
+ /*
+ * Because we chose above to return immediately as soon as we emit
+ * some data, bytes_returned must be zero at this point. Therefore
+ * the failure exits below can just return -1 without worrying about
+ * whether we already emitted some data.
+ */
+ Assert(bytes_returned == 0);
+
+ /*
+ * At this point, our result buffer is empty with more bytes being
+ * requested to be read. We are now ready to load the next packet and
+ * decrypt it (entirely) into our result buffer.
+ */
+
+ /* Collect the length if we haven't already */
+ if (PqGSSRecvLength < sizeof(uint32))
+ {
+ ret = secure_raw_read(port, PqGSSRecvBuffer + PqGSSRecvLength,
+ sizeof(uint32) - PqGSSRecvLength);
+
+ /* If ret <= 0, secure_raw_read already set the correct errno */
+ if (ret <= 0)
+ return ret;
+
+ PqGSSRecvLength += ret;
+
+ /* If we still haven't got the length, return to the caller */
+ if (PqGSSRecvLength < sizeof(uint32))
+ {
+ errno = EWOULDBLOCK;
+ return -1;
+ }
+ }
+
+ /* Decode the packet length and check for overlength packet */
+ input.length = pg_ntoh32(*(uint32 *) PqGSSRecvBuffer);
+
+ if (input.length > PQ_GSS_RECV_BUFFER_SIZE - sizeof(uint32))
+ {
+ ereport(COMMERROR,
+ (errmsg("oversize GSSAPI packet sent by the client (%zu > %zu)",
+ (size_t) input.length,
+ PQ_GSS_RECV_BUFFER_SIZE - sizeof(uint32))));
+ errno = ECONNRESET;
+ return -1;
+ }
+
+ /*
+ * Read as much of the packet as we are able to on this call into
+ * wherever we left off from the last time we were called.
+ */
+ ret = secure_raw_read(port, PqGSSRecvBuffer + PqGSSRecvLength,
+ input.length - (PqGSSRecvLength - sizeof(uint32)));
+ /* If ret <= 0, secure_raw_read already set the correct errno */
+ if (ret <= 0)
+ return ret;
+
+ PqGSSRecvLength += ret;
+
+ /* If we don't yet have the whole packet, return to the caller */
+ if (PqGSSRecvLength - sizeof(uint32) < input.length)
+ {
+ errno = EWOULDBLOCK;
+ return -1;
+ }
+
+ /*
+ * We now have the full packet and we can perform the decryption and
+ * refill our result buffer, then loop back up to pass data back to
+ * the caller.
+ */
+ output.value = NULL;
+ output.length = 0;
+ input.value = PqGSSRecvBuffer + sizeof(uint32);
+
+ major = gss_unwrap(&minor, gctx, &input, &output, &conf_state, NULL);
+ if (major != GSS_S_COMPLETE)
+ {
+ pg_GSS_error(_("GSSAPI unwrap error"), major, minor);
+ errno = ECONNRESET;
+ return -1;
+ }
+ if (conf_state == 0)
+ {
+ ereport(COMMERROR,
+ (errmsg("incoming GSSAPI message did not use confidentiality")));
+ errno = ECONNRESET;
+ return -1;
+ }
+
+ memcpy(PqGSSResultBuffer, output.value, output.length);
+ PqGSSResultLength = output.length;
+
+ /* Our receive buffer is now empty, reset it */
+ PqGSSRecvLength = 0;
+
+ /* Release buffer storage allocated by GSSAPI */
+ gss_release_buffer(&minor, &output);
+ }
+
+ return bytes_returned;
+}
+
+/*
+ * Read the specified number of bytes off the wire, waiting using
+ * WaitLatchOrSocket if we would block.
+ *
+ * Results are read into PqGSSRecvBuffer.
+ *
+ * Will always return either -1, to indicate a permanent error, or len.
+ */
+static ssize_t
+read_or_wait(Port *port, ssize_t len)
+{
+ ssize_t ret;
+
+ /*
+ * Keep going until we either read in everything we were asked to, or we
+ * error out.
+ */
+ while (PqGSSRecvLength < len)
+ {
+ ret = secure_raw_read(port, PqGSSRecvBuffer + PqGSSRecvLength, len - PqGSSRecvLength);
+
+ /*
+ * If we got back an error and it wasn't just
+ * EWOULDBLOCK/EAGAIN/EINTR, then give up.
+ */
+ if (ret < 0 &&
+ !(errno == EWOULDBLOCK || errno == EAGAIN || errno == EINTR))
+ return -1;
+
+ /*
+ * Ok, we got back either a positive value, zero, or a negative result
+ * indicating we should retry.
+ *
+ * If it was zero or negative, then we wait on the socket to be
+ * readable again.
+ */
+ if (ret <= 0)
+ {
+ WaitLatchOrSocket(MyLatch,
+ WL_SOCKET_READABLE | WL_EXIT_ON_PM_DEATH,
+ port->sock, 0, WAIT_EVENT_GSS_OPEN_SERVER);
+
+ /*
+ * If we got back zero bytes, and then waited on the socket to be
+ * readable and got back zero bytes on a second read, then this is
+ * EOF and the client hung up on us.
+ *
+ * If we did get data here, then we can just fall through and
+ * handle it just as if we got data the first time.
+ *
+ * Otherwise loop back to the top and try again.
+ */
+ if (ret == 0)
+ {
+ ret = secure_raw_read(port, PqGSSRecvBuffer + PqGSSRecvLength, len - PqGSSRecvLength);
+ if (ret == 0)
+ return -1;
+ }
+ if (ret < 0)
+ continue;
+ }
+
+ PqGSSRecvLength += ret;
+ }
+
+ return len;
+}
+
+/*
+ * Start up a GSSAPI-encrypted connection. This performs GSSAPI
+ * authentication; after this function completes, it is safe to call
+ * be_gssapi_read and be_gssapi_write. Returns -1 and logs on failure;
+ * otherwise, returns 0 and marks the connection as ready for GSSAPI
+ * encryption.
+ *
+ * Note that unlike the be_gssapi_read/be_gssapi_write functions, this
+ * function WILL block on the socket to be ready for read/write (using
+ * WaitLatchOrSocket) as appropriate while establishing the GSSAPI
+ * session.
+ */
+ssize_t
+secure_open_gssapi(Port *port)
+{
+ bool complete_next = false;
+ OM_uint32 major,
+ minor;
+
+ /*
+ * Allocate subsidiary Port data for GSSAPI operations.
+ */
+ port->gss = (pg_gssinfo *)
+ MemoryContextAllocZero(TopMemoryContext, sizeof(pg_gssinfo));
+
+ /*
+ * Allocate buffers and initialize state variables. By malloc'ing the
+ * buffers at this point, we avoid wasting static data space in processes
+ * that will never use them, and we ensure that the buffers are
+ * sufficiently aligned for the length-word accesses that we do in some
+ * places in this file.
+ */
+ PqGSSSendBuffer = malloc(PQ_GSS_SEND_BUFFER_SIZE);
+ PqGSSRecvBuffer = malloc(PQ_GSS_RECV_BUFFER_SIZE);
+ PqGSSResultBuffer = malloc(PQ_GSS_RECV_BUFFER_SIZE);
+ if (!PqGSSSendBuffer || !PqGSSRecvBuffer || !PqGSSResultBuffer)
+ ereport(FATAL,
+ (errcode(ERRCODE_OUT_OF_MEMORY),
+ errmsg("out of memory")));
+ PqGSSSendLength = PqGSSSendNext = PqGSSSendConsumed = 0;
+ PqGSSRecvLength = PqGSSResultLength = PqGSSResultNext = 0;
+
+ /*
+ * Use the configured keytab, if there is one. Unfortunately, Heimdal
+ * doesn't support the cred store extensions, so use the env var.
+ */
+ if (pg_krb_server_keyfile != NULL && pg_krb_server_keyfile[0] != '\0')
+ {
+ if (setenv("KRB5_KTNAME", pg_krb_server_keyfile, 1) != 0)
+ {
+ /* The only likely failure cause is OOM, so use that errcode */
+ ereport(FATAL,
+ (errcode(ERRCODE_OUT_OF_MEMORY),
+ errmsg("could not set environment: %m")));
+ }
+ }
+
+ while (true)
+ {
+ ssize_t ret;
+ gss_buffer_desc input,
+ output = GSS_C_EMPTY_BUFFER;
+
+ /*
+ * The client always sends first, so try to go ahead and read the
+ * length and wait on the socket to be readable again if that fails.
+ */
+ ret = read_or_wait(port, sizeof(uint32));
+ if (ret < 0)
+ return ret;
+
+ /*
+ * Get the length for this packet from the length header.
+ */
+ input.length = pg_ntoh32(*(uint32 *) PqGSSRecvBuffer);
+
+ /* Done with the length, reset our buffer */
+ PqGSSRecvLength = 0;
+
+ /*
+ * During initialization, packets are always fully consumed and
+ * shouldn't ever be over PQ_GSS_RECV_BUFFER_SIZE in length.
+ *
+ * Verify on our side that the client doesn't do something funny.
+ */
+ if (input.length > PQ_GSS_RECV_BUFFER_SIZE)
+ {
+ ereport(COMMERROR,
+ (errmsg("oversize GSSAPI packet sent by the client (%zu > %d)",
+ (size_t) input.length,
+ PQ_GSS_RECV_BUFFER_SIZE)));
+ return -1;
+ }
+
+ /*
+ * Get the rest of the packet so we can pass it to GSSAPI to accept
+ * the context.
+ */
+ ret = read_or_wait(port, input.length);
+ if (ret < 0)
+ return ret;
+
+ input.value = PqGSSRecvBuffer;
+
+ /* Process incoming data. (The client sends first.) */
+ major = gss_accept_sec_context(&minor, &port->gss->ctx,
+ GSS_C_NO_CREDENTIAL, &input,
+ GSS_C_NO_CHANNEL_BINDINGS,
+ &port->gss->name, NULL, &output, NULL,
+ NULL, NULL);
+ if (GSS_ERROR(major))
+ {
+ pg_GSS_error(_("could not accept GSSAPI security context"),
+ major, minor);
+ gss_release_buffer(&minor, &output);
+ return -1;
+ }
+ else if (!(major & GSS_S_CONTINUE_NEEDED))
+ {
+ /*
+ * rfc2744 technically permits context negotiation to be complete
+ * both with and without a packet to be sent.
+ */
+ complete_next = true;
+ }
+
+ /* Done handling the incoming packet, reset our buffer */
+ PqGSSRecvLength = 0;
+
+ /*
+ * Check if we have data to send and, if we do, make sure to send it
+ * all
+ */
+ if (output.length > 0)
+ {
+ uint32 netlen = pg_hton32(output.length);
+
+ if (output.length > PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32))
+ {
+ ereport(COMMERROR,
+ (errmsg("server tried to send oversize GSSAPI packet (%zu > %zu)",
+ (size_t) output.length,
+ PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32))));
+ gss_release_buffer(&minor, &output);
+ return -1;
+ }
+
+ memcpy(PqGSSSendBuffer, (char *) &netlen, sizeof(uint32));
+ PqGSSSendLength += sizeof(uint32);
+
+ memcpy(PqGSSSendBuffer + PqGSSSendLength, output.value, output.length);
+ PqGSSSendLength += output.length;
+
+ /* we don't bother with PqGSSSendConsumed here */
+
+ while (PqGSSSendNext < PqGSSSendLength)
+ {
+ ret = secure_raw_write(port, PqGSSSendBuffer + PqGSSSendNext,
+ PqGSSSendLength - PqGSSSendNext);
+
+ /*
+ * If we got back an error and it wasn't just
+ * EWOULDBLOCK/EAGAIN/EINTR, then give up.
+ */
+ if (ret < 0 &&
+ !(errno == EWOULDBLOCK || errno == EAGAIN || errno == EINTR))
+ {
+ gss_release_buffer(&minor, &output);
+ return -1;
+ }
+
+ /* Wait and retry if we couldn't write yet */
+ if (ret <= 0)
+ {
+ WaitLatchOrSocket(MyLatch,
+ WL_SOCKET_WRITEABLE | WL_EXIT_ON_PM_DEATH,
+ port->sock, 0, WAIT_EVENT_GSS_OPEN_SERVER);
+ continue;
+ }
+
+ PqGSSSendNext += ret;
+ }
+
+ /* Done sending the packet, reset our buffer */
+ PqGSSSendLength = PqGSSSendNext = 0;
+
+ gss_release_buffer(&minor, &output);
+ }
+
+ /*
+ * If we got back that the connection is finished being set up, now
+ * that we've sent the last packet, exit our loop.
+ */
+ if (complete_next)
+ break;
+ }
+
+ /*
+ * Determine the max packet size which will fit in our buffer, after
+ * accounting for the length. be_gssapi_write will need this.
+ */
+ major = gss_wrap_size_limit(&minor, port->gss->ctx, 1, GSS_C_QOP_DEFAULT,
+ PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32),
+ &PqGSSMaxPktSize);
+
+ if (GSS_ERROR(major))
+ {
+ pg_GSS_error(_("GSSAPI size check error"), major, minor);
+ return -1;
+ }
+
+ port->gss->enc = true;
+
+ return 0;
+}
+
+/*
+ * Return if GSSAPI authentication was used on this connection.
+ */
+bool
+be_gssapi_get_auth(Port *port)
+{
+ if (!port || !port->gss)
+ return false;
+
+ return port->gss->auth;
+}
+
+/*
+ * Return if GSSAPI encryption is enabled and being used on this connection.
+ */
+bool
+be_gssapi_get_enc(Port *port)
+{
+ if (!port || !port->gss)
+ return false;
+
+ return port->gss->enc;
+}
+
+/*
+ * Return the GSSAPI principal used for authentication on this connection
+ * (NULL if we did not perform GSSAPI authentication).
+ */
+const char *
+be_gssapi_get_princ(Port *port)
+{
+ if (!port || !port->gss)
+ return NULL;
+
+ return port->gss->princ;
+}
diff --git a/src/backend/libpq/be-secure-openssl.c b/src/backend/libpq/be-secure-openssl.c
new file mode 100644
index 0000000..e3b02b1
--- /dev/null
+++ b/src/backend/libpq/be-secure-openssl.c
@@ -0,0 +1,1526 @@
+/*-------------------------------------------------------------------------
+ *
+ * be-secure-openssl.c
+ * functions for OpenSSL support in the backend.
+ *
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/be-secure-openssl.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include <sys/stat.h>
+#include <signal.h>
+#include <fcntl.h>
+#include <ctype.h>
+#include <sys/socket.h>
+#include <unistd.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#ifdef HAVE_NETINET_TCP_H
+#include <netinet/tcp.h>
+#include <arpa/inet.h>
+#endif
+
+#include <openssl/ssl.h>
+#include <openssl/dh.h>
+#include <openssl/conf.h>
+#ifndef OPENSSL_NO_ECDH
+#include <openssl/ec.h>
+#endif
+
+#include "common/openssl.h"
+#include "libpq/libpq.h"
+#include "miscadmin.h"
+#include "pgstat.h"
+#include "storage/fd.h"
+#include "storage/latch.h"
+#include "tcop/tcopprot.h"
+#include "utils/memutils.h"
+
+/* default init hook can be overridden by a shared library */
+static void default_openssl_tls_init(SSL_CTX *context, bool isServerStart);
+openssl_tls_init_hook_typ openssl_tls_init_hook = default_openssl_tls_init;
+
+static int my_sock_read(BIO *h, char *buf, int size);
+static int my_sock_write(BIO *h, const char *buf, int size);
+static BIO_METHOD *my_BIO_s_socket(void);
+static int my_SSL_set_fd(Port *port, int fd);
+
+static DH *load_dh_file(char *filename, bool isServerStart);
+static DH *load_dh_buffer(const char *, size_t);
+static int ssl_external_passwd_cb(char *buf, int size, int rwflag, void *userdata);
+static int dummy_ssl_passwd_cb(char *buf, int size, int rwflag, void *userdata);
+static int verify_cb(int, X509_STORE_CTX *);
+static void info_cb(const SSL *ssl, int type, int args);
+static bool initialize_dh(SSL_CTX *context, bool isServerStart);
+static bool initialize_ecdh(SSL_CTX *context, bool isServerStart);
+static const char *SSLerrmessage(unsigned long ecode);
+
+static char *X509_NAME_to_cstring(X509_NAME *name);
+
+static SSL_CTX *SSL_context = NULL;
+static bool SSL_initialized = false;
+static bool dummy_ssl_passwd_cb_called = false;
+static bool ssl_is_server_start;
+
+static int ssl_protocol_version_to_openssl(int v);
+static const char *ssl_protocol_version_to_string(int v);
+
+/* ------------------------------------------------------------ */
+/* Public interface */
+/* ------------------------------------------------------------ */
+
+int
+be_tls_init(bool isServerStart)
+{
+ SSL_CTX *context;
+ int ssl_ver_min = -1;
+ int ssl_ver_max = -1;
+
+ /* This stuff need be done only once. */
+ if (!SSL_initialized)
+ {
+#ifdef HAVE_OPENSSL_INIT_SSL
+ OPENSSL_init_ssl(OPENSSL_INIT_LOAD_CONFIG, NULL);
+#else
+ OPENSSL_config(NULL);
+ SSL_library_init();
+ SSL_load_error_strings();
+#endif
+ SSL_initialized = true;
+ }
+
+ /*
+ * Create a new SSL context into which we'll load all the configuration
+ * settings. If we fail partway through, we can avoid memory leakage by
+ * freeing this context; we don't install it as active until the end.
+ *
+ * We use SSLv23_method() because it can negotiate use of the highest
+ * mutually supported protocol version, while alternatives like
+ * TLSv1_2_method() permit only one specific version. Note that we don't
+ * actually allow SSL v2 or v3, only TLS protocols (see below).
+ */
+ context = SSL_CTX_new(SSLv23_method());
+ if (!context)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errmsg("could not create SSL context: %s",
+ SSLerrmessage(ERR_get_error()))));
+ goto error;
+ }
+
+ /*
+ * Disable OpenSSL's moving-write-buffer sanity check, because it causes
+ * unnecessary failures in nonblocking send cases.
+ */
+ SSL_CTX_set_mode(context, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
+
+ /*
+ * Call init hook (usually to set password callback)
+ */
+ (*openssl_tls_init_hook) (context, isServerStart);
+
+ /* used by the callback */
+ ssl_is_server_start = isServerStart;
+
+ /*
+ * Load and verify server's certificate and private key
+ */
+ if (SSL_CTX_use_certificate_chain_file(context, ssl_cert_file) != 1)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not load server certificate file \"%s\": %s",
+ ssl_cert_file, SSLerrmessage(ERR_get_error()))));
+ goto error;
+ }
+
+ if (!check_ssl_key_file_permissions(ssl_key_file, isServerStart))
+ goto error;
+
+ /*
+ * OK, try to load the private key file.
+ */
+ dummy_ssl_passwd_cb_called = false;
+
+ if (SSL_CTX_use_PrivateKey_file(context,
+ ssl_key_file,
+ SSL_FILETYPE_PEM) != 1)
+ {
+ if (dummy_ssl_passwd_cb_called)
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("private key file \"%s\" cannot be reloaded because it requires a passphrase",
+ ssl_key_file)));
+ else
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not load private key file \"%s\": %s",
+ ssl_key_file, SSLerrmessage(ERR_get_error()))));
+ goto error;
+ }
+
+ if (SSL_CTX_check_private_key(context) != 1)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("check of private key failed: %s",
+ SSLerrmessage(ERR_get_error()))));
+ goto error;
+ }
+
+ if (ssl_min_protocol_version)
+ {
+ ssl_ver_min = ssl_protocol_version_to_openssl(ssl_min_protocol_version);
+
+ if (ssl_ver_min == -1)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ /*- translator: first %s is a GUC option name, second %s is its value */
+ (errmsg("\"%s\" setting \"%s\" not supported by this build",
+ "ssl_min_protocol_version",
+ GetConfigOption("ssl_min_protocol_version",
+ false, false))));
+ goto error;
+ }
+
+ if (!SSL_CTX_set_min_proto_version(context, ssl_ver_min))
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errmsg("could not set minimum SSL protocol version")));
+ goto error;
+ }
+ }
+
+ if (ssl_max_protocol_version)
+ {
+ ssl_ver_max = ssl_protocol_version_to_openssl(ssl_max_protocol_version);
+
+ if (ssl_ver_max == -1)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ /*- translator: first %s is a GUC option name, second %s is its value */
+ (errmsg("\"%s\" setting \"%s\" not supported by this build",
+ "ssl_max_protocol_version",
+ GetConfigOption("ssl_max_protocol_version",
+ false, false))));
+ goto error;
+ }
+
+ if (!SSL_CTX_set_max_proto_version(context, ssl_ver_max))
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errmsg("could not set maximum SSL protocol version")));
+ goto error;
+ }
+ }
+
+ /* Check compatibility of min/max protocols */
+ if (ssl_min_protocol_version &&
+ ssl_max_protocol_version)
+ {
+ /*
+ * No need to check for invalid values (-1) for each protocol number
+ * as the code above would have already generated an error.
+ */
+ if (ssl_ver_min > ssl_ver_max)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errmsg("could not set SSL protocol version range"),
+ errdetail("\"%s\" cannot be higher than \"%s\"",
+ "ssl_min_protocol_version",
+ "ssl_max_protocol_version")));
+ goto error;
+ }
+ }
+
+ /* disallow SSL session tickets */
+ SSL_CTX_set_options(context, SSL_OP_NO_TICKET);
+
+ /* disallow SSL session caching, too */
+ SSL_CTX_set_session_cache_mode(context, SSL_SESS_CACHE_OFF);
+
+ /* disallow SSL compression */
+ SSL_CTX_set_options(context, SSL_OP_NO_COMPRESSION);
+
+#ifdef SSL_OP_NO_RENEGOTIATION
+
+ /*
+ * Disallow SSL renegotiation, option available since 1.1.0h. This
+ * concerns only TLSv1.2 and older protocol versions, as TLSv1.3 has no
+ * support for renegotiation.
+ */
+ SSL_CTX_set_options(context, SSL_OP_NO_RENEGOTIATION);
+#endif
+
+ /* set up ephemeral DH and ECDH keys */
+ if (!initialize_dh(context, isServerStart))
+ goto error;
+ if (!initialize_ecdh(context, isServerStart))
+ goto error;
+
+ /* set up the allowed cipher list */
+ if (SSL_CTX_set_cipher_list(context, SSLCipherSuites) != 1)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not set the cipher list (no valid ciphers available)")));
+ goto error;
+ }
+
+ /* Let server choose order */
+ if (SSLPreferServerCiphers)
+ SSL_CTX_set_options(context, SSL_OP_CIPHER_SERVER_PREFERENCE);
+
+ /*
+ * Load CA store, so we can verify client certificates if needed.
+ */
+ if (ssl_ca_file[0])
+ {
+ STACK_OF(X509_NAME) * root_cert_list;
+
+ if (SSL_CTX_load_verify_locations(context, ssl_ca_file, NULL) != 1 ||
+ (root_cert_list = SSL_load_client_CA_file(ssl_ca_file)) == NULL)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not load root certificate file \"%s\": %s",
+ ssl_ca_file, SSLerrmessage(ERR_get_error()))));
+ goto error;
+ }
+
+ /*
+ * Tell OpenSSL to send the list of root certs we trust to clients in
+ * CertificateRequests. This lets a client with a keystore select the
+ * appropriate client certificate to send to us. Also, this ensures
+ * that the SSL context will "own" the root_cert_list and remember to
+ * free it when no longer needed.
+ */
+ SSL_CTX_set_client_CA_list(context, root_cert_list);
+
+ /*
+ * Always ask for SSL client cert, but don't fail if it's not
+ * presented. We might fail such connections later, depending on what
+ * we find in pg_hba.conf.
+ */
+ SSL_CTX_set_verify(context,
+ (SSL_VERIFY_PEER |
+ SSL_VERIFY_CLIENT_ONCE),
+ verify_cb);
+ }
+
+ /*----------
+ * Load the Certificate Revocation List (CRL).
+ * http://searchsecurity.techtarget.com/sDefinition/0,,sid14_gci803160,00.html
+ *----------
+ */
+ if (ssl_crl_file[0] || ssl_crl_dir[0])
+ {
+ X509_STORE *cvstore = SSL_CTX_get_cert_store(context);
+
+ if (cvstore)
+ {
+ /* Set the flags to check against the complete CRL chain */
+ if (X509_STORE_load_locations(cvstore,
+ ssl_crl_file[0] ? ssl_crl_file : NULL,
+ ssl_crl_dir[0] ? ssl_crl_dir : NULL)
+ == 1)
+ {
+ X509_STORE_set_flags(cvstore,
+ X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL);
+ }
+ else if (ssl_crl_dir[0] == 0)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not load SSL certificate revocation list file \"%s\": %s",
+ ssl_crl_file, SSLerrmessage(ERR_get_error()))));
+ goto error;
+ }
+ else if (ssl_crl_file[0] == 0)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not load SSL certificate revocation list directory \"%s\": %s",
+ ssl_crl_dir, SSLerrmessage(ERR_get_error()))));
+ goto error;
+ }
+ else
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not load SSL certificate revocation list file \"%s\" or directory \"%s\": %s",
+ ssl_crl_file, ssl_crl_dir,
+ SSLerrmessage(ERR_get_error()))));
+ goto error;
+ }
+ }
+ }
+
+ /*
+ * Success! Replace any existing SSL_context.
+ */
+ if (SSL_context)
+ SSL_CTX_free(SSL_context);
+
+ SSL_context = context;
+
+ /*
+ * Set flag to remember whether CA store has been loaded into SSL_context.
+ */
+ if (ssl_ca_file[0])
+ ssl_loaded_verify_locations = true;
+ else
+ ssl_loaded_verify_locations = false;
+
+ return 0;
+
+ /* Clean up by releasing working context. */
+error:
+ if (context)
+ SSL_CTX_free(context);
+ return -1;
+}
+
+void
+be_tls_destroy(void)
+{
+ if (SSL_context)
+ SSL_CTX_free(SSL_context);
+ SSL_context = NULL;
+ ssl_loaded_verify_locations = false;
+}
+
+int
+be_tls_open_server(Port *port)
+{
+ int r;
+ int err;
+ int waitfor;
+ unsigned long ecode;
+ bool give_proto_hint;
+
+ Assert(!port->ssl);
+ Assert(!port->peer);
+
+ if (!SSL_context)
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("could not initialize SSL connection: SSL context not set up")));
+ return -1;
+ }
+
+ /* set up debugging/info callback */
+ SSL_CTX_set_info_callback(SSL_context, info_cb);
+
+ if (!(port->ssl = SSL_new(SSL_context)))
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("could not initialize SSL connection: %s",
+ SSLerrmessage(ERR_get_error()))));
+ return -1;
+ }
+ if (!my_SSL_set_fd(port, port->sock))
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("could not set SSL socket: %s",
+ SSLerrmessage(ERR_get_error()))));
+ return -1;
+ }
+ port->ssl_in_use = true;
+
+aloop:
+
+ /*
+ * Prepare to call SSL_get_error() by clearing thread's OpenSSL error
+ * queue. In general, the current thread's error queue must be empty
+ * before the TLS/SSL I/O operation is attempted, or SSL_get_error() will
+ * not work reliably. An extension may have failed to clear the
+ * per-thread error queue following another call to an OpenSSL I/O
+ * routine.
+ */
+ ERR_clear_error();
+ r = SSL_accept(port->ssl);
+ if (r <= 0)
+ {
+ err = SSL_get_error(port->ssl, r);
+
+ /*
+ * Other clients of OpenSSL in the backend may fail to call
+ * ERR_get_error(), but we always do, so as to not cause problems for
+ * OpenSSL clients that don't call ERR_clear_error() defensively. Be
+ * sure that this happens by calling now. SSL_get_error() relies on
+ * the OpenSSL per-thread error queue being intact, so this is the
+ * earliest possible point ERR_get_error() may be called.
+ */
+ ecode = ERR_get_error();
+ switch (err)
+ {
+ case SSL_ERROR_WANT_READ:
+ case SSL_ERROR_WANT_WRITE:
+ /* not allowed during connection establishment */
+ Assert(!port->noblock);
+
+ /*
+ * No need to care about timeouts/interrupts here. At this
+ * point authentication_timeout still employs
+ * StartupPacketTimeoutHandler() which directly exits.
+ */
+ if (err == SSL_ERROR_WANT_READ)
+ waitfor = WL_SOCKET_READABLE | WL_EXIT_ON_PM_DEATH;
+ else
+ waitfor = WL_SOCKET_WRITEABLE | WL_EXIT_ON_PM_DEATH;
+
+ (void) WaitLatchOrSocket(MyLatch, waitfor, port->sock, 0,
+ WAIT_EVENT_SSL_OPEN_SERVER);
+ goto aloop;
+ case SSL_ERROR_SYSCALL:
+ if (r < 0)
+ ereport(COMMERROR,
+ (errcode_for_socket_access(),
+ errmsg("could not accept SSL connection: %m")));
+ else
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("could not accept SSL connection: EOF detected")));
+ break;
+ case SSL_ERROR_SSL:
+ switch (ERR_GET_REASON(ecode))
+ {
+ /*
+ * UNSUPPORTED_PROTOCOL, WRONG_VERSION_NUMBER, and
+ * TLSV1_ALERT_PROTOCOL_VERSION have been observed
+ * when trying to communicate with an old OpenSSL
+ * library, or when the client and server specify
+ * disjoint protocol ranges. NO_PROTOCOLS_AVAILABLE
+ * occurs if there's a local misconfiguration (which
+ * can happen despite our checks, if openssl.cnf
+ * injects a limit we didn't account for). It's not
+ * very clear what would make OpenSSL return the other
+ * codes listed here, but a hint about protocol
+ * versions seems like it's appropriate for all.
+ */
+ case SSL_R_NO_PROTOCOLS_AVAILABLE:
+ case SSL_R_UNSUPPORTED_PROTOCOL:
+ case SSL_R_BAD_PROTOCOL_VERSION_NUMBER:
+ case SSL_R_UNKNOWN_PROTOCOL:
+ case SSL_R_UNKNOWN_SSL_VERSION:
+ case SSL_R_UNSUPPORTED_SSL_VERSION:
+ case SSL_R_WRONG_SSL_VERSION:
+ case SSL_R_WRONG_VERSION_NUMBER:
+ case SSL_R_TLSV1_ALERT_PROTOCOL_VERSION:
+#ifdef SSL_R_VERSION_TOO_HIGH
+ case SSL_R_VERSION_TOO_HIGH:
+ case SSL_R_VERSION_TOO_LOW:
+#endif
+ give_proto_hint = true;
+ break;
+ default:
+ give_proto_hint = false;
+ break;
+ }
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("could not accept SSL connection: %s",
+ SSLerrmessage(ecode)),
+ give_proto_hint ?
+ errhint("This may indicate that the client does not support any SSL protocol version between %s and %s.",
+ ssl_min_protocol_version ?
+ ssl_protocol_version_to_string(ssl_min_protocol_version) :
+ MIN_OPENSSL_TLS_VERSION,
+ ssl_max_protocol_version ?
+ ssl_protocol_version_to_string(ssl_max_protocol_version) :
+ MAX_OPENSSL_TLS_VERSION) : 0));
+ break;
+ case SSL_ERROR_ZERO_RETURN:
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("could not accept SSL connection: EOF detected")));
+ break;
+ default:
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("unrecognized SSL error code: %d",
+ err)));
+ break;
+ }
+ return -1;
+ }
+
+ /* Get client certificate, if available. */
+ port->peer = SSL_get_peer_certificate(port->ssl);
+
+ /* and extract the Common Name and Distinguished Name from it. */
+ port->peer_cn = NULL;
+ port->peer_dn = NULL;
+ port->peer_cert_valid = false;
+ if (port->peer != NULL)
+ {
+ int len;
+ X509_NAME *x509name = X509_get_subject_name(port->peer);
+ char *peer_dn;
+ BIO *bio = NULL;
+ BUF_MEM *bio_buf = NULL;
+
+ len = X509_NAME_get_text_by_NID(x509name, NID_commonName, NULL, 0);
+ if (len != -1)
+ {
+ char *peer_cn;
+
+ peer_cn = MemoryContextAlloc(TopMemoryContext, len + 1);
+ r = X509_NAME_get_text_by_NID(x509name, NID_commonName, peer_cn,
+ len + 1);
+ peer_cn[len] = '\0';
+ if (r != len)
+ {
+ /* shouldn't happen */
+ pfree(peer_cn);
+ return -1;
+ }
+
+ /*
+ * Reject embedded NULLs in certificate common name to prevent
+ * attacks like CVE-2009-4034.
+ */
+ if (len != strlen(peer_cn))
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("SSL certificate's common name contains embedded null")));
+ pfree(peer_cn);
+ return -1;
+ }
+
+ port->peer_cn = peer_cn;
+ }
+
+ bio = BIO_new(BIO_s_mem());
+ if (!bio)
+ {
+ pfree(port->peer_cn);
+ port->peer_cn = NULL;
+ return -1;
+ }
+
+ /*
+ * RFC2253 is the closest thing to an accepted standard format for
+ * DNs. We have documented how to produce this format from a
+ * certificate. It uses commas instead of slashes for delimiters,
+ * which make regular expression matching a bit easier. Also note that
+ * it prints the Subject fields in reverse order.
+ */
+ X509_NAME_print_ex(bio, x509name, 0, XN_FLAG_RFC2253);
+ if (BIO_get_mem_ptr(bio, &bio_buf) <= 0)
+ {
+ BIO_free(bio);
+ pfree(port->peer_cn);
+ port->peer_cn = NULL;
+ return -1;
+ }
+ peer_dn = MemoryContextAlloc(TopMemoryContext, bio_buf->length + 1);
+ memcpy(peer_dn, bio_buf->data, bio_buf->length);
+ len = bio_buf->length;
+ BIO_free(bio);
+ peer_dn[len] = '\0';
+ if (len != strlen(peer_dn))
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("SSL certificate's distinguished name contains embedded null")));
+ pfree(peer_dn);
+ pfree(port->peer_cn);
+ port->peer_cn = NULL;
+ return -1;
+ }
+
+ port->peer_dn = peer_dn;
+
+ port->peer_cert_valid = true;
+ }
+
+ return 0;
+}
+
+void
+be_tls_close(Port *port)
+{
+ if (port->ssl)
+ {
+ SSL_shutdown(port->ssl);
+ SSL_free(port->ssl);
+ port->ssl = NULL;
+ port->ssl_in_use = false;
+ }
+
+ if (port->peer)
+ {
+ X509_free(port->peer);
+ port->peer = NULL;
+ }
+
+ if (port->peer_cn)
+ {
+ pfree(port->peer_cn);
+ port->peer_cn = NULL;
+ }
+
+ if (port->peer_dn)
+ {
+ pfree(port->peer_dn);
+ port->peer_dn = NULL;
+ }
+}
+
+ssize_t
+be_tls_read(Port *port, void *ptr, size_t len, int *waitfor)
+{
+ ssize_t n;
+ int err;
+ unsigned long ecode;
+
+ errno = 0;
+ ERR_clear_error();
+ n = SSL_read(port->ssl, ptr, len);
+ err = SSL_get_error(port->ssl, n);
+ ecode = (err != SSL_ERROR_NONE || n < 0) ? ERR_get_error() : 0;
+ switch (err)
+ {
+ case SSL_ERROR_NONE:
+ /* a-ok */
+ break;
+ case SSL_ERROR_WANT_READ:
+ *waitfor = WL_SOCKET_READABLE;
+ errno = EWOULDBLOCK;
+ n = -1;
+ break;
+ case SSL_ERROR_WANT_WRITE:
+ *waitfor = WL_SOCKET_WRITEABLE;
+ errno = EWOULDBLOCK;
+ n = -1;
+ break;
+ case SSL_ERROR_SYSCALL:
+ /* leave it to caller to ereport the value of errno */
+ if (n != -1)
+ {
+ errno = ECONNRESET;
+ n = -1;
+ }
+ break;
+ case SSL_ERROR_SSL:
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("SSL error: %s", SSLerrmessage(ecode))));
+ errno = ECONNRESET;
+ n = -1;
+ break;
+ case SSL_ERROR_ZERO_RETURN:
+ /* connection was cleanly shut down by peer */
+ n = 0;
+ break;
+ default:
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("unrecognized SSL error code: %d",
+ err)));
+ errno = ECONNRESET;
+ n = -1;
+ break;
+ }
+
+ return n;
+}
+
+ssize_t
+be_tls_write(Port *port, void *ptr, size_t len, int *waitfor)
+{
+ ssize_t n;
+ int err;
+ unsigned long ecode;
+
+ errno = 0;
+ ERR_clear_error();
+ n = SSL_write(port->ssl, ptr, len);
+ err = SSL_get_error(port->ssl, n);
+ ecode = (err != SSL_ERROR_NONE || n < 0) ? ERR_get_error() : 0;
+ switch (err)
+ {
+ case SSL_ERROR_NONE:
+ /* a-ok */
+ break;
+ case SSL_ERROR_WANT_READ:
+ *waitfor = WL_SOCKET_READABLE;
+ errno = EWOULDBLOCK;
+ n = -1;
+ break;
+ case SSL_ERROR_WANT_WRITE:
+ *waitfor = WL_SOCKET_WRITEABLE;
+ errno = EWOULDBLOCK;
+ n = -1;
+ break;
+ case SSL_ERROR_SYSCALL:
+ /* leave it to caller to ereport the value of errno */
+ if (n != -1)
+ {
+ errno = ECONNRESET;
+ n = -1;
+ }
+ break;
+ case SSL_ERROR_SSL:
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("SSL error: %s", SSLerrmessage(ecode))));
+ errno = ECONNRESET;
+ n = -1;
+ break;
+ case SSL_ERROR_ZERO_RETURN:
+
+ /*
+ * the SSL connection was closed, leave it to the caller to
+ * ereport it
+ */
+ errno = ECONNRESET;
+ n = -1;
+ break;
+ default:
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("unrecognized SSL error code: %d",
+ err)));
+ errno = ECONNRESET;
+ n = -1;
+ break;
+ }
+
+ return n;
+}
+
+/* ------------------------------------------------------------ */
+/* Internal functions */
+/* ------------------------------------------------------------ */
+
+/*
+ * Private substitute BIO: this does the sending and receiving using send() and
+ * recv() instead. This is so that we can enable and disable interrupts
+ * just while calling recv(). We cannot have interrupts occurring while
+ * the bulk of OpenSSL runs, because it uses malloc() and possibly other
+ * non-reentrant libc facilities. We also need to call send() and recv()
+ * directly so it gets passed through the socket/signals layer on Win32.
+ *
+ * These functions are closely modelled on the standard socket BIO in OpenSSL;
+ * see sock_read() and sock_write() in OpenSSL's crypto/bio/bss_sock.c.
+ * XXX OpenSSL 1.0.1e considers many more errcodes than just EINTR as reasons
+ * to retry; do we need to adopt their logic for that?
+ */
+
+#ifndef HAVE_BIO_GET_DATA
+#define BIO_get_data(bio) (bio->ptr)
+#define BIO_set_data(bio, data) (bio->ptr = data)
+#endif
+
+static BIO_METHOD *my_bio_methods = NULL;
+
+static int
+my_sock_read(BIO *h, char *buf, int size)
+{
+ int res = 0;
+
+ if (buf != NULL)
+ {
+ res = secure_raw_read(((Port *) BIO_get_data(h)), buf, size);
+ BIO_clear_retry_flags(h);
+ if (res <= 0)
+ {
+ /* If we were interrupted, tell caller to retry */
+ if (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)
+ {
+ BIO_set_retry_read(h);
+ }
+ }
+ }
+
+ return res;
+}
+
+static int
+my_sock_write(BIO *h, const char *buf, int size)
+{
+ int res = 0;
+
+ res = secure_raw_write(((Port *) BIO_get_data(h)), buf, size);
+ BIO_clear_retry_flags(h);
+ if (res <= 0)
+ {
+ /* If we were interrupted, tell caller to retry */
+ if (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)
+ {
+ BIO_set_retry_write(h);
+ }
+ }
+
+ return res;
+}
+
+static BIO_METHOD *
+my_BIO_s_socket(void)
+{
+ if (!my_bio_methods)
+ {
+ BIO_METHOD *biom = (BIO_METHOD *) BIO_s_socket();
+#ifdef HAVE_BIO_METH_NEW
+ int my_bio_index;
+
+ my_bio_index = BIO_get_new_index();
+ if (my_bio_index == -1)
+ return NULL;
+ my_bio_index |= (BIO_TYPE_DESCRIPTOR | BIO_TYPE_SOURCE_SINK);
+ my_bio_methods = BIO_meth_new(my_bio_index, "PostgreSQL backend socket");
+ if (!my_bio_methods)
+ return NULL;
+ if (!BIO_meth_set_write(my_bio_methods, my_sock_write) ||
+ !BIO_meth_set_read(my_bio_methods, my_sock_read) ||
+ !BIO_meth_set_gets(my_bio_methods, BIO_meth_get_gets(biom)) ||
+ !BIO_meth_set_puts(my_bio_methods, BIO_meth_get_puts(biom)) ||
+ !BIO_meth_set_ctrl(my_bio_methods, BIO_meth_get_ctrl(biom)) ||
+ !BIO_meth_set_create(my_bio_methods, BIO_meth_get_create(biom)) ||
+ !BIO_meth_set_destroy(my_bio_methods, BIO_meth_get_destroy(biom)) ||
+ !BIO_meth_set_callback_ctrl(my_bio_methods, BIO_meth_get_callback_ctrl(biom)))
+ {
+ BIO_meth_free(my_bio_methods);
+ my_bio_methods = NULL;
+ return NULL;
+ }
+#else
+ my_bio_methods = malloc(sizeof(BIO_METHOD));
+ if (!my_bio_methods)
+ return NULL;
+ memcpy(my_bio_methods, biom, sizeof(BIO_METHOD));
+ my_bio_methods->bread = my_sock_read;
+ my_bio_methods->bwrite = my_sock_write;
+#endif
+ }
+ return my_bio_methods;
+}
+
+/* This should exactly match OpenSSL's SSL_set_fd except for using my BIO */
+static int
+my_SSL_set_fd(Port *port, int fd)
+{
+ int ret = 0;
+ BIO *bio;
+ BIO_METHOD *bio_method;
+
+ bio_method = my_BIO_s_socket();
+ if (bio_method == NULL)
+ {
+ SSLerr(SSL_F_SSL_SET_FD, ERR_R_BUF_LIB);
+ goto err;
+ }
+ bio = BIO_new(bio_method);
+
+ if (bio == NULL)
+ {
+ SSLerr(SSL_F_SSL_SET_FD, ERR_R_BUF_LIB);
+ goto err;
+ }
+ BIO_set_data(bio, port);
+
+ BIO_set_fd(bio, fd, BIO_NOCLOSE);
+ SSL_set_bio(port->ssl, bio, bio);
+ ret = 1;
+err:
+ return ret;
+}
+
+/*
+ * Load precomputed DH parameters.
+ *
+ * To prevent "downgrade" attacks, we perform a number of checks
+ * to verify that the DBA-generated DH parameters file contains
+ * what we expect it to contain.
+ */
+static DH *
+load_dh_file(char *filename, bool isServerStart)
+{
+ FILE *fp;
+ DH *dh = NULL;
+ int codes;
+
+ /* attempt to open file. It's not an error if it doesn't exist. */
+ if ((fp = AllocateFile(filename, "r")) == NULL)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode_for_file_access(),
+ errmsg("could not open DH parameters file \"%s\": %m",
+ filename)));
+ return NULL;
+ }
+
+ dh = PEM_read_DHparams(fp, NULL, NULL, NULL);
+ FreeFile(fp);
+
+ if (dh == NULL)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not load DH parameters file: %s",
+ SSLerrmessage(ERR_get_error()))));
+ return NULL;
+ }
+
+ /* make sure the DH parameters are usable */
+ if (DH_check(dh, &codes) == 0)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid DH parameters: %s",
+ SSLerrmessage(ERR_get_error()))));
+ DH_free(dh);
+ return NULL;
+ }
+ if (codes & DH_CHECK_P_NOT_PRIME)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid DH parameters: p is not prime")));
+ DH_free(dh);
+ return NULL;
+ }
+ if ((codes & DH_NOT_SUITABLE_GENERATOR) &&
+ (codes & DH_CHECK_P_NOT_SAFE_PRIME))
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid DH parameters: neither suitable generator or safe prime")));
+ DH_free(dh);
+ return NULL;
+ }
+
+ return dh;
+}
+
+/*
+ * Load hardcoded DH parameters.
+ *
+ * If DH parameters cannot be loaded from a specified file, we can load
+ * the hardcoded DH parameters supplied with the backend to prevent
+ * problems.
+ */
+static DH *
+load_dh_buffer(const char *buffer, size_t len)
+{
+ BIO *bio;
+ DH *dh = NULL;
+
+ bio = BIO_new_mem_buf(unconstify(char *, buffer), len);
+ if (bio == NULL)
+ return NULL;
+ dh = PEM_read_bio_DHparams(bio, NULL, NULL, NULL);
+ if (dh == NULL)
+ ereport(DEBUG2,
+ (errmsg_internal("DH load buffer: %s",
+ SSLerrmessage(ERR_get_error()))));
+ BIO_free(bio);
+
+ return dh;
+}
+
+/*
+ * Passphrase collection callback using ssl_passphrase_command
+ */
+static int
+ssl_external_passwd_cb(char *buf, int size, int rwflag, void *userdata)
+{
+ /* same prompt as OpenSSL uses internally */
+ const char *prompt = "Enter PEM pass phrase:";
+
+ Assert(rwflag == 0);
+
+ return run_ssl_passphrase_command(prompt, ssl_is_server_start, buf, size);
+}
+
+/*
+ * Dummy passphrase callback
+ *
+ * If OpenSSL is told to use a passphrase-protected server key, by default
+ * it will issue a prompt on /dev/tty and try to read a key from there.
+ * That's no good during a postmaster SIGHUP cycle, not to mention SSL context
+ * reload in an EXEC_BACKEND postmaster child. So override it with this dummy
+ * function that just returns an empty passphrase, guaranteeing failure.
+ */
+static int
+dummy_ssl_passwd_cb(char *buf, int size, int rwflag, void *userdata)
+{
+ /* Set flag to change the error message we'll report */
+ dummy_ssl_passwd_cb_called = true;
+ /* And return empty string */
+ Assert(size > 0);
+ buf[0] = '\0';
+ return 0;
+}
+
+/*
+ * Certificate verification callback
+ *
+ * This callback allows us to log intermediate problems during
+ * verification, but for now we'll see if the final error message
+ * contains enough information.
+ *
+ * This callback also allows us to override the default acceptance
+ * criteria (e.g., accepting self-signed or expired certs), but
+ * for now we accept the default checks.
+ */
+static int
+verify_cb(int ok, X509_STORE_CTX *ctx)
+{
+ return ok;
+}
+
+/*
+ * This callback is used to copy SSL information messages
+ * into the PostgreSQL log.
+ */
+static void
+info_cb(const SSL *ssl, int type, int args)
+{
+ const char *desc;
+
+ desc = SSL_state_string_long(ssl);
+
+ switch (type)
+ {
+ case SSL_CB_HANDSHAKE_START:
+ ereport(DEBUG4,
+ (errmsg_internal("SSL: handshake start: \"%s\"", desc)));
+ break;
+ case SSL_CB_HANDSHAKE_DONE:
+ ereport(DEBUG4,
+ (errmsg_internal("SSL: handshake done: \"%s\"", desc)));
+ break;
+ case SSL_CB_ACCEPT_LOOP:
+ ereport(DEBUG4,
+ (errmsg_internal("SSL: accept loop: \"%s\"", desc)));
+ break;
+ case SSL_CB_ACCEPT_EXIT:
+ ereport(DEBUG4,
+ (errmsg_internal("SSL: accept exit (%d): \"%s\"", args, desc)));
+ break;
+ case SSL_CB_CONNECT_LOOP:
+ ereport(DEBUG4,
+ (errmsg_internal("SSL: connect loop: \"%s\"", desc)));
+ break;
+ case SSL_CB_CONNECT_EXIT:
+ ereport(DEBUG4,
+ (errmsg_internal("SSL: connect exit (%d): \"%s\"", args, desc)));
+ break;
+ case SSL_CB_READ_ALERT:
+ ereport(DEBUG4,
+ (errmsg_internal("SSL: read alert (0x%04x): \"%s\"", args, desc)));
+ break;
+ case SSL_CB_WRITE_ALERT:
+ ereport(DEBUG4,
+ (errmsg_internal("SSL: write alert (0x%04x): \"%s\"", args, desc)));
+ break;
+ }
+}
+
+/*
+ * Set DH parameters for generating ephemeral DH keys. The
+ * DH parameters can take a long time to compute, so they must be
+ * precomputed.
+ *
+ * Since few sites will bother to create a parameter file, we also
+ * provide a fallback to the parameters provided by the OpenSSL
+ * project.
+ *
+ * These values can be static (once loaded or computed) since the
+ * OpenSSL library can efficiently generate random keys from the
+ * information provided.
+ */
+static bool
+initialize_dh(SSL_CTX *context, bool isServerStart)
+{
+ DH *dh = NULL;
+
+ SSL_CTX_set_options(context, SSL_OP_SINGLE_DH_USE);
+
+ if (ssl_dh_params_file[0])
+ dh = load_dh_file(ssl_dh_params_file, isServerStart);
+ if (!dh)
+ dh = load_dh_buffer(FILE_DH2048, sizeof(FILE_DH2048));
+ if (!dh)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("DH: could not load DH parameters")));
+ return false;
+ }
+
+ if (SSL_CTX_set_tmp_dh(context, dh) != 1)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("DH: could not set DH parameters: %s",
+ SSLerrmessage(ERR_get_error()))));
+ DH_free(dh);
+ return false;
+ }
+
+ DH_free(dh);
+ return true;
+}
+
+/*
+ * Set ECDH parameters for generating ephemeral Elliptic Curve DH
+ * keys. This is much simpler than the DH parameters, as we just
+ * need to provide the name of the curve to OpenSSL.
+ */
+static bool
+initialize_ecdh(SSL_CTX *context, bool isServerStart)
+{
+#ifndef OPENSSL_NO_ECDH
+ EC_KEY *ecdh;
+ int nid;
+
+ nid = OBJ_sn2nid(SSLECDHCurve);
+ if (!nid)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("ECDH: unrecognized curve name: %s", SSLECDHCurve)));
+ return false;
+ }
+
+ ecdh = EC_KEY_new_by_curve_name(nid);
+ if (!ecdh)
+ {
+ ereport(isServerStart ? FATAL : LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("ECDH: could not create key")));
+ return false;
+ }
+
+ SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE);
+ SSL_CTX_set_tmp_ecdh(context, ecdh);
+ EC_KEY_free(ecdh);
+#endif
+
+ return true;
+}
+
+/*
+ * Obtain reason string for passed SSL errcode
+ *
+ * ERR_get_error() is used by caller to get errcode to pass here.
+ *
+ * Some caution is needed here since ERR_reason_error_string will
+ * return NULL if it doesn't recognize the error code. We don't
+ * want to return NULL ever.
+ */
+static const char *
+SSLerrmessage(unsigned long ecode)
+{
+ const char *errreason;
+ static char errbuf[36];
+
+ if (ecode == 0)
+ return _("no SSL error reported");
+ errreason = ERR_reason_error_string(ecode);
+ if (errreason != NULL)
+ return errreason;
+ snprintf(errbuf, sizeof(errbuf), _("SSL error code %lu"), ecode);
+ return errbuf;
+}
+
+int
+be_tls_get_cipher_bits(Port *port)
+{
+ int bits;
+
+ if (port->ssl)
+ {
+ SSL_get_cipher_bits(port->ssl, &bits);
+ return bits;
+ }
+ else
+ return 0;
+}
+
+const char *
+be_tls_get_version(Port *port)
+{
+ if (port->ssl)
+ return SSL_get_version(port->ssl);
+ else
+ return NULL;
+}
+
+const char *
+be_tls_get_cipher(Port *port)
+{
+ if (port->ssl)
+ return SSL_get_cipher(port->ssl);
+ else
+ return NULL;
+}
+
+void
+be_tls_get_peer_subject_name(Port *port, char *ptr, size_t len)
+{
+ if (port->peer)
+ strlcpy(ptr, X509_NAME_to_cstring(X509_get_subject_name(port->peer)), len);
+ else
+ ptr[0] = '\0';
+}
+
+void
+be_tls_get_peer_issuer_name(Port *port, char *ptr, size_t len)
+{
+ if (port->peer)
+ strlcpy(ptr, X509_NAME_to_cstring(X509_get_issuer_name(port->peer)), len);
+ else
+ ptr[0] = '\0';
+}
+
+void
+be_tls_get_peer_serial(Port *port, char *ptr, size_t len)
+{
+ if (port->peer)
+ {
+ ASN1_INTEGER *serial;
+ BIGNUM *b;
+ char *decimal;
+
+ serial = X509_get_serialNumber(port->peer);
+ b = ASN1_INTEGER_to_BN(serial, NULL);
+ decimal = BN_bn2dec(b);
+
+ BN_free(b);
+ strlcpy(ptr, decimal, len);
+ OPENSSL_free(decimal);
+ }
+ else
+ ptr[0] = '\0';
+}
+
+#ifdef HAVE_X509_GET_SIGNATURE_NID
+char *
+be_tls_get_certificate_hash(Port *port, size_t *len)
+{
+ X509 *server_cert;
+ char *cert_hash;
+ const EVP_MD *algo_type = NULL;
+ unsigned char hash[EVP_MAX_MD_SIZE]; /* size for SHA-512 */
+ unsigned int hash_size;
+ int algo_nid;
+
+ *len = 0;
+ server_cert = SSL_get_certificate(port->ssl);
+ if (server_cert == NULL)
+ return NULL;
+
+ /*
+ * Get the signature algorithm of the certificate to determine the hash
+ * algorithm to use for the result.
+ */
+ if (!OBJ_find_sigid_algs(X509_get_signature_nid(server_cert),
+ &algo_nid, NULL))
+ elog(ERROR, "could not determine server certificate signature algorithm");
+
+ /*
+ * The TLS server's certificate bytes need to be hashed with SHA-256 if
+ * its signature algorithm is MD5 or SHA-1 as per RFC 5929
+ * (https://tools.ietf.org/html/rfc5929#section-4.1). If something else
+ * is used, the same hash as the signature algorithm is used.
+ */
+ switch (algo_nid)
+ {
+ case NID_md5:
+ case NID_sha1:
+ algo_type = EVP_sha256();
+ break;
+ default:
+ algo_type = EVP_get_digestbynid(algo_nid);
+ if (algo_type == NULL)
+ elog(ERROR, "could not find digest for NID %s",
+ OBJ_nid2sn(algo_nid));
+ break;
+ }
+
+ /* generate and save the certificate hash */
+ if (!X509_digest(server_cert, algo_type, hash, &hash_size))
+ elog(ERROR, "could not generate server certificate hash");
+
+ cert_hash = palloc(hash_size);
+ memcpy(cert_hash, hash, hash_size);
+ *len = hash_size;
+
+ return cert_hash;
+}
+#endif
+
+/*
+ * Convert an X509 subject name to a cstring.
+ *
+ */
+static char *
+X509_NAME_to_cstring(X509_NAME *name)
+{
+ BIO *membuf = BIO_new(BIO_s_mem());
+ int i,
+ nid,
+ count = X509_NAME_entry_count(name);
+ X509_NAME_ENTRY *e;
+ ASN1_STRING *v;
+ const char *field_name;
+ size_t size;
+ char nullterm;
+ char *sp;
+ char *dp;
+ char *result;
+
+ if (membuf == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_OUT_OF_MEMORY),
+ errmsg("could not create BIO")));
+
+ (void) BIO_set_close(membuf, BIO_CLOSE);
+ for (i = 0; i < count; i++)
+ {
+ e = X509_NAME_get_entry(name, i);
+ nid = OBJ_obj2nid(X509_NAME_ENTRY_get_object(e));
+ if (nid == NID_undef)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+ errmsg("could not get NID for ASN1_OBJECT object")));
+ v = X509_NAME_ENTRY_get_data(e);
+ field_name = OBJ_nid2sn(nid);
+ if (field_name == NULL)
+ field_name = OBJ_nid2ln(nid);
+ if (field_name == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+ errmsg("could not convert NID %d to an ASN1_OBJECT structure", nid)));
+ BIO_printf(membuf, "/%s=", field_name);
+ ASN1_STRING_print_ex(membuf, v,
+ ((ASN1_STRFLGS_RFC2253 & ~ASN1_STRFLGS_ESC_MSB)
+ | ASN1_STRFLGS_UTF8_CONVERT));
+ }
+
+ /* ensure null termination of the BIO's content */
+ nullterm = '\0';
+ BIO_write(membuf, &nullterm, 1);
+ size = BIO_get_mem_data(membuf, &sp);
+ dp = pg_any_to_server(sp, size - 1, PG_UTF8);
+
+ result = pstrdup(dp);
+ if (dp != sp)
+ pfree(dp);
+ if (BIO_free(membuf) != 1)
+ elog(ERROR, "could not free OpenSSL BIO structure");
+
+ return result;
+}
+
+/*
+ * Convert TLS protocol version GUC enum to OpenSSL values
+ *
+ * This is a straightforward one-to-one mapping, but doing it this way makes
+ * guc.c independent of OpenSSL availability and version.
+ *
+ * If a version is passed that is not supported by the current OpenSSL
+ * version, then we return -1. If a nonnegative value is returned,
+ * subsequent code can assume it's working with a supported version.
+ *
+ * Note: this is rather similar to libpq's routine in fe-secure-openssl.c,
+ * so make sure to update both routines if changing this one.
+ */
+static int
+ssl_protocol_version_to_openssl(int v)
+{
+ switch (v)
+ {
+ case PG_TLS_ANY:
+ return 0;
+ case PG_TLS1_VERSION:
+ return TLS1_VERSION;
+ case PG_TLS1_1_VERSION:
+#ifdef TLS1_1_VERSION
+ return TLS1_1_VERSION;
+#else
+ break;
+#endif
+ case PG_TLS1_2_VERSION:
+#ifdef TLS1_2_VERSION
+ return TLS1_2_VERSION;
+#else
+ break;
+#endif
+ case PG_TLS1_3_VERSION:
+#ifdef TLS1_3_VERSION
+ return TLS1_3_VERSION;
+#else
+ break;
+#endif
+ }
+
+ return -1;
+}
+
+/*
+ * Likewise provide a mapping to strings.
+ */
+static const char *
+ssl_protocol_version_to_string(int v)
+{
+ switch (v)
+ {
+ case PG_TLS_ANY:
+ return "any";
+ case PG_TLS1_VERSION:
+ return "TLSv1";
+ case PG_TLS1_1_VERSION:
+ return "TLSv1.1";
+ case PG_TLS1_2_VERSION:
+ return "TLSv1.2";
+ case PG_TLS1_3_VERSION:
+ return "TLSv1.3";
+ }
+
+ return "(unrecognized)";
+}
+
+
+static void
+default_openssl_tls_init(SSL_CTX *context, bool isServerStart)
+{
+ if (isServerStart)
+ {
+ if (ssl_passphrase_command[0])
+ SSL_CTX_set_default_passwd_cb(context, ssl_external_passwd_cb);
+ }
+ else
+ {
+ if (ssl_passphrase_command[0] && ssl_passphrase_command_supports_reload)
+ SSL_CTX_set_default_passwd_cb(context, ssl_external_passwd_cb);
+ else
+
+ /*
+ * If reloading and no external command is configured, override
+ * OpenSSL's default handling of passphrase-protected files,
+ * because we don't want to prompt for a passphrase in an
+ * already-running server.
+ */
+ SSL_CTX_set_default_passwd_cb(context, dummy_ssl_passwd_cb);
+ }
+}
diff --git a/src/backend/libpq/be-secure.c b/src/backend/libpq/be-secure.c
new file mode 100644
index 0000000..8ef0832
--- /dev/null
+++ b/src/backend/libpq/be-secure.c
@@ -0,0 +1,345 @@
+/*-------------------------------------------------------------------------
+ *
+ * be-secure.c
+ * functions related to setting up a secure connection to the frontend.
+ * Secure connections are expected to provide confidentiality,
+ * message integrity and endpoint authentication.
+ *
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/be-secure.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include <signal.h>
+#include <fcntl.h>
+#include <ctype.h>
+#include <sys/socket.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#ifdef HAVE_NETINET_TCP_H
+#include <netinet/tcp.h>
+#include <arpa/inet.h>
+#endif
+
+#include "libpq/libpq.h"
+#include "miscadmin.h"
+#include "pgstat.h"
+#include "storage/ipc.h"
+#include "storage/proc.h"
+#include "tcop/tcopprot.h"
+#include "utils/memutils.h"
+
+char *ssl_library;
+char *ssl_cert_file;
+char *ssl_key_file;
+char *ssl_ca_file;
+char *ssl_crl_file;
+char *ssl_crl_dir;
+char *ssl_dh_params_file;
+char *ssl_passphrase_command;
+bool ssl_passphrase_command_supports_reload;
+
+#ifdef USE_SSL
+bool ssl_loaded_verify_locations = false;
+#endif
+
+/* GUC variable controlling SSL cipher list */
+char *SSLCipherSuites = NULL;
+
+/* GUC variable for default ECHD curve. */
+char *SSLECDHCurve;
+
+/* GUC variable: if false, prefer client ciphers */
+bool SSLPreferServerCiphers;
+
+int ssl_min_protocol_version;
+int ssl_max_protocol_version;
+
+/* ------------------------------------------------------------ */
+/* Procedures common to all secure sessions */
+/* ------------------------------------------------------------ */
+
+/*
+ * Initialize global context.
+ *
+ * If isServerStart is true, report any errors as FATAL (so we don't return).
+ * Otherwise, log errors at LOG level and return -1 to indicate trouble,
+ * preserving the old SSL state if any. Returns 0 if OK.
+ */
+int
+secure_initialize(bool isServerStart)
+{
+#ifdef USE_SSL
+ return be_tls_init(isServerStart);
+#else
+ return 0;
+#endif
+}
+
+/*
+ * Destroy global context, if any.
+ */
+void
+secure_destroy(void)
+{
+#ifdef USE_SSL
+ be_tls_destroy();
+#endif
+}
+
+/*
+ * Indicate if we have loaded the root CA store to verify certificates
+ */
+bool
+secure_loaded_verify_locations(void)
+{
+#ifdef USE_SSL
+ return ssl_loaded_verify_locations;
+#else
+ return false;
+#endif
+}
+
+/*
+ * Attempt to negotiate secure session.
+ */
+int
+secure_open_server(Port *port)
+{
+ int r = 0;
+
+#ifdef USE_SSL
+ r = be_tls_open_server(port);
+
+ ereport(DEBUG2,
+ (errmsg_internal("SSL connection from DN:\"%s\" CN:\"%s\"",
+ port->peer_dn ? port->peer_dn : "(anonymous)",
+ port->peer_cn ? port->peer_cn : "(anonymous)")));
+#endif
+
+ return r;
+}
+
+/*
+ * Close secure session.
+ */
+void
+secure_close(Port *port)
+{
+#ifdef USE_SSL
+ if (port->ssl_in_use)
+ be_tls_close(port);
+#endif
+}
+
+/*
+ * Read data from a secure connection.
+ */
+ssize_t
+secure_read(Port *port, void *ptr, size_t len)
+{
+ ssize_t n;
+ int waitfor;
+
+ /* Deal with any already-pending interrupt condition. */
+ ProcessClientReadInterrupt(false);
+
+retry:
+#ifdef USE_SSL
+ waitfor = 0;
+ if (port->ssl_in_use)
+ {
+ n = be_tls_read(port, ptr, len, &waitfor);
+ }
+ else
+#endif
+#ifdef ENABLE_GSS
+ if (port->gss && port->gss->enc)
+ {
+ n = be_gssapi_read(port, ptr, len);
+ waitfor = WL_SOCKET_READABLE;
+ }
+ else
+#endif
+ {
+ n = secure_raw_read(port, ptr, len);
+ waitfor = WL_SOCKET_READABLE;
+ }
+
+ /* In blocking mode, wait until the socket is ready */
+ if (n < 0 && !port->noblock && (errno == EWOULDBLOCK || errno == EAGAIN))
+ {
+ WaitEvent event;
+
+ Assert(waitfor);
+
+ ModifyWaitEvent(FeBeWaitSet, FeBeWaitSetSocketPos, waitfor, NULL);
+
+ WaitEventSetWait(FeBeWaitSet, -1 /* no timeout */ , &event, 1,
+ WAIT_EVENT_CLIENT_READ);
+
+ /*
+ * If the postmaster has died, it's not safe to continue running,
+ * because it is the postmaster's job to kill us if some other backend
+ * exits uncleanly. Moreover, we won't run very well in this state;
+ * helper processes like walwriter and the bgwriter will exit, so
+ * performance may be poor. Finally, if we don't exit, pg_ctl will be
+ * unable to restart the postmaster without manual intervention, so no
+ * new connections can be accepted. Exiting clears the deck for a
+ * postmaster restart.
+ *
+ * (Note that we only make this check when we would otherwise sleep on
+ * our latch. We might still continue running for a while if the
+ * postmaster is killed in mid-query, or even through multiple queries
+ * if we never have to wait for read. We don't want to burn too many
+ * cycles checking for this very rare condition, and this should cause
+ * us to exit quickly in most cases.)
+ */
+ if (event.events & WL_POSTMASTER_DEATH)
+ ereport(FATAL,
+ (errcode(ERRCODE_ADMIN_SHUTDOWN),
+ errmsg("terminating connection due to unexpected postmaster exit")));
+
+ /* Handle interrupt. */
+ if (event.events & WL_LATCH_SET)
+ {
+ ResetLatch(MyLatch);
+ ProcessClientReadInterrupt(true);
+
+ /*
+ * We'll retry the read. Most likely it will return immediately
+ * because there's still no data available, and we'll wait for the
+ * socket to become ready again.
+ */
+ }
+ goto retry;
+ }
+
+ /*
+ * Process interrupts that happened during a successful (or non-blocking,
+ * or hard-failed) read.
+ */
+ ProcessClientReadInterrupt(false);
+
+ return n;
+}
+
+ssize_t
+secure_raw_read(Port *port, void *ptr, size_t len)
+{
+ ssize_t n;
+
+ /*
+ * Try to read from the socket without blocking. If it succeeds we're
+ * done, otherwise we'll wait for the socket using the latch mechanism.
+ */
+#ifdef WIN32
+ pgwin32_noblock = true;
+#endif
+ n = recv(port->sock, ptr, len, 0);
+#ifdef WIN32
+ pgwin32_noblock = false;
+#endif
+
+ return n;
+}
+
+
+/*
+ * Write data to a secure connection.
+ */
+ssize_t
+secure_write(Port *port, void *ptr, size_t len)
+{
+ ssize_t n;
+ int waitfor;
+
+ /* Deal with any already-pending interrupt condition. */
+ ProcessClientWriteInterrupt(false);
+
+retry:
+ waitfor = 0;
+#ifdef USE_SSL
+ if (port->ssl_in_use)
+ {
+ n = be_tls_write(port, ptr, len, &waitfor);
+ }
+ else
+#endif
+#ifdef ENABLE_GSS
+ if (port->gss && port->gss->enc)
+ {
+ n = be_gssapi_write(port, ptr, len);
+ waitfor = WL_SOCKET_WRITEABLE;
+ }
+ else
+#endif
+ {
+ n = secure_raw_write(port, ptr, len);
+ waitfor = WL_SOCKET_WRITEABLE;
+ }
+
+ if (n < 0 && !port->noblock && (errno == EWOULDBLOCK || errno == EAGAIN))
+ {
+ WaitEvent event;
+
+ Assert(waitfor);
+
+ ModifyWaitEvent(FeBeWaitSet, FeBeWaitSetSocketPos, waitfor, NULL);
+
+ WaitEventSetWait(FeBeWaitSet, -1 /* no timeout */ , &event, 1,
+ WAIT_EVENT_CLIENT_WRITE);
+
+ /* See comments in secure_read. */
+ if (event.events & WL_POSTMASTER_DEATH)
+ ereport(FATAL,
+ (errcode(ERRCODE_ADMIN_SHUTDOWN),
+ errmsg("terminating connection due to unexpected postmaster exit")));
+
+ /* Handle interrupt. */
+ if (event.events & WL_LATCH_SET)
+ {
+ ResetLatch(MyLatch);
+ ProcessClientWriteInterrupt(true);
+
+ /*
+ * We'll retry the write. Most likely it will return immediately
+ * because there's still no buffer space available, and we'll wait
+ * for the socket to become ready again.
+ */
+ }
+ goto retry;
+ }
+
+ /*
+ * Process interrupts that happened during a successful (or non-blocking,
+ * or hard-failed) write.
+ */
+ ProcessClientWriteInterrupt(false);
+
+ return n;
+}
+
+ssize_t
+secure_raw_write(Port *port, const void *ptr, size_t len)
+{
+ ssize_t n;
+
+#ifdef WIN32
+ pgwin32_noblock = true;
+#endif
+ n = send(port->sock, ptr, len, 0);
+#ifdef WIN32
+ pgwin32_noblock = false;
+#endif
+
+ return n;
+}
diff --git a/src/backend/libpq/crypt.c b/src/backend/libpq/crypt.c
new file mode 100644
index 0000000..3fcad99
--- /dev/null
+++ b/src/backend/libpq/crypt.c
@@ -0,0 +1,290 @@
+/*-------------------------------------------------------------------------
+ *
+ * crypt.c
+ * Functions for dealing with encrypted passwords stored in
+ * pg_authid.rolpassword.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/crypt.c
+ *
+ *-------------------------------------------------------------------------
+ */
+#include "postgres.h"
+
+#include <unistd.h>
+
+#include "catalog/pg_authid.h"
+#include "common/md5.h"
+#include "common/scram-common.h"
+#include "libpq/crypt.h"
+#include "libpq/scram.h"
+#include "miscadmin.h"
+#include "utils/builtins.h"
+#include "utils/syscache.h"
+#include "utils/timestamp.h"
+
+
+/*
+ * Fetch stored password for a user, for authentication.
+ *
+ * On error, returns NULL, and stores a palloc'd string describing the reason,
+ * for the postmaster log, in *logdetail. The error reason should *not* be
+ * sent to the client, to avoid giving away user information!
+ */
+char *
+get_role_password(const char *role, char **logdetail)
+{
+ TimestampTz vuntil = 0;
+ HeapTuple roleTup;
+ Datum datum;
+ bool isnull;
+ char *shadow_pass;
+
+ /* Get role info from pg_authid */
+ roleTup = SearchSysCache1(AUTHNAME, PointerGetDatum(role));
+ if (!HeapTupleIsValid(roleTup))
+ {
+ *logdetail = psprintf(_("Role \"%s\" does not exist."),
+ role);
+ return NULL; /* no such user */
+ }
+
+ datum = SysCacheGetAttr(AUTHNAME, roleTup,
+ Anum_pg_authid_rolpassword, &isnull);
+ if (isnull)
+ {
+ ReleaseSysCache(roleTup);
+ *logdetail = psprintf(_("User \"%s\" has no password assigned."),
+ role);
+ return NULL; /* user has no password */
+ }
+ shadow_pass = TextDatumGetCString(datum);
+
+ datum = SysCacheGetAttr(AUTHNAME, roleTup,
+ Anum_pg_authid_rolvaliduntil, &isnull);
+ if (!isnull)
+ vuntil = DatumGetTimestampTz(datum);
+
+ ReleaseSysCache(roleTup);
+
+ /*
+ * Password OK, but check to be sure we are not past rolvaliduntil
+ */
+ if (!isnull && vuntil < GetCurrentTimestamp())
+ {
+ *logdetail = psprintf(_("User \"%s\" has an expired password."),
+ role);
+ return NULL;
+ }
+
+ return shadow_pass;
+}
+
+/*
+ * What kind of a password type is 'shadow_pass'?
+ */
+PasswordType
+get_password_type(const char *shadow_pass)
+{
+ char *encoded_salt;
+ int iterations;
+ uint8 stored_key[SCRAM_KEY_LEN];
+ uint8 server_key[SCRAM_KEY_LEN];
+
+ if (strncmp(shadow_pass, "md5", 3) == 0 &&
+ strlen(shadow_pass) == MD5_PASSWD_LEN &&
+ strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3)
+ return PASSWORD_TYPE_MD5;
+ if (parse_scram_secret(shadow_pass, &iterations, &encoded_salt,
+ stored_key, server_key))
+ return PASSWORD_TYPE_SCRAM_SHA_256;
+ return PASSWORD_TYPE_PLAINTEXT;
+}
+
+/*
+ * Given a user-supplied password, convert it into a secret of
+ * 'target_type' kind.
+ *
+ * If the password is already in encrypted form, we cannot reverse the
+ * hash, so it is stored as it is regardless of the requested type.
+ */
+char *
+encrypt_password(PasswordType target_type, const char *role,
+ const char *password)
+{
+ PasswordType guessed_type = get_password_type(password);
+ char *encrypted_password;
+
+ if (guessed_type != PASSWORD_TYPE_PLAINTEXT)
+ {
+ /*
+ * Cannot convert an already-encrypted password from one format to
+ * another, so return it as it is.
+ */
+ return pstrdup(password);
+ }
+
+ switch (target_type)
+ {
+ case PASSWORD_TYPE_MD5:
+ encrypted_password = palloc(MD5_PASSWD_LEN + 1);
+
+ if (!pg_md5_encrypt(password, role, strlen(role),
+ encrypted_password))
+ elog(ERROR, "password encryption failed");
+ return encrypted_password;
+
+ case PASSWORD_TYPE_SCRAM_SHA_256:
+ return pg_be_scram_build_secret(password);
+
+ case PASSWORD_TYPE_PLAINTEXT:
+ elog(ERROR, "cannot encrypt password with 'plaintext'");
+ }
+
+ /*
+ * This shouldn't happen, because the above switch statements should
+ * handle every combination of source and target password types.
+ */
+ elog(ERROR, "cannot encrypt password to requested type");
+ return NULL; /* keep compiler quiet */
+}
+
+/*
+ * Check MD5 authentication response, and return STATUS_OK or STATUS_ERROR.
+ *
+ * 'shadow_pass' is the user's correct password or password hash, as stored
+ * in pg_authid.rolpassword.
+ * 'client_pass' is the response given by the remote user to the MD5 challenge.
+ * 'md5_salt' is the salt used in the MD5 authentication challenge.
+ *
+ * In the error case, optionally store a palloc'd string at *logdetail
+ * that will be sent to the postmaster log (but not the client).
+ */
+int
+md5_crypt_verify(const char *role, const char *shadow_pass,
+ const char *client_pass,
+ const char *md5_salt, int md5_salt_len,
+ char **logdetail)
+{
+ int retval;
+ char crypt_pwd[MD5_PASSWD_LEN + 1];
+
+ Assert(md5_salt_len > 0);
+
+ if (get_password_type(shadow_pass) != PASSWORD_TYPE_MD5)
+ {
+ /* incompatible password hash format. */
+ *logdetail = psprintf(_("User \"%s\" has a password that cannot be used with MD5 authentication."),
+ role);
+ return STATUS_ERROR;
+ }
+
+ /*
+ * Compute the correct answer for the MD5 challenge.
+ *
+ * We do not bother setting logdetail for any pg_md5_encrypt failure
+ * below: the only possible error is out-of-memory, which is unlikely, and
+ * if it did happen adding a psprintf call would only make things worse.
+ */
+ /* stored password already encrypted, only do salt */
+ if (!pg_md5_encrypt(shadow_pass + strlen("md5"),
+ md5_salt, md5_salt_len,
+ crypt_pwd))
+ {
+ return STATUS_ERROR;
+ }
+
+ if (strcmp(client_pass, crypt_pwd) == 0)
+ retval = STATUS_OK;
+ else
+ {
+ *logdetail = psprintf(_("Password does not match for user \"%s\"."),
+ role);
+ retval = STATUS_ERROR;
+ }
+
+ return retval;
+}
+
+/*
+ * Check given password for given user, and return STATUS_OK or STATUS_ERROR.
+ *
+ * 'shadow_pass' is the user's correct password hash, as stored in
+ * pg_authid.rolpassword.
+ * 'client_pass' is the password given by the remote user.
+ *
+ * In the error case, optionally store a palloc'd string at *logdetail
+ * that will be sent to the postmaster log (but not the client).
+ */
+int
+plain_crypt_verify(const char *role, const char *shadow_pass,
+ const char *client_pass,
+ char **logdetail)
+{
+ char crypt_client_pass[MD5_PASSWD_LEN + 1];
+
+ /*
+ * Client sent password in plaintext. If we have an MD5 hash stored, hash
+ * the password the client sent, and compare the hashes. Otherwise
+ * compare the plaintext passwords directly.
+ */
+ switch (get_password_type(shadow_pass))
+ {
+ case PASSWORD_TYPE_SCRAM_SHA_256:
+ if (scram_verify_plain_password(role,
+ client_pass,
+ shadow_pass))
+ {
+ return STATUS_OK;
+ }
+ else
+ {
+ *logdetail = psprintf(_("Password does not match for user \"%s\"."),
+ role);
+ return STATUS_ERROR;
+ }
+ break;
+
+ case PASSWORD_TYPE_MD5:
+ if (!pg_md5_encrypt(client_pass,
+ role,
+ strlen(role),
+ crypt_client_pass))
+ {
+ /*
+ * We do not bother setting logdetail for pg_md5_encrypt
+ * failure: the only possible error is out-of-memory, which is
+ * unlikely, and if it did happen adding a psprintf call would
+ * only make things worse.
+ */
+ return STATUS_ERROR;
+ }
+ if (strcmp(crypt_client_pass, shadow_pass) == 0)
+ return STATUS_OK;
+ else
+ {
+ *logdetail = psprintf(_("Password does not match for user \"%s\"."),
+ role);
+ return STATUS_ERROR;
+ }
+ break;
+
+ case PASSWORD_TYPE_PLAINTEXT:
+
+ /*
+ * We never store passwords in plaintext, so this shouldn't
+ * happen.
+ */
+ break;
+ }
+
+ /*
+ * This shouldn't happen. Plain "password" authentication is possible
+ * with any kind of stored password hash.
+ */
+ *logdetail = psprintf(_("Password of user \"%s\" is in unrecognized format."),
+ role);
+ return STATUS_ERROR;
+}
diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c
new file mode 100644
index 0000000..64e59d4
--- /dev/null
+++ b/src/backend/libpq/hba.c
@@ -0,0 +1,3166 @@
+/*-------------------------------------------------------------------------
+ *
+ * hba.c
+ * Routines to handle host based authentication (that's the scheme
+ * wherein you authenticate a user by seeing what IP address the system
+ * says he comes from and choosing authentication method based on it).
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/hba.c
+ *
+ *-------------------------------------------------------------------------
+ */
+#include "postgres.h"
+
+#include <ctype.h>
+#include <pwd.h>
+#include <fcntl.h>
+#include <sys/param.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#include <unistd.h>
+
+#include "access/htup_details.h"
+#include "catalog/pg_collation.h"
+#include "catalog/pg_type.h"
+#include "common/ip.h"
+#include "common/string.h"
+#include "funcapi.h"
+#include "libpq/ifaddr.h"
+#include "libpq/libpq.h"
+#include "miscadmin.h"
+#include "postmaster/postmaster.h"
+#include "regex/regex.h"
+#include "replication/walsender.h"
+#include "storage/fd.h"
+#include "utils/acl.h"
+#include "utils/builtins.h"
+#include "utils/guc.h"
+#include "utils/lsyscache.h"
+#include "utils/memutils.h"
+#include "utils/varlena.h"
+
+#ifdef USE_LDAP
+#ifdef WIN32
+#include <winldap.h>
+#else
+#include <ldap.h>
+#endif
+#endif
+
+
+#define MAX_TOKEN 256
+
+/* callback data for check_network_callback */
+typedef struct check_network_data
+{
+ IPCompareMethod method; /* test method */
+ SockAddr *raddr; /* client's actual address */
+ bool result; /* set to true if match */
+} check_network_data;
+
+
+#define token_is_keyword(t, k) (!t->quoted && strcmp(t->string, k) == 0)
+#define token_matches(t, k) (strcmp(t->string, k) == 0)
+
+/*
+ * A single string token lexed from a config file, together with whether
+ * the token had been quoted.
+ */
+typedef struct HbaToken
+{
+ char *string;
+ bool quoted;
+} HbaToken;
+
+/*
+ * TokenizedLine represents one line lexed from a config file.
+ * Each item in the "fields" list is a sub-list of HbaTokens.
+ * We don't emit a TokenizedLine for empty or all-comment lines,
+ * so "fields" is never NIL (nor are any of its sub-lists).
+ * Exception: if an error occurs during tokenization, we might
+ * have fields == NIL, in which case err_msg != NULL.
+ */
+typedef struct TokenizedLine
+{
+ List *fields; /* List of lists of HbaTokens */
+ int line_num; /* Line number */
+ char *raw_line; /* Raw line text */
+ char *err_msg; /* Error message if any */
+} TokenizedLine;
+
+/*
+ * pre-parsed content of HBA config file: list of HbaLine structs.
+ * parsed_hba_context is the memory context where it lives.
+ */
+static List *parsed_hba_lines = NIL;
+static MemoryContext parsed_hba_context = NULL;
+
+/*
+ * pre-parsed content of ident mapping file: list of IdentLine structs.
+ * parsed_ident_context is the memory context where it lives.
+ *
+ * NOTE: the IdentLine structs can contain pre-compiled regular expressions
+ * that live outside the memory context. Before destroying or resetting the
+ * memory context, they need to be explicitly free'd.
+ */
+static List *parsed_ident_lines = NIL;
+static MemoryContext parsed_ident_context = NULL;
+
+/*
+ * The following character array represents the names of the authentication
+ * methods that are supported by PostgreSQL.
+ *
+ * Note: keep this in sync with the UserAuth enum in hba.h.
+ */
+static const char *const UserAuthName[] =
+{
+ "reject",
+ "implicit reject", /* Not a user-visible option */
+ "trust",
+ "ident",
+ "password",
+ "md5",
+ "scram-sha-256",
+ "gss",
+ "sspi",
+ "pam",
+ "bsd",
+ "ldap",
+ "cert",
+ "radius",
+ "peer"
+};
+
+
+static MemoryContext tokenize_file(const char *filename, FILE *file,
+ List **tok_lines, int elevel);
+static List *tokenize_inc_file(List *tokens, const char *outer_filename,
+ const char *inc_filename, int elevel, char **err_msg);
+static bool parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
+ int elevel, char **err_msg);
+static ArrayType *gethba_options(HbaLine *hba);
+static void fill_hba_line(Tuplestorestate *tuple_store, TupleDesc tupdesc,
+ int lineno, HbaLine *hba, const char *err_msg);
+static void fill_hba_view(Tuplestorestate *tuple_store, TupleDesc tupdesc);
+
+
+/*
+ * isblank() exists in the ISO C99 spec, but it's not very portable yet,
+ * so provide our own version.
+ */
+bool
+pg_isblank(const char c)
+{
+ return c == ' ' || c == '\t' || c == '\r';
+}
+
+
+/*
+ * Grab one token out of the string pointed to by *lineptr.
+ *
+ * Tokens are strings of non-blank characters bounded by blank characters,
+ * commas, beginning of line, and end of line. Blank means space or tab.
+ *
+ * Tokens can be delimited by double quotes (this allows the inclusion of
+ * blanks or '#', but not newlines). As in SQL, write two double-quotes
+ * to represent a double quote.
+ *
+ * Comments (started by an unquoted '#') are skipped, i.e. the remainder
+ * of the line is ignored.
+ *
+ * (Note that line continuation processing happens before tokenization.
+ * Thus, if a continuation occurs within quoted text or a comment, the
+ * quoted text or comment is considered to continue to the next line.)
+ *
+ * The token, if any, is returned at *buf (a buffer of size bufsz), and
+ * *lineptr is advanced past the token.
+ *
+ * Also, we set *initial_quote to indicate whether there was quoting before
+ * the first character. (We use that to prevent "@x" from being treated
+ * as a file inclusion request. Note that @"x" should be so treated;
+ * we want to allow that to support embedded spaces in file paths.)
+ *
+ * We set *terminating_comma to indicate whether the token is terminated by a
+ * comma (which is not returned).
+ *
+ * In event of an error, log a message at ereport level elevel, and also
+ * set *err_msg to a string describing the error. Currently the only
+ * possible error is token too long for buf.
+ *
+ * If successful: store null-terminated token at *buf and return true.
+ * If no more tokens on line: set *buf = '\0' and return false.
+ * If error: fill buf with truncated or misformatted token and return false.
+ */
+static bool
+next_token(char **lineptr, char *buf, int bufsz,
+ bool *initial_quote, bool *terminating_comma,
+ int elevel, char **err_msg)
+{
+ int c;
+ char *start_buf = buf;
+ char *end_buf = buf + (bufsz - 1);
+ bool in_quote = false;
+ bool was_quote = false;
+ bool saw_quote = false;
+
+ Assert(end_buf > start_buf);
+
+ *initial_quote = false;
+ *terminating_comma = false;
+
+ /* Move over any whitespace and commas preceding the next token */
+ while ((c = (*(*lineptr)++)) != '\0' && (pg_isblank(c) || c == ','))
+ ;
+
+ /*
+ * Build a token in buf of next characters up to EOL, unquoted comma, or
+ * unquoted whitespace.
+ */
+ while (c != '\0' &&
+ (!pg_isblank(c) || in_quote))
+ {
+ /* skip comments to EOL */
+ if (c == '#' && !in_quote)
+ {
+ while ((c = (*(*lineptr)++)) != '\0')
+ ;
+ break;
+ }
+
+ if (buf >= end_buf)
+ {
+ *buf = '\0';
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("authentication file token too long, skipping: \"%s\"",
+ start_buf)));
+ *err_msg = "authentication file token too long";
+ /* Discard remainder of line */
+ while ((c = (*(*lineptr)++)) != '\0')
+ ;
+ /* Un-eat the '\0', in case we're called again */
+ (*lineptr)--;
+ return false;
+ }
+
+ /* we do not pass back a terminating comma in the token */
+ if (c == ',' && !in_quote)
+ {
+ *terminating_comma = true;
+ break;
+ }
+
+ if (c != '"' || was_quote)
+ *buf++ = c;
+
+ /* Literal double-quote is two double-quotes */
+ if (in_quote && c == '"')
+ was_quote = !was_quote;
+ else
+ was_quote = false;
+
+ if (c == '"')
+ {
+ in_quote = !in_quote;
+ saw_quote = true;
+ if (buf == start_buf)
+ *initial_quote = true;
+ }
+
+ c = *(*lineptr)++;
+ }
+
+ /*
+ * Un-eat the char right after the token (critical in case it is '\0',
+ * else next call will read past end of string).
+ */
+ (*lineptr)--;
+
+ *buf = '\0';
+
+ return (saw_quote || buf > start_buf);
+}
+
+/*
+ * Construct a palloc'd HbaToken struct, copying the given string.
+ */
+static HbaToken *
+make_hba_token(const char *token, bool quoted)
+{
+ HbaToken *hbatoken;
+ int toklen;
+
+ toklen = strlen(token);
+ /* we copy string into same palloc block as the struct */
+ hbatoken = (HbaToken *) palloc(sizeof(HbaToken) + toklen + 1);
+ hbatoken->string = (char *) hbatoken + sizeof(HbaToken);
+ hbatoken->quoted = quoted;
+ memcpy(hbatoken->string, token, toklen + 1);
+
+ return hbatoken;
+}
+
+/*
+ * Copy a HbaToken struct into freshly palloc'd memory.
+ */
+static HbaToken *
+copy_hba_token(HbaToken *in)
+{
+ HbaToken *out = make_hba_token(in->string, in->quoted);
+
+ return out;
+}
+
+
+/*
+ * Tokenize one HBA field from a line, handling file inclusion and comma lists.
+ *
+ * filename: current file's pathname (needed to resolve relative pathnames)
+ * *lineptr: current line pointer, which will be advanced past field
+ *
+ * In event of an error, log a message at ereport level elevel, and also
+ * set *err_msg to a string describing the error. Note that the result
+ * may be non-NIL anyway, so *err_msg must be tested to determine whether
+ * there was an error.
+ *
+ * The result is a List of HbaToken structs, one for each token in the field,
+ * or NIL if we reached EOL.
+ */
+static List *
+next_field_expand(const char *filename, char **lineptr,
+ int elevel, char **err_msg)
+{
+ char buf[MAX_TOKEN];
+ bool trailing_comma;
+ bool initial_quote;
+ List *tokens = NIL;
+
+ do
+ {
+ if (!next_token(lineptr, buf, sizeof(buf),
+ &initial_quote, &trailing_comma,
+ elevel, err_msg))
+ break;
+
+ /* Is this referencing a file? */
+ if (!initial_quote && buf[0] == '@' && buf[1] != '\0')
+ tokens = tokenize_inc_file(tokens, filename, buf + 1,
+ elevel, err_msg);
+ else
+ tokens = lappend(tokens, make_hba_token(buf, initial_quote));
+ } while (trailing_comma && (*err_msg == NULL));
+
+ return tokens;
+}
+
+/*
+ * tokenize_inc_file
+ * Expand a file included from another file into an hba "field"
+ *
+ * Opens and tokenises a file included from another HBA config file with @,
+ * and returns all values found therein as a flat list of HbaTokens. If a
+ * @-token is found, recursively expand it. The newly read tokens are
+ * appended to "tokens" (so that foo,bar,@baz does what you expect).
+ * All new tokens are allocated in caller's memory context.
+ *
+ * In event of an error, log a message at ereport level elevel, and also
+ * set *err_msg to a string describing the error. Note that the result
+ * may be non-NIL anyway, so *err_msg must be tested to determine whether
+ * there was an error.
+ */
+static List *
+tokenize_inc_file(List *tokens,
+ const char *outer_filename,
+ const char *inc_filename,
+ int elevel,
+ char **err_msg)
+{
+ char *inc_fullname;
+ FILE *inc_file;
+ List *inc_lines;
+ ListCell *inc_line;
+ MemoryContext linecxt;
+
+ if (is_absolute_path(inc_filename))
+ {
+ /* absolute path is taken as-is */
+ inc_fullname = pstrdup(inc_filename);
+ }
+ else
+ {
+ /* relative path is relative to dir of calling file */
+ inc_fullname = (char *) palloc(strlen(outer_filename) + 1 +
+ strlen(inc_filename) + 1);
+ strcpy(inc_fullname, outer_filename);
+ get_parent_directory(inc_fullname);
+ join_path_components(inc_fullname, inc_fullname, inc_filename);
+ canonicalize_path(inc_fullname);
+ }
+
+ inc_file = AllocateFile(inc_fullname, "r");
+ if (inc_file == NULL)
+ {
+ int save_errno = errno;
+
+ ereport(elevel,
+ (errcode_for_file_access(),
+ errmsg("could not open secondary authentication file \"@%s\" as \"%s\": %m",
+ inc_filename, inc_fullname)));
+ *err_msg = psprintf("could not open secondary authentication file \"@%s\" as \"%s\": %s",
+ inc_filename, inc_fullname, strerror(save_errno));
+ pfree(inc_fullname);
+ return tokens;
+ }
+
+ /* There is possible recursion here if the file contains @ */
+ linecxt = tokenize_file(inc_fullname, inc_file, &inc_lines, elevel);
+
+ FreeFile(inc_file);
+ pfree(inc_fullname);
+
+ /* Copy all tokens found in the file and append to the tokens list */
+ foreach(inc_line, inc_lines)
+ {
+ TokenizedLine *tok_line = (TokenizedLine *) lfirst(inc_line);
+ ListCell *inc_field;
+
+ /* If any line has an error, propagate that up to caller */
+ if (tok_line->err_msg)
+ {
+ *err_msg = pstrdup(tok_line->err_msg);
+ break;
+ }
+
+ foreach(inc_field, tok_line->fields)
+ {
+ List *inc_tokens = lfirst(inc_field);
+ ListCell *inc_token;
+
+ foreach(inc_token, inc_tokens)
+ {
+ HbaToken *token = lfirst(inc_token);
+
+ tokens = lappend(tokens, copy_hba_token(token));
+ }
+ }
+ }
+
+ MemoryContextDelete(linecxt);
+ return tokens;
+}
+
+/*
+ * Tokenize the given file.
+ *
+ * The output is a list of TokenizedLine structs; see struct definition above.
+ *
+ * filename: the absolute path to the target file
+ * file: the already-opened target file
+ * tok_lines: receives output list
+ * elevel: message logging level
+ *
+ * Errors are reported by logging messages at ereport level elevel and by
+ * adding TokenizedLine structs containing non-null err_msg fields to the
+ * output list.
+ *
+ * Return value is a memory context which contains all memory allocated by
+ * this function (it's a child of caller's context).
+ */
+static MemoryContext
+tokenize_file(const char *filename, FILE *file, List **tok_lines, int elevel)
+{
+ int line_number = 1;
+ StringInfoData buf;
+ MemoryContext linecxt;
+ MemoryContext oldcxt;
+
+ linecxt = AllocSetContextCreate(CurrentMemoryContext,
+ "tokenize_file",
+ ALLOCSET_SMALL_SIZES);
+ oldcxt = MemoryContextSwitchTo(linecxt);
+
+ initStringInfo(&buf);
+
+ *tok_lines = NIL;
+
+ while (!feof(file) && !ferror(file))
+ {
+ char *lineptr;
+ List *current_line = NIL;
+ char *err_msg = NULL;
+ int last_backslash_buflen = 0;
+ int continuations = 0;
+
+ /* Collect the next input line, handling backslash continuations */
+ resetStringInfo(&buf);
+
+ while (pg_get_line_append(file, &buf))
+ {
+ /* Strip trailing newline, including \r in case we're on Windows */
+ buf.len = pg_strip_crlf(buf.data);
+
+ /*
+ * Check for backslash continuation. The backslash must be after
+ * the last place we found a continuation, else two backslashes
+ * followed by two \n's would behave surprisingly.
+ */
+ if (buf.len > last_backslash_buflen &&
+ buf.data[buf.len - 1] == '\\')
+ {
+ /* Continuation, so strip it and keep reading */
+ buf.data[--buf.len] = '\0';
+ last_backslash_buflen = buf.len;
+ continuations++;
+ continue;
+ }
+
+ /* Nope, so we have the whole line */
+ break;
+ }
+
+ if (ferror(file))
+ {
+ /* I/O error! */
+ int save_errno = errno;
+
+ ereport(elevel,
+ (errcode_for_file_access(),
+ errmsg("could not read file \"%s\": %m", filename)));
+ err_msg = psprintf("could not read file \"%s\": %s",
+ filename, strerror(save_errno));
+ break;
+ }
+
+ /* Parse fields */
+ lineptr = buf.data;
+ while (*lineptr && err_msg == NULL)
+ {
+ List *current_field;
+
+ current_field = next_field_expand(filename, &lineptr,
+ elevel, &err_msg);
+ /* add field to line, unless we are at EOL or comment start */
+ if (current_field != NIL)
+ current_line = lappend(current_line, current_field);
+ }
+
+ /* Reached EOL; emit line to TokenizedLine list unless it's boring */
+ if (current_line != NIL || err_msg != NULL)
+ {
+ TokenizedLine *tok_line;
+
+ tok_line = (TokenizedLine *) palloc(sizeof(TokenizedLine));
+ tok_line->fields = current_line;
+ tok_line->line_num = line_number;
+ tok_line->raw_line = pstrdup(buf.data);
+ tok_line->err_msg = err_msg;
+ *tok_lines = lappend(*tok_lines, tok_line);
+ }
+
+ line_number += continuations + 1;
+ }
+
+ MemoryContextSwitchTo(oldcxt);
+
+ return linecxt;
+}
+
+
+/*
+ * Does user belong to role?
+ *
+ * userid is the OID of the role given as the attempted login identifier.
+ * We check to see if it is a member of the specified role name.
+ */
+static bool
+is_member(Oid userid, const char *role)
+{
+ Oid roleid;
+
+ if (!OidIsValid(userid))
+ return false; /* if user not exist, say "no" */
+
+ roleid = get_role_oid(role, true);
+
+ if (!OidIsValid(roleid))
+ return false; /* if target role not exist, say "no" */
+
+ /*
+ * See if user is directly or indirectly a member of role. For this
+ * purpose, a superuser is not considered to be automatically a member of
+ * the role, so group auth only applies to explicit membership.
+ */
+ return is_member_of_role_nosuper(userid, roleid);
+}
+
+/*
+ * Check HbaToken list for a match to role, allowing group names.
+ */
+static bool
+check_role(const char *role, Oid roleid, List *tokens)
+{
+ ListCell *cell;
+ HbaToken *tok;
+
+ foreach(cell, tokens)
+ {
+ tok = lfirst(cell);
+ if (!tok->quoted && tok->string[0] == '+')
+ {
+ if (is_member(roleid, tok->string + 1))
+ return true;
+ }
+ else if (token_matches(tok, role) ||
+ token_is_keyword(tok, "all"))
+ return true;
+ }
+ return false;
+}
+
+/*
+ * Check to see if db/role combination matches HbaToken list.
+ */
+static bool
+check_db(const char *dbname, const char *role, Oid roleid, List *tokens)
+{
+ ListCell *cell;
+ HbaToken *tok;
+
+ foreach(cell, tokens)
+ {
+ tok = lfirst(cell);
+ if (am_walsender && !am_db_walsender)
+ {
+ /*
+ * physical replication walsender connections can only match
+ * replication keyword
+ */
+ if (token_is_keyword(tok, "replication"))
+ return true;
+ }
+ else if (token_is_keyword(tok, "all"))
+ return true;
+ else if (token_is_keyword(tok, "sameuser"))
+ {
+ if (strcmp(dbname, role) == 0)
+ return true;
+ }
+ else if (token_is_keyword(tok, "samegroup") ||
+ token_is_keyword(tok, "samerole"))
+ {
+ if (is_member(roleid, dbname))
+ return true;
+ }
+ else if (token_is_keyword(tok, "replication"))
+ continue; /* never match this if not walsender */
+ else if (token_matches(tok, dbname))
+ return true;
+ }
+ return false;
+}
+
+static bool
+ipv4eq(struct sockaddr_in *a, struct sockaddr_in *b)
+{
+ return (a->sin_addr.s_addr == b->sin_addr.s_addr);
+}
+
+#ifdef HAVE_IPV6
+
+static bool
+ipv6eq(struct sockaddr_in6 *a, struct sockaddr_in6 *b)
+{
+ int i;
+
+ for (i = 0; i < 16; i++)
+ if (a->sin6_addr.s6_addr[i] != b->sin6_addr.s6_addr[i])
+ return false;
+
+ return true;
+}
+#endif /* HAVE_IPV6 */
+
+/*
+ * Check whether host name matches pattern.
+ */
+static bool
+hostname_match(const char *pattern, const char *actual_hostname)
+{
+ if (pattern[0] == '.') /* suffix match */
+ {
+ size_t plen = strlen(pattern);
+ size_t hlen = strlen(actual_hostname);
+
+ if (hlen < plen)
+ return false;
+
+ return (pg_strcasecmp(pattern, actual_hostname + (hlen - plen)) == 0);
+ }
+ else
+ return (pg_strcasecmp(pattern, actual_hostname) == 0);
+}
+
+/*
+ * Check to see if a connecting IP matches a given host name.
+ */
+static bool
+check_hostname(hbaPort *port, const char *hostname)
+{
+ struct addrinfo *gai_result,
+ *gai;
+ int ret;
+ bool found;
+
+ /* Quick out if remote host name already known bad */
+ if (port->remote_hostname_resolv < 0)
+ return false;
+
+ /* Lookup remote host name if not already done */
+ if (!port->remote_hostname)
+ {
+ char remote_hostname[NI_MAXHOST];
+
+ ret = pg_getnameinfo_all(&port->raddr.addr, port->raddr.salen,
+ remote_hostname, sizeof(remote_hostname),
+ NULL, 0,
+ NI_NAMEREQD);
+ if (ret != 0)
+ {
+ /* remember failure; don't complain in the postmaster log yet */
+ port->remote_hostname_resolv = -2;
+ port->remote_hostname_errcode = ret;
+ return false;
+ }
+
+ port->remote_hostname = pstrdup(remote_hostname);
+ }
+
+ /* Now see if remote host name matches this pg_hba line */
+ if (!hostname_match(hostname, port->remote_hostname))
+ return false;
+
+ /* If we already verified the forward lookup, we're done */
+ if (port->remote_hostname_resolv == +1)
+ return true;
+
+ /* Lookup IP from host name and check against original IP */
+ ret = getaddrinfo(port->remote_hostname, NULL, NULL, &gai_result);
+ if (ret != 0)
+ {
+ /* remember failure; don't complain in the postmaster log yet */
+ port->remote_hostname_resolv = -2;
+ port->remote_hostname_errcode = ret;
+ return false;
+ }
+
+ found = false;
+ for (gai = gai_result; gai; gai = gai->ai_next)
+ {
+ if (gai->ai_addr->sa_family == port->raddr.addr.ss_family)
+ {
+ if (gai->ai_addr->sa_family == AF_INET)
+ {
+ if (ipv4eq((struct sockaddr_in *) gai->ai_addr,
+ (struct sockaddr_in *) &port->raddr.addr))
+ {
+ found = true;
+ break;
+ }
+ }
+#ifdef HAVE_IPV6
+ else if (gai->ai_addr->sa_family == AF_INET6)
+ {
+ if (ipv6eq((struct sockaddr_in6 *) gai->ai_addr,
+ (struct sockaddr_in6 *) &port->raddr.addr))
+ {
+ found = true;
+ break;
+ }
+ }
+#endif
+ }
+ }
+
+ if (gai_result)
+ freeaddrinfo(gai_result);
+
+ if (!found)
+ elog(DEBUG2, "pg_hba.conf host name \"%s\" rejected because address resolution did not return a match with IP address of client",
+ hostname);
+
+ port->remote_hostname_resolv = found ? +1 : -1;
+
+ return found;
+}
+
+/*
+ * Check to see if a connecting IP matches the given address and netmask.
+ */
+static bool
+check_ip(SockAddr *raddr, struct sockaddr *addr, struct sockaddr *mask)
+{
+ if (raddr->addr.ss_family == addr->sa_family &&
+ pg_range_sockaddr(&raddr->addr,
+ (struct sockaddr_storage *) addr,
+ (struct sockaddr_storage *) mask))
+ return true;
+ return false;
+}
+
+/*
+ * pg_foreach_ifaddr callback: does client addr match this machine interface?
+ */
+static void
+check_network_callback(struct sockaddr *addr, struct sockaddr *netmask,
+ void *cb_data)
+{
+ check_network_data *cn = (check_network_data *) cb_data;
+ struct sockaddr_storage mask;
+
+ /* Already found a match? */
+ if (cn->result)
+ return;
+
+ if (cn->method == ipCmpSameHost)
+ {
+ /* Make an all-ones netmask of appropriate length for family */
+ pg_sockaddr_cidr_mask(&mask, NULL, addr->sa_family);
+ cn->result = check_ip(cn->raddr, addr, (struct sockaddr *) &mask);
+ }
+ else
+ {
+ /* Use the netmask of the interface itself */
+ cn->result = check_ip(cn->raddr, addr, netmask);
+ }
+}
+
+/*
+ * Use pg_foreach_ifaddr to check a samehost or samenet match
+ */
+static bool
+check_same_host_or_net(SockAddr *raddr, IPCompareMethod method)
+{
+ check_network_data cn;
+
+ cn.method = method;
+ cn.raddr = raddr;
+ cn.result = false;
+
+ errno = 0;
+ if (pg_foreach_ifaddr(check_network_callback, &cn) < 0)
+ {
+ ereport(LOG,
+ (errmsg("error enumerating network interfaces: %m")));
+ return false;
+ }
+
+ return cn.result;
+}
+
+
+/*
+ * Macros used to check and report on invalid configuration options.
+ * On error: log a message at level elevel, set *err_msg, and exit the function.
+ * These macros are not as general-purpose as they look, because they know
+ * what the calling function's error-exit value is.
+ *
+ * INVALID_AUTH_OPTION = reports when an option is specified for a method where it's
+ * not supported.
+ * REQUIRE_AUTH_OPTION = same as INVALID_AUTH_OPTION, except it also checks if the
+ * method is actually the one specified. Used as a shortcut when
+ * the option is only valid for one authentication method.
+ * MANDATORY_AUTH_ARG = check if a required option is set for an authentication method,
+ * reporting error if it's not.
+ */
+#define INVALID_AUTH_OPTION(optname, validmethods) \
+do { \
+ ereport(elevel, \
+ (errcode(ERRCODE_CONFIG_FILE_ERROR), \
+ /* translator: the second %s is a list of auth methods */ \
+ errmsg("authentication option \"%s\" is only valid for authentication methods %s", \
+ optname, _(validmethods)), \
+ errcontext("line %d of configuration file \"%s\"", \
+ line_num, HbaFileName))); \
+ *err_msg = psprintf("authentication option \"%s\" is only valid for authentication methods %s", \
+ optname, validmethods); \
+ return false; \
+} while (0)
+
+#define REQUIRE_AUTH_OPTION(methodval, optname, validmethods) \
+do { \
+ if (hbaline->auth_method != methodval) \
+ INVALID_AUTH_OPTION(optname, validmethods); \
+} while (0)
+
+#define MANDATORY_AUTH_ARG(argvar, argname, authname) \
+do { \
+ if (argvar == NULL) { \
+ ereport(elevel, \
+ (errcode(ERRCODE_CONFIG_FILE_ERROR), \
+ errmsg("authentication method \"%s\" requires argument \"%s\" to be set", \
+ authname, argname), \
+ errcontext("line %d of configuration file \"%s\"", \
+ line_num, HbaFileName))); \
+ *err_msg = psprintf("authentication method \"%s\" requires argument \"%s\" to be set", \
+ authname, argname); \
+ return NULL; \
+ } \
+} while (0)
+
+/*
+ * Macros for handling pg_ident problems.
+ * Much as above, but currently the message level is hardwired as LOG
+ * and there is no provision for an err_msg string.
+ *
+ * IDENT_FIELD_ABSENT:
+ * Log a message and exit the function if the given ident field ListCell is
+ * not populated.
+ *
+ * IDENT_MULTI_VALUE:
+ * Log a message and exit the function if the given ident token List has more
+ * than one element.
+ */
+#define IDENT_FIELD_ABSENT(field) \
+do { \
+ if (!field) { \
+ ereport(LOG, \
+ (errcode(ERRCODE_CONFIG_FILE_ERROR), \
+ errmsg("missing entry in file \"%s\" at end of line %d", \
+ IdentFileName, line_num))); \
+ return NULL; \
+ } \
+} while (0)
+
+#define IDENT_MULTI_VALUE(tokens) \
+do { \
+ if (tokens->length > 1) { \
+ ereport(LOG, \
+ (errcode(ERRCODE_CONFIG_FILE_ERROR), \
+ errmsg("multiple values in ident field"), \
+ errcontext("line %d of configuration file \"%s\"", \
+ line_num, IdentFileName))); \
+ return NULL; \
+ } \
+} while (0)
+
+
+/*
+ * Parse one tokenised line from the hba config file and store the result in a
+ * HbaLine structure.
+ *
+ * If parsing fails, log a message at ereport level elevel, store an error
+ * string in tok_line->err_msg, and return NULL. (Some non-error conditions
+ * can also result in such messages.)
+ *
+ * Note: this function leaks memory when an error occurs. Caller is expected
+ * to have set a memory context that will be reset if this function returns
+ * NULL.
+ */
+static HbaLine *
+parse_hba_line(TokenizedLine *tok_line, int elevel)
+{
+ int line_num = tok_line->line_num;
+ char **err_msg = &tok_line->err_msg;
+ char *str;
+ struct addrinfo *gai_result;
+ struct addrinfo hints;
+ int ret;
+ char *cidr_slash;
+ char *unsupauth;
+ ListCell *field;
+ List *tokens;
+ ListCell *tokencell;
+ HbaToken *token;
+ HbaLine *parsedline;
+
+ parsedline = palloc0(sizeof(HbaLine));
+ parsedline->linenumber = line_num;
+ parsedline->rawline = pstrdup(tok_line->raw_line);
+
+ /* Check the record type. */
+ Assert(tok_line->fields != NIL);
+ field = list_head(tok_line->fields);
+ tokens = lfirst(field);
+ if (tokens->length > 1)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("multiple values specified for connection type"),
+ errhint("Specify exactly one connection type per line."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "multiple values specified for connection type";
+ return NULL;
+ }
+ token = linitial(tokens);
+ if (strcmp(token->string, "local") == 0)
+ {
+#ifdef HAVE_UNIX_SOCKETS
+ parsedline->conntype = ctLocal;
+#else
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("local connections are not supported by this build"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "local connections are not supported by this build";
+ return NULL;
+#endif
+ }
+ else if (strcmp(token->string, "host") == 0 ||
+ strcmp(token->string, "hostssl") == 0 ||
+ strcmp(token->string, "hostnossl") == 0 ||
+ strcmp(token->string, "hostgssenc") == 0 ||
+ strcmp(token->string, "hostnogssenc") == 0)
+ {
+
+ if (token->string[4] == 's') /* "hostssl" */
+ {
+ parsedline->conntype = ctHostSSL;
+ /* Log a warning if SSL support is not active */
+#ifdef USE_SSL
+ if (!EnableSSL)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("hostssl record cannot match because SSL is disabled"),
+ errhint("Set ssl = on in postgresql.conf."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "hostssl record cannot match because SSL is disabled";
+ }
+#else
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("hostssl record cannot match because SSL is not supported by this build"),
+ errhint("Compile with --with-ssl to use SSL connections."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "hostssl record cannot match because SSL is not supported by this build";
+#endif
+ }
+ else if (token->string[4] == 'g') /* "hostgssenc" */
+ {
+ parsedline->conntype = ctHostGSS;
+#ifndef ENABLE_GSS
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("hostgssenc record cannot match because GSSAPI is not supported by this build"),
+ errhint("Compile with --with-gssapi to use GSSAPI connections."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "hostgssenc record cannot match because GSSAPI is not supported by this build";
+#endif
+ }
+ else if (token->string[4] == 'n' && token->string[6] == 's')
+ parsedline->conntype = ctHostNoSSL;
+ else if (token->string[4] == 'n' && token->string[6] == 'g')
+ parsedline->conntype = ctHostNoGSS;
+ else
+ {
+ /* "host" */
+ parsedline->conntype = ctHost;
+ }
+ } /* record type */
+ else
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid connection type \"%s\"",
+ token->string),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("invalid connection type \"%s\"", token->string);
+ return NULL;
+ }
+
+ /* Get the databases. */
+ field = lnext(tok_line->fields, field);
+ if (!field)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("end-of-line before database specification"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "end-of-line before database specification";
+ return NULL;
+ }
+ parsedline->databases = NIL;
+ tokens = lfirst(field);
+ foreach(tokencell, tokens)
+ {
+ parsedline->databases = lappend(parsedline->databases,
+ copy_hba_token(lfirst(tokencell)));
+ }
+
+ /* Get the roles. */
+ field = lnext(tok_line->fields, field);
+ if (!field)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("end-of-line before role specification"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "end-of-line before role specification";
+ return NULL;
+ }
+ parsedline->roles = NIL;
+ tokens = lfirst(field);
+ foreach(tokencell, tokens)
+ {
+ parsedline->roles = lappend(parsedline->roles,
+ copy_hba_token(lfirst(tokencell)));
+ }
+
+ if (parsedline->conntype != ctLocal)
+ {
+ /* Read the IP address field. (with or without CIDR netmask) */
+ field = lnext(tok_line->fields, field);
+ if (!field)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("end-of-line before IP address specification"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "end-of-line before IP address specification";
+ return NULL;
+ }
+ tokens = lfirst(field);
+ if (tokens->length > 1)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("multiple values specified for host address"),
+ errhint("Specify one address range per line."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "multiple values specified for host address";
+ return NULL;
+ }
+ token = linitial(tokens);
+
+ if (token_is_keyword(token, "all"))
+ {
+ parsedline->ip_cmp_method = ipCmpAll;
+ }
+ else if (token_is_keyword(token, "samehost"))
+ {
+ /* Any IP on this host is allowed to connect */
+ parsedline->ip_cmp_method = ipCmpSameHost;
+ }
+ else if (token_is_keyword(token, "samenet"))
+ {
+ /* Any IP on the host's subnets is allowed to connect */
+ parsedline->ip_cmp_method = ipCmpSameNet;
+ }
+ else
+ {
+ /* IP and netmask are specified */
+ parsedline->ip_cmp_method = ipCmpMask;
+
+ /* need a modifiable copy of token */
+ str = pstrdup(token->string);
+
+ /* Check if it has a CIDR suffix and if so isolate it */
+ cidr_slash = strchr(str, '/');
+ if (cidr_slash)
+ *cidr_slash = '\0';
+
+ /* Get the IP address either way */
+ hints.ai_flags = AI_NUMERICHOST;
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = 0;
+ hints.ai_protocol = 0;
+ hints.ai_addrlen = 0;
+ hints.ai_canonname = NULL;
+ hints.ai_addr = NULL;
+ hints.ai_next = NULL;
+
+ ret = pg_getaddrinfo_all(str, NULL, &hints, &gai_result);
+ if (ret == 0 && gai_result)
+ {
+ memcpy(&parsedline->addr, gai_result->ai_addr,
+ gai_result->ai_addrlen);
+ parsedline->addrlen = gai_result->ai_addrlen;
+ }
+ else if (ret == EAI_NONAME)
+ parsedline->hostname = str;
+ else
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid IP address \"%s\": %s",
+ str, gai_strerror(ret)),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("invalid IP address \"%s\": %s",
+ str, gai_strerror(ret));
+ if (gai_result)
+ pg_freeaddrinfo_all(hints.ai_family, gai_result);
+ return NULL;
+ }
+
+ pg_freeaddrinfo_all(hints.ai_family, gai_result);
+
+ /* Get the netmask */
+ if (cidr_slash)
+ {
+ if (parsedline->hostname)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("specifying both host name and CIDR mask is invalid: \"%s\"",
+ token->string),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("specifying both host name and CIDR mask is invalid: \"%s\"",
+ token->string);
+ return NULL;
+ }
+
+ if (pg_sockaddr_cidr_mask(&parsedline->mask, cidr_slash + 1,
+ parsedline->addr.ss_family) < 0)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid CIDR mask in address \"%s\"",
+ token->string),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("invalid CIDR mask in address \"%s\"",
+ token->string);
+ return NULL;
+ }
+ parsedline->masklen = parsedline->addrlen;
+ pfree(str);
+ }
+ else if (!parsedline->hostname)
+ {
+ /* Read the mask field. */
+ pfree(str);
+ field = lnext(tok_line->fields, field);
+ if (!field)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("end-of-line before netmask specification"),
+ errhint("Specify an address range in CIDR notation, or provide a separate netmask."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "end-of-line before netmask specification";
+ return NULL;
+ }
+ tokens = lfirst(field);
+ if (tokens->length > 1)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("multiple values specified for netmask"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "multiple values specified for netmask";
+ return NULL;
+ }
+ token = linitial(tokens);
+
+ ret = pg_getaddrinfo_all(token->string, NULL,
+ &hints, &gai_result);
+ if (ret || !gai_result)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid IP mask \"%s\": %s",
+ token->string, gai_strerror(ret)),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("invalid IP mask \"%s\": %s",
+ token->string, gai_strerror(ret));
+ if (gai_result)
+ pg_freeaddrinfo_all(hints.ai_family, gai_result);
+ return NULL;
+ }
+
+ memcpy(&parsedline->mask, gai_result->ai_addr,
+ gai_result->ai_addrlen);
+ parsedline->masklen = gai_result->ai_addrlen;
+ pg_freeaddrinfo_all(hints.ai_family, gai_result);
+
+ if (parsedline->addr.ss_family != parsedline->mask.ss_family)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("IP address and mask do not match"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "IP address and mask do not match";
+ return NULL;
+ }
+ }
+ }
+ } /* != ctLocal */
+
+ /* Get the authentication method */
+ field = lnext(tok_line->fields, field);
+ if (!field)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("end-of-line before authentication method"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "end-of-line before authentication method";
+ return NULL;
+ }
+ tokens = lfirst(field);
+ if (tokens->length > 1)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("multiple values specified for authentication type"),
+ errhint("Specify exactly one authentication type per line."),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "multiple values specified for authentication type";
+ return NULL;
+ }
+ token = linitial(tokens);
+
+ unsupauth = NULL;
+ if (strcmp(token->string, "trust") == 0)
+ parsedline->auth_method = uaTrust;
+ else if (strcmp(token->string, "ident") == 0)
+ parsedline->auth_method = uaIdent;
+ else if (strcmp(token->string, "peer") == 0)
+ parsedline->auth_method = uaPeer;
+ else if (strcmp(token->string, "password") == 0)
+ parsedline->auth_method = uaPassword;
+ else if (strcmp(token->string, "gss") == 0)
+#ifdef ENABLE_GSS
+ parsedline->auth_method = uaGSS;
+#else
+ unsupauth = "gss";
+#endif
+ else if (strcmp(token->string, "sspi") == 0)
+#ifdef ENABLE_SSPI
+ parsedline->auth_method = uaSSPI;
+#else
+ unsupauth = "sspi";
+#endif
+ else if (strcmp(token->string, "reject") == 0)
+ parsedline->auth_method = uaReject;
+ else if (strcmp(token->string, "md5") == 0)
+ {
+ if (Db_user_namespace)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("MD5 authentication is not supported when \"db_user_namespace\" is enabled"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "MD5 authentication is not supported when \"db_user_namespace\" is enabled";
+ return NULL;
+ }
+ parsedline->auth_method = uaMD5;
+ }
+ else if (strcmp(token->string, "scram-sha-256") == 0)
+ parsedline->auth_method = uaSCRAM;
+ else if (strcmp(token->string, "pam") == 0)
+#ifdef USE_PAM
+ parsedline->auth_method = uaPAM;
+#else
+ unsupauth = "pam";
+#endif
+ else if (strcmp(token->string, "bsd") == 0)
+#ifdef USE_BSD_AUTH
+ parsedline->auth_method = uaBSD;
+#else
+ unsupauth = "bsd";
+#endif
+ else if (strcmp(token->string, "ldap") == 0)
+#ifdef USE_LDAP
+ parsedline->auth_method = uaLDAP;
+#else
+ unsupauth = "ldap";
+#endif
+ else if (strcmp(token->string, "cert") == 0)
+#ifdef USE_SSL
+ parsedline->auth_method = uaCert;
+#else
+ unsupauth = "cert";
+#endif
+ else if (strcmp(token->string, "radius") == 0)
+ parsedline->auth_method = uaRADIUS;
+ else
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid authentication method \"%s\"",
+ token->string),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("invalid authentication method \"%s\"",
+ token->string);
+ return NULL;
+ }
+
+ if (unsupauth)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid authentication method \"%s\": not supported by this build",
+ token->string),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("invalid authentication method \"%s\": not supported by this build",
+ token->string);
+ return NULL;
+ }
+
+ /*
+ * XXX: When using ident on local connections, change it to peer, for
+ * backwards compatibility.
+ */
+ if (parsedline->conntype == ctLocal &&
+ parsedline->auth_method == uaIdent)
+ parsedline->auth_method = uaPeer;
+
+ /* Invalid authentication combinations */
+ if (parsedline->conntype == ctLocal &&
+ parsedline->auth_method == uaGSS)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("gssapi authentication is not supported on local sockets"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "gssapi authentication is not supported on local sockets";
+ return NULL;
+ }
+
+ if (parsedline->conntype != ctLocal &&
+ parsedline->auth_method == uaPeer)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("peer authentication is only supported on local sockets"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "peer authentication is only supported on local sockets";
+ return NULL;
+ }
+
+ /*
+ * SSPI authentication can never be enabled on ctLocal connections,
+ * because it's only supported on Windows, where ctLocal isn't supported.
+ */
+
+
+ if (parsedline->conntype != ctHostSSL &&
+ parsedline->auth_method == uaCert)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("cert authentication is only supported on hostssl connections"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "cert authentication is only supported on hostssl connections";
+ return NULL;
+ }
+
+ /*
+ * For GSS and SSPI, set the default value of include_realm to true.
+ * Having include_realm set to false is dangerous in multi-realm
+ * situations and is generally considered bad practice. We keep the
+ * capability around for backwards compatibility, but we might want to
+ * remove it at some point in the future. Users who still need to strip
+ * the realm off would be better served by using an appropriate regex in a
+ * pg_ident.conf mapping.
+ */
+ if (parsedline->auth_method == uaGSS ||
+ parsedline->auth_method == uaSSPI)
+ parsedline->include_realm = true;
+
+ /*
+ * For SSPI, include_realm defaults to the SAM-compatible domain (aka
+ * NetBIOS name) and user names instead of the Kerberos principal name for
+ * compatibility.
+ */
+ if (parsedline->auth_method == uaSSPI)
+ {
+ parsedline->compat_realm = true;
+ parsedline->upn_username = false;
+ }
+
+ /* Parse remaining arguments */
+ while ((field = lnext(tok_line->fields, field)) != NULL)
+ {
+ tokens = lfirst(field);
+ foreach(tokencell, tokens)
+ {
+ char *val;
+
+ token = lfirst(tokencell);
+
+ str = pstrdup(token->string);
+ val = strchr(str, '=');
+ if (val == NULL)
+ {
+ /*
+ * Got something that's not a name=value pair.
+ */
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("authentication option not in name=value format: %s", token->string),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("authentication option not in name=value format: %s",
+ token->string);
+ return NULL;
+ }
+
+ *val++ = '\0'; /* str now holds "name", val holds "value" */
+ if (!parse_hba_auth_opt(str, val, parsedline, elevel, err_msg))
+ /* parse_hba_auth_opt already logged the error message */
+ return NULL;
+ pfree(str);
+ }
+ }
+
+ /*
+ * Check if the selected authentication method has any mandatory arguments
+ * that are not set.
+ */
+ if (parsedline->auth_method == uaLDAP)
+ {
+#ifndef HAVE_LDAP_INITIALIZE
+ /* Not mandatory for OpenLDAP, because it can use DNS SRV records */
+ MANDATORY_AUTH_ARG(parsedline->ldapserver, "ldapserver", "ldap");
+#endif
+
+ /*
+ * LDAP can operate in two modes: either with a direct bind, using
+ * ldapprefix and ldapsuffix, or using a search+bind, using
+ * ldapbasedn, ldapbinddn, ldapbindpasswd and one of
+ * ldapsearchattribute or ldapsearchfilter. Disallow mixing these
+ * parameters.
+ */
+ if (parsedline->ldapprefix || parsedline->ldapsuffix)
+ {
+ if (parsedline->ldapbasedn ||
+ parsedline->ldapbinddn ||
+ parsedline->ldapbindpasswd ||
+ parsedline->ldapsearchattribute ||
+ parsedline->ldapsearchfilter)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("cannot use ldapbasedn, ldapbinddn, ldapbindpasswd, ldapsearchattribute, ldapsearchfilter, or ldapurl together with ldapprefix"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "cannot use ldapbasedn, ldapbinddn, ldapbindpasswd, ldapsearchattribute, ldapsearchfilter, or ldapurl together with ldapprefix";
+ return NULL;
+ }
+ }
+ else if (!parsedline->ldapbasedn)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("authentication method \"ldap\" requires argument \"ldapbasedn\", \"ldapprefix\", or \"ldapsuffix\" to be set"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "authentication method \"ldap\" requires argument \"ldapbasedn\", \"ldapprefix\", or \"ldapsuffix\" to be set";
+ return NULL;
+ }
+
+ /*
+ * When using search+bind, you can either use a simple attribute
+ * (defaulting to "uid") or a fully custom search filter. You can't
+ * do both.
+ */
+ if (parsedline->ldapsearchattribute && parsedline->ldapsearchfilter)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("cannot use ldapsearchattribute together with ldapsearchfilter"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "cannot use ldapsearchattribute together with ldapsearchfilter";
+ return NULL;
+ }
+ }
+
+ if (parsedline->auth_method == uaRADIUS)
+ {
+ MANDATORY_AUTH_ARG(parsedline->radiusservers, "radiusservers", "radius");
+ MANDATORY_AUTH_ARG(parsedline->radiussecrets, "radiussecrets", "radius");
+
+ if (list_length(parsedline->radiusservers) < 1)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("list of RADIUS servers cannot be empty"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "list of RADIUS servers cannot be empty";
+ return NULL;
+ }
+
+ if (list_length(parsedline->radiussecrets) < 1)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("list of RADIUS secrets cannot be empty"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "list of RADIUS secrets cannot be empty";
+ return NULL;
+ }
+
+ /*
+ * Verify length of option lists - each can be 0 (except for secrets,
+ * but that's already checked above), 1 (use the same value
+ * everywhere) or the same as the number of servers.
+ */
+ if (!(list_length(parsedline->radiussecrets) == 1 ||
+ list_length(parsedline->radiussecrets) == list_length(parsedline->radiusservers)))
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("the number of RADIUS secrets (%d) must be 1 or the same as the number of RADIUS servers (%d)",
+ list_length(parsedline->radiussecrets),
+ list_length(parsedline->radiusservers)),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("the number of RADIUS secrets (%d) must be 1 or the same as the number of RADIUS servers (%d)",
+ list_length(parsedline->radiussecrets),
+ list_length(parsedline->radiusservers));
+ return NULL;
+ }
+ if (!(list_length(parsedline->radiusports) == 0 ||
+ list_length(parsedline->radiusports) == 1 ||
+ list_length(parsedline->radiusports) == list_length(parsedline->radiusservers)))
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("the number of RADIUS ports (%d) must be 1 or the same as the number of RADIUS servers (%d)",
+ list_length(parsedline->radiusports),
+ list_length(parsedline->radiusservers)),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("the number of RADIUS ports (%d) must be 1 or the same as the number of RADIUS servers (%d)",
+ list_length(parsedline->radiusports),
+ list_length(parsedline->radiusservers));
+ return NULL;
+ }
+ if (!(list_length(parsedline->radiusidentifiers) == 0 ||
+ list_length(parsedline->radiusidentifiers) == 1 ||
+ list_length(parsedline->radiusidentifiers) == list_length(parsedline->radiusservers)))
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("the number of RADIUS identifiers (%d) must be 1 or the same as the number of RADIUS servers (%d)",
+ list_length(parsedline->radiusidentifiers),
+ list_length(parsedline->radiusservers)),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("the number of RADIUS identifiers (%d) must be 1 or the same as the number of RADIUS servers (%d)",
+ list_length(parsedline->radiusidentifiers),
+ list_length(parsedline->radiusservers));
+ return NULL;
+ }
+ }
+
+ /*
+ * Enforce any parameters implied by other settings.
+ */
+ if (parsedline->auth_method == uaCert)
+ {
+ /*
+ * For auth method cert, client certificate validation is mandatory, and it implies
+ * the level of verify-full.
+ */
+ parsedline->clientcert = clientCertFull;
+ }
+
+ return parsedline;
+}
+
+
+/*
+ * Parse one name-value pair as an authentication option into the given
+ * HbaLine. Return true if we successfully parse the option, false if we
+ * encounter an error. In the event of an error, also log a message at
+ * ereport level elevel, and store a message string into *err_msg.
+ */
+static bool
+parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline,
+ int elevel, char **err_msg)
+{
+ int line_num = hbaline->linenumber;
+
+#ifdef USE_LDAP
+ hbaline->ldapscope = LDAP_SCOPE_SUBTREE;
+#endif
+
+ if (strcmp(name, "map") == 0)
+ {
+ if (hbaline->auth_method != uaIdent &&
+ hbaline->auth_method != uaPeer &&
+ hbaline->auth_method != uaGSS &&
+ hbaline->auth_method != uaSSPI &&
+ hbaline->auth_method != uaCert)
+ INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert"));
+ hbaline->usermap = pstrdup(val);
+ }
+ else if (strcmp(name, "clientcert") == 0)
+ {
+ if (hbaline->conntype != ctHostSSL)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("clientcert can only be configured for \"hostssl\" rows"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "clientcert can only be configured for \"hostssl\" rows";
+ return false;
+ }
+
+ if (strcmp(val, "verify-full") == 0)
+ {
+ hbaline->clientcert = clientCertFull;
+ }
+ else if (strcmp(val, "verify-ca") == 0)
+ {
+ if (hbaline->auth_method == uaCert)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("clientcert only accepts \"verify-full\" when using \"cert\" authentication"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "clientcert can only be set to \"verify-full\" when using \"cert\" authentication";
+ return false;
+ }
+
+ hbaline->clientcert = clientCertCA;
+ }
+ else
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid value for clientcert: \"%s\"", val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ return false;
+ }
+ }
+ else if (strcmp(name, "clientname") == 0)
+ {
+ if (hbaline->conntype != ctHostSSL)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("clientname can only be configured for \"hostssl\" rows"),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = "clientname can only be configured for \"hostssl\" rows";
+ return false;
+ }
+
+ if (strcmp(val, "CN") == 0)
+ {
+ hbaline->clientcertname = clientCertCN;
+ }
+ else if (strcmp(val, "DN") == 0)
+ {
+ hbaline->clientcertname = clientCertDN;
+ }
+ else
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid value for clientname: \"%s\"", val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ return false;
+ }
+ }
+ else if (strcmp(name, "pamservice") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaPAM, "pamservice", "pam");
+ hbaline->pamservice = pstrdup(val);
+ }
+ else if (strcmp(name, "pam_use_hostname") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaPAM, "pam_use_hostname", "pam");
+ if (strcmp(val, "1") == 0)
+ hbaline->pam_use_hostname = true;
+ else
+ hbaline->pam_use_hostname = false;
+
+ }
+ else if (strcmp(name, "ldapurl") == 0)
+ {
+#ifdef LDAP_API_FEATURE_X_OPENLDAP
+ LDAPURLDesc *urldata;
+ int rc;
+#endif
+
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapurl", "ldap");
+#ifdef LDAP_API_FEATURE_X_OPENLDAP
+ rc = ldap_url_parse(val, &urldata);
+ if (rc != LDAP_SUCCESS)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not parse LDAP URL \"%s\": %s", val, ldap_err2string(rc))));
+ *err_msg = psprintf("could not parse LDAP URL \"%s\": %s",
+ val, ldap_err2string(rc));
+ return false;
+ }
+
+ if (strcmp(urldata->lud_scheme, "ldap") != 0 &&
+ strcmp(urldata->lud_scheme, "ldaps") != 0)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("unsupported LDAP URL scheme: %s", urldata->lud_scheme)));
+ *err_msg = psprintf("unsupported LDAP URL scheme: %s",
+ urldata->lud_scheme);
+ ldap_free_urldesc(urldata);
+ return false;
+ }
+
+ if (urldata->lud_scheme)
+ hbaline->ldapscheme = pstrdup(urldata->lud_scheme);
+ if (urldata->lud_host)
+ hbaline->ldapserver = pstrdup(urldata->lud_host);
+ hbaline->ldapport = urldata->lud_port;
+ if (urldata->lud_dn)
+ hbaline->ldapbasedn = pstrdup(urldata->lud_dn);
+
+ if (urldata->lud_attrs)
+ hbaline->ldapsearchattribute = pstrdup(urldata->lud_attrs[0]); /* only use first one */
+ hbaline->ldapscope = urldata->lud_scope;
+ if (urldata->lud_filter)
+ hbaline->ldapsearchfilter = pstrdup(urldata->lud_filter);
+ ldap_free_urldesc(urldata);
+#else /* not OpenLDAP */
+ ereport(elevel,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("LDAP URLs not supported on this platform")));
+ *err_msg = "LDAP URLs not supported on this platform";
+#endif /* not OpenLDAP */
+ }
+ else if (strcmp(name, "ldaptls") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldaptls", "ldap");
+ if (strcmp(val, "1") == 0)
+ hbaline->ldaptls = true;
+ else
+ hbaline->ldaptls = false;
+ }
+ else if (strcmp(name, "ldapscheme") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapscheme", "ldap");
+ if (strcmp(val, "ldap") != 0 && strcmp(val, "ldaps") != 0)
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid ldapscheme value: \"%s\"", val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ hbaline->ldapscheme = pstrdup(val);
+ }
+ else if (strcmp(name, "ldapserver") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapserver", "ldap");
+ hbaline->ldapserver = pstrdup(val);
+ }
+ else if (strcmp(name, "ldapport") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapport", "ldap");
+ hbaline->ldapport = atoi(val);
+ if (hbaline->ldapport == 0)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid LDAP port number: \"%s\"", val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("invalid LDAP port number: \"%s\"", val);
+ return false;
+ }
+ }
+ else if (strcmp(name, "ldapbinddn") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapbinddn", "ldap");
+ hbaline->ldapbinddn = pstrdup(val);
+ }
+ else if (strcmp(name, "ldapbindpasswd") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapbindpasswd", "ldap");
+ hbaline->ldapbindpasswd = pstrdup(val);
+ }
+ else if (strcmp(name, "ldapsearchattribute") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapsearchattribute", "ldap");
+ hbaline->ldapsearchattribute = pstrdup(val);
+ }
+ else if (strcmp(name, "ldapsearchfilter") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapsearchfilter", "ldap");
+ hbaline->ldapsearchfilter = pstrdup(val);
+ }
+ else if (strcmp(name, "ldapbasedn") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapbasedn", "ldap");
+ hbaline->ldapbasedn = pstrdup(val);
+ }
+ else if (strcmp(name, "ldapprefix") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapprefix", "ldap");
+ hbaline->ldapprefix = pstrdup(val);
+ }
+ else if (strcmp(name, "ldapsuffix") == 0)
+ {
+ REQUIRE_AUTH_OPTION(uaLDAP, "ldapsuffix", "ldap");
+ hbaline->ldapsuffix = pstrdup(val);
+ }
+ else if (strcmp(name, "krb_realm") == 0)
+ {
+ if (hbaline->auth_method != uaGSS &&
+ hbaline->auth_method != uaSSPI)
+ INVALID_AUTH_OPTION("krb_realm", gettext_noop("gssapi and sspi"));
+ hbaline->krb_realm = pstrdup(val);
+ }
+ else if (strcmp(name, "include_realm") == 0)
+ {
+ if (hbaline->auth_method != uaGSS &&
+ hbaline->auth_method != uaSSPI)
+ INVALID_AUTH_OPTION("include_realm", gettext_noop("gssapi and sspi"));
+ if (strcmp(val, "1") == 0)
+ hbaline->include_realm = true;
+ else
+ hbaline->include_realm = false;
+ }
+ else if (strcmp(name, "compat_realm") == 0)
+ {
+ if (hbaline->auth_method != uaSSPI)
+ INVALID_AUTH_OPTION("compat_realm", gettext_noop("sspi"));
+ if (strcmp(val, "1") == 0)
+ hbaline->compat_realm = true;
+ else
+ hbaline->compat_realm = false;
+ }
+ else if (strcmp(name, "upn_username") == 0)
+ {
+ if (hbaline->auth_method != uaSSPI)
+ INVALID_AUTH_OPTION("upn_username", gettext_noop("sspi"));
+ if (strcmp(val, "1") == 0)
+ hbaline->upn_username = true;
+ else
+ hbaline->upn_username = false;
+ }
+ else if (strcmp(name, "radiusservers") == 0)
+ {
+ struct addrinfo *gai_result;
+ struct addrinfo hints;
+ int ret;
+ List *parsed_servers;
+ ListCell *l;
+ char *dupval = pstrdup(val);
+
+ REQUIRE_AUTH_OPTION(uaRADIUS, "radiusservers", "radius");
+
+ if (!SplitGUCList(dupval, ',', &parsed_servers))
+ {
+ /* syntax error in list */
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not parse RADIUS server list \"%s\"",
+ val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ return false;
+ }
+
+ /* For each entry in the list, translate it */
+ foreach(l, parsed_servers)
+ {
+ MemSet(&hints, 0, sizeof(hints));
+ hints.ai_socktype = SOCK_DGRAM;
+ hints.ai_family = AF_UNSPEC;
+
+ ret = pg_getaddrinfo_all((char *) lfirst(l), NULL, &hints, &gai_result);
+ if (ret || !gai_result)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not translate RADIUS server name \"%s\" to address: %s",
+ (char *) lfirst(l), gai_strerror(ret)),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ if (gai_result)
+ pg_freeaddrinfo_all(hints.ai_family, gai_result);
+
+ list_free(parsed_servers);
+ return false;
+ }
+ pg_freeaddrinfo_all(hints.ai_family, gai_result);
+ }
+
+ /* All entries are OK, so store them */
+ hbaline->radiusservers = parsed_servers;
+ hbaline->radiusservers_s = pstrdup(val);
+ }
+ else if (strcmp(name, "radiusports") == 0)
+ {
+ List *parsed_ports;
+ ListCell *l;
+ char *dupval = pstrdup(val);
+
+ REQUIRE_AUTH_OPTION(uaRADIUS, "radiusports", "radius");
+
+ if (!SplitGUCList(dupval, ',', &parsed_ports))
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not parse RADIUS port list \"%s\"",
+ val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("invalid RADIUS port number: \"%s\"", val);
+ return false;
+ }
+
+ foreach(l, parsed_ports)
+ {
+ if (atoi(lfirst(l)) == 0)
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("invalid RADIUS port number: \"%s\"", val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+
+ return false;
+ }
+ }
+ hbaline->radiusports = parsed_ports;
+ hbaline->radiusports_s = pstrdup(val);
+ }
+ else if (strcmp(name, "radiussecrets") == 0)
+ {
+ List *parsed_secrets;
+ char *dupval = pstrdup(val);
+
+ REQUIRE_AUTH_OPTION(uaRADIUS, "radiussecrets", "radius");
+
+ if (!SplitGUCList(dupval, ',', &parsed_secrets))
+ {
+ /* syntax error in list */
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not parse RADIUS secret list \"%s\"",
+ val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ return false;
+ }
+
+ hbaline->radiussecrets = parsed_secrets;
+ hbaline->radiussecrets_s = pstrdup(val);
+ }
+ else if (strcmp(name, "radiusidentifiers") == 0)
+ {
+ List *parsed_identifiers;
+ char *dupval = pstrdup(val);
+
+ REQUIRE_AUTH_OPTION(uaRADIUS, "radiusidentifiers", "radius");
+
+ if (!SplitGUCList(dupval, ',', &parsed_identifiers))
+ {
+ /* syntax error in list */
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("could not parse RADIUS identifiers list \"%s\"",
+ val),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ return false;
+ }
+
+ hbaline->radiusidentifiers = parsed_identifiers;
+ hbaline->radiusidentifiers_s = pstrdup(val);
+ }
+ else
+ {
+ ereport(elevel,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("unrecognized authentication option name: \"%s\"",
+ name),
+ errcontext("line %d of configuration file \"%s\"",
+ line_num, HbaFileName)));
+ *err_msg = psprintf("unrecognized authentication option name: \"%s\"",
+ name);
+ return false;
+ }
+ return true;
+}
+
+/*
+ * Scan the pre-parsed hba file, looking for a match to the port's connection
+ * request.
+ */
+static void
+check_hba(hbaPort *port)
+{
+ Oid roleid;
+ ListCell *line;
+ HbaLine *hba;
+
+ /* Get the target role's OID. Note we do not error out for bad role. */
+ roleid = get_role_oid(port->user_name, true);
+
+ foreach(line, parsed_hba_lines)
+ {
+ hba = (HbaLine *) lfirst(line);
+
+ /* Check connection type */
+ if (hba->conntype == ctLocal)
+ {
+ if (!IS_AF_UNIX(port->raddr.addr.ss_family))
+ continue;
+ }
+ else
+ {
+ if (IS_AF_UNIX(port->raddr.addr.ss_family))
+ continue;
+
+ /* Check SSL state */
+ if (port->ssl_in_use)
+ {
+ /* Connection is SSL, match both "host" and "hostssl" */
+ if (hba->conntype == ctHostNoSSL)
+ continue;
+ }
+ else
+ {
+ /* Connection is not SSL, match both "host" and "hostnossl" */
+ if (hba->conntype == ctHostSSL)
+ continue;
+ }
+
+ /* Check GSSAPI state */
+#ifdef ENABLE_GSS
+ if (port->gss && port->gss->enc &&
+ hba->conntype == ctHostNoGSS)
+ continue;
+ else if (!(port->gss && port->gss->enc) &&
+ hba->conntype == ctHostGSS)
+ continue;
+#else
+ if (hba->conntype == ctHostGSS)
+ continue;
+#endif
+
+ /* Check IP address */
+ switch (hba->ip_cmp_method)
+ {
+ case ipCmpMask:
+ if (hba->hostname)
+ {
+ if (!check_hostname(port,
+ hba->hostname))
+ continue;
+ }
+ else
+ {
+ if (!check_ip(&port->raddr,
+ (struct sockaddr *) &hba->addr,
+ (struct sockaddr *) &hba->mask))
+ continue;
+ }
+ break;
+ case ipCmpAll:
+ break;
+ case ipCmpSameHost:
+ case ipCmpSameNet:
+ if (!check_same_host_or_net(&port->raddr,
+ hba->ip_cmp_method))
+ continue;
+ break;
+ default:
+ /* shouldn't get here, but deem it no-match if so */
+ continue;
+ }
+ } /* != ctLocal */
+
+ /* Check database and role */
+ if (!check_db(port->database_name, port->user_name, roleid,
+ hba->databases))
+ continue;
+
+ if (!check_role(port->user_name, roleid, hba->roles))
+ continue;
+
+ /* Found a record that matched! */
+ port->hba = hba;
+ return;
+ }
+
+ /* If no matching entry was found, then implicitly reject. */
+ hba = palloc0(sizeof(HbaLine));
+ hba->auth_method = uaImplicitReject;
+ port->hba = hba;
+}
+
+/*
+ * Read the config file and create a List of HbaLine records for the contents.
+ *
+ * The configuration is read into a temporary list, and if any parse error
+ * occurs the old list is kept in place and false is returned. Only if the
+ * whole file parses OK is the list replaced, and the function returns true.
+ *
+ * On a false result, caller will take care of reporting a FATAL error in case
+ * this is the initial startup. If it happens on reload, we just keep running
+ * with the old data.
+ */
+bool
+load_hba(void)
+{
+ FILE *file;
+ List *hba_lines = NIL;
+ ListCell *line;
+ List *new_parsed_lines = NIL;
+ bool ok = true;
+ MemoryContext linecxt;
+ MemoryContext oldcxt;
+ MemoryContext hbacxt;
+
+ file = AllocateFile(HbaFileName, "r");
+ if (file == NULL)
+ {
+ ereport(LOG,
+ (errcode_for_file_access(),
+ errmsg("could not open configuration file \"%s\": %m",
+ HbaFileName)));
+ return false;
+ }
+
+ linecxt = tokenize_file(HbaFileName, file, &hba_lines, LOG);
+ FreeFile(file);
+
+ /* Now parse all the lines */
+ Assert(PostmasterContext);
+ hbacxt = AllocSetContextCreate(PostmasterContext,
+ "hba parser context",
+ ALLOCSET_SMALL_SIZES);
+ oldcxt = MemoryContextSwitchTo(hbacxt);
+ foreach(line, hba_lines)
+ {
+ TokenizedLine *tok_line = (TokenizedLine *) lfirst(line);
+ HbaLine *newline;
+
+ /* don't parse lines that already have errors */
+ if (tok_line->err_msg != NULL)
+ {
+ ok = false;
+ continue;
+ }
+
+ if ((newline = parse_hba_line(tok_line, LOG)) == NULL)
+ {
+ /* Parse error; remember there's trouble */
+ ok = false;
+
+ /*
+ * Keep parsing the rest of the file so we can report errors on
+ * more than the first line. Error has already been logged, no
+ * need for more chatter here.
+ */
+ continue;
+ }
+
+ new_parsed_lines = lappend(new_parsed_lines, newline);
+ }
+
+ /*
+ * A valid HBA file must have at least one entry; else there's no way to
+ * connect to the postmaster. But only complain about this if we didn't
+ * already have parsing errors.
+ */
+ if (ok && new_parsed_lines == NIL)
+ {
+ ereport(LOG,
+ (errcode(ERRCODE_CONFIG_FILE_ERROR),
+ errmsg("configuration file \"%s\" contains no entries",
+ HbaFileName)));
+ ok = false;
+ }
+
+ /* Free tokenizer memory */
+ MemoryContextDelete(linecxt);
+ MemoryContextSwitchTo(oldcxt);
+
+ if (!ok)
+ {
+ /* File contained one or more errors, so bail out */
+ MemoryContextDelete(hbacxt);
+ return false;
+ }
+
+ /* Loaded new file successfully, replace the one we use */
+ if (parsed_hba_context != NULL)
+ MemoryContextDelete(parsed_hba_context);
+ parsed_hba_context = hbacxt;
+ parsed_hba_lines = new_parsed_lines;
+
+ return true;
+}
+
+/*
+ * This macro specifies the maximum number of authentication options
+ * that are possible with any given authentication method that is supported.
+ * Currently LDAP supports 11, and there are 3 that are not dependent on
+ * the auth method here. It may not actually be possible to set all of them
+ * at the same time, but we'll set the macro value high enough to be
+ * conservative and avoid warnings from static analysis tools.
+ */
+#define MAX_HBA_OPTIONS 14
+
+/*
+ * Create a text array listing the options specified in the HBA line.
+ * Return NULL if no options are specified.
+ */
+static ArrayType *
+gethba_options(HbaLine *hba)
+{
+ int noptions;
+ Datum options[MAX_HBA_OPTIONS];
+
+ noptions = 0;
+
+ if (hba->auth_method == uaGSS || hba->auth_method == uaSSPI)
+ {
+ if (hba->include_realm)
+ options[noptions++] =
+ CStringGetTextDatum("include_realm=true");
+
+ if (hba->krb_realm)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("krb_realm=%s", hba->krb_realm));
+ }
+
+ if (hba->usermap)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("map=%s", hba->usermap));
+
+ if (hba->clientcert != clientCertOff)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("clientcert=%s", (hba->clientcert == clientCertCA) ? "verify-ca" : "verify-full"));
+
+ if (hba->pamservice)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("pamservice=%s", hba->pamservice));
+
+ if (hba->auth_method == uaLDAP)
+ {
+ if (hba->ldapserver)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapserver=%s", hba->ldapserver));
+
+ if (hba->ldapport)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapport=%d", hba->ldapport));
+
+ if (hba->ldaptls)
+ options[noptions++] =
+ CStringGetTextDatum("ldaptls=true");
+
+ if (hba->ldapprefix)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapprefix=%s", hba->ldapprefix));
+
+ if (hba->ldapsuffix)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapsuffix=%s", hba->ldapsuffix));
+
+ if (hba->ldapbasedn)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapbasedn=%s", hba->ldapbasedn));
+
+ if (hba->ldapbinddn)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapbinddn=%s", hba->ldapbinddn));
+
+ if (hba->ldapbindpasswd)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapbindpasswd=%s",
+ hba->ldapbindpasswd));
+
+ if (hba->ldapsearchattribute)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapsearchattribute=%s",
+ hba->ldapsearchattribute));
+
+ if (hba->ldapsearchfilter)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapsearchfilter=%s",
+ hba->ldapsearchfilter));
+
+ if (hba->ldapscope)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("ldapscope=%d", hba->ldapscope));
+ }
+
+ if (hba->auth_method == uaRADIUS)
+ {
+ if (hba->radiusservers_s)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("radiusservers=%s", hba->radiusservers_s));
+
+ if (hba->radiussecrets_s)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("radiussecrets=%s", hba->radiussecrets_s));
+
+ if (hba->radiusidentifiers_s)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("radiusidentifiers=%s", hba->radiusidentifiers_s));
+
+ if (hba->radiusports_s)
+ options[noptions++] =
+ CStringGetTextDatum(psprintf("radiusports=%s", hba->radiusports_s));
+ }
+
+ /* If you add more options, consider increasing MAX_HBA_OPTIONS. */
+ Assert(noptions <= MAX_HBA_OPTIONS);
+
+ if (noptions > 0)
+ return construct_array(options, noptions, TEXTOID, -1, false, TYPALIGN_INT);
+ else
+ return NULL;
+}
+
+/* Number of columns in pg_hba_file_rules view */
+#define NUM_PG_HBA_FILE_RULES_ATTS 9
+
+/*
+ * fill_hba_line: build one row of pg_hba_file_rules view, add it to tuplestore
+ *
+ * tuple_store: where to store data
+ * tupdesc: tuple descriptor for the view
+ * lineno: pg_hba.conf line number (must always be valid)
+ * hba: parsed line data (can be NULL, in which case err_msg should be set)
+ * err_msg: error message (NULL if none)
+ *
+ * Note: leaks memory, but we don't care since this is run in a short-lived
+ * memory context.
+ */
+static void
+fill_hba_line(Tuplestorestate *tuple_store, TupleDesc tupdesc,
+ int lineno, HbaLine *hba, const char *err_msg)
+{
+ Datum values[NUM_PG_HBA_FILE_RULES_ATTS];
+ bool nulls[NUM_PG_HBA_FILE_RULES_ATTS];
+ char buffer[NI_MAXHOST];
+ HeapTuple tuple;
+ int index;
+ ListCell *lc;
+ const char *typestr;
+ const char *addrstr;
+ const char *maskstr;
+ ArrayType *options;
+
+ Assert(tupdesc->natts == NUM_PG_HBA_FILE_RULES_ATTS);
+
+ memset(values, 0, sizeof(values));
+ memset(nulls, 0, sizeof(nulls));
+ index = 0;
+
+ /* line_number */
+ values[index++] = Int32GetDatum(lineno);
+
+ if (hba != NULL)
+ {
+ /* type */
+ /* Avoid a default: case so compiler will warn about missing cases */
+ typestr = NULL;
+ switch (hba->conntype)
+ {
+ case ctLocal:
+ typestr = "local";
+ break;
+ case ctHost:
+ typestr = "host";
+ break;
+ case ctHostSSL:
+ typestr = "hostssl";
+ break;
+ case ctHostNoSSL:
+ typestr = "hostnossl";
+ break;
+ case ctHostGSS:
+ typestr = "hostgssenc";
+ break;
+ case ctHostNoGSS:
+ typestr = "hostnogssenc";
+ break;
+ }
+ if (typestr)
+ values[index++] = CStringGetTextDatum(typestr);
+ else
+ nulls[index++] = true;
+
+ /* database */
+ if (hba->databases)
+ {
+ /*
+ * Flatten HbaToken list to string list. It might seem that we
+ * should re-quote any quoted tokens, but that has been rejected
+ * on the grounds that it makes it harder to compare the array
+ * elements to other system catalogs. That makes entries like
+ * "all" or "samerole" formally ambiguous ... but users who name
+ * databases/roles that way are inflicting their own pain.
+ */
+ List *names = NIL;
+
+ foreach(lc, hba->databases)
+ {
+ HbaToken *tok = lfirst(lc);
+
+ names = lappend(names, tok->string);
+ }
+ values[index++] = PointerGetDatum(strlist_to_textarray(names));
+ }
+ else
+ nulls[index++] = true;
+
+ /* user */
+ if (hba->roles)
+ {
+ /* Flatten HbaToken list to string list; see comment above */
+ List *roles = NIL;
+
+ foreach(lc, hba->roles)
+ {
+ HbaToken *tok = lfirst(lc);
+
+ roles = lappend(roles, tok->string);
+ }
+ values[index++] = PointerGetDatum(strlist_to_textarray(roles));
+ }
+ else
+ nulls[index++] = true;
+
+ /* address and netmask */
+ /* Avoid a default: case so compiler will warn about missing cases */
+ addrstr = maskstr = NULL;
+ switch (hba->ip_cmp_method)
+ {
+ case ipCmpMask:
+ if (hba->hostname)
+ {
+ addrstr = hba->hostname;
+ }
+ else
+ {
+ /*
+ * Note: if pg_getnameinfo_all fails, it'll set buffer to
+ * "???", which we want to return.
+ */
+ if (hba->addrlen > 0)
+ {
+ if (pg_getnameinfo_all(&hba->addr, hba->addrlen,
+ buffer, sizeof(buffer),
+ NULL, 0,
+ NI_NUMERICHOST) == 0)
+ clean_ipv6_addr(hba->addr.ss_family, buffer);
+ addrstr = pstrdup(buffer);
+ }
+ if (hba->masklen > 0)
+ {
+ if (pg_getnameinfo_all(&hba->mask, hba->masklen,
+ buffer, sizeof(buffer),
+ NULL, 0,
+ NI_NUMERICHOST) == 0)
+ clean_ipv6_addr(hba->mask.ss_family, buffer);
+ maskstr = pstrdup(buffer);
+ }
+ }
+ break;
+ case ipCmpAll:
+ addrstr = "all";
+ break;
+ case ipCmpSameHost:
+ addrstr = "samehost";
+ break;
+ case ipCmpSameNet:
+ addrstr = "samenet";
+ break;
+ }
+ if (addrstr)
+ values[index++] = CStringGetTextDatum(addrstr);
+ else
+ nulls[index++] = true;
+ if (maskstr)
+ values[index++] = CStringGetTextDatum(maskstr);
+ else
+ nulls[index++] = true;
+
+ /* auth_method */
+ values[index++] = CStringGetTextDatum(hba_authname(hba->auth_method));
+
+ /* options */
+ options = gethba_options(hba);
+ if (options)
+ values[index++] = PointerGetDatum(options);
+ else
+ nulls[index++] = true;
+ }
+ else
+ {
+ /* no parsing result, so set relevant fields to nulls */
+ memset(&nulls[1], true, (NUM_PG_HBA_FILE_RULES_ATTS - 2) * sizeof(bool));
+ }
+
+ /* error */
+ if (err_msg)
+ values[NUM_PG_HBA_FILE_RULES_ATTS - 1] = CStringGetTextDatum(err_msg);
+ else
+ nulls[NUM_PG_HBA_FILE_RULES_ATTS - 1] = true;
+
+ tuple = heap_form_tuple(tupdesc, values, nulls);
+ tuplestore_puttuple(tuple_store, tuple);
+}
+
+/*
+ * Read the pg_hba.conf file and fill the tuplestore with view records.
+ */
+static void
+fill_hba_view(Tuplestorestate *tuple_store, TupleDesc tupdesc)
+{
+ FILE *file;
+ List *hba_lines = NIL;
+ ListCell *line;
+ MemoryContext linecxt;
+ MemoryContext hbacxt;
+ MemoryContext oldcxt;
+
+ /*
+ * In the unlikely event that we can't open pg_hba.conf, we throw an
+ * error, rather than trying to report it via some sort of view entry.
+ * (Most other error conditions should result in a message in a view
+ * entry.)
+ */
+ file = AllocateFile(HbaFileName, "r");
+ if (file == NULL)
+ ereport(ERROR,
+ (errcode_for_file_access(),
+ errmsg("could not open configuration file \"%s\": %m",
+ HbaFileName)));
+
+ linecxt = tokenize_file(HbaFileName, file, &hba_lines, DEBUG3);
+ FreeFile(file);
+
+ /* Now parse all the lines */
+ hbacxt = AllocSetContextCreate(CurrentMemoryContext,
+ "hba parser context",
+ ALLOCSET_SMALL_SIZES);
+ oldcxt = MemoryContextSwitchTo(hbacxt);
+ foreach(line, hba_lines)
+ {
+ TokenizedLine *tok_line = (TokenizedLine *) lfirst(line);
+ HbaLine *hbaline = NULL;
+
+ /* don't parse lines that already have errors */
+ if (tok_line->err_msg == NULL)
+ hbaline = parse_hba_line(tok_line, DEBUG3);
+
+ fill_hba_line(tuple_store, tupdesc, tok_line->line_num,
+ hbaline, tok_line->err_msg);
+ }
+
+ /* Free tokenizer memory */
+ MemoryContextDelete(linecxt);
+ /* Free parse_hba_line memory */
+ MemoryContextSwitchTo(oldcxt);
+ MemoryContextDelete(hbacxt);
+}
+
+/*
+ * SQL-accessible SRF to return all the entries in the pg_hba.conf file.
+ */
+Datum
+pg_hba_file_rules(PG_FUNCTION_ARGS)
+{
+ Tuplestorestate *tuple_store;
+ TupleDesc tupdesc;
+ MemoryContext old_cxt;
+ ReturnSetInfo *rsi;
+
+ /*
+ * We must use the Materialize mode to be safe against HBA file changes
+ * while the cursor is open. It's also more efficient than having to look
+ * up our current position in the parsed list every time.
+ */
+ rsi = (ReturnSetInfo *) fcinfo->resultinfo;
+
+ /* Check to see if caller supports us returning a tuplestore */
+ if (rsi == NULL || !IsA(rsi, ReturnSetInfo))
+ ereport(ERROR,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("set-valued function called in context that cannot accept a set")));
+ if (!(rsi->allowedModes & SFRM_Materialize))
+ ereport(ERROR,
+ (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+ errmsg("materialize mode required, but it is not allowed in this context")));
+
+ rsi->returnMode = SFRM_Materialize;
+
+ /* Build a tuple descriptor for our result type */
+ if (get_call_result_type(fcinfo, NULL, &tupdesc) != TYPEFUNC_COMPOSITE)
+ elog(ERROR, "return type must be a row type");
+
+ /* Build tuplestore to hold the result rows */
+ old_cxt = MemoryContextSwitchTo(rsi->econtext->ecxt_per_query_memory);
+
+ tuple_store =
+ tuplestore_begin_heap(rsi->allowedModes & SFRM_Materialize_Random,
+ false, work_mem);
+ rsi->setDesc = tupdesc;
+ rsi->setResult = tuple_store;
+
+ MemoryContextSwitchTo(old_cxt);
+
+ /* Fill the tuplestore */
+ fill_hba_view(tuple_store, tupdesc);
+
+ PG_RETURN_NULL();
+}
+
+
+/*
+ * Parse one tokenised line from the ident config file and store the result in
+ * an IdentLine structure.
+ *
+ * If parsing fails, log a message and return NULL.
+ *
+ * If ident_user is a regular expression (ie. begins with a slash), it is
+ * compiled and stored in IdentLine structure.
+ *
+ * Note: this function leaks memory when an error occurs. Caller is expected
+ * to have set a memory context that will be reset if this function returns
+ * NULL.
+ */
+static IdentLine *
+parse_ident_line(TokenizedLine *tok_line)
+{
+ int line_num = tok_line->line_num;
+ ListCell *field;
+ List *tokens;
+ HbaToken *token;
+ IdentLine *parsedline;
+
+ Assert(tok_line->fields != NIL);
+ field = list_head(tok_line->fields);
+
+ parsedline = palloc0(sizeof(IdentLine));
+ parsedline->linenumber = line_num;
+
+ /* Get the map token (must exist) */
+ tokens = lfirst(field);
+ IDENT_MULTI_VALUE(tokens);
+ token = linitial(tokens);
+ parsedline->usermap = pstrdup(token->string);
+
+ /* Get the ident user token */
+ field = lnext(tok_line->fields, field);
+ IDENT_FIELD_ABSENT(field);
+ tokens = lfirst(field);
+ IDENT_MULTI_VALUE(tokens);
+ token = linitial(tokens);
+ parsedline->ident_user = pstrdup(token->string);
+
+ /* Get the PG rolename token */
+ field = lnext(tok_line->fields, field);
+ IDENT_FIELD_ABSENT(field);
+ tokens = lfirst(field);
+ IDENT_MULTI_VALUE(tokens);
+ token = linitial(tokens);
+ parsedline->pg_role = pstrdup(token->string);
+
+ if (parsedline->ident_user[0] == '/')
+ {
+ /*
+ * When system username starts with a slash, treat it as a regular
+ * expression. Pre-compile it.
+ */
+ int r;
+ pg_wchar *wstr;
+ int wlen;
+
+ wstr = palloc((strlen(parsedline->ident_user + 1) + 1) * sizeof(pg_wchar));
+ wlen = pg_mb2wchar_with_len(parsedline->ident_user + 1,
+ wstr, strlen(parsedline->ident_user + 1));
+
+ r = pg_regcomp(&parsedline->re, wstr, wlen, REG_ADVANCED, C_COLLATION_OID);
+ if (r)
+ {
+ char errstr[100];
+
+ pg_regerror(r, &parsedline->re, errstr, sizeof(errstr));
+ ereport(LOG,
+ (errcode(ERRCODE_INVALID_REGULAR_EXPRESSION),
+ errmsg("invalid regular expression \"%s\": %s",
+ parsedline->ident_user + 1, errstr)));
+
+ pfree(wstr);
+ return NULL;
+ }
+ pfree(wstr);
+ }
+
+ return parsedline;
+}
+
+/*
+ * Process one line from the parsed ident config lines.
+ *
+ * Compare input parsed ident line to the needed map, pg_role and ident_user.
+ * *found_p and *error_p are set according to our results.
+ */
+static void
+check_ident_usermap(IdentLine *identLine, const char *usermap_name,
+ const char *pg_role, const char *ident_user,
+ bool case_insensitive, bool *found_p, bool *error_p)
+{
+ *found_p = false;
+ *error_p = false;
+
+ if (strcmp(identLine->usermap, usermap_name) != 0)
+ /* Line does not match the map name we're looking for, so just abort */
+ return;
+
+ /* Match? */
+ if (identLine->ident_user[0] == '/')
+ {
+ /*
+ * When system username starts with a slash, treat it as a regular
+ * expression. In this case, we process the system username as a
+ * regular expression that returns exactly one match. This is replaced
+ * for \1 in the database username string, if present.
+ */
+ int r;
+ regmatch_t matches[2];
+ pg_wchar *wstr;
+ int wlen;
+ char *ofs;
+ char *regexp_pgrole;
+
+ wstr = palloc((strlen(ident_user) + 1) * sizeof(pg_wchar));
+ wlen = pg_mb2wchar_with_len(ident_user, wstr, strlen(ident_user));
+
+ r = pg_regexec(&identLine->re, wstr, wlen, 0, NULL, 2, matches, 0);
+ if (r)
+ {
+ char errstr[100];
+
+ if (r != REG_NOMATCH)
+ {
+ /* REG_NOMATCH is not an error, everything else is */
+ pg_regerror(r, &identLine->re, errstr, sizeof(errstr));
+ ereport(LOG,
+ (errcode(ERRCODE_INVALID_REGULAR_EXPRESSION),
+ errmsg("regular expression match for \"%s\" failed: %s",
+ identLine->ident_user + 1, errstr)));
+ *error_p = true;
+ }
+
+ pfree(wstr);
+ return;
+ }
+ pfree(wstr);
+
+ if ((ofs = strstr(identLine->pg_role, "\\1")) != NULL)
+ {
+ int offset;
+
+ /* substitution of the first argument requested */
+ if (matches[1].rm_so < 0)
+ {
+ ereport(LOG,
+ (errcode(ERRCODE_INVALID_REGULAR_EXPRESSION),
+ errmsg("regular expression \"%s\" has no subexpressions as requested by backreference in \"%s\"",
+ identLine->ident_user + 1, identLine->pg_role)));
+ *error_p = true;
+ return;
+ }
+
+ /*
+ * length: original length minus length of \1 plus length of match
+ * plus null terminator
+ */
+ regexp_pgrole = palloc0(strlen(identLine->pg_role) - 2 + (matches[1].rm_eo - matches[1].rm_so) + 1);
+ offset = ofs - identLine->pg_role;
+ memcpy(regexp_pgrole, identLine->pg_role, offset);
+ memcpy(regexp_pgrole + offset,
+ ident_user + matches[1].rm_so,
+ matches[1].rm_eo - matches[1].rm_so);
+ strcat(regexp_pgrole, ofs + 2);
+ }
+ else
+ {
+ /* no substitution, so copy the match */
+ regexp_pgrole = pstrdup(identLine->pg_role);
+ }
+
+ /*
+ * now check if the username actually matched what the user is trying
+ * to connect as
+ */
+ if (case_insensitive)
+ {
+ if (pg_strcasecmp(regexp_pgrole, pg_role) == 0)
+ *found_p = true;
+ }
+ else
+ {
+ if (strcmp(regexp_pgrole, pg_role) == 0)
+ *found_p = true;
+ }
+ pfree(regexp_pgrole);
+
+ return;
+ }
+ else
+ {
+ /* Not regular expression, so make complete match */
+ if (case_insensitive)
+ {
+ if (pg_strcasecmp(identLine->pg_role, pg_role) == 0 &&
+ pg_strcasecmp(identLine->ident_user, ident_user) == 0)
+ *found_p = true;
+ }
+ else
+ {
+ if (strcmp(identLine->pg_role, pg_role) == 0 &&
+ strcmp(identLine->ident_user, ident_user) == 0)
+ *found_p = true;
+ }
+ }
+}
+
+
+/*
+ * Scan the (pre-parsed) ident usermap file line by line, looking for a match
+ *
+ * See if the user with ident username "auth_user" is allowed to act
+ * as Postgres user "pg_role" according to usermap "usermap_name".
+ *
+ * Special case: Usermap NULL, equivalent to what was previously called
+ * "sameuser" or "samerole", means don't look in the usermap file.
+ * That's an implied map wherein "pg_role" must be identical to
+ * "auth_user" in order to be authorized.
+ *
+ * Iff authorized, return STATUS_OK, otherwise return STATUS_ERROR.
+ */
+int
+check_usermap(const char *usermap_name,
+ const char *pg_role,
+ const char *auth_user,
+ bool case_insensitive)
+{
+ bool found_entry = false,
+ error = false;
+
+ if (usermap_name == NULL || usermap_name[0] == '\0')
+ {
+ if (case_insensitive)
+ {
+ if (pg_strcasecmp(pg_role, auth_user) == 0)
+ return STATUS_OK;
+ }
+ else
+ {
+ if (strcmp(pg_role, auth_user) == 0)
+ return STATUS_OK;
+ }
+ ereport(LOG,
+ (errmsg("provided user name (%s) and authenticated user name (%s) do not match",
+ pg_role, auth_user)));
+ return STATUS_ERROR;
+ }
+ else
+ {
+ ListCell *line_cell;
+
+ foreach(line_cell, parsed_ident_lines)
+ {
+ check_ident_usermap(lfirst(line_cell), usermap_name,
+ pg_role, auth_user, case_insensitive,
+ &found_entry, &error);
+ if (found_entry || error)
+ break;
+ }
+ }
+ if (!found_entry && !error)
+ {
+ ereport(LOG,
+ (errmsg("no match in usermap \"%s\" for user \"%s\" authenticated as \"%s\"",
+ usermap_name, pg_role, auth_user)));
+ }
+ return found_entry ? STATUS_OK : STATUS_ERROR;
+}
+
+
+/*
+ * Read the ident config file and create a List of IdentLine records for
+ * the contents.
+ *
+ * This works the same as load_hba(), but for the user config file.
+ */
+bool
+load_ident(void)
+{
+ FILE *file;
+ List *ident_lines = NIL;
+ ListCell *line_cell,
+ *parsed_line_cell;
+ List *new_parsed_lines = NIL;
+ bool ok = true;
+ MemoryContext linecxt;
+ MemoryContext oldcxt;
+ MemoryContext ident_context;
+ IdentLine *newline;
+
+ file = AllocateFile(IdentFileName, "r");
+ if (file == NULL)
+ {
+ /* not fatal ... we just won't do any special ident maps */
+ ereport(LOG,
+ (errcode_for_file_access(),
+ errmsg("could not open usermap file \"%s\": %m",
+ IdentFileName)));
+ return false;
+ }
+
+ linecxt = tokenize_file(IdentFileName, file, &ident_lines, LOG);
+ FreeFile(file);
+
+ /* Now parse all the lines */
+ Assert(PostmasterContext);
+ ident_context = AllocSetContextCreate(PostmasterContext,
+ "ident parser context",
+ ALLOCSET_SMALL_SIZES);
+ oldcxt = MemoryContextSwitchTo(ident_context);
+ foreach(line_cell, ident_lines)
+ {
+ TokenizedLine *tok_line = (TokenizedLine *) lfirst(line_cell);
+
+ /* don't parse lines that already have errors */
+ if (tok_line->err_msg != NULL)
+ {
+ ok = false;
+ continue;
+ }
+
+ if ((newline = parse_ident_line(tok_line)) == NULL)
+ {
+ /* Parse error; remember there's trouble */
+ ok = false;
+
+ /*
+ * Keep parsing the rest of the file so we can report errors on
+ * more than the first line. Error has already been logged, no
+ * need for more chatter here.
+ */
+ continue;
+ }
+
+ new_parsed_lines = lappend(new_parsed_lines, newline);
+ }
+
+ /* Free tokenizer memory */
+ MemoryContextDelete(linecxt);
+ MemoryContextSwitchTo(oldcxt);
+
+ if (!ok)
+ {
+ /*
+ * File contained one or more errors, so bail out, first being careful
+ * to clean up whatever we allocated. Most stuff will go away via
+ * MemoryContextDelete, but we have to clean up regexes explicitly.
+ */
+ foreach(parsed_line_cell, new_parsed_lines)
+ {
+ newline = (IdentLine *) lfirst(parsed_line_cell);
+ if (newline->ident_user[0] == '/')
+ pg_regfree(&newline->re);
+ }
+ MemoryContextDelete(ident_context);
+ return false;
+ }
+
+ /* Loaded new file successfully, replace the one we use */
+ if (parsed_ident_lines != NIL)
+ {
+ foreach(parsed_line_cell, parsed_ident_lines)
+ {
+ newline = (IdentLine *) lfirst(parsed_line_cell);
+ if (newline->ident_user[0] == '/')
+ pg_regfree(&newline->re);
+ }
+ }
+ if (parsed_ident_context != NULL)
+ MemoryContextDelete(parsed_ident_context);
+
+ parsed_ident_context = ident_context;
+ parsed_ident_lines = new_parsed_lines;
+
+ return true;
+}
+
+
+
+/*
+ * Determine what authentication method should be used when accessing database
+ * "database" from frontend "raddr", user "user". Return the method and
+ * an optional argument (stored in fields of *port), and STATUS_OK.
+ *
+ * If the file does not contain any entry matching the request, we return
+ * method = uaImplicitReject.
+ */
+void
+hba_getauthmethod(hbaPort *port)
+{
+ check_hba(port);
+}
+
+
+/*
+ * Return the name of the auth method in use ("gss", "md5", "trust", etc.).
+ *
+ * The return value is statically allocated (see the UserAuthName array) and
+ * should not be freed.
+ */
+const char *
+hba_authname(UserAuth auth_method)
+{
+ /*
+ * Make sure UserAuthName[] tracks additions to the UserAuth enum
+ */
+ StaticAssertStmt(lengthof(UserAuthName) == USER_AUTH_LAST + 1,
+ "UserAuthName[] must match the UserAuth enum");
+
+ return UserAuthName[auth_method];
+}
diff --git a/src/backend/libpq/ifaddr.c b/src/backend/libpq/ifaddr.c
new file mode 100644
index 0000000..75760f3
--- /dev/null
+++ b/src/backend/libpq/ifaddr.c
@@ -0,0 +1,594 @@
+/*-------------------------------------------------------------------------
+ *
+ * ifaddr.c
+ * IP netmask calculations, and enumerating network interfaces.
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/ifaddr.c
+ *
+ * This file and the IPV6 implementation were initially provided by
+ * Nigel Kukard <nkukard@lbsd.net>, Linux Based Systems Design
+ * http://www.lbsd.net.
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include <unistd.h>
+#include <sys/stat.h>
+#include <sys/socket.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#ifdef HAVE_NETINET_TCP_H
+#include <netinet/tcp.h>
+#endif
+#include <sys/file.h>
+
+#include "libpq/ifaddr.h"
+#include "port/pg_bswap.h"
+
+static int range_sockaddr_AF_INET(const struct sockaddr_in *addr,
+ const struct sockaddr_in *netaddr,
+ const struct sockaddr_in *netmask);
+
+#ifdef HAVE_IPV6
+static int range_sockaddr_AF_INET6(const struct sockaddr_in6 *addr,
+ const struct sockaddr_in6 *netaddr,
+ const struct sockaddr_in6 *netmask);
+#endif
+
+
+/*
+ * pg_range_sockaddr - is addr within the subnet specified by netaddr/netmask ?
+ *
+ * Note: caller must already have verified that all three addresses are
+ * in the same address family; and AF_UNIX addresses are not supported.
+ */
+int
+pg_range_sockaddr(const struct sockaddr_storage *addr,
+ const struct sockaddr_storage *netaddr,
+ const struct sockaddr_storage *netmask)
+{
+ if (addr->ss_family == AF_INET)
+ return range_sockaddr_AF_INET((const struct sockaddr_in *) addr,
+ (const struct sockaddr_in *) netaddr,
+ (const struct sockaddr_in *) netmask);
+#ifdef HAVE_IPV6
+ else if (addr->ss_family == AF_INET6)
+ return range_sockaddr_AF_INET6((const struct sockaddr_in6 *) addr,
+ (const struct sockaddr_in6 *) netaddr,
+ (const struct sockaddr_in6 *) netmask);
+#endif
+ else
+ return 0;
+}
+
+static int
+range_sockaddr_AF_INET(const struct sockaddr_in *addr,
+ const struct sockaddr_in *netaddr,
+ const struct sockaddr_in *netmask)
+{
+ if (((addr->sin_addr.s_addr ^ netaddr->sin_addr.s_addr) &
+ netmask->sin_addr.s_addr) == 0)
+ return 1;
+ else
+ return 0;
+}
+
+
+#ifdef HAVE_IPV6
+
+static int
+range_sockaddr_AF_INET6(const struct sockaddr_in6 *addr,
+ const struct sockaddr_in6 *netaddr,
+ const struct sockaddr_in6 *netmask)
+{
+ int i;
+
+ for (i = 0; i < 16; i++)
+ {
+ if (((addr->sin6_addr.s6_addr[i] ^ netaddr->sin6_addr.s6_addr[i]) &
+ netmask->sin6_addr.s6_addr[i]) != 0)
+ return 0;
+ }
+
+ return 1;
+}
+#endif /* HAVE_IPV6 */
+
+/*
+ * pg_sockaddr_cidr_mask - make a network mask of the appropriate family
+ * and required number of significant bits
+ *
+ * numbits can be null, in which case the mask is fully set.
+ *
+ * The resulting mask is placed in *mask, which had better be big enough.
+ *
+ * Return value is 0 if okay, -1 if not.
+ */
+int
+pg_sockaddr_cidr_mask(struct sockaddr_storage *mask, char *numbits, int family)
+{
+ long bits;
+ char *endptr;
+
+ if (numbits == NULL)
+ {
+ bits = (family == AF_INET) ? 32 : 128;
+ }
+ else
+ {
+ bits = strtol(numbits, &endptr, 10);
+ if (*numbits == '\0' || *endptr != '\0')
+ return -1;
+ }
+
+ switch (family)
+ {
+ case AF_INET:
+ {
+ struct sockaddr_in mask4;
+ long maskl;
+
+ if (bits < 0 || bits > 32)
+ return -1;
+ memset(&mask4, 0, sizeof(mask4));
+ /* avoid "x << 32", which is not portable */
+ if (bits > 0)
+ maskl = (0xffffffffUL << (32 - (int) bits))
+ & 0xffffffffUL;
+ else
+ maskl = 0;
+ mask4.sin_addr.s_addr = pg_hton32(maskl);
+ memcpy(mask, &mask4, sizeof(mask4));
+ break;
+ }
+
+#ifdef HAVE_IPV6
+ case AF_INET6:
+ {
+ struct sockaddr_in6 mask6;
+ int i;
+
+ if (bits < 0 || bits > 128)
+ return -1;
+ memset(&mask6, 0, sizeof(mask6));
+ for (i = 0; i < 16; i++)
+ {
+ if (bits <= 0)
+ mask6.sin6_addr.s6_addr[i] = 0;
+ else if (bits >= 8)
+ mask6.sin6_addr.s6_addr[i] = 0xff;
+ else
+ {
+ mask6.sin6_addr.s6_addr[i] =
+ (0xff << (8 - (int) bits)) & 0xff;
+ }
+ bits -= 8;
+ }
+ memcpy(mask, &mask6, sizeof(mask6));
+ break;
+ }
+#endif
+ default:
+ return -1;
+ }
+
+ mask->ss_family = family;
+ return 0;
+}
+
+
+/*
+ * Run the callback function for the addr/mask, after making sure the
+ * mask is sane for the addr.
+ */
+static void
+run_ifaddr_callback(PgIfAddrCallback callback, void *cb_data,
+ struct sockaddr *addr, struct sockaddr *mask)
+{
+ struct sockaddr_storage fullmask;
+
+ if (!addr)
+ return;
+
+ /* Check that the mask is valid */
+ if (mask)
+ {
+ if (mask->sa_family != addr->sa_family)
+ {
+ mask = NULL;
+ }
+ else if (mask->sa_family == AF_INET)
+ {
+ if (((struct sockaddr_in *) mask)->sin_addr.s_addr == INADDR_ANY)
+ mask = NULL;
+ }
+#ifdef HAVE_IPV6
+ else if (mask->sa_family == AF_INET6)
+ {
+ if (IN6_IS_ADDR_UNSPECIFIED(&((struct sockaddr_in6 *) mask)->sin6_addr))
+ mask = NULL;
+ }
+#endif
+ }
+
+ /* If mask is invalid, generate our own fully-set mask */
+ if (!mask)
+ {
+ pg_sockaddr_cidr_mask(&fullmask, NULL, addr->sa_family);
+ mask = (struct sockaddr *) &fullmask;
+ }
+
+ (*callback) (addr, mask, cb_data);
+}
+
+#ifdef WIN32
+
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+/*
+ * Enumerate the system's network interface addresses and call the callback
+ * for each one. Returns 0 if successful, -1 if trouble.
+ *
+ * This version is for Win32. Uses the Winsock 2 functions (ie: ws2_32.dll)
+ */
+int
+pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
+{
+ INTERFACE_INFO *ptr,
+ *ii = NULL;
+ unsigned long length,
+ i;
+ unsigned long n_ii = 0;
+ SOCKET sock;
+ int error;
+
+ sock = WSASocket(AF_INET, SOCK_DGRAM, 0, 0, 0, 0);
+ if (sock == INVALID_SOCKET)
+ return -1;
+
+ while (n_ii < 1024)
+ {
+ n_ii += 64;
+ ptr = realloc(ii, sizeof(INTERFACE_INFO) * n_ii);
+ if (!ptr)
+ {
+ free(ii);
+ closesocket(sock);
+ errno = ENOMEM;
+ return -1;
+ }
+
+ ii = ptr;
+ if (WSAIoctl(sock, SIO_GET_INTERFACE_LIST, 0, 0,
+ ii, n_ii * sizeof(INTERFACE_INFO),
+ &length, 0, 0) == SOCKET_ERROR)
+ {
+ error = WSAGetLastError();
+ if (error == WSAEFAULT || error == WSAENOBUFS)
+ continue; /* need to make the buffer bigger */
+ closesocket(sock);
+ free(ii);
+ return -1;
+ }
+
+ break;
+ }
+
+ for (i = 0; i < length / sizeof(INTERFACE_INFO); ++i)
+ run_ifaddr_callback(callback, cb_data,
+ (struct sockaddr *) &ii[i].iiAddress,
+ (struct sockaddr *) &ii[i].iiNetmask);
+
+ closesocket(sock);
+ free(ii);
+ return 0;
+}
+#elif HAVE_GETIFADDRS /* && !WIN32 */
+
+#ifdef HAVE_IFADDRS_H
+#include <ifaddrs.h>
+#endif
+
+/*
+ * Enumerate the system's network interface addresses and call the callback
+ * for each one. Returns 0 if successful, -1 if trouble.
+ *
+ * This version uses the getifaddrs() interface, which is available on
+ * BSDs, AIX, and modern Linux.
+ */
+int
+pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
+{
+ struct ifaddrs *ifa,
+ *l;
+
+ if (getifaddrs(&ifa) < 0)
+ return -1;
+
+ for (l = ifa; l; l = l->ifa_next)
+ run_ifaddr_callback(callback, cb_data,
+ l->ifa_addr, l->ifa_netmask);
+
+ freeifaddrs(ifa);
+ return 0;
+}
+#else /* !HAVE_GETIFADDRS && !WIN32 */
+
+#include <sys/ioctl.h>
+
+#ifdef HAVE_NET_IF_H
+#include <net/if.h>
+#endif
+
+#ifdef HAVE_SYS_SOCKIO_H
+#include <sys/sockio.h>
+#endif
+
+/*
+ * SIOCGIFCONF does not return IPv6 addresses on Solaris
+ * and HP/UX. So we prefer SIOCGLIFCONF if it's available.
+ *
+ * On HP/UX, however, it *only* returns IPv6 addresses,
+ * and the structs are named slightly differently too.
+ * We'd have to do another call with SIOCGIFCONF to get the
+ * IPv4 addresses as well. We don't currently bother, just
+ * fall back to SIOCGIFCONF on HP/UX.
+ */
+
+#if defined(SIOCGLIFCONF) && !defined(__hpux)
+
+/*
+ * Enumerate the system's network interface addresses and call the callback
+ * for each one. Returns 0 if successful, -1 if trouble.
+ *
+ * This version uses ioctl(SIOCGLIFCONF).
+ */
+int
+pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
+{
+ struct lifconf lifc;
+ struct lifreq *lifr,
+ lmask;
+ struct sockaddr *addr,
+ *mask;
+ char *ptr,
+ *buffer = NULL;
+ size_t n_buffer = 1024;
+ pgsocket sock,
+ fd;
+
+#ifdef HAVE_IPV6
+ pgsocket sock6;
+#endif
+ int i,
+ total;
+
+ sock = socket(AF_INET, SOCK_DGRAM, 0);
+ if (sock == PGINVALID_SOCKET)
+ return -1;
+
+ while (n_buffer < 1024 * 100)
+ {
+ n_buffer += 1024;
+ ptr = realloc(buffer, n_buffer);
+ if (!ptr)
+ {
+ free(buffer);
+ close(sock);
+ errno = ENOMEM;
+ return -1;
+ }
+
+ memset(&lifc, 0, sizeof(lifc));
+ lifc.lifc_family = AF_UNSPEC;
+ lifc.lifc_buf = buffer = ptr;
+ lifc.lifc_len = n_buffer;
+
+ if (ioctl(sock, SIOCGLIFCONF, &lifc) < 0)
+ {
+ if (errno == EINVAL)
+ continue;
+ free(buffer);
+ close(sock);
+ return -1;
+ }
+
+ /*
+ * Some Unixes try to return as much data as possible, with no
+ * indication of whether enough space allocated. Don't believe we have
+ * it all unless there's lots of slop.
+ */
+ if (lifc.lifc_len < n_buffer - 1024)
+ break;
+ }
+
+#ifdef HAVE_IPV6
+ /* We'll need an IPv6 socket too for the SIOCGLIFNETMASK ioctls */
+ sock6 = socket(AF_INET6, SOCK_DGRAM, 0);
+ if (sock6 == PGINVALID_SOCKET)
+ {
+ free(buffer);
+ close(sock);
+ return -1;
+ }
+#endif
+
+ total = lifc.lifc_len / sizeof(struct lifreq);
+ lifr = lifc.lifc_req;
+ for (i = 0; i < total; ++i)
+ {
+ addr = (struct sockaddr *) &lifr[i].lifr_addr;
+ memcpy(&lmask, &lifr[i], sizeof(struct lifreq));
+#ifdef HAVE_IPV6
+ fd = (addr->sa_family == AF_INET6) ? sock6 : sock;
+#else
+ fd = sock;
+#endif
+ if (ioctl(fd, SIOCGLIFNETMASK, &lmask) < 0)
+ mask = NULL;
+ else
+ mask = (struct sockaddr *) &lmask.lifr_addr;
+ run_ifaddr_callback(callback, cb_data, addr, mask);
+ }
+
+ free(buffer);
+ close(sock);
+#ifdef HAVE_IPV6
+ close(sock6);
+#endif
+ return 0;
+}
+#elif defined(SIOCGIFCONF)
+
+/*
+ * Remaining Unixes use SIOCGIFCONF. Some only return IPv4 information
+ * here, so this is the least preferred method. Note that there is no
+ * standard way to iterate the struct ifreq returned in the array.
+ * On some OSs the structures are padded large enough for any address,
+ * on others you have to calculate the size of the struct ifreq.
+ */
+
+/* Some OSs have _SIZEOF_ADDR_IFREQ, so just use that */
+#ifndef _SIZEOF_ADDR_IFREQ
+
+/* Calculate based on sockaddr.sa_len */
+#ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
+#define _SIZEOF_ADDR_IFREQ(ifr) \
+ ((ifr).ifr_addr.sa_len > sizeof(struct sockaddr) ? \
+ (sizeof(struct ifreq) - sizeof(struct sockaddr) + \
+ (ifr).ifr_addr.sa_len) : sizeof(struct ifreq))
+
+/* Padded ifreq structure, simple */
+#else
+#define _SIZEOF_ADDR_IFREQ(ifr) \
+ sizeof (struct ifreq)
+#endif
+#endif /* !_SIZEOF_ADDR_IFREQ */
+
+/*
+ * Enumerate the system's network interface addresses and call the callback
+ * for each one. Returns 0 if successful, -1 if trouble.
+ *
+ * This version uses ioctl(SIOCGIFCONF).
+ */
+int
+pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
+{
+ struct ifconf ifc;
+ struct ifreq *ifr,
+ *end,
+ addr,
+ mask;
+ char *ptr,
+ *buffer = NULL;
+ size_t n_buffer = 1024;
+ pgsocket sock;
+
+ sock = socket(AF_INET, SOCK_DGRAM, 0);
+ if (sock == PGINVALID_SOCKET)
+ return -1;
+
+ while (n_buffer < 1024 * 100)
+ {
+ n_buffer += 1024;
+ ptr = realloc(buffer, n_buffer);
+ if (!ptr)
+ {
+ free(buffer);
+ close(sock);
+ errno = ENOMEM;
+ return -1;
+ }
+
+ memset(&ifc, 0, sizeof(ifc));
+ ifc.ifc_buf = buffer = ptr;
+ ifc.ifc_len = n_buffer;
+
+ if (ioctl(sock, SIOCGIFCONF, &ifc) < 0)
+ {
+ if (errno == EINVAL)
+ continue;
+ free(buffer);
+ close(sock);
+ return -1;
+ }
+
+ /*
+ * Some Unixes try to return as much data as possible, with no
+ * indication of whether enough space allocated. Don't believe we have
+ * it all unless there's lots of slop.
+ */
+ if (ifc.ifc_len < n_buffer - 1024)
+ break;
+ }
+
+ end = (struct ifreq *) (buffer + ifc.ifc_len);
+ for (ifr = ifc.ifc_req; ifr < end;)
+ {
+ memcpy(&addr, ifr, sizeof(addr));
+ memcpy(&mask, ifr, sizeof(mask));
+ if (ioctl(sock, SIOCGIFADDR, &addr, sizeof(addr)) == 0 &&
+ ioctl(sock, SIOCGIFNETMASK, &mask, sizeof(mask)) == 0)
+ run_ifaddr_callback(callback, cb_data,
+ &addr.ifr_addr, &mask.ifr_addr);
+ ifr = (struct ifreq *) ((char *) ifr + _SIZEOF_ADDR_IFREQ(*ifr));
+ }
+
+ free(buffer);
+ close(sock);
+ return 0;
+}
+#else /* !defined(SIOCGIFCONF) */
+
+/*
+ * Enumerate the system's network interface addresses and call the callback
+ * for each one. Returns 0 if successful, -1 if trouble.
+ *
+ * This version is our fallback if there's no known way to get the
+ * interface addresses. Just return the standard loopback addresses.
+ */
+int
+pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data)
+{
+ struct sockaddr_in addr;
+ struct sockaddr_storage mask;
+
+#ifdef HAVE_IPV6
+ struct sockaddr_in6 addr6;
+#endif
+
+ /* addr 127.0.0.1/8 */
+ memset(&addr, 0, sizeof(addr));
+ addr.sin_family = AF_INET;
+ addr.sin_addr.s_addr = pg_ntoh32(0x7f000001);
+ memset(&mask, 0, sizeof(mask));
+ pg_sockaddr_cidr_mask(&mask, "8", AF_INET);
+ run_ifaddr_callback(callback, cb_data,
+ (struct sockaddr *) &addr,
+ (struct sockaddr *) &mask);
+
+#ifdef HAVE_IPV6
+ /* addr ::1/128 */
+ memset(&addr6, 0, sizeof(addr6));
+ addr6.sin6_family = AF_INET6;
+ addr6.sin6_addr.s6_addr[15] = 1;
+ memset(&mask, 0, sizeof(mask));
+ pg_sockaddr_cidr_mask(&mask, "128", AF_INET6);
+ run_ifaddr_callback(callback, cb_data,
+ (struct sockaddr *) &addr6,
+ (struct sockaddr *) &mask);
+#endif
+
+ return 0;
+}
+#endif /* !defined(SIOCGIFCONF) */
+
+#endif /* !HAVE_GETIFADDRS */
diff --git a/src/backend/libpq/pg_hba.conf.sample b/src/backend/libpq/pg_hba.conf.sample
new file mode 100644
index 0000000..5f3f63e
--- /dev/null
+++ b/src/backend/libpq/pg_hba.conf.sample
@@ -0,0 +1,94 @@
+# PostgreSQL Client Authentication Configuration File
+# ===================================================
+#
+# Refer to the "Client Authentication" section in the PostgreSQL
+# documentation for a complete description of this file. A short
+# synopsis follows.
+#
+# This file controls: which hosts are allowed to connect, how clients
+# are authenticated, which PostgreSQL user names they can use, which
+# databases they can access. Records take one of these forms:
+#
+# local DATABASE USER METHOD [OPTIONS]
+# host DATABASE USER ADDRESS METHOD [OPTIONS]
+# hostssl DATABASE USER ADDRESS METHOD [OPTIONS]
+# hostnossl DATABASE USER ADDRESS METHOD [OPTIONS]
+# hostgssenc DATABASE USER ADDRESS METHOD [OPTIONS]
+# hostnogssenc DATABASE USER ADDRESS METHOD [OPTIONS]
+#
+# (The uppercase items must be replaced by actual values.)
+#
+# The first field is the connection type:
+# - "local" is a Unix-domain socket
+# - "host" is a TCP/IP socket (encrypted or not)
+# - "hostssl" is a TCP/IP socket that is SSL-encrypted
+# - "hostnossl" is a TCP/IP socket that is not SSL-encrypted
+# - "hostgssenc" is a TCP/IP socket that is GSSAPI-encrypted
+# - "hostnogssenc" is a TCP/IP socket that is not GSSAPI-encrypted
+#
+# DATABASE can be "all", "sameuser", "samerole", "replication", a
+# database name, or a comma-separated list thereof. The "all"
+# keyword does not match "replication". Access to replication
+# must be enabled in a separate record (see example below).
+#
+# USER can be "all", a user name, a group name prefixed with "+", or a
+# comma-separated list thereof. In both the DATABASE and USER fields
+# you can also write a file name prefixed with "@" to include names
+# from a separate file.
+#
+# ADDRESS specifies the set of hosts the record matches. It can be a
+# host name, or it is made up of an IP address and a CIDR mask that is
+# an integer (between 0 and 32 (IPv4) or 128 (IPv6) inclusive) that
+# specifies the number of significant bits in the mask. A host name
+# that starts with a dot (.) matches a suffix of the actual host name.
+# Alternatively, you can write an IP address and netmask in separate
+# columns to specify the set of hosts. Instead of a CIDR-address, you
+# can write "samehost" to match any of the server's own IP addresses,
+# or "samenet" to match any address in any subnet that the server is
+# directly connected to.
+#
+# METHOD can be "trust", "reject", "md5", "password", "scram-sha-256",
+# "gss", "sspi", "ident", "peer", "pam", "ldap", "radius" or "cert".
+# Note that "password" sends passwords in clear text; "md5" or
+# "scram-sha-256" are preferred since they send encrypted passwords.
+#
+# OPTIONS are a set of options for the authentication in the format
+# NAME=VALUE. The available options depend on the different
+# authentication methods -- refer to the "Client Authentication"
+# section in the documentation for a list of which options are
+# available for which authentication methods.
+#
+# Database and user names containing spaces, commas, quotes and other
+# special characters must be quoted. Quoting one of the keywords
+# "all", "sameuser", "samerole" or "replication" makes the name lose
+# its special character, and just match a database or username with
+# that name.
+#
+# This file is read on server startup and when the server receives a
+# SIGHUP signal. If you edit the file on a running system, you have to
+# SIGHUP the server for the changes to take effect, run "pg_ctl reload",
+# or execute "SELECT pg_reload_conf()".
+#
+# Put your actual configuration here
+# ----------------------------------
+#
+# If you want to allow non-local connections, you need to add more
+# "host" records. In that case you will also need to make PostgreSQL
+# listen on a non-local interface via the listen_addresses
+# configuration parameter, or via the -i or -h command line switches.
+
+@authcomment@
+
+# TYPE DATABASE USER ADDRESS METHOD
+
+@remove-line-for-nolocal@# "local" is for Unix domain socket connections only
+@remove-line-for-nolocal@local all all @authmethodlocal@
+# IPv4 local connections:
+host all all 127.0.0.1/32 @authmethodhost@
+# IPv6 local connections:
+host all all ::1/128 @authmethodhost@
+# Allow replication connections from localhost, by a user with the
+# replication privilege.
+@remove-line-for-nolocal@local replication all @authmethodlocal@
+host replication all 127.0.0.1/32 @authmethodhost@
+host replication all ::1/128 @authmethodhost@
diff --git a/src/backend/libpq/pg_ident.conf.sample b/src/backend/libpq/pg_ident.conf.sample
new file mode 100644
index 0000000..a5870e6
--- /dev/null
+++ b/src/backend/libpq/pg_ident.conf.sample
@@ -0,0 +1,42 @@
+# PostgreSQL User Name Maps
+# =========================
+#
+# Refer to the PostgreSQL documentation, chapter "Client
+# Authentication" for a complete description. A short synopsis
+# follows.
+#
+# This file controls PostgreSQL user name mapping. It maps external
+# user names to their corresponding PostgreSQL user names. Records
+# are of the form:
+#
+# MAPNAME SYSTEM-USERNAME PG-USERNAME
+#
+# (The uppercase quantities must be replaced by actual values.)
+#
+# MAPNAME is the (otherwise freely chosen) map name that was used in
+# pg_hba.conf. SYSTEM-USERNAME is the detected user name of the
+# client. PG-USERNAME is the requested PostgreSQL user name. The
+# existence of a record specifies that SYSTEM-USERNAME may connect as
+# PG-USERNAME.
+#
+# If SYSTEM-USERNAME starts with a slash (/), it will be treated as a
+# regular expression. Optionally this can contain a capture (a
+# parenthesized subexpression). The substring matching the capture
+# will be substituted for \1 (backslash-one) if present in
+# PG-USERNAME.
+#
+# Multiple maps may be specified in this file and used by pg_hba.conf.
+#
+# No map names are defined in the default configuration. If all
+# system user names and PostgreSQL user names are the same, you don't
+# need anything in this file.
+#
+# This file is read on server startup and when the postmaster receives
+# a SIGHUP signal. If you edit the file on a running system, you have
+# to SIGHUP the postmaster for the changes to take effect. You can
+# use "pg_ctl reload" to do that.
+
+# Put your actual configuration here
+# ----------------------------------
+
+# MAPNAME SYSTEM-USERNAME PG-USERNAME
diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c
new file mode 100644
index 0000000..44782f2
--- /dev/null
+++ b/src/backend/libpq/pqcomm.c
@@ -0,0 +1,1976 @@
+/*-------------------------------------------------------------------------
+ *
+ * pqcomm.c
+ * Communication functions between the Frontend and the Backend
+ *
+ * These routines handle the low-level details of communication between
+ * frontend and backend. They just shove data across the communication
+ * channel, and are ignorant of the semantics of the data.
+ *
+ * To emit an outgoing message, use the routines in pqformat.c to construct
+ * the message in a buffer and then emit it in one call to pq_putmessage.
+ * There are no functions to send raw bytes or partial messages; this
+ * ensures that the channel will not be clogged by an incomplete message if
+ * execution is aborted by ereport(ERROR) partway through the message.
+ *
+ * At one time, libpq was shared between frontend and backend, but now
+ * the backend's "backend/libpq" is quite separate from "interfaces/libpq".
+ * All that remains is similarities of names to trap the unwary...
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/pqcomm.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+/*------------------------
+ * INTERFACE ROUTINES
+ *
+ * setup/teardown:
+ * StreamServerPort - Open postmaster's server port
+ * StreamConnection - Create new connection with client
+ * StreamClose - Close a client/backend connection
+ * TouchSocketFiles - Protect socket files against /tmp cleaners
+ * pq_init - initialize libpq at backend startup
+ * socket_comm_reset - reset libpq during error recovery
+ * socket_close - shutdown libpq at backend exit
+ *
+ * low-level I/O:
+ * pq_getbytes - get a known number of bytes from connection
+ * pq_getmessage - get a message with length word from connection
+ * pq_getbyte - get next byte from connection
+ * pq_peekbyte - peek at next byte from connection
+ * pq_flush - flush pending output
+ * pq_flush_if_writable - flush pending output if writable without blocking
+ * pq_getbyte_if_available - get a byte if available without blocking
+ *
+ * message-level I/O
+ * pq_putmessage - send a normal message (suppressed in COPY OUT mode)
+ * pq_putmessage_noblock - buffer a normal message (suppressed in COPY OUT)
+ *
+ *------------------------
+ */
+#include "postgres.h"
+
+#ifdef HAVE_POLL_H
+#include <poll.h>
+#endif
+#include <signal.h>
+#include <fcntl.h>
+#include <grp.h>
+#include <unistd.h>
+#include <sys/file.h>
+#include <sys/socket.h>
+#include <sys/stat.h>
+#include <sys/time.h>
+#include <netdb.h>
+#include <netinet/in.h>
+#ifdef HAVE_NETINET_TCP_H
+#include <netinet/tcp.h>
+#endif
+#include <utime.h>
+#ifdef _MSC_VER /* mstcpip.h is missing on mingw */
+#include <mstcpip.h>
+#endif
+
+#include "common/ip.h"
+#include "libpq/libpq.h"
+#include "miscadmin.h"
+#include "port/pg_bswap.h"
+#include "storage/ipc.h"
+#include "utils/guc.h"
+#include "utils/memutils.h"
+
+/*
+ * Cope with the various platform-specific ways to spell TCP keepalive socket
+ * options. This doesn't cover Windows, which as usual does its own thing.
+ */
+#if defined(TCP_KEEPIDLE)
+/* TCP_KEEPIDLE is the name of this option on Linux and *BSD */
+#define PG_TCP_KEEPALIVE_IDLE TCP_KEEPIDLE
+#define PG_TCP_KEEPALIVE_IDLE_STR "TCP_KEEPIDLE"
+#elif defined(TCP_KEEPALIVE_THRESHOLD)
+/* TCP_KEEPALIVE_THRESHOLD is the name of this option on Solaris >= 11 */
+#define PG_TCP_KEEPALIVE_IDLE TCP_KEEPALIVE_THRESHOLD
+#define PG_TCP_KEEPALIVE_IDLE_STR "TCP_KEEPALIVE_THRESHOLD"
+#elif defined(TCP_KEEPALIVE) && defined(__darwin__)
+/* TCP_KEEPALIVE is the name of this option on macOS */
+/* Caution: Solaris has this symbol but it means something different */
+#define PG_TCP_KEEPALIVE_IDLE TCP_KEEPALIVE
+#define PG_TCP_KEEPALIVE_IDLE_STR "TCP_KEEPALIVE"
+#endif
+
+/*
+ * Configuration options
+ */
+int Unix_socket_permissions;
+char *Unix_socket_group;
+
+/* Where the Unix socket files are (list of palloc'd strings) */
+static List *sock_paths = NIL;
+
+/*
+ * Buffers for low-level I/O.
+ *
+ * The receive buffer is fixed size. Send buffer is usually 8k, but can be
+ * enlarged by pq_putmessage_noblock() if the message doesn't fit otherwise.
+ */
+
+#define PQ_SEND_BUFFER_SIZE 8192
+#define PQ_RECV_BUFFER_SIZE 8192
+
+static char *PqSendBuffer;
+static int PqSendBufferSize; /* Size send buffer */
+static int PqSendPointer; /* Next index to store a byte in PqSendBuffer */
+static int PqSendStart; /* Next index to send a byte in PqSendBuffer */
+
+static char PqRecvBuffer[PQ_RECV_BUFFER_SIZE];
+static int PqRecvPointer; /* Next index to read a byte from PqRecvBuffer */
+static int PqRecvLength; /* End of data available in PqRecvBuffer */
+
+/*
+ * Message status
+ */
+static bool PqCommBusy; /* busy sending data to the client */
+static bool PqCommReadingMsg; /* in the middle of reading a message */
+
+
+/* Internal functions */
+static void socket_comm_reset(void);
+static void socket_close(int code, Datum arg);
+static void socket_set_nonblocking(bool nonblocking);
+static int socket_flush(void);
+static int socket_flush_if_writable(void);
+static bool socket_is_send_pending(void);
+static int socket_putmessage(char msgtype, const char *s, size_t len);
+static void socket_putmessage_noblock(char msgtype, const char *s, size_t len);
+static int internal_putbytes(const char *s, size_t len);
+static int internal_flush(void);
+
+#ifdef HAVE_UNIX_SOCKETS
+static int Lock_AF_UNIX(const char *unixSocketDir, const char *unixSocketPath);
+static int Setup_AF_UNIX(const char *sock_path);
+#endif /* HAVE_UNIX_SOCKETS */
+
+static const PQcommMethods PqCommSocketMethods = {
+ socket_comm_reset,
+ socket_flush,
+ socket_flush_if_writable,
+ socket_is_send_pending,
+ socket_putmessage,
+ socket_putmessage_noblock
+};
+
+const PQcommMethods *PqCommMethods = &PqCommSocketMethods;
+
+WaitEventSet *FeBeWaitSet;
+
+
+/* --------------------------------
+ * pq_init - initialize libpq at backend startup
+ * --------------------------------
+ */
+void
+pq_init(void)
+{
+ int socket_pos PG_USED_FOR_ASSERTS_ONLY;
+ int latch_pos PG_USED_FOR_ASSERTS_ONLY;
+
+ /* initialize state variables */
+ PqSendBufferSize = PQ_SEND_BUFFER_SIZE;
+ PqSendBuffer = MemoryContextAlloc(TopMemoryContext, PqSendBufferSize);
+ PqSendPointer = PqSendStart = PqRecvPointer = PqRecvLength = 0;
+ PqCommBusy = false;
+ PqCommReadingMsg = false;
+
+ /* set up process-exit hook to close the socket */
+ on_proc_exit(socket_close, 0);
+
+ /*
+ * In backends (as soon as forked) we operate the underlying socket in
+ * nonblocking mode and use latches to implement blocking semantics if
+ * needed. That allows us to provide safely interruptible reads and
+ * writes.
+ *
+ * Use COMMERROR on failure, because ERROR would try to send the error to
+ * the client, which might require changing the mode again, leading to
+ * infinite recursion.
+ */
+#ifndef WIN32
+ if (!pg_set_noblock(MyProcPort->sock))
+ ereport(COMMERROR,
+ (errmsg("could not set socket to nonblocking mode: %m")));
+#endif
+
+ FeBeWaitSet = CreateWaitEventSet(TopMemoryContext, 3);
+ socket_pos = AddWaitEventToSet(FeBeWaitSet, WL_SOCKET_WRITEABLE,
+ MyProcPort->sock, NULL, NULL);
+ latch_pos = AddWaitEventToSet(FeBeWaitSet, WL_LATCH_SET, PGINVALID_SOCKET,
+ MyLatch, NULL);
+ AddWaitEventToSet(FeBeWaitSet, WL_POSTMASTER_DEATH, PGINVALID_SOCKET,
+ NULL, NULL);
+
+ /*
+ * The event positions match the order we added them, but let's sanity
+ * check them to be sure.
+ */
+ Assert(socket_pos == FeBeWaitSetSocketPos);
+ Assert(latch_pos == FeBeWaitSetLatchPos);
+}
+
+/* --------------------------------
+ * socket_comm_reset - reset libpq during error recovery
+ *
+ * This is called from error recovery at the outer idle loop. It's
+ * just to get us out of trouble if we somehow manage to elog() from
+ * inside a pqcomm.c routine (which ideally will never happen, but...)
+ * --------------------------------
+ */
+static void
+socket_comm_reset(void)
+{
+ /* Do not throw away pending data, but do reset the busy flag */
+ PqCommBusy = false;
+}
+
+/* --------------------------------
+ * socket_close - shutdown libpq at backend exit
+ *
+ * This is the one pg_on_exit_callback in place during BackendInitialize().
+ * That function's unusual signal handling constrains that this callback be
+ * safe to run at any instant.
+ * --------------------------------
+ */
+static void
+socket_close(int code, Datum arg)
+{
+ /* Nothing to do in a standalone backend, where MyProcPort is NULL. */
+ if (MyProcPort != NULL)
+ {
+#ifdef ENABLE_GSS
+ /*
+ * Shutdown GSSAPI layer. This section does nothing when interrupting
+ * BackendInitialize(), because pg_GSS_recvauth() makes first use of
+ * "ctx" and "cred".
+ *
+ * Note that we don't bother to free MyProcPort->gss, since we're
+ * about to exit anyway.
+ */
+ if (MyProcPort->gss)
+ {
+ OM_uint32 min_s;
+
+ if (MyProcPort->gss->ctx != GSS_C_NO_CONTEXT)
+ gss_delete_sec_context(&min_s, &MyProcPort->gss->ctx, NULL);
+
+ if (MyProcPort->gss->cred != GSS_C_NO_CREDENTIAL)
+ gss_release_cred(&min_s, &MyProcPort->gss->cred);
+ }
+#endif /* ENABLE_GSS */
+
+ /*
+ * Cleanly shut down SSL layer. Nowhere else does a postmaster child
+ * call this, so this is safe when interrupting BackendInitialize().
+ */
+ secure_close(MyProcPort);
+
+ /*
+ * Formerly we did an explicit close() here, but it seems better to
+ * leave the socket open until the process dies. This allows clients
+ * to perform a "synchronous close" if they care --- wait till the
+ * transport layer reports connection closure, and you can be sure the
+ * backend has exited.
+ *
+ * We do set sock to PGINVALID_SOCKET to prevent any further I/O,
+ * though.
+ */
+ MyProcPort->sock = PGINVALID_SOCKET;
+ }
+}
+
+
+
+/*
+ * Streams -- wrapper around Unix socket system calls
+ *
+ *
+ * Stream functions are used for vanilla TCP connection protocol.
+ */
+
+
+/*
+ * StreamServerPort -- open a "listening" port to accept connections.
+ *
+ * family should be AF_UNIX or AF_UNSPEC; portNumber is the port number.
+ * For AF_UNIX ports, hostName should be NULL and unixSocketDir must be
+ * specified. For TCP ports, hostName is either NULL for all interfaces or
+ * the interface to listen on, and unixSocketDir is ignored (can be NULL).
+ *
+ * Successfully opened sockets are added to the ListenSocket[] array (of
+ * length MaxListen), at the first position that isn't PGINVALID_SOCKET.
+ *
+ * RETURNS: STATUS_OK or STATUS_ERROR
+ */
+
+int
+StreamServerPort(int family, const char *hostName, unsigned short portNumber,
+ const char *unixSocketDir,
+ pgsocket ListenSocket[], int MaxListen)
+{
+ pgsocket fd;
+ int err;
+ int maxconn;
+ int ret;
+ char portNumberStr[32];
+ const char *familyDesc;
+ char familyDescBuf[64];
+ const char *addrDesc;
+ char addrBuf[NI_MAXHOST];
+ char *service;
+ struct addrinfo *addrs = NULL,
+ *addr;
+ struct addrinfo hint;
+ int listen_index = 0;
+ int added = 0;
+
+#ifdef HAVE_UNIX_SOCKETS
+ char unixSocketPath[MAXPGPATH];
+#endif
+#if !defined(WIN32) || defined(IPV6_V6ONLY)
+ int one = 1;
+#endif
+
+ /* Initialize hint structure */
+ MemSet(&hint, 0, sizeof(hint));
+ hint.ai_family = family;
+ hint.ai_flags = AI_PASSIVE;
+ hint.ai_socktype = SOCK_STREAM;
+
+#ifdef HAVE_UNIX_SOCKETS
+ if (family == AF_UNIX)
+ {
+ /*
+ * Create unixSocketPath from portNumber and unixSocketDir and lock
+ * that file path
+ */
+ UNIXSOCK_PATH(unixSocketPath, portNumber, unixSocketDir);
+ if (strlen(unixSocketPath) >= UNIXSOCK_PATH_BUFLEN)
+ {
+ ereport(LOG,
+ (errmsg("Unix-domain socket path \"%s\" is too long (maximum %d bytes)",
+ unixSocketPath,
+ (int) (UNIXSOCK_PATH_BUFLEN - 1))));
+ return STATUS_ERROR;
+ }
+ if (Lock_AF_UNIX(unixSocketDir, unixSocketPath) != STATUS_OK)
+ return STATUS_ERROR;
+ service = unixSocketPath;
+ }
+ else
+#endif /* HAVE_UNIX_SOCKETS */
+ {
+ snprintf(portNumberStr, sizeof(portNumberStr), "%d", portNumber);
+ service = portNumberStr;
+ }
+
+ ret = pg_getaddrinfo_all(hostName, service, &hint, &addrs);
+ if (ret || !addrs)
+ {
+ if (hostName)
+ ereport(LOG,
+ (errmsg("could not translate host name \"%s\", service \"%s\" to address: %s",
+ hostName, service, gai_strerror(ret))));
+ else
+ ereport(LOG,
+ (errmsg("could not translate service \"%s\" to address: %s",
+ service, gai_strerror(ret))));
+ if (addrs)
+ pg_freeaddrinfo_all(hint.ai_family, addrs);
+ return STATUS_ERROR;
+ }
+
+ for (addr = addrs; addr; addr = addr->ai_next)
+ {
+ if (!IS_AF_UNIX(family) && IS_AF_UNIX(addr->ai_family))
+ {
+ /*
+ * Only set up a unix domain socket when they really asked for it.
+ * The service/port is different in that case.
+ */
+ continue;
+ }
+
+ /* See if there is still room to add 1 more socket. */
+ for (; listen_index < MaxListen; listen_index++)
+ {
+ if (ListenSocket[listen_index] == PGINVALID_SOCKET)
+ break;
+ }
+ if (listen_index >= MaxListen)
+ {
+ ereport(LOG,
+ (errmsg("could not bind to all requested addresses: MAXLISTEN (%d) exceeded",
+ MaxListen)));
+ break;
+ }
+
+ /* set up address family name for log messages */
+ switch (addr->ai_family)
+ {
+ case AF_INET:
+ familyDesc = _("IPv4");
+ break;
+#ifdef HAVE_IPV6
+ case AF_INET6:
+ familyDesc = _("IPv6");
+ break;
+#endif
+#ifdef HAVE_UNIX_SOCKETS
+ case AF_UNIX:
+ familyDesc = _("Unix");
+ break;
+#endif
+ default:
+ snprintf(familyDescBuf, sizeof(familyDescBuf),
+ _("unrecognized address family %d"),
+ addr->ai_family);
+ familyDesc = familyDescBuf;
+ break;
+ }
+
+ /* set up text form of address for log messages */
+#ifdef HAVE_UNIX_SOCKETS
+ if (addr->ai_family == AF_UNIX)
+ addrDesc = unixSocketPath;
+ else
+#endif
+ {
+ pg_getnameinfo_all((const struct sockaddr_storage *) addr->ai_addr,
+ addr->ai_addrlen,
+ addrBuf, sizeof(addrBuf),
+ NULL, 0,
+ NI_NUMERICHOST);
+ addrDesc = addrBuf;
+ }
+
+ if ((fd = socket(addr->ai_family, SOCK_STREAM, 0)) == PGINVALID_SOCKET)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ /* translator: first %s is IPv4, IPv6, or Unix */
+ errmsg("could not create %s socket for address \"%s\": %m",
+ familyDesc, addrDesc)));
+ continue;
+ }
+
+#ifndef WIN32
+
+ /*
+ * Without the SO_REUSEADDR flag, a new postmaster can't be started
+ * right away after a stop or crash, giving "address already in use"
+ * error on TCP ports.
+ *
+ * On win32, however, this behavior only happens if the
+ * SO_EXCLUSIVEADDRUSE is set. With SO_REUSEADDR, win32 allows
+ * multiple servers to listen on the same address, resulting in
+ * unpredictable behavior. With no flags at all, win32 behaves as Unix
+ * with SO_REUSEADDR.
+ */
+ if (!IS_AF_UNIX(addr->ai_family))
+ {
+ if ((setsockopt(fd, SOL_SOCKET, SO_REUSEADDR,
+ (char *) &one, sizeof(one))) == -1)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ /* translator: third %s is IPv4, IPv6, or Unix */
+ errmsg("%s(%s) failed for %s address \"%s\": %m",
+ "setsockopt", "SO_REUSEADDR",
+ familyDesc, addrDesc)));
+ closesocket(fd);
+ continue;
+ }
+ }
+#endif
+
+#ifdef IPV6_V6ONLY
+ if (addr->ai_family == AF_INET6)
+ {
+ if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY,
+ (char *) &one, sizeof(one)) == -1)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ /* translator: third %s is IPv4, IPv6, or Unix */
+ errmsg("%s(%s) failed for %s address \"%s\": %m",
+ "setsockopt", "IPV6_V6ONLY",
+ familyDesc, addrDesc)));
+ closesocket(fd);
+ continue;
+ }
+ }
+#endif
+
+ /*
+ * Note: This might fail on some OS's, like Linux older than
+ * 2.4.21-pre3, that don't have the IPV6_V6ONLY socket option, and map
+ * ipv4 addresses to ipv6. It will show ::ffff:ipv4 for all ipv4
+ * connections.
+ */
+ err = bind(fd, addr->ai_addr, addr->ai_addrlen);
+ if (err < 0)
+ {
+ int saved_errno = errno;
+
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ /* translator: first %s is IPv4, IPv6, or Unix */
+ errmsg("could not bind %s address \"%s\": %m",
+ familyDesc, addrDesc),
+ saved_errno == EADDRINUSE ?
+ (IS_AF_UNIX(addr->ai_family) ?
+ errhint("Is another postmaster already running on port %d?",
+ (int) portNumber) :
+ errhint("Is another postmaster already running on port %d?"
+ " If not, wait a few seconds and retry.",
+ (int) portNumber)) : 0));
+ closesocket(fd);
+ continue;
+ }
+
+#ifdef HAVE_UNIX_SOCKETS
+ if (addr->ai_family == AF_UNIX)
+ {
+ if (Setup_AF_UNIX(service) != STATUS_OK)
+ {
+ closesocket(fd);
+ break;
+ }
+ }
+#endif
+
+ /*
+ * Select appropriate accept-queue length limit. PG_SOMAXCONN is only
+ * intended to provide a clamp on the request on platforms where an
+ * overly large request provokes a kernel error (are there any?).
+ */
+ maxconn = MaxBackends * 2;
+ if (maxconn > PG_SOMAXCONN)
+ maxconn = PG_SOMAXCONN;
+
+ err = listen(fd, maxconn);
+ if (err < 0)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ /* translator: first %s is IPv4, IPv6, or Unix */
+ errmsg("could not listen on %s address \"%s\": %m",
+ familyDesc, addrDesc)));
+ closesocket(fd);
+ continue;
+ }
+
+#ifdef HAVE_UNIX_SOCKETS
+ if (addr->ai_family == AF_UNIX)
+ ereport(LOG,
+ (errmsg("listening on Unix socket \"%s\"",
+ addrDesc)));
+ else
+#endif
+ ereport(LOG,
+ /* translator: first %s is IPv4 or IPv6 */
+ (errmsg("listening on %s address \"%s\", port %d",
+ familyDesc, addrDesc, (int) portNumber)));
+
+ ListenSocket[listen_index] = fd;
+ added++;
+ }
+
+ pg_freeaddrinfo_all(hint.ai_family, addrs);
+
+ if (!added)
+ return STATUS_ERROR;
+
+ return STATUS_OK;
+}
+
+
+#ifdef HAVE_UNIX_SOCKETS
+
+/*
+ * Lock_AF_UNIX -- configure unix socket file path
+ */
+static int
+Lock_AF_UNIX(const char *unixSocketDir, const char *unixSocketPath)
+{
+ /* no lock file for abstract sockets */
+ if (unixSocketPath[0] == '@')
+ return STATUS_OK;
+
+ /*
+ * Grab an interlock file associated with the socket file.
+ *
+ * Note: there are two reasons for using a socket lock file, rather than
+ * trying to interlock directly on the socket itself. First, it's a lot
+ * more portable, and second, it lets us remove any pre-existing socket
+ * file without race conditions.
+ */
+ CreateSocketLockFile(unixSocketPath, true, unixSocketDir);
+
+ /*
+ * Once we have the interlock, we can safely delete any pre-existing
+ * socket file to avoid failure at bind() time.
+ */
+ (void) unlink(unixSocketPath);
+
+ /*
+ * Remember socket file pathnames for later maintenance.
+ */
+ sock_paths = lappend(sock_paths, pstrdup(unixSocketPath));
+
+ return STATUS_OK;
+}
+
+
+/*
+ * Setup_AF_UNIX -- configure unix socket permissions
+ */
+static int
+Setup_AF_UNIX(const char *sock_path)
+{
+ /* no file system permissions for abstract sockets */
+ if (sock_path[0] == '@')
+ return STATUS_OK;
+
+ /*
+ * Fix socket ownership/permission if requested. Note we must do this
+ * before we listen() to avoid a window where unwanted connections could
+ * get accepted.
+ */
+ Assert(Unix_socket_group);
+ if (Unix_socket_group[0] != '\0')
+ {
+#ifdef WIN32
+ elog(WARNING, "configuration item unix_socket_group is not supported on this platform");
+#else
+ char *endptr;
+ unsigned long val;
+ gid_t gid;
+
+ val = strtoul(Unix_socket_group, &endptr, 10);
+ if (*endptr == '\0')
+ { /* numeric group id */
+ gid = val;
+ }
+ else
+ { /* convert group name to id */
+ struct group *gr;
+
+ gr = getgrnam(Unix_socket_group);
+ if (!gr)
+ {
+ ereport(LOG,
+ (errmsg("group \"%s\" does not exist",
+ Unix_socket_group)));
+ return STATUS_ERROR;
+ }
+ gid = gr->gr_gid;
+ }
+ if (chown(sock_path, -1, gid) == -1)
+ {
+ ereport(LOG,
+ (errcode_for_file_access(),
+ errmsg("could not set group of file \"%s\": %m",
+ sock_path)));
+ return STATUS_ERROR;
+ }
+#endif
+ }
+
+ if (chmod(sock_path, Unix_socket_permissions) == -1)
+ {
+ ereport(LOG,
+ (errcode_for_file_access(),
+ errmsg("could not set permissions of file \"%s\": %m",
+ sock_path)));
+ return STATUS_ERROR;
+ }
+ return STATUS_OK;
+}
+#endif /* HAVE_UNIX_SOCKETS */
+
+
+/*
+ * StreamConnection -- create a new connection with client using
+ * server port. Set port->sock to the FD of the new connection.
+ *
+ * ASSUME: that this doesn't need to be non-blocking because
+ * the Postmaster uses select() to tell when the socket is ready for
+ * accept().
+ *
+ * RETURNS: STATUS_OK or STATUS_ERROR
+ */
+int
+StreamConnection(pgsocket server_fd, Port *port)
+{
+ /* accept connection and fill in the client (remote) address */
+ port->raddr.salen = sizeof(port->raddr.addr);
+ if ((port->sock = accept(server_fd,
+ (struct sockaddr *) &port->raddr.addr,
+ &port->raddr.salen)) == PGINVALID_SOCKET)
+ {
+ ereport(LOG,
+ (errcode_for_socket_access(),
+ errmsg("could not accept new connection: %m")));
+
+ /*
+ * If accept() fails then postmaster.c will still see the server
+ * socket as read-ready, and will immediately try again. To avoid
+ * uselessly sucking lots of CPU, delay a bit before trying again.
+ * (The most likely reason for failure is being out of kernel file
+ * table slots; we can do little except hope some will get freed up.)
+ */
+ pg_usleep(100000L); /* wait 0.1 sec */
+ return STATUS_ERROR;
+ }
+
+ /* fill in the server (local) address */
+ port->laddr.salen = sizeof(port->laddr.addr);
+ if (getsockname(port->sock,
+ (struct sockaddr *) &port->laddr.addr,
+ &port->laddr.salen) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s() failed: %m", "getsockname")));
+ return STATUS_ERROR;
+ }
+
+ /* select NODELAY and KEEPALIVE options if it's a TCP connection */
+ if (!IS_AF_UNIX(port->laddr.addr.ss_family))
+ {
+ int on;
+#ifdef WIN32
+ int oldopt;
+ int optlen;
+ int newopt;
+#endif
+
+#ifdef TCP_NODELAY
+ on = 1;
+ if (setsockopt(port->sock, IPPROTO_TCP, TCP_NODELAY,
+ (char *) &on, sizeof(on)) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "setsockopt", "TCP_NODELAY")));
+ return STATUS_ERROR;
+ }
+#endif
+ on = 1;
+ if (setsockopt(port->sock, SOL_SOCKET, SO_KEEPALIVE,
+ (char *) &on, sizeof(on)) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "setsockopt", "SO_KEEPALIVE")));
+ return STATUS_ERROR;
+ }
+
+#ifdef WIN32
+
+ /*
+ * This is a Win32 socket optimization. The OS send buffer should be
+ * large enough to send the whole Postgres send buffer in one go, or
+ * performance suffers. The Postgres send buffer can be enlarged if a
+ * very large message needs to be sent, but we won't attempt to
+ * enlarge the OS buffer if that happens, so somewhat arbitrarily
+ * ensure that the OS buffer is at least PQ_SEND_BUFFER_SIZE * 4.
+ * (That's 32kB with the current default).
+ *
+ * The default OS buffer size used to be 8kB in earlier Windows
+ * versions, but was raised to 64kB in Windows 2012. So it shouldn't
+ * be necessary to change it in later versions anymore. Changing it
+ * unnecessarily can even reduce performance, because setting
+ * SO_SNDBUF in the application disables the "dynamic send buffering"
+ * feature that was introduced in Windows 7. So before fiddling with
+ * SO_SNDBUF, check if the current buffer size is already large enough
+ * and only increase it if necessary.
+ *
+ * See https://support.microsoft.com/kb/823764/EN-US/ and
+ * https://msdn.microsoft.com/en-us/library/bb736549%28v=vs.85%29.aspx
+ */
+ optlen = sizeof(oldopt);
+ if (getsockopt(port->sock, SOL_SOCKET, SO_SNDBUF, (char *) &oldopt,
+ &optlen) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "getsockopt", "SO_SNDBUF")));
+ return STATUS_ERROR;
+ }
+ newopt = PQ_SEND_BUFFER_SIZE * 4;
+ if (oldopt < newopt)
+ {
+ if (setsockopt(port->sock, SOL_SOCKET, SO_SNDBUF, (char *) &newopt,
+ sizeof(newopt)) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "setsockopt", "SO_SNDBUF")));
+ return STATUS_ERROR;
+ }
+ }
+#endif
+
+ /*
+ * Also apply the current keepalive parameters. If we fail to set a
+ * parameter, don't error out, because these aren't universally
+ * supported. (Note: you might think we need to reset the GUC
+ * variables to 0 in such a case, but it's not necessary because the
+ * show hooks for these variables report the truth anyway.)
+ */
+ (void) pq_setkeepalivesidle(tcp_keepalives_idle, port);
+ (void) pq_setkeepalivesinterval(tcp_keepalives_interval, port);
+ (void) pq_setkeepalivescount(tcp_keepalives_count, port);
+ (void) pq_settcpusertimeout(tcp_user_timeout, port);
+ }
+
+ return STATUS_OK;
+}
+
+/*
+ * StreamClose -- close a client/backend connection
+ *
+ * NOTE: this is NOT used to terminate a session; it is just used to release
+ * the file descriptor in a process that should no longer have the socket
+ * open. (For example, the postmaster calls this after passing ownership
+ * of the connection to a child process.) It is expected that someone else
+ * still has the socket open. So, we only want to close the descriptor,
+ * we do NOT want to send anything to the far end.
+ */
+void
+StreamClose(pgsocket sock)
+{
+ closesocket(sock);
+}
+
+/*
+ * TouchSocketFiles -- mark socket files as recently accessed
+ *
+ * This routine should be called every so often to ensure that the socket
+ * files have a recent mod date (ordinary operations on sockets usually won't
+ * change the mod date). That saves them from being removed by
+ * overenthusiastic /tmp-directory-cleaner daemons. (Another reason we should
+ * never have put the socket file in /tmp...)
+ */
+void
+TouchSocketFiles(void)
+{
+ ListCell *l;
+
+ /* Loop through all created sockets... */
+ foreach(l, sock_paths)
+ {
+ char *sock_path = (char *) lfirst(l);
+
+ /* Ignore errors; there's no point in complaining */
+ (void) utime(sock_path, NULL);
+ }
+}
+
+/*
+ * RemoveSocketFiles -- unlink socket files at postmaster shutdown
+ */
+void
+RemoveSocketFiles(void)
+{
+ ListCell *l;
+
+ /* Loop through all created sockets... */
+ foreach(l, sock_paths)
+ {
+ char *sock_path = (char *) lfirst(l);
+
+ /* Ignore any error. */
+ (void) unlink(sock_path);
+ }
+ /* Since we're about to exit, no need to reclaim storage */
+ sock_paths = NIL;
+}
+
+
+/* --------------------------------
+ * Low-level I/O routines begin here.
+ *
+ * These routines communicate with a frontend client across a connection
+ * already established by the preceding routines.
+ * --------------------------------
+ */
+
+/* --------------------------------
+ * socket_set_nonblocking - set socket blocking/non-blocking
+ *
+ * Sets the socket non-blocking if nonblocking is true, or sets it
+ * blocking otherwise.
+ * --------------------------------
+ */
+static void
+socket_set_nonblocking(bool nonblocking)
+{
+ if (MyProcPort == NULL)
+ ereport(ERROR,
+ (errcode(ERRCODE_CONNECTION_DOES_NOT_EXIST),
+ errmsg("there is no client connection")));
+
+ MyProcPort->noblock = nonblocking;
+}
+
+/* --------------------------------
+ * pq_recvbuf - load some bytes into the input buffer
+ *
+ * returns 0 if OK, EOF if trouble
+ * --------------------------------
+ */
+static int
+pq_recvbuf(void)
+{
+ if (PqRecvPointer > 0)
+ {
+ if (PqRecvLength > PqRecvPointer)
+ {
+ /* still some unread data, left-justify it in the buffer */
+ memmove(PqRecvBuffer, PqRecvBuffer + PqRecvPointer,
+ PqRecvLength - PqRecvPointer);
+ PqRecvLength -= PqRecvPointer;
+ PqRecvPointer = 0;
+ }
+ else
+ PqRecvLength = PqRecvPointer = 0;
+ }
+
+ /* Ensure that we're in blocking mode */
+ socket_set_nonblocking(false);
+
+ /* Can fill buffer from PqRecvLength and upwards */
+ for (;;)
+ {
+ int r;
+
+ r = secure_read(MyProcPort, PqRecvBuffer + PqRecvLength,
+ PQ_RECV_BUFFER_SIZE - PqRecvLength);
+
+ if (r < 0)
+ {
+ if (errno == EINTR)
+ continue; /* Ok if interrupted */
+
+ /*
+ * Careful: an ereport() that tries to write to the client would
+ * cause recursion to here, leading to stack overflow and core
+ * dump! This message must go *only* to the postmaster log.
+ */
+ ereport(COMMERROR,
+ (errcode_for_socket_access(),
+ errmsg("could not receive data from client: %m")));
+ return EOF;
+ }
+ if (r == 0)
+ {
+ /*
+ * EOF detected. We used to write a log message here, but it's
+ * better to expect the ultimate caller to do that.
+ */
+ return EOF;
+ }
+ /* r contains number of bytes read, so just incr length */
+ PqRecvLength += r;
+ return 0;
+ }
+}
+
+/* --------------------------------
+ * pq_getbyte - get a single byte from connection, or return EOF
+ * --------------------------------
+ */
+int
+pq_getbyte(void)
+{
+ Assert(PqCommReadingMsg);
+
+ while (PqRecvPointer >= PqRecvLength)
+ {
+ if (pq_recvbuf()) /* If nothing in buffer, then recv some */
+ return EOF; /* Failed to recv data */
+ }
+ return (unsigned char) PqRecvBuffer[PqRecvPointer++];
+}
+
+/* --------------------------------
+ * pq_peekbyte - peek at next byte from connection
+ *
+ * Same as pq_getbyte() except we don't advance the pointer.
+ * --------------------------------
+ */
+int
+pq_peekbyte(void)
+{
+ Assert(PqCommReadingMsg);
+
+ while (PqRecvPointer >= PqRecvLength)
+ {
+ if (pq_recvbuf()) /* If nothing in buffer, then recv some */
+ return EOF; /* Failed to recv data */
+ }
+ return (unsigned char) PqRecvBuffer[PqRecvPointer];
+}
+
+/* --------------------------------
+ * pq_getbyte_if_available - get a single byte from connection,
+ * if available
+ *
+ * The received byte is stored in *c. Returns 1 if a byte was read,
+ * 0 if no data was available, or EOF if trouble.
+ * --------------------------------
+ */
+int
+pq_getbyte_if_available(unsigned char *c)
+{
+ int r;
+
+ Assert(PqCommReadingMsg);
+
+ if (PqRecvPointer < PqRecvLength)
+ {
+ *c = PqRecvBuffer[PqRecvPointer++];
+ return 1;
+ }
+
+ /* Put the socket into non-blocking mode */
+ socket_set_nonblocking(true);
+
+ r = secure_read(MyProcPort, c, 1);
+ if (r < 0)
+ {
+ /*
+ * Ok if no data available without blocking or interrupted (though
+ * EINTR really shouldn't happen with a non-blocking socket). Report
+ * other errors.
+ */
+ if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
+ r = 0;
+ else
+ {
+ /*
+ * Careful: an ereport() that tries to write to the client would
+ * cause recursion to here, leading to stack overflow and core
+ * dump! This message must go *only* to the postmaster log.
+ */
+ ereport(COMMERROR,
+ (errcode_for_socket_access(),
+ errmsg("could not receive data from client: %m")));
+ r = EOF;
+ }
+ }
+ else if (r == 0)
+ {
+ /* EOF detected */
+ r = EOF;
+ }
+
+ return r;
+}
+
+/* --------------------------------
+ * pq_getbytes - get a known number of bytes from connection
+ *
+ * returns 0 if OK, EOF if trouble
+ * --------------------------------
+ */
+int
+pq_getbytes(char *s, size_t len)
+{
+ size_t amount;
+
+ Assert(PqCommReadingMsg);
+
+ while (len > 0)
+ {
+ while (PqRecvPointer >= PqRecvLength)
+ {
+ if (pq_recvbuf()) /* If nothing in buffer, then recv some */
+ return EOF; /* Failed to recv data */
+ }
+ amount = PqRecvLength - PqRecvPointer;
+ if (amount > len)
+ amount = len;
+ memcpy(s, PqRecvBuffer + PqRecvPointer, amount);
+ PqRecvPointer += amount;
+ s += amount;
+ len -= amount;
+ }
+ return 0;
+}
+
+/* --------------------------------
+ * pq_discardbytes - throw away a known number of bytes
+ *
+ * same as pq_getbytes except we do not copy the data to anyplace.
+ * this is used for resynchronizing after read errors.
+ *
+ * returns 0 if OK, EOF if trouble
+ * --------------------------------
+ */
+static int
+pq_discardbytes(size_t len)
+{
+ size_t amount;
+
+ Assert(PqCommReadingMsg);
+
+ while (len > 0)
+ {
+ while (PqRecvPointer >= PqRecvLength)
+ {
+ if (pq_recvbuf()) /* If nothing in buffer, then recv some */
+ return EOF; /* Failed to recv data */
+ }
+ amount = PqRecvLength - PqRecvPointer;
+ if (amount > len)
+ amount = len;
+ PqRecvPointer += amount;
+ len -= amount;
+ }
+ return 0;
+}
+
+/* --------------------------------
+ * pq_buffer_has_data - is any buffered data available to read?
+ *
+ * This will *not* attempt to read more data.
+ * --------------------------------
+ */
+bool
+pq_buffer_has_data(void)
+{
+ return (PqRecvPointer < PqRecvLength);
+}
+
+
+/* --------------------------------
+ * pq_startmsgread - begin reading a message from the client.
+ *
+ * This must be called before any of the pq_get* functions.
+ * --------------------------------
+ */
+void
+pq_startmsgread(void)
+{
+ /*
+ * There shouldn't be a read active already, but let's check just to be
+ * sure.
+ */
+ if (PqCommReadingMsg)
+ ereport(FATAL,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("terminating connection because protocol synchronization was lost")));
+
+ PqCommReadingMsg = true;
+}
+
+
+/* --------------------------------
+ * pq_endmsgread - finish reading message.
+ *
+ * This must be called after reading a message with pq_getbytes()
+ * and friends, to indicate that we have read the whole message.
+ * pq_getmessage() does this implicitly.
+ * --------------------------------
+ */
+void
+pq_endmsgread(void)
+{
+ Assert(PqCommReadingMsg);
+
+ PqCommReadingMsg = false;
+}
+
+/* --------------------------------
+ * pq_is_reading_msg - are we currently reading a message?
+ *
+ * This is used in error recovery at the outer idle loop to detect if we have
+ * lost protocol sync, and need to terminate the connection. pq_startmsgread()
+ * will check for that too, but it's nicer to detect it earlier.
+ * --------------------------------
+ */
+bool
+pq_is_reading_msg(void)
+{
+ return PqCommReadingMsg;
+}
+
+/* --------------------------------
+ * pq_getmessage - get a message with length word from connection
+ *
+ * The return value is placed in an expansible StringInfo, which has
+ * already been initialized by the caller.
+ * Only the message body is placed in the StringInfo; the length word
+ * is removed. Also, s->cursor is initialized to zero for convenience
+ * in scanning the message contents.
+ *
+ * maxlen is the upper limit on the length of the
+ * message we are willing to accept. We abort the connection (by
+ * returning EOF) if client tries to send more than that.
+ *
+ * returns 0 if OK, EOF if trouble
+ * --------------------------------
+ */
+int
+pq_getmessage(StringInfo s, int maxlen)
+{
+ int32 len;
+
+ Assert(PqCommReadingMsg);
+
+ resetStringInfo(s);
+
+ /* Read message length word */
+ if (pq_getbytes((char *) &len, 4) == EOF)
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("unexpected EOF within message length word")));
+ return EOF;
+ }
+
+ len = pg_ntoh32(len);
+
+ if (len < 4 || len > maxlen)
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("invalid message length")));
+ return EOF;
+ }
+
+ len -= 4; /* discount length itself */
+
+ if (len > 0)
+ {
+ /*
+ * Allocate space for message. If we run out of room (ridiculously
+ * large message), we will elog(ERROR), but we want to discard the
+ * message body so as not to lose communication sync.
+ */
+ PG_TRY();
+ {
+ enlargeStringInfo(s, len);
+ }
+ PG_CATCH();
+ {
+ if (pq_discardbytes(len) == EOF)
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("incomplete message from client")));
+
+ /* we discarded the rest of the message so we're back in sync. */
+ PqCommReadingMsg = false;
+ PG_RE_THROW();
+ }
+ PG_END_TRY();
+
+ /* And grab the message */
+ if (pq_getbytes(s->data, len) == EOF)
+ {
+ ereport(COMMERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("incomplete message from client")));
+ return EOF;
+ }
+ s->len = len;
+ /* Place a trailing null per StringInfo convention */
+ s->data[len] = '\0';
+ }
+
+ /* finished reading the message. */
+ PqCommReadingMsg = false;
+
+ return 0;
+}
+
+
+static int
+internal_putbytes(const char *s, size_t len)
+{
+ size_t amount;
+
+ while (len > 0)
+ {
+ /* If buffer is full, then flush it out */
+ if (PqSendPointer >= PqSendBufferSize)
+ {
+ socket_set_nonblocking(false);
+ if (internal_flush())
+ return EOF;
+ }
+ amount = PqSendBufferSize - PqSendPointer;
+ if (amount > len)
+ amount = len;
+ memcpy(PqSendBuffer + PqSendPointer, s, amount);
+ PqSendPointer += amount;
+ s += amount;
+ len -= amount;
+ }
+ return 0;
+}
+
+/* --------------------------------
+ * socket_flush - flush pending output
+ *
+ * returns 0 if OK, EOF if trouble
+ * --------------------------------
+ */
+static int
+socket_flush(void)
+{
+ int res;
+
+ /* No-op if reentrant call */
+ if (PqCommBusy)
+ return 0;
+ PqCommBusy = true;
+ socket_set_nonblocking(false);
+ res = internal_flush();
+ PqCommBusy = false;
+ return res;
+}
+
+/* --------------------------------
+ * internal_flush - flush pending output
+ *
+ * Returns 0 if OK (meaning everything was sent, or operation would block
+ * and the socket is in non-blocking mode), or EOF if trouble.
+ * --------------------------------
+ */
+static int
+internal_flush(void)
+{
+ static int last_reported_send_errno = 0;
+
+ char *bufptr = PqSendBuffer + PqSendStart;
+ char *bufend = PqSendBuffer + PqSendPointer;
+
+ while (bufptr < bufend)
+ {
+ int r;
+
+ r = secure_write(MyProcPort, bufptr, bufend - bufptr);
+
+ if (r <= 0)
+ {
+ if (errno == EINTR)
+ continue; /* Ok if we were interrupted */
+
+ /*
+ * Ok if no data writable without blocking, and the socket is in
+ * non-blocking mode.
+ */
+ if (errno == EAGAIN ||
+ errno == EWOULDBLOCK)
+ {
+ return 0;
+ }
+
+ /*
+ * Careful: an ereport() that tries to write to the client would
+ * cause recursion to here, leading to stack overflow and core
+ * dump! This message must go *only* to the postmaster log.
+ *
+ * If a client disconnects while we're in the midst of output, we
+ * might write quite a bit of data before we get to a safe query
+ * abort point. So, suppress duplicate log messages.
+ */
+ if (errno != last_reported_send_errno)
+ {
+ last_reported_send_errno = errno;
+ ereport(COMMERROR,
+ (errcode_for_socket_access(),
+ errmsg("could not send data to client: %m")));
+ }
+
+ /*
+ * We drop the buffered data anyway so that processing can
+ * continue, even though we'll probably quit soon. We also set a
+ * flag that'll cause the next CHECK_FOR_INTERRUPTS to terminate
+ * the connection.
+ */
+ PqSendStart = PqSendPointer = 0;
+ ClientConnectionLost = 1;
+ InterruptPending = 1;
+ return EOF;
+ }
+
+ last_reported_send_errno = 0; /* reset after any successful send */
+ bufptr += r;
+ PqSendStart += r;
+ }
+
+ PqSendStart = PqSendPointer = 0;
+ return 0;
+}
+
+/* --------------------------------
+ * pq_flush_if_writable - flush pending output if writable without blocking
+ *
+ * Returns 0 if OK, or EOF if trouble.
+ * --------------------------------
+ */
+static int
+socket_flush_if_writable(void)
+{
+ int res;
+
+ /* Quick exit if nothing to do */
+ if (PqSendPointer == PqSendStart)
+ return 0;
+
+ /* No-op if reentrant call */
+ if (PqCommBusy)
+ return 0;
+
+ /* Temporarily put the socket into non-blocking mode */
+ socket_set_nonblocking(true);
+
+ PqCommBusy = true;
+ res = internal_flush();
+ PqCommBusy = false;
+ return res;
+}
+
+/* --------------------------------
+ * socket_is_send_pending - is there any pending data in the output buffer?
+ * --------------------------------
+ */
+static bool
+socket_is_send_pending(void)
+{
+ return (PqSendStart < PqSendPointer);
+}
+
+/* --------------------------------
+ * Message-level I/O routines begin here.
+ * --------------------------------
+ */
+
+
+/* --------------------------------
+ * socket_putmessage - send a normal message (suppressed in COPY OUT mode)
+ *
+ * msgtype is a message type code to place before the message body.
+ *
+ * len is the length of the message body data at *s. A message length
+ * word (equal to len+4 because it counts itself too) is inserted by this
+ * routine.
+ *
+ * We suppress messages generated while pqcomm.c is busy. This
+ * avoids any possibility of messages being inserted within other
+ * messages. The only known trouble case arises if SIGQUIT occurs
+ * during a pqcomm.c routine --- quickdie() will try to send a warning
+ * message, and the most reasonable approach seems to be to drop it.
+ *
+ * returns 0 if OK, EOF if trouble
+ * --------------------------------
+ */
+static int
+socket_putmessage(char msgtype, const char *s, size_t len)
+{
+ uint32 n32;
+
+ Assert(msgtype != 0);
+
+ if (PqCommBusy)
+ return 0;
+ PqCommBusy = true;
+ if (internal_putbytes(&msgtype, 1))
+ goto fail;
+
+ n32 = pg_hton32((uint32) (len + 4));
+ if (internal_putbytes((char *) &n32, 4))
+ goto fail;
+
+ if (internal_putbytes(s, len))
+ goto fail;
+ PqCommBusy = false;
+ return 0;
+
+fail:
+ PqCommBusy = false;
+ return EOF;
+}
+
+/* --------------------------------
+ * pq_putmessage_noblock - like pq_putmessage, but never blocks
+ *
+ * If the output buffer is too small to hold the message, the buffer
+ * is enlarged.
+ */
+static void
+socket_putmessage_noblock(char msgtype, const char *s, size_t len)
+{
+ int res PG_USED_FOR_ASSERTS_ONLY;
+ int required;
+
+ /*
+ * Ensure we have enough space in the output buffer for the message header
+ * as well as the message itself.
+ */
+ required = PqSendPointer + 1 + 4 + len;
+ if (required > PqSendBufferSize)
+ {
+ PqSendBuffer = repalloc(PqSendBuffer, required);
+ PqSendBufferSize = required;
+ }
+ res = pq_putmessage(msgtype, s, len);
+ Assert(res == 0); /* should not fail when the message fits in
+ * buffer */
+}
+
+/* --------------------------------
+ * pq_putmessage_v2 - send a message in protocol version 2
+ *
+ * msgtype is a message type code to place before the message body.
+ *
+ * We no longer support protocol version 2, but we have kept this
+ * function so that if a client tries to connect with protocol version 2,
+ * as a courtesy we can still send the "unsupported protocol version"
+ * error to the client in the old format.
+ *
+ * Like in pq_putmessage(), we suppress messages generated while
+ * pqcomm.c is busy.
+ *
+ * returns 0 if OK, EOF if trouble
+ * --------------------------------
+ */
+int
+pq_putmessage_v2(char msgtype, const char *s, size_t len)
+{
+ Assert(msgtype != 0);
+
+ if (PqCommBusy)
+ return 0;
+ PqCommBusy = true;
+ if (internal_putbytes(&msgtype, 1))
+ goto fail;
+
+ if (internal_putbytes(s, len))
+ goto fail;
+ PqCommBusy = false;
+ return 0;
+
+fail:
+ PqCommBusy = false;
+ return EOF;
+}
+
+/*
+ * Support for TCP Keepalive parameters
+ */
+
+/*
+ * On Windows, we need to set both idle and interval at the same time.
+ * We also cannot reset them to the default (setting to zero will
+ * actually set them to zero, not default), therefore we fallback to
+ * the out-of-the-box default instead.
+ */
+#if defined(WIN32) && defined(SIO_KEEPALIVE_VALS)
+static int
+pq_setkeepaliveswin32(Port *port, int idle, int interval)
+{
+ struct tcp_keepalive ka;
+ DWORD retsize;
+
+ if (idle <= 0)
+ idle = 2 * 60 * 60; /* default = 2 hours */
+ if (interval <= 0)
+ interval = 1; /* default = 1 second */
+
+ ka.onoff = 1;
+ ka.keepalivetime = idle * 1000;
+ ka.keepaliveinterval = interval * 1000;
+
+ if (WSAIoctl(port->sock,
+ SIO_KEEPALIVE_VALS,
+ (LPVOID) &ka,
+ sizeof(ka),
+ NULL,
+ 0,
+ &retsize,
+ NULL,
+ NULL)
+ != 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: error code %d",
+ "WSAIoctl", "SIO_KEEPALIVE_VALS", WSAGetLastError())));
+ return STATUS_ERROR;
+ }
+ if (port->keepalives_idle != idle)
+ port->keepalives_idle = idle;
+ if (port->keepalives_interval != interval)
+ port->keepalives_interval = interval;
+ return STATUS_OK;
+}
+#endif
+
+int
+pq_getkeepalivesidle(Port *port)
+{
+#if defined(PG_TCP_KEEPALIVE_IDLE) || defined(SIO_KEEPALIVE_VALS)
+ if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family))
+ return 0;
+
+ if (port->keepalives_idle != 0)
+ return port->keepalives_idle;
+
+ if (port->default_keepalives_idle == 0)
+ {
+#ifndef WIN32
+ ACCEPT_TYPE_ARG3 size = sizeof(port->default_keepalives_idle);
+
+ if (getsockopt(port->sock, IPPROTO_TCP, PG_TCP_KEEPALIVE_IDLE,
+ (char *) &port->default_keepalives_idle,
+ &size) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "getsockopt", PG_TCP_KEEPALIVE_IDLE_STR)));
+ port->default_keepalives_idle = -1; /* don't know */
+ }
+#else /* WIN32 */
+ /* We can't get the defaults on Windows, so return "don't know" */
+ port->default_keepalives_idle = -1;
+#endif /* WIN32 */
+ }
+
+ return port->default_keepalives_idle;
+#else
+ return 0;
+#endif
+}
+
+int
+pq_setkeepalivesidle(int idle, Port *port)
+{
+ if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family))
+ return STATUS_OK;
+
+/* check SIO_KEEPALIVE_VALS here, not just WIN32, as some toolchains lack it */
+#if defined(PG_TCP_KEEPALIVE_IDLE) || defined(SIO_KEEPALIVE_VALS)
+ if (idle == port->keepalives_idle)
+ return STATUS_OK;
+
+#ifndef WIN32
+ if (port->default_keepalives_idle <= 0)
+ {
+ if (pq_getkeepalivesidle(port) < 0)
+ {
+ if (idle == 0)
+ return STATUS_OK; /* default is set but unknown */
+ else
+ return STATUS_ERROR;
+ }
+ }
+
+ if (idle == 0)
+ idle = port->default_keepalives_idle;
+
+ if (setsockopt(port->sock, IPPROTO_TCP, PG_TCP_KEEPALIVE_IDLE,
+ (char *) &idle, sizeof(idle)) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "setsockopt", PG_TCP_KEEPALIVE_IDLE_STR)));
+ return STATUS_ERROR;
+ }
+
+ port->keepalives_idle = idle;
+#else /* WIN32 */
+ return pq_setkeepaliveswin32(port, idle, port->keepalives_interval);
+#endif
+#else
+ if (idle != 0)
+ {
+ ereport(LOG,
+ (errmsg("setting the keepalive idle time is not supported")));
+ return STATUS_ERROR;
+ }
+#endif
+
+ return STATUS_OK;
+}
+
+int
+pq_getkeepalivesinterval(Port *port)
+{
+#if defined(TCP_KEEPINTVL) || defined(SIO_KEEPALIVE_VALS)
+ if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family))
+ return 0;
+
+ if (port->keepalives_interval != 0)
+ return port->keepalives_interval;
+
+ if (port->default_keepalives_interval == 0)
+ {
+#ifndef WIN32
+ ACCEPT_TYPE_ARG3 size = sizeof(port->default_keepalives_interval);
+
+ if (getsockopt(port->sock, IPPROTO_TCP, TCP_KEEPINTVL,
+ (char *) &port->default_keepalives_interval,
+ &size) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "getsockopt", "TCP_KEEPINTVL")));
+ port->default_keepalives_interval = -1; /* don't know */
+ }
+#else
+ /* We can't get the defaults on Windows, so return "don't know" */
+ port->default_keepalives_interval = -1;
+#endif /* WIN32 */
+ }
+
+ return port->default_keepalives_interval;
+#else
+ return 0;
+#endif
+}
+
+int
+pq_setkeepalivesinterval(int interval, Port *port)
+{
+ if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family))
+ return STATUS_OK;
+
+#if defined(TCP_KEEPINTVL) || defined(SIO_KEEPALIVE_VALS)
+ if (interval == port->keepalives_interval)
+ return STATUS_OK;
+
+#ifndef WIN32
+ if (port->default_keepalives_interval <= 0)
+ {
+ if (pq_getkeepalivesinterval(port) < 0)
+ {
+ if (interval == 0)
+ return STATUS_OK; /* default is set but unknown */
+ else
+ return STATUS_ERROR;
+ }
+ }
+
+ if (interval == 0)
+ interval = port->default_keepalives_interval;
+
+ if (setsockopt(port->sock, IPPROTO_TCP, TCP_KEEPINTVL,
+ (char *) &interval, sizeof(interval)) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "setsockopt", "TCP_KEEPINTVL")));
+ return STATUS_ERROR;
+ }
+
+ port->keepalives_interval = interval;
+#else /* WIN32 */
+ return pq_setkeepaliveswin32(port, port->keepalives_idle, interval);
+#endif
+#else
+ if (interval != 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) not supported", "setsockopt", "TCP_KEEPINTVL")));
+ return STATUS_ERROR;
+ }
+#endif
+
+ return STATUS_OK;
+}
+
+int
+pq_getkeepalivescount(Port *port)
+{
+#ifdef TCP_KEEPCNT
+ if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family))
+ return 0;
+
+ if (port->keepalives_count != 0)
+ return port->keepalives_count;
+
+ if (port->default_keepalives_count == 0)
+ {
+ ACCEPT_TYPE_ARG3 size = sizeof(port->default_keepalives_count);
+
+ if (getsockopt(port->sock, IPPROTO_TCP, TCP_KEEPCNT,
+ (char *) &port->default_keepalives_count,
+ &size) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "getsockopt", "TCP_KEEPCNT")));
+ port->default_keepalives_count = -1; /* don't know */
+ }
+ }
+
+ return port->default_keepalives_count;
+#else
+ return 0;
+#endif
+}
+
+int
+pq_setkeepalivescount(int count, Port *port)
+{
+ if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family))
+ return STATUS_OK;
+
+#ifdef TCP_KEEPCNT
+ if (count == port->keepalives_count)
+ return STATUS_OK;
+
+ if (port->default_keepalives_count <= 0)
+ {
+ if (pq_getkeepalivescount(port) < 0)
+ {
+ if (count == 0)
+ return STATUS_OK; /* default is set but unknown */
+ else
+ return STATUS_ERROR;
+ }
+ }
+
+ if (count == 0)
+ count = port->default_keepalives_count;
+
+ if (setsockopt(port->sock, IPPROTO_TCP, TCP_KEEPCNT,
+ (char *) &count, sizeof(count)) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "setsockopt", "TCP_KEEPCNT")));
+ return STATUS_ERROR;
+ }
+
+ port->keepalives_count = count;
+#else
+ if (count != 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) not supported", "setsockopt", "TCP_KEEPCNT")));
+ return STATUS_ERROR;
+ }
+#endif
+
+ return STATUS_OK;
+}
+
+int
+pq_gettcpusertimeout(Port *port)
+{
+#ifdef TCP_USER_TIMEOUT
+ if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family))
+ return 0;
+
+ if (port->tcp_user_timeout != 0)
+ return port->tcp_user_timeout;
+
+ if (port->default_tcp_user_timeout == 0)
+ {
+ ACCEPT_TYPE_ARG3 size = sizeof(port->default_tcp_user_timeout);
+
+ if (getsockopt(port->sock, IPPROTO_TCP, TCP_USER_TIMEOUT,
+ (char *) &port->default_tcp_user_timeout,
+ &size) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "getsockopt", "TCP_USER_TIMEOUT")));
+ port->default_tcp_user_timeout = -1; /* don't know */
+ }
+ }
+
+ return port->default_tcp_user_timeout;
+#else
+ return 0;
+#endif
+}
+
+int
+pq_settcpusertimeout(int timeout, Port *port)
+{
+ if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family))
+ return STATUS_OK;
+
+#ifdef TCP_USER_TIMEOUT
+ if (timeout == port->tcp_user_timeout)
+ return STATUS_OK;
+
+ if (port->default_tcp_user_timeout <= 0)
+ {
+ if (pq_gettcpusertimeout(port) < 0)
+ {
+ if (timeout == 0)
+ return STATUS_OK; /* default is set but unknown */
+ else
+ return STATUS_ERROR;
+ }
+ }
+
+ if (timeout == 0)
+ timeout = port->default_tcp_user_timeout;
+
+ if (setsockopt(port->sock, IPPROTO_TCP, TCP_USER_TIMEOUT,
+ (char *) &timeout, sizeof(timeout)) < 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) failed: %m", "setsockopt", "TCP_USER_TIMEOUT")));
+ return STATUS_ERROR;
+ }
+
+ port->tcp_user_timeout = timeout;
+#else
+ if (timeout != 0)
+ {
+ ereport(LOG,
+ (errmsg("%s(%s) not supported", "setsockopt", "TCP_USER_TIMEOUT")));
+ return STATUS_ERROR;
+ }
+#endif
+
+ return STATUS_OK;
+}
+
+/*
+ * Check if the client is still connected.
+ */
+bool
+pq_check_connection(void)
+{
+#if defined(POLLRDHUP)
+ /*
+ * POLLRDHUP is a Linux extension to poll(2) to detect sockets closed by
+ * the other end. We don't have a portable way to do that without
+ * actually trying to read or write data on other systems. We don't want
+ * to read because that would be confused by pipelined queries and COPY
+ * data. Perhaps in future we'll try to write a heartbeat message instead.
+ */
+ struct pollfd pollfd;
+ int rc;
+
+ pollfd.fd = MyProcPort->sock;
+ pollfd.events = POLLOUT | POLLIN | POLLRDHUP;
+ pollfd.revents = 0;
+
+ rc = poll(&pollfd, 1, 0);
+
+ if (rc < 0)
+ {
+ ereport(COMMERROR,
+ (errcode_for_socket_access(),
+ errmsg("could not poll socket: %m")));
+ return false;
+ }
+ else if (rc == 1 && (pollfd.revents & (POLLHUP | POLLRDHUP)))
+ return false;
+#endif
+
+ return true;
+}
diff --git a/src/backend/libpq/pqformat.c b/src/backend/libpq/pqformat.c
new file mode 100644
index 0000000..1999898
--- /dev/null
+++ b/src/backend/libpq/pqformat.c
@@ -0,0 +1,643 @@
+/*-------------------------------------------------------------------------
+ *
+ * pqformat.c
+ * Routines for formatting and parsing frontend/backend messages
+ *
+ * Outgoing messages are built up in a StringInfo buffer (which is expansible)
+ * and then sent in a single call to pq_putmessage. This module provides data
+ * formatting/conversion routines that are needed to produce valid messages.
+ * Note in particular the distinction between "raw data" and "text"; raw data
+ * is message protocol characters and binary values that are not subject to
+ * character set conversion, while text is converted by character encoding
+ * rules.
+ *
+ * Incoming messages are similarly read into a StringInfo buffer, via
+ * pq_getmessage, and then parsed and converted from that using the routines
+ * in this module.
+ *
+ * These same routines support reading and writing of external binary formats
+ * (typsend/typreceive routines). The conversion routines for individual
+ * data types are exactly the same, only initialization and completion
+ * are different.
+ *
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/pqformat.c
+ *
+ *-------------------------------------------------------------------------
+ */
+/*
+ * INTERFACE ROUTINES
+ * Message assembly and output:
+ * pq_beginmessage - initialize StringInfo buffer
+ * pq_sendbyte - append a raw byte to a StringInfo buffer
+ * pq_sendint - append a binary integer to a StringInfo buffer
+ * pq_sendint64 - append a binary 8-byte int to a StringInfo buffer
+ * pq_sendfloat4 - append a float4 to a StringInfo buffer
+ * pq_sendfloat8 - append a float8 to a StringInfo buffer
+ * pq_sendbytes - append raw data to a StringInfo buffer
+ * pq_sendcountedtext - append a counted text string (with character set conversion)
+ * pq_sendtext - append a text string (with conversion)
+ * pq_sendstring - append a null-terminated text string (with conversion)
+ * pq_send_ascii_string - append a null-terminated text string (without conversion)
+ * pq_endmessage - send the completed message to the frontend
+ * Note: it is also possible to append data to the StringInfo buffer using
+ * the regular StringInfo routines, but this is discouraged since required
+ * character set conversion may not occur.
+ *
+ * typsend support (construct a bytea value containing external binary data):
+ * pq_begintypsend - initialize StringInfo buffer
+ * pq_endtypsend - return the completed string as a "bytea*"
+ *
+ * Special-case message output:
+ * pq_puttextmessage - generate a character set-converted message in one step
+ * pq_putemptymessage - convenience routine for message with empty body
+ *
+ * Message parsing after input:
+ * pq_getmsgbyte - get a raw byte from a message buffer
+ * pq_getmsgint - get a binary integer from a message buffer
+ * pq_getmsgint64 - get a binary 8-byte int from a message buffer
+ * pq_getmsgfloat4 - get a float4 from a message buffer
+ * pq_getmsgfloat8 - get a float8 from a message buffer
+ * pq_getmsgbytes - get raw data from a message buffer
+ * pq_copymsgbytes - copy raw data from a message buffer
+ * pq_getmsgtext - get a counted text string (with conversion)
+ * pq_getmsgstring - get a null-terminated text string (with conversion)
+ * pq_getmsgrawstring - get a null-terminated text string - NO conversion
+ * pq_getmsgend - verify message fully consumed
+ */
+
+#include "postgres.h"
+
+#include <sys/param.h>
+
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
+#include "mb/pg_wchar.h"
+#include "port/pg_bswap.h"
+
+
+/* --------------------------------
+ * pq_beginmessage - initialize for sending a message
+ * --------------------------------
+ */
+void
+pq_beginmessage(StringInfo buf, char msgtype)
+{
+ initStringInfo(buf);
+
+ /*
+ * We stash the message type into the buffer's cursor field, expecting
+ * that the pq_sendXXX routines won't touch it. We could alternatively
+ * make it the first byte of the buffer contents, but this seems easier.
+ */
+ buf->cursor = msgtype;
+}
+
+/* --------------------------------
+
+ * pq_beginmessage_reuse - initialize for sending a message, reuse buffer
+ *
+ * This requires the buffer to be allocated in a sufficiently long-lived
+ * memory context.
+ * --------------------------------
+ */
+void
+pq_beginmessage_reuse(StringInfo buf, char msgtype)
+{
+ resetStringInfo(buf);
+
+ /*
+ * We stash the message type into the buffer's cursor field, expecting
+ * that the pq_sendXXX routines won't touch it. We could alternatively
+ * make it the first byte of the buffer contents, but this seems easier.
+ */
+ buf->cursor = msgtype;
+}
+
+/* --------------------------------
+ * pq_sendbytes - append raw data to a StringInfo buffer
+ * --------------------------------
+ */
+void
+pq_sendbytes(StringInfo buf, const char *data, int datalen)
+{
+ /* use variant that maintains a trailing null-byte, out of caution */
+ appendBinaryStringInfo(buf, data, datalen);
+}
+
+/* --------------------------------
+ * pq_sendcountedtext - append a counted text string (with character set conversion)
+ *
+ * The data sent to the frontend by this routine is a 4-byte count field
+ * followed by the string. The count includes itself or not, as per the
+ * countincludesself flag (pre-3.0 protocol requires it to include itself).
+ * The passed text string need not be null-terminated, and the data sent
+ * to the frontend isn't either.
+ * --------------------------------
+ */
+void
+pq_sendcountedtext(StringInfo buf, const char *str, int slen,
+ bool countincludesself)
+{
+ int extra = countincludesself ? 4 : 0;
+ char *p;
+
+ p = pg_server_to_client(str, slen);
+ if (p != str) /* actual conversion has been done? */
+ {
+ slen = strlen(p);
+ pq_sendint32(buf, slen + extra);
+ appendBinaryStringInfoNT(buf, p, slen);
+ pfree(p);
+ }
+ else
+ {
+ pq_sendint32(buf, slen + extra);
+ appendBinaryStringInfoNT(buf, str, slen);
+ }
+}
+
+/* --------------------------------
+ * pq_sendtext - append a text string (with conversion)
+ *
+ * The passed text string need not be null-terminated, and the data sent
+ * to the frontend isn't either. Note that this is not actually useful
+ * for direct frontend transmissions, since there'd be no way for the
+ * frontend to determine the string length. But it is useful for binary
+ * format conversions.
+ * --------------------------------
+ */
+void
+pq_sendtext(StringInfo buf, const char *str, int slen)
+{
+ char *p;
+
+ p = pg_server_to_client(str, slen);
+ if (p != str) /* actual conversion has been done? */
+ {
+ slen = strlen(p);
+ appendBinaryStringInfo(buf, p, slen);
+ pfree(p);
+ }
+ else
+ appendBinaryStringInfo(buf, str, slen);
+}
+
+/* --------------------------------
+ * pq_sendstring - append a null-terminated text string (with conversion)
+ *
+ * NB: passed text string must be null-terminated, and so is the data
+ * sent to the frontend.
+ * --------------------------------
+ */
+void
+pq_sendstring(StringInfo buf, const char *str)
+{
+ int slen = strlen(str);
+ char *p;
+
+ p = pg_server_to_client(str, slen);
+ if (p != str) /* actual conversion has been done? */
+ {
+ slen = strlen(p);
+ appendBinaryStringInfoNT(buf, p, slen + 1);
+ pfree(p);
+ }
+ else
+ appendBinaryStringInfoNT(buf, str, slen + 1);
+}
+
+/* --------------------------------
+ * pq_send_ascii_string - append a null-terminated text string (without conversion)
+ *
+ * This function intentionally bypasses encoding conversion, instead just
+ * silently replacing any non-7-bit-ASCII characters with question marks.
+ * It is used only when we are having trouble sending an error message to
+ * the client with normal localization and encoding conversion. The caller
+ * should already have taken measures to ensure the string is just ASCII;
+ * the extra work here is just to make certain we don't send a badly encoded
+ * string to the client (which might or might not be robust about that).
+ *
+ * NB: passed text string must be null-terminated, and so is the data
+ * sent to the frontend.
+ * --------------------------------
+ */
+void
+pq_send_ascii_string(StringInfo buf, const char *str)
+{
+ while (*str)
+ {
+ char ch = *str++;
+
+ if (IS_HIGHBIT_SET(ch))
+ ch = '?';
+ appendStringInfoCharMacro(buf, ch);
+ }
+ appendStringInfoChar(buf, '\0');
+}
+
+/* --------------------------------
+ * pq_sendfloat4 - append a float4 to a StringInfo buffer
+ *
+ * The point of this routine is to localize knowledge of the external binary
+ * representation of float4, which is a component of several datatypes.
+ *
+ * We currently assume that float4 should be byte-swapped in the same way
+ * as int4. This rule is not perfect but it gives us portability across
+ * most IEEE-float-using architectures.
+ * --------------------------------
+ */
+void
+pq_sendfloat4(StringInfo buf, float4 f)
+{
+ union
+ {
+ float4 f;
+ uint32 i;
+ } swap;
+
+ swap.f = f;
+ pq_sendint32(buf, swap.i);
+}
+
+/* --------------------------------
+ * pq_sendfloat8 - append a float8 to a StringInfo buffer
+ *
+ * The point of this routine is to localize knowledge of the external binary
+ * representation of float8, which is a component of several datatypes.
+ *
+ * We currently assume that float8 should be byte-swapped in the same way
+ * as int8. This rule is not perfect but it gives us portability across
+ * most IEEE-float-using architectures.
+ * --------------------------------
+ */
+void
+pq_sendfloat8(StringInfo buf, float8 f)
+{
+ union
+ {
+ float8 f;
+ int64 i;
+ } swap;
+
+ swap.f = f;
+ pq_sendint64(buf, swap.i);
+}
+
+/* --------------------------------
+ * pq_endmessage - send the completed message to the frontend
+ *
+ * The data buffer is pfree()d, but if the StringInfo was allocated with
+ * makeStringInfo then the caller must still pfree it.
+ * --------------------------------
+ */
+void
+pq_endmessage(StringInfo buf)
+{
+ /* msgtype was saved in cursor field */
+ (void) pq_putmessage(buf->cursor, buf->data, buf->len);
+ /* no need to complain about any failure, since pqcomm.c already did */
+ pfree(buf->data);
+ buf->data = NULL;
+}
+
+/* --------------------------------
+ * pq_endmessage_reuse - send the completed message to the frontend
+ *
+ * The data buffer is *not* freed, allowing to reuse the buffer with
+ * pq_beginmessage_reuse.
+ --------------------------------
+ */
+
+void
+pq_endmessage_reuse(StringInfo buf)
+{
+ /* msgtype was saved in cursor field */
+ (void) pq_putmessage(buf->cursor, buf->data, buf->len);
+}
+
+
+/* --------------------------------
+ * pq_begintypsend - initialize for constructing a bytea result
+ * --------------------------------
+ */
+void
+pq_begintypsend(StringInfo buf)
+{
+ initStringInfo(buf);
+ /* Reserve four bytes for the bytea length word */
+ appendStringInfoCharMacro(buf, '\0');
+ appendStringInfoCharMacro(buf, '\0');
+ appendStringInfoCharMacro(buf, '\0');
+ appendStringInfoCharMacro(buf, '\0');
+}
+
+/* --------------------------------
+ * pq_endtypsend - finish constructing a bytea result
+ *
+ * The data buffer is returned as the palloc'd bytea value. (We expect
+ * that it will be suitably aligned for this because it has been palloc'd.)
+ * We assume the StringInfoData is just a local variable in the caller and
+ * need not be pfree'd.
+ * --------------------------------
+ */
+bytea *
+pq_endtypsend(StringInfo buf)
+{
+ bytea *result = (bytea *) buf->data;
+
+ /* Insert correct length into bytea length word */
+ Assert(buf->len >= VARHDRSZ);
+ SET_VARSIZE(result, buf->len);
+
+ return result;
+}
+
+
+/* --------------------------------
+ * pq_puttextmessage - generate a character set-converted message in one step
+ *
+ * This is the same as the pqcomm.c routine pq_putmessage, except that
+ * the message body is a null-terminated string to which encoding
+ * conversion applies.
+ * --------------------------------
+ */
+void
+pq_puttextmessage(char msgtype, const char *str)
+{
+ int slen = strlen(str);
+ char *p;
+
+ p = pg_server_to_client(str, slen);
+ if (p != str) /* actual conversion has been done? */
+ {
+ (void) pq_putmessage(msgtype, p, strlen(p) + 1);
+ pfree(p);
+ return;
+ }
+ (void) pq_putmessage(msgtype, str, slen + 1);
+}
+
+
+/* --------------------------------
+ * pq_putemptymessage - convenience routine for message with empty body
+ * --------------------------------
+ */
+void
+pq_putemptymessage(char msgtype)
+{
+ (void) pq_putmessage(msgtype, NULL, 0);
+}
+
+
+/* --------------------------------
+ * pq_getmsgbyte - get a raw byte from a message buffer
+ * --------------------------------
+ */
+int
+pq_getmsgbyte(StringInfo msg)
+{
+ if (msg->cursor >= msg->len)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("no data left in message")));
+ return (unsigned char) msg->data[msg->cursor++];
+}
+
+/* --------------------------------
+ * pq_getmsgint - get a binary integer from a message buffer
+ *
+ * Values are treated as unsigned.
+ * --------------------------------
+ */
+unsigned int
+pq_getmsgint(StringInfo msg, int b)
+{
+ unsigned int result;
+ unsigned char n8;
+ uint16 n16;
+ uint32 n32;
+
+ switch (b)
+ {
+ case 1:
+ pq_copymsgbytes(msg, (char *) &n8, 1);
+ result = n8;
+ break;
+ case 2:
+ pq_copymsgbytes(msg, (char *) &n16, 2);
+ result = pg_ntoh16(n16);
+ break;
+ case 4:
+ pq_copymsgbytes(msg, (char *) &n32, 4);
+ result = pg_ntoh32(n32);
+ break;
+ default:
+ elog(ERROR, "unsupported integer size %d", b);
+ result = 0; /* keep compiler quiet */
+ break;
+ }
+ return result;
+}
+
+/* --------------------------------
+ * pq_getmsgint64 - get a binary 8-byte int from a message buffer
+ *
+ * It is tempting to merge this with pq_getmsgint, but we'd have to make the
+ * result int64 for all data widths --- that could be a big performance
+ * hit on machines where int64 isn't efficient.
+ * --------------------------------
+ */
+int64
+pq_getmsgint64(StringInfo msg)
+{
+ uint64 n64;
+
+ pq_copymsgbytes(msg, (char *) &n64, sizeof(n64));
+
+ return pg_ntoh64(n64);
+}
+
+/* --------------------------------
+ * pq_getmsgfloat4 - get a float4 from a message buffer
+ *
+ * See notes for pq_sendfloat4.
+ * --------------------------------
+ */
+float4
+pq_getmsgfloat4(StringInfo msg)
+{
+ union
+ {
+ float4 f;
+ uint32 i;
+ } swap;
+
+ swap.i = pq_getmsgint(msg, 4);
+ return swap.f;
+}
+
+/* --------------------------------
+ * pq_getmsgfloat8 - get a float8 from a message buffer
+ *
+ * See notes for pq_sendfloat8.
+ * --------------------------------
+ */
+float8
+pq_getmsgfloat8(StringInfo msg)
+{
+ union
+ {
+ float8 f;
+ int64 i;
+ } swap;
+
+ swap.i = pq_getmsgint64(msg);
+ return swap.f;
+}
+
+/* --------------------------------
+ * pq_getmsgbytes - get raw data from a message buffer
+ *
+ * Returns a pointer directly into the message buffer; note this
+ * may not have any particular alignment.
+ * --------------------------------
+ */
+const char *
+pq_getmsgbytes(StringInfo msg, int datalen)
+{
+ const char *result;
+
+ if (datalen < 0 || datalen > (msg->len - msg->cursor))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("insufficient data left in message")));
+ result = &msg->data[msg->cursor];
+ msg->cursor += datalen;
+ return result;
+}
+
+/* --------------------------------
+ * pq_copymsgbytes - copy raw data from a message buffer
+ *
+ * Same as above, except data is copied to caller's buffer.
+ * --------------------------------
+ */
+void
+pq_copymsgbytes(StringInfo msg, char *buf, int datalen)
+{
+ if (datalen < 0 || datalen > (msg->len - msg->cursor))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("insufficient data left in message")));
+ memcpy(buf, &msg->data[msg->cursor], datalen);
+ msg->cursor += datalen;
+}
+
+/* --------------------------------
+ * pq_getmsgtext - get a counted text string (with conversion)
+ *
+ * Always returns a pointer to a freshly palloc'd result.
+ * The result has a trailing null, *and* we return its strlen in *nbytes.
+ * --------------------------------
+ */
+char *
+pq_getmsgtext(StringInfo msg, int rawbytes, int *nbytes)
+{
+ char *str;
+ char *p;
+
+ if (rawbytes < 0 || rawbytes > (msg->len - msg->cursor))
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("insufficient data left in message")));
+ str = &msg->data[msg->cursor];
+ msg->cursor += rawbytes;
+
+ p = pg_client_to_server(str, rawbytes);
+ if (p != str) /* actual conversion has been done? */
+ *nbytes = strlen(p);
+ else
+ {
+ p = (char *) palloc(rawbytes + 1);
+ memcpy(p, str, rawbytes);
+ p[rawbytes] = '\0';
+ *nbytes = rawbytes;
+ }
+ return p;
+}
+
+/* --------------------------------
+ * pq_getmsgstring - get a null-terminated text string (with conversion)
+ *
+ * May return a pointer directly into the message buffer, or a pointer
+ * to a palloc'd conversion result.
+ * --------------------------------
+ */
+const char *
+pq_getmsgstring(StringInfo msg)
+{
+ char *str;
+ int slen;
+
+ str = &msg->data[msg->cursor];
+
+ /*
+ * It's safe to use strlen() here because a StringInfo is guaranteed to
+ * have a trailing null byte. But check we found a null inside the
+ * message.
+ */
+ slen = strlen(str);
+ if (msg->cursor + slen >= msg->len)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("invalid string in message")));
+ msg->cursor += slen + 1;
+
+ return pg_client_to_server(str, slen);
+}
+
+/* --------------------------------
+ * pq_getmsgrawstring - get a null-terminated text string - NO conversion
+ *
+ * Returns a pointer directly into the message buffer.
+ * --------------------------------
+ */
+const char *
+pq_getmsgrawstring(StringInfo msg)
+{
+ char *str;
+ int slen;
+
+ str = &msg->data[msg->cursor];
+
+ /*
+ * It's safe to use strlen() here because a StringInfo is guaranteed to
+ * have a trailing null byte. But check we found a null inside the
+ * message.
+ */
+ slen = strlen(str);
+ if (msg->cursor + slen >= msg->len)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("invalid string in message")));
+ msg->cursor += slen + 1;
+
+ return str;
+}
+
+/* --------------------------------
+ * pq_getmsgend - verify message fully consumed
+ * --------------------------------
+ */
+void
+pq_getmsgend(StringInfo msg)
+{
+ if (msg->cursor != msg->len)
+ ereport(ERROR,
+ (errcode(ERRCODE_PROTOCOL_VIOLATION),
+ errmsg("invalid message format")));
+}
diff --git a/src/backend/libpq/pqmq.c b/src/backend/libpq/pqmq.c
new file mode 100644
index 0000000..d1a1f47
--- /dev/null
+++ b/src/backend/libpq/pqmq.c
@@ -0,0 +1,313 @@
+/*-------------------------------------------------------------------------
+ *
+ * pqmq.c
+ * Use the frontend/backend protocol for communication over a shm_mq
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ * src/backend/libpq/pqmq.c
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
+#include "libpq/pqmq.h"
+#include "miscadmin.h"
+#include "pgstat.h"
+#include "tcop/tcopprot.h"
+#include "utils/builtins.h"
+
+static shm_mq_handle *pq_mq_handle;
+static bool pq_mq_busy = false;
+static pid_t pq_mq_parallel_leader_pid = 0;
+static pid_t pq_mq_parallel_leader_backend_id = InvalidBackendId;
+
+static void pq_cleanup_redirect_to_shm_mq(dsm_segment *seg, Datum arg);
+static void mq_comm_reset(void);
+static int mq_flush(void);
+static int mq_flush_if_writable(void);
+static bool mq_is_send_pending(void);
+static int mq_putmessage(char msgtype, const char *s, size_t len);
+static void mq_putmessage_noblock(char msgtype, const char *s, size_t len);
+
+static const PQcommMethods PqCommMqMethods = {
+ mq_comm_reset,
+ mq_flush,
+ mq_flush_if_writable,
+ mq_is_send_pending,
+ mq_putmessage,
+ mq_putmessage_noblock
+};
+
+/*
+ * Arrange to redirect frontend/backend protocol messages to a shared-memory
+ * message queue.
+ */
+void
+pq_redirect_to_shm_mq(dsm_segment *seg, shm_mq_handle *mqh)
+{
+ PqCommMethods = &PqCommMqMethods;
+ pq_mq_handle = mqh;
+ whereToSendOutput = DestRemote;
+ FrontendProtocol = PG_PROTOCOL_LATEST;
+ on_dsm_detach(seg, pq_cleanup_redirect_to_shm_mq, (Datum) 0);
+}
+
+/*
+ * When the DSM that contains our shm_mq goes away, we need to stop sending
+ * messages to it.
+ */
+static void
+pq_cleanup_redirect_to_shm_mq(dsm_segment *seg, Datum arg)
+{
+ pq_mq_handle = NULL;
+ whereToSendOutput = DestNone;
+}
+
+/*
+ * Arrange to SendProcSignal() to the parallel leader each time we transmit
+ * message data via the shm_mq.
+ */
+void
+pq_set_parallel_leader(pid_t pid, BackendId backend_id)
+{
+ Assert(PqCommMethods == &PqCommMqMethods);
+ pq_mq_parallel_leader_pid = pid;
+ pq_mq_parallel_leader_backend_id = backend_id;
+}
+
+static void
+mq_comm_reset(void)
+{
+ /* Nothing to do. */
+}
+
+static int
+mq_flush(void)
+{
+ /* Nothing to do. */
+ return 0;
+}
+
+static int
+mq_flush_if_writable(void)
+{
+ /* Nothing to do. */
+ return 0;
+}
+
+static bool
+mq_is_send_pending(void)
+{
+ /* There's never anything pending. */
+ return 0;
+}
+
+/*
+ * Transmit a libpq protocol message to the shared memory message queue
+ * selected via pq_mq_handle. We don't include a length word, because the
+ * receiver will know the length of the message from shm_mq_receive().
+ */
+static int
+mq_putmessage(char msgtype, const char *s, size_t len)
+{
+ shm_mq_iovec iov[2];
+ shm_mq_result result;
+
+ /*
+ * If we're sending a message, and we have to wait because the queue is
+ * full, and then we get interrupted, and that interrupt results in trying
+ * to send another message, we respond by detaching the queue. There's no
+ * way to return to the original context, but even if there were, just
+ * queueing the message would amount to indefinitely postponing the
+ * response to the interrupt. So we do this instead.
+ */
+ if (pq_mq_busy)
+ {
+ if (pq_mq_handle != NULL)
+ shm_mq_detach(pq_mq_handle);
+ pq_mq_handle = NULL;
+ return EOF;
+ }
+
+ /*
+ * If the message queue is already gone, just ignore the message. This
+ * doesn't necessarily indicate a problem; for example, DEBUG messages can
+ * be generated late in the shutdown sequence, after all DSMs have already
+ * been detached.
+ */
+ if (pq_mq_handle == NULL)
+ return 0;
+
+ pq_mq_busy = true;
+
+ iov[0].data = &msgtype;
+ iov[0].len = 1;
+ iov[1].data = s;
+ iov[1].len = len;
+
+ Assert(pq_mq_handle != NULL);
+
+ for (;;)
+ {
+ result = shm_mq_sendv(pq_mq_handle, iov, 2, true);
+
+ if (pq_mq_parallel_leader_pid != 0)
+ SendProcSignal(pq_mq_parallel_leader_pid,
+ PROCSIG_PARALLEL_MESSAGE,
+ pq_mq_parallel_leader_backend_id);
+
+ if (result != SHM_MQ_WOULD_BLOCK)
+ break;
+
+ (void) WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, 0,
+ WAIT_EVENT_MQ_PUT_MESSAGE);
+ ResetLatch(MyLatch);
+ CHECK_FOR_INTERRUPTS();
+ }
+
+ pq_mq_busy = false;
+
+ Assert(result == SHM_MQ_SUCCESS || result == SHM_MQ_DETACHED);
+ if (result != SHM_MQ_SUCCESS)
+ return EOF;
+ return 0;
+}
+
+static void
+mq_putmessage_noblock(char msgtype, const char *s, size_t len)
+{
+ /*
+ * While the shm_mq machinery does support sending a message in
+ * non-blocking mode, there's currently no way to try sending beginning to
+ * send the message that doesn't also commit us to completing the
+ * transmission. This could be improved in the future, but for now we
+ * don't need it.
+ */
+ elog(ERROR, "not currently supported");
+}
+
+/*
+ * Parse an ErrorResponse or NoticeResponse payload and populate an ErrorData
+ * structure with the results.
+ */
+void
+pq_parse_errornotice(StringInfo msg, ErrorData *edata)
+{
+ /* Initialize edata with reasonable defaults. */
+ MemSet(edata, 0, sizeof(ErrorData));
+ edata->elevel = ERROR;
+ edata->assoc_context = CurrentMemoryContext;
+
+ /* Loop over fields and extract each one. */
+ for (;;)
+ {
+ char code = pq_getmsgbyte(msg);
+ const char *value;
+
+ if (code == '\0')
+ {
+ pq_getmsgend(msg);
+ break;
+ }
+ value = pq_getmsgrawstring(msg);
+
+ switch (code)
+ {
+ case PG_DIAG_SEVERITY:
+ /* ignore, trusting we'll get a nonlocalized version */
+ break;
+ case PG_DIAG_SEVERITY_NONLOCALIZED:
+ if (strcmp(value, "DEBUG") == 0)
+ {
+ /*
+ * We can't reconstruct the exact DEBUG level, but
+ * presumably it was >= client_min_messages, so select
+ * DEBUG1 to ensure we'll pass it on to the client.
+ */
+ edata->elevel = DEBUG1;
+ }
+ else if (strcmp(value, "LOG") == 0)
+ {
+ /*
+ * It can't be LOG_SERVER_ONLY, or the worker wouldn't
+ * have sent it to us; so LOG is the correct value.
+ */
+ edata->elevel = LOG;
+ }
+ else if (strcmp(value, "INFO") == 0)
+ edata->elevel = INFO;
+ else if (strcmp(value, "NOTICE") == 0)
+ edata->elevel = NOTICE;
+ else if (strcmp(value, "WARNING") == 0)
+ edata->elevel = WARNING;
+ else if (strcmp(value, "ERROR") == 0)
+ edata->elevel = ERROR;
+ else if (strcmp(value, "FATAL") == 0)
+ edata->elevel = FATAL;
+ else if (strcmp(value, "PANIC") == 0)
+ edata->elevel = PANIC;
+ else
+ elog(ERROR, "unrecognized error severity: \"%s\"", value);
+ break;
+ case PG_DIAG_SQLSTATE:
+ if (strlen(value) != 5)
+ elog(ERROR, "invalid SQLSTATE: \"%s\"", value);
+ edata->sqlerrcode = MAKE_SQLSTATE(value[0], value[1], value[2],
+ value[3], value[4]);
+ break;
+ case PG_DIAG_MESSAGE_PRIMARY:
+ edata->message = pstrdup(value);
+ break;
+ case PG_DIAG_MESSAGE_DETAIL:
+ edata->detail = pstrdup(value);
+ break;
+ case PG_DIAG_MESSAGE_HINT:
+ edata->hint = pstrdup(value);
+ break;
+ case PG_DIAG_STATEMENT_POSITION:
+ edata->cursorpos = pg_strtoint32(value);
+ break;
+ case PG_DIAG_INTERNAL_POSITION:
+ edata->internalpos = pg_strtoint32(value);
+ break;
+ case PG_DIAG_INTERNAL_QUERY:
+ edata->internalquery = pstrdup(value);
+ break;
+ case PG_DIAG_CONTEXT:
+ edata->context = pstrdup(value);
+ break;
+ case PG_DIAG_SCHEMA_NAME:
+ edata->schema_name = pstrdup(value);
+ break;
+ case PG_DIAG_TABLE_NAME:
+ edata->table_name = pstrdup(value);
+ break;
+ case PG_DIAG_COLUMN_NAME:
+ edata->column_name = pstrdup(value);
+ break;
+ case PG_DIAG_DATATYPE_NAME:
+ edata->datatype_name = pstrdup(value);
+ break;
+ case PG_DIAG_CONSTRAINT_NAME:
+ edata->constraint_name = pstrdup(value);
+ break;
+ case PG_DIAG_SOURCE_FILE:
+ edata->filename = pstrdup(value);
+ break;
+ case PG_DIAG_SOURCE_LINE:
+ edata->lineno = pg_strtoint32(value);
+ break;
+ case PG_DIAG_SOURCE_FUNCTION:
+ edata->funcname = pstrdup(value);
+ break;
+ default:
+ elog(ERROR, "unrecognized error field code: %d", (int) code);
+ break;
+ }
+ }
+}
diff --git a/src/backend/libpq/pqsignal.c b/src/backend/libpq/pqsignal.c
new file mode 100644
index 0000000..dedf3a4
--- /dev/null
+++ b/src/backend/libpq/pqsignal.c
@@ -0,0 +1,148 @@
+/*-------------------------------------------------------------------------
+ *
+ * pqsignal.c
+ * Backend signal(2) support (see also src/port/pqsignal.c)
+ *
+ * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group
+ * Portions Copyright (c) 1994, Regents of the University of California
+ *
+ *
+ * IDENTIFICATION
+ * src/backend/libpq/pqsignal.c
+ *
+ * ------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+
+#include "libpq/pqsignal.h"
+
+
+/* Global variables */
+sigset_t UnBlockSig,
+ BlockSig,
+ StartupBlockSig;
+
+
+/*
+ * Initialize BlockSig, UnBlockSig, and StartupBlockSig.
+ *
+ * BlockSig is the set of signals to block when we are trying to block
+ * signals. This includes all signals we normally expect to get, but NOT
+ * signals that should never be turned off.
+ *
+ * StartupBlockSig is the set of signals to block during startup packet
+ * collection; it's essentially BlockSig minus SIGTERM, SIGQUIT, SIGALRM.
+ *
+ * UnBlockSig is the set of signals to block when we don't want to block
+ * signals.
+ */
+void
+pqinitmask(void)
+{
+ sigemptyset(&UnBlockSig);
+
+ /* Note: InitializeLatchSupport() modifies UnBlockSig. */
+
+ /* First set all signals, then clear some. */
+ sigfillset(&BlockSig);
+ sigfillset(&StartupBlockSig);
+
+ /*
+ * Unmark those signals that should never be blocked. Some of these signal
+ * names don't exist on all platforms. Most do, but might as well ifdef
+ * them all for consistency...
+ */
+#ifdef SIGTRAP
+ sigdelset(&BlockSig, SIGTRAP);
+ sigdelset(&StartupBlockSig, SIGTRAP);
+#endif
+#ifdef SIGABRT
+ sigdelset(&BlockSig, SIGABRT);
+ sigdelset(&StartupBlockSig, SIGABRT);
+#endif
+#ifdef SIGILL
+ sigdelset(&BlockSig, SIGILL);
+ sigdelset(&StartupBlockSig, SIGILL);
+#endif
+#ifdef SIGFPE
+ sigdelset(&BlockSig, SIGFPE);
+ sigdelset(&StartupBlockSig, SIGFPE);
+#endif
+#ifdef SIGSEGV
+ sigdelset(&BlockSig, SIGSEGV);
+ sigdelset(&StartupBlockSig, SIGSEGV);
+#endif
+#ifdef SIGBUS
+ sigdelset(&BlockSig, SIGBUS);
+ sigdelset(&StartupBlockSig, SIGBUS);
+#endif
+#ifdef SIGSYS
+ sigdelset(&BlockSig, SIGSYS);
+ sigdelset(&StartupBlockSig, SIGSYS);
+#endif
+#ifdef SIGCONT
+ sigdelset(&BlockSig, SIGCONT);
+ sigdelset(&StartupBlockSig, SIGCONT);
+#endif
+
+/* Signals unique to startup */
+#ifdef SIGQUIT
+ sigdelset(&StartupBlockSig, SIGQUIT);
+#endif
+#ifdef SIGTERM
+ sigdelset(&StartupBlockSig, SIGTERM);
+#endif
+#ifdef SIGALRM
+ sigdelset(&StartupBlockSig, SIGALRM);
+#endif
+}
+
+/*
+ * Set up a postmaster signal handler for signal "signo"
+ *
+ * Returns the previous handler.
+ *
+ * This is used only in the postmaster, which has its own odd approach to
+ * signal handling. For signals with handlers, we block all signals for the
+ * duration of signal handler execution. We also do not set the SA_RESTART
+ * flag; this should be safe given the tiny range of code in which the
+ * postmaster ever unblocks signals.
+ *
+ * pqinitmask() must have been invoked previously.
+ *
+ * On Windows, this function is just an alias for pqsignal()
+ * (and note that it's calling the code in src/backend/port/win32/signal.c,
+ * not src/port/pqsignal.c). On that platform, the postmaster's signal
+ * handlers still have to block signals for themselves.
+ */
+pqsigfunc
+pqsignal_pm(int signo, pqsigfunc func)
+{
+#ifndef WIN32
+ struct sigaction act,
+ oact;
+
+ act.sa_handler = func;
+ if (func == SIG_IGN || func == SIG_DFL)
+ {
+ /* in these cases, act the same as pqsignal() */
+ sigemptyset(&act.sa_mask);
+ act.sa_flags = SA_RESTART;
+ }
+ else
+ {
+ act.sa_mask = BlockSig;
+ act.sa_flags = 0;
+ }
+#ifdef SA_NOCLDSTOP
+ if (signo == SIGCHLD)
+ act.sa_flags |= SA_NOCLDSTOP;
+#endif
+ if (sigaction(signo, &act, &oact) < 0)
+ return SIG_ERR;
+ return oact.sa_handler;
+#else /* WIN32 */
+ return pqsignal(signo, func);
+#endif
+}