diff options
Diffstat (limited to 'src/backend/libpq')
-rw-r--r-- | src/backend/libpq/Makefile | 39 | ||||
-rw-r--r-- | src/backend/libpq/README.SSL | 82 | ||||
-rw-r--r-- | src/backend/libpq/auth-scram.c | 1445 | ||||
-rw-r--r-- | src/backend/libpq/auth.c | 3492 | ||||
-rw-r--r-- | src/backend/libpq/be-fsstubs.c | 860 | ||||
-rw-r--r-- | src/backend/libpq/be-gssapi-common.c | 94 | ||||
-rw-r--r-- | src/backend/libpq/be-secure-common.c | 195 | ||||
-rw-r--r-- | src/backend/libpq/be-secure-gssapi.c | 733 | ||||
-rw-r--r-- | src/backend/libpq/be-secure-openssl.c | 1526 | ||||
-rw-r--r-- | src/backend/libpq/be-secure.c | 345 | ||||
-rw-r--r-- | src/backend/libpq/crypt.c | 290 | ||||
-rw-r--r-- | src/backend/libpq/hba.c | 3166 | ||||
-rw-r--r-- | src/backend/libpq/ifaddr.c | 594 | ||||
-rw-r--r-- | src/backend/libpq/pg_hba.conf.sample | 94 | ||||
-rw-r--r-- | src/backend/libpq/pg_ident.conf.sample | 42 | ||||
-rw-r--r-- | src/backend/libpq/pqcomm.c | 1976 | ||||
-rw-r--r-- | src/backend/libpq/pqformat.c | 643 | ||||
-rw-r--r-- | src/backend/libpq/pqmq.c | 313 | ||||
-rw-r--r-- | src/backend/libpq/pqsignal.c | 148 |
19 files changed, 16077 insertions, 0 deletions
diff --git a/src/backend/libpq/Makefile b/src/backend/libpq/Makefile new file mode 100644 index 0000000..8d1d16b --- /dev/null +++ b/src/backend/libpq/Makefile @@ -0,0 +1,39 @@ +#------------------------------------------------------------------------- +# +# Makefile-- +# Makefile for libpq subsystem (backend half of libpq interface) +# +# IDENTIFICATION +# src/backend/libpq/Makefile +# +#------------------------------------------------------------------------- + +subdir = src/backend/libpq +top_builddir = ../../.. +include $(top_builddir)/src/Makefile.global + +# be-fsstubs is here for historical reasons, probably belongs elsewhere + +OBJS = \ + auth-scram.o \ + auth.o \ + be-fsstubs.o \ + be-secure-common.o \ + be-secure.o \ + crypt.o \ + hba.o \ + ifaddr.o \ + pqcomm.o \ + pqformat.o \ + pqmq.o \ + pqsignal.o + +ifeq ($(with_ssl),openssl) +OBJS += be-secure-openssl.o +endif + +ifeq ($(with_gssapi),yes) +OBJS += be-gssapi-common.o be-secure-gssapi.o +endif + +include $(top_srcdir)/src/backend/common.mk diff --git a/src/backend/libpq/README.SSL b/src/backend/libpq/README.SSL new file mode 100644 index 0000000..d84a434 --- /dev/null +++ b/src/backend/libpq/README.SSL @@ -0,0 +1,82 @@ +src/backend/libpq/README.SSL + +SSL +=== + +>From the servers perspective: + + + Receives StartupPacket + | + | + (Is SSL_NEGOTIATE_CODE?) ----------- Normal startup + | No + | + | Yes + | + | + (Server compiled with USE_SSL?) ------- Send 'N' + | No | + | | + | Yes Normal startup + | + | + Send 'S' + | + | + Establish SSL + | + | + Normal startup + + + + + +>From the clients perspective (v6.6 client _with_ SSL): + + + Connect + | + | + Send packet with SSL_NEGOTIATE_CODE + | + | + Receive single char ------- 'S' -------- Establish SSL + | | + | '<else>' | + | Normal startup + | + | + Is it 'E' for error ------------------- Retry connection + | Yes without SSL + | No + | + Is it 'N' for normal ------------------- Normal startup + | Yes + | + Fail with unknown + +--------------------------------------------------------------------------- + +Ephemeral DH +============ + +Since the server static private key ($DataDir/server.key) will +normally be stored unencrypted so that the database backend can +restart automatically, it is important that we select an algorithm +that continues to provide confidentiality even if the attacker has the +server's private key. Ephemeral DH (EDH) keys provide this and more +(Perfect Forward Secrecy aka PFS). + +N.B., the static private key should still be protected to the largest +extent possible, to minimize the risk of impersonations. + +Another benefit of EDH is that it allows the backend and clients to +use DSA keys. DSA keys can only provide digital signatures, not +encryption, and are often acceptable in jurisdictions where RSA keys +are unacceptable. + +The downside to EDH is that it makes it impossible to use ssldump(1) +if there's a problem establishing an SSL session. In this case you'll +need to temporarily disable EDH (see initialize_dh()). diff --git a/src/backend/libpq/auth-scram.c b/src/backend/libpq/auth-scram.c new file mode 100644 index 0000000..f9e1026 --- /dev/null +++ b/src/backend/libpq/auth-scram.c @@ -0,0 +1,1445 @@ +/*------------------------------------------------------------------------- + * + * auth-scram.c + * Server-side implementation of the SASL SCRAM-SHA-256 mechanism. + * + * See the following RFCs for more details: + * - RFC 5802: https://tools.ietf.org/html/rfc5802 + * - RFC 5803: https://tools.ietf.org/html/rfc5803 + * - RFC 7677: https://tools.ietf.org/html/rfc7677 + * + * Here are some differences: + * + * - Username from the authentication exchange is not used. The client + * should send an empty string as the username. + * + * - If the password isn't valid UTF-8, or contains characters prohibited + * by the SASLprep profile, we skip the SASLprep pre-processing and use + * the raw bytes in calculating the hash. + * + * - If channel binding is used, the channel binding type is always + * "tls-server-end-point". The spec says the default is "tls-unique" + * (RFC 5802, section 6.1. Default Channel Binding), but there are some + * problems with that. Firstly, not all SSL libraries provide an API to + * get the TLS Finished message, required to use "tls-unique". Secondly, + * "tls-unique" is not specified for TLS v1.3, and as of this writing, + * it's not clear if there will be a replacement. We could support both + * "tls-server-end-point" and "tls-unique", but for our use case, + * "tls-unique" doesn't really have any advantages. The main advantage + * of "tls-unique" would be that it works even if the server doesn't + * have a certificate, but PostgreSQL requires a server certificate + * whenever SSL is used, anyway. + * + * + * The password stored in pg_authid consists of the iteration count, salt, + * StoredKey and ServerKey. + * + * SASLprep usage + * -------------- + * + * One notable difference to the SCRAM specification is that while the + * specification dictates that the password is in UTF-8, and prohibits + * certain characters, we are more lenient. If the password isn't a valid + * UTF-8 string, or contains prohibited characters, the raw bytes are used + * to calculate the hash instead, without SASLprep processing. This is + * because PostgreSQL supports other encodings too, and the encoding being + * used during authentication is undefined (client_encoding isn't set until + * after authentication). In effect, we try to interpret the password as + * UTF-8 and apply SASLprep processing, but if it looks invalid, we assume + * that it's in some other encoding. + * + * In the worst case, we misinterpret a password that's in a different + * encoding as being Unicode, because it happens to consists entirely of + * valid UTF-8 bytes, and we apply Unicode normalization to it. As long + * as we do that consistently, that will not lead to failed logins. + * Fortunately, the UTF-8 byte sequences that are ignored by SASLprep + * don't correspond to any commonly used characters in any of the other + * supported encodings, so it should not lead to any significant loss in + * entropy, even if the normalization is incorrectly applied to a + * non-UTF-8 password. + * + * Error handling + * -------------- + * + * Don't reveal user information to an unauthenticated client. We don't + * want an attacker to be able to probe whether a particular username is + * valid. In SCRAM, the server has to read the salt and iteration count + * from the user's stored secret, and send it to the client. To avoid + * revealing whether a user exists, when the client tries to authenticate + * with a username that doesn't exist, or doesn't have a valid SCRAM + * secret in pg_authid, we create a fake salt and iteration count + * on-the-fly, and proceed with the authentication with that. In the end, + * we'll reject the attempt, as if an incorrect password was given. When + * we are performing a "mock" authentication, the 'doomed' flag in + * scram_state is set. + * + * In the error messages, avoid printing strings from the client, unless + * you check that they are pure ASCII. We don't want an unauthenticated + * attacker to be able to spam the logs with characters that are not valid + * to the encoding being used, whatever that is. We cannot avoid that in + * general, after logging in, but let's do what we can here. + * + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/backend/libpq/auth-scram.c + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include <unistd.h> + +#include "access/xlog.h" +#include "catalog/pg_authid.h" +#include "catalog/pg_control.h" +#include "common/base64.h" +#include "common/hmac.h" +#include "common/saslprep.h" +#include "common/scram-common.h" +#include "common/sha2.h" +#include "libpq/auth.h" +#include "libpq/crypt.h" +#include "libpq/scram.h" +#include "miscadmin.h" +#include "utils/builtins.h" +#include "utils/timestamp.h" + +/* + * Status data for a SCRAM authentication exchange. This should be kept + * internal to this file. + */ +typedef enum +{ + SCRAM_AUTH_INIT, + SCRAM_AUTH_SALT_SENT, + SCRAM_AUTH_FINISHED +} scram_state_enum; + +typedef struct +{ + scram_state_enum state; + + const char *username; /* username from startup packet */ + + Port *port; + bool channel_binding_in_use; + + int iterations; + char *salt; /* base64-encoded */ + uint8 StoredKey[SCRAM_KEY_LEN]; + uint8 ServerKey[SCRAM_KEY_LEN]; + + /* Fields of the first message from client */ + char cbind_flag; + char *client_first_message_bare; + char *client_username; + char *client_nonce; + + /* Fields from the last message from client */ + char *client_final_message_without_proof; + char *client_final_nonce; + char ClientProof[SCRAM_KEY_LEN]; + + /* Fields generated in the server */ + char *server_first_message; + char *server_nonce; + + /* + * If something goes wrong during the authentication, or we are performing + * a "mock" authentication (see comments at top of file), the 'doomed' + * flag is set. A reason for the failure, for the server log, is put in + * 'logdetail'. + */ + bool doomed; + char *logdetail; +} scram_state; + +static void read_client_first_message(scram_state *state, const char *input); +static void read_client_final_message(scram_state *state, const char *input); +static char *build_server_first_message(scram_state *state); +static char *build_server_final_message(scram_state *state); +static bool verify_client_proof(scram_state *state); +static bool verify_final_nonce(scram_state *state); +static void mock_scram_secret(const char *username, int *iterations, + char **salt, uint8 *stored_key, uint8 *server_key); +static bool is_scram_printable(char *p); +static char *sanitize_char(char c); +static char *sanitize_str(const char *s); +static char *scram_mock_salt(const char *username); + +/* + * pg_be_scram_get_mechanisms + * + * Get a list of SASL mechanisms that this module supports. + * + * For the convenience of building the FE/BE packet that lists the + * mechanisms, the names are appended to the given StringInfo buffer, + * separated by '\0' bytes. + */ +void +pg_be_scram_get_mechanisms(Port *port, StringInfo buf) +{ + /* + * Advertise the mechanisms in decreasing order of importance. So the + * channel-binding variants go first, if they are supported. Channel + * binding is only supported with SSL, and only if the SSL implementation + * has a function to get the certificate's hash. + */ +#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH + if (port->ssl_in_use) + { + appendStringInfoString(buf, SCRAM_SHA_256_PLUS_NAME); + appendStringInfoChar(buf, '\0'); + } +#endif + appendStringInfoString(buf, SCRAM_SHA_256_NAME); + appendStringInfoChar(buf, '\0'); +} + +/* + * pg_be_scram_init + * + * Initialize a new SCRAM authentication exchange status tracker. This + * needs to be called before doing any exchange. It will be filled later + * after the beginning of the exchange with authentication information. + * + * 'selected_mech' identifies the SASL mechanism that the client selected. + * It should be one of the mechanisms that we support, as returned by + * pg_be_scram_get_mechanisms(). + * + * 'shadow_pass' is the role's stored secret, from pg_authid.rolpassword. + * The username was provided by the client in the startup message, and is + * available in port->user_name. If 'shadow_pass' is NULL, we still perform + * an authentication exchange, but it will fail, as if an incorrect password + * was given. + */ +void * +pg_be_scram_init(Port *port, + const char *selected_mech, + const char *shadow_pass) +{ + scram_state *state; + bool got_secret; + + state = (scram_state *) palloc0(sizeof(scram_state)); + state->port = port; + state->state = SCRAM_AUTH_INIT; + + /* + * Parse the selected mechanism. + * + * Note that if we don't support channel binding, either because the SSL + * implementation doesn't support it or we're not using SSL at all, we + * would not have advertised the PLUS variant in the first place. If the + * client nevertheless tries to select it, it's a protocol violation like + * selecting any other SASL mechanism we don't support. + */ +#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH + if (strcmp(selected_mech, SCRAM_SHA_256_PLUS_NAME) == 0 && port->ssl_in_use) + state->channel_binding_in_use = true; + else +#endif + if (strcmp(selected_mech, SCRAM_SHA_256_NAME) == 0) + state->channel_binding_in_use = false; + else + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("client selected an invalid SASL authentication mechanism"))); + + /* + * Parse the stored secret. + */ + if (shadow_pass) + { + int password_type = get_password_type(shadow_pass); + + if (password_type == PASSWORD_TYPE_SCRAM_SHA_256) + { + if (parse_scram_secret(shadow_pass, &state->iterations, &state->salt, + state->StoredKey, state->ServerKey)) + got_secret = true; + else + { + /* + * The password looked like a SCRAM secret, but could not be + * parsed. + */ + ereport(LOG, + (errmsg("invalid SCRAM secret for user \"%s\"", + state->port->user_name))); + got_secret = false; + } + } + else + { + /* + * The user doesn't have SCRAM secret. (You cannot do SCRAM + * authentication with an MD5 hash.) + */ + state->logdetail = psprintf(_("User \"%s\" does not have a valid SCRAM secret."), + state->port->user_name); + got_secret = false; + } + } + else + { + /* + * The caller requested us to perform a dummy authentication. This is + * considered normal, since the caller requested it, so don't set log + * detail. + */ + got_secret = false; + } + + /* + * If the user did not have a valid SCRAM secret, we still go through the + * motions with a mock one, and fail as if the client supplied an + * incorrect password. This is to avoid revealing information to an + * attacker. + */ + if (!got_secret) + { + mock_scram_secret(state->port->user_name, &state->iterations, + &state->salt, state->StoredKey, state->ServerKey); + state->doomed = true; + } + + return state; +} + +/* + * Continue a SCRAM authentication exchange. + * + * 'input' is the SCRAM payload sent by the client. On the first call, + * 'input' contains the "Initial Client Response" that the client sent as + * part of the SASLInitialResponse message, or NULL if no Initial Client + * Response was given. (The SASL specification distinguishes between an + * empty response and non-existing one.) On subsequent calls, 'input' + * cannot be NULL. For convenience in this function, the caller must + * ensure that there is a null terminator at input[inputlen]. + * + * The next message to send to client is saved in 'output', for a length + * of 'outputlen'. In the case of an error, optionally store a palloc'd + * string at *logdetail that will be sent to the postmaster log (but not + * the client). + */ +int +pg_be_scram_exchange(void *opaq, const char *input, int inputlen, + char **output, int *outputlen, char **logdetail) +{ + scram_state *state = (scram_state *) opaq; + int result; + + *output = NULL; + + /* + * If the client didn't include an "Initial Client Response" in the + * SASLInitialResponse message, send an empty challenge, to which the + * client will respond with the same data that usually comes in the + * Initial Client Response. + */ + if (input == NULL) + { + Assert(state->state == SCRAM_AUTH_INIT); + + *output = pstrdup(""); + *outputlen = 0; + return SASL_EXCHANGE_CONTINUE; + } + + /* + * Check that the input length agrees with the string length of the input. + * We can ignore inputlen after this. + */ + if (inputlen == 0) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("The message is empty."))); + if (inputlen != strlen(input)) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Message length does not match input length."))); + + switch (state->state) + { + case SCRAM_AUTH_INIT: + + /* + * Initialization phase. Receive the first message from client + * and be sure that it parsed correctly. Then send the challenge + * to the client. + */ + read_client_first_message(state, input); + + /* prepare message to send challenge */ + *output = build_server_first_message(state); + + state->state = SCRAM_AUTH_SALT_SENT; + result = SASL_EXCHANGE_CONTINUE; + break; + + case SCRAM_AUTH_SALT_SENT: + + /* + * Final phase for the server. Receive the response to the + * challenge previously sent, verify, and let the client know that + * everything went well (or not). + */ + read_client_final_message(state, input); + + if (!verify_final_nonce(state)) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("invalid SCRAM response"), + errdetail("Nonce does not match."))); + + /* + * Now check the final nonce and the client proof. + * + * If we performed a "mock" authentication that we knew would fail + * from the get go, this is where we fail. + * + * The SCRAM specification includes an error code, + * "invalid-proof", for authentication failure, but it also allows + * erroring out in an application-specific way. We choose to do + * the latter, so that the error message for invalid password is + * the same for all authentication methods. The caller will call + * ereport(), when we return SASL_EXCHANGE_FAILURE with no output. + * + * NB: the order of these checks is intentional. We calculate the + * client proof even in a mock authentication, even though it's + * bound to fail, to thwart timing attacks to determine if a role + * with the given name exists or not. + */ + if (!verify_client_proof(state) || state->doomed) + { + result = SASL_EXCHANGE_FAILURE; + break; + } + + /* Build final message for client */ + *output = build_server_final_message(state); + + /* Success! */ + result = SASL_EXCHANGE_SUCCESS; + state->state = SCRAM_AUTH_FINISHED; + break; + + default: + elog(ERROR, "invalid SCRAM exchange state"); + result = SASL_EXCHANGE_FAILURE; + } + + if (result == SASL_EXCHANGE_FAILURE && state->logdetail && logdetail) + *logdetail = state->logdetail; + + if (*output) + *outputlen = strlen(*output); + + return result; +} + +/* + * Construct a SCRAM secret, for storing in pg_authid.rolpassword. + * + * The result is palloc'd, so caller is responsible for freeing it. + */ +char * +pg_be_scram_build_secret(const char *password) +{ + char *prep_password; + pg_saslprep_rc rc; + char saltbuf[SCRAM_DEFAULT_SALT_LEN]; + char *result; + + /* + * Normalize the password with SASLprep. If that doesn't work, because + * the password isn't valid UTF-8 or contains prohibited characters, just + * proceed with the original password. (See comments at top of file.) + */ + rc = pg_saslprep(password, &prep_password); + if (rc == SASLPREP_SUCCESS) + password = (const char *) prep_password; + + /* Generate random salt */ + if (!pg_strong_random(saltbuf, SCRAM_DEFAULT_SALT_LEN)) + ereport(ERROR, + (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("could not generate random salt"))); + + result = scram_build_secret(saltbuf, SCRAM_DEFAULT_SALT_LEN, + SCRAM_DEFAULT_ITERATIONS, password); + + if (prep_password) + pfree(prep_password); + + return result; +} + +/* + * Verify a plaintext password against a SCRAM secret. This is used when + * performing plaintext password authentication for a user that has a SCRAM + * secret stored in pg_authid. + */ +bool +scram_verify_plain_password(const char *username, const char *password, + const char *secret) +{ + char *encoded_salt; + char *salt; + int saltlen; + int iterations; + uint8 salted_password[SCRAM_KEY_LEN]; + uint8 stored_key[SCRAM_KEY_LEN]; + uint8 server_key[SCRAM_KEY_LEN]; + uint8 computed_key[SCRAM_KEY_LEN]; + char *prep_password; + pg_saslprep_rc rc; + + if (!parse_scram_secret(secret, &iterations, &encoded_salt, + stored_key, server_key)) + { + /* + * The password looked like a SCRAM secret, but could not be parsed. + */ + ereport(LOG, + (errmsg("invalid SCRAM secret for user \"%s\"", username))); + return false; + } + + saltlen = pg_b64_dec_len(strlen(encoded_salt)); + salt = palloc(saltlen); + saltlen = pg_b64_decode(encoded_salt, strlen(encoded_salt), salt, + saltlen); + if (saltlen < 0) + { + ereport(LOG, + (errmsg("invalid SCRAM secret for user \"%s\"", username))); + return false; + } + + /* Normalize the password */ + rc = pg_saslprep(password, &prep_password); + if (rc == SASLPREP_SUCCESS) + password = prep_password; + + /* Compute Server Key based on the user-supplied plaintext password */ + if (scram_SaltedPassword(password, salt, saltlen, iterations, + salted_password) < 0 || + scram_ServerKey(salted_password, computed_key) < 0) + { + elog(ERROR, "could not compute server key"); + } + + if (prep_password) + pfree(prep_password); + + /* + * Compare the secret's Server Key with the one computed from the + * user-supplied password. + */ + return memcmp(computed_key, server_key, SCRAM_KEY_LEN) == 0; +} + + +/* + * Parse and validate format of given SCRAM secret. + * + * On success, the iteration count, salt, stored key, and server key are + * extracted from the secret, and returned to the caller. For 'stored_key' + * and 'server_key', the caller must pass pre-allocated buffers of size + * SCRAM_KEY_LEN. Salt is returned as a base64-encoded, null-terminated + * string. The buffer for the salt is palloc'd by this function. + * + * Returns true if the SCRAM secret has been parsed, and false otherwise. + */ +bool +parse_scram_secret(const char *secret, int *iterations, char **salt, + uint8 *stored_key, uint8 *server_key) +{ + char *v; + char *p; + char *scheme_str; + char *salt_str; + char *iterations_str; + char *storedkey_str; + char *serverkey_str; + int decoded_len; + char *decoded_salt_buf; + char *decoded_stored_buf; + char *decoded_server_buf; + + /* + * The secret is of form: + * + * SCRAM-SHA-256$<iterations>:<salt>$<storedkey>:<serverkey> + */ + v = pstrdup(secret); + if ((scheme_str = strtok(v, "$")) == NULL) + goto invalid_secret; + if ((iterations_str = strtok(NULL, ":")) == NULL) + goto invalid_secret; + if ((salt_str = strtok(NULL, "$")) == NULL) + goto invalid_secret; + if ((storedkey_str = strtok(NULL, ":")) == NULL) + goto invalid_secret; + if ((serverkey_str = strtok(NULL, "")) == NULL) + goto invalid_secret; + + /* Parse the fields */ + if (strcmp(scheme_str, "SCRAM-SHA-256") != 0) + goto invalid_secret; + + errno = 0; + *iterations = strtol(iterations_str, &p, 10); + if (*p || errno != 0) + goto invalid_secret; + + /* + * Verify that the salt is in Base64-encoded format, by decoding it, + * although we return the encoded version to the caller. + */ + decoded_len = pg_b64_dec_len(strlen(salt_str)); + decoded_salt_buf = palloc(decoded_len); + decoded_len = pg_b64_decode(salt_str, strlen(salt_str), + decoded_salt_buf, decoded_len); + if (decoded_len < 0) + goto invalid_secret; + *salt = pstrdup(salt_str); + + /* + * Decode StoredKey and ServerKey. + */ + decoded_len = pg_b64_dec_len(strlen(storedkey_str)); + decoded_stored_buf = palloc(decoded_len); + decoded_len = pg_b64_decode(storedkey_str, strlen(storedkey_str), + decoded_stored_buf, decoded_len); + if (decoded_len != SCRAM_KEY_LEN) + goto invalid_secret; + memcpy(stored_key, decoded_stored_buf, SCRAM_KEY_LEN); + + decoded_len = pg_b64_dec_len(strlen(serverkey_str)); + decoded_server_buf = palloc(decoded_len); + decoded_len = pg_b64_decode(serverkey_str, strlen(serverkey_str), + decoded_server_buf, decoded_len); + if (decoded_len != SCRAM_KEY_LEN) + goto invalid_secret; + memcpy(server_key, decoded_server_buf, SCRAM_KEY_LEN); + + return true; + +invalid_secret: + *salt = NULL; + return false; +} + +/* + * Generate plausible SCRAM secret parameters for mock authentication. + * + * In a normal authentication, these are extracted from the secret + * stored in the server. This function generates values that look + * realistic, for when there is no stored secret. + * + * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the + * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN, and + * the buffer for the salt is palloc'd by this function. + */ +static void +mock_scram_secret(const char *username, int *iterations, char **salt, + uint8 *stored_key, uint8 *server_key) +{ + char *raw_salt; + char *encoded_salt; + int encoded_len; + + /* + * Generate deterministic salt. + * + * Note that we cannot reveal any information to an attacker here so the + * error messages need to remain generic. This should never fail anyway + * as the salt generated for mock authentication uses the cluster's nonce + * value. + */ + raw_salt = scram_mock_salt(username); + if (raw_salt == NULL) + elog(ERROR, "could not encode salt"); + + encoded_len = pg_b64_enc_len(SCRAM_DEFAULT_SALT_LEN); + /* don't forget the zero-terminator */ + encoded_salt = (char *) palloc(encoded_len + 1); + encoded_len = pg_b64_encode(raw_salt, SCRAM_DEFAULT_SALT_LEN, encoded_salt, + encoded_len); + + if (encoded_len < 0) + elog(ERROR, "could not encode salt"); + encoded_salt[encoded_len] = '\0'; + + *salt = encoded_salt; + *iterations = SCRAM_DEFAULT_ITERATIONS; + + /* StoredKey and ServerKey are not used in a doomed authentication */ + memset(stored_key, 0, SCRAM_KEY_LEN); + memset(server_key, 0, SCRAM_KEY_LEN); +} + +/* + * Read the value in a given SCRAM exchange message for given attribute. + */ +static char * +read_attr_value(char **input, char attr) +{ + char *begin = *input; + char *end; + + if (*begin != attr) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Expected attribute \"%c\" but found \"%s\".", + attr, sanitize_char(*begin)))); + begin++; + + if (*begin != '=') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Expected character \"=\" for attribute \"%c\".", attr))); + begin++; + + end = begin; + while (*end && *end != ',') + end++; + + if (*end) + { + *end = '\0'; + *input = end + 1; + } + else + *input = end; + + return begin; +} + +static bool +is_scram_printable(char *p) +{ + /*------ + * Printable characters, as defined by SCRAM spec: (RFC 5802) + * + * printable = %x21-2B / %x2D-7E + * ;; Printable ASCII except ",". + * ;; Note that any "printable" is also + * ;; a valid "value". + *------ + */ + for (; *p; p++) + { + if (*p < 0x21 || *p > 0x7E || *p == 0x2C /* comma */ ) + return false; + } + return true; +} + +/* + * Convert an arbitrary byte to printable form. For error messages. + * + * If it's a printable ASCII character, print it as a single character. + * otherwise, print it in hex. + * + * The returned pointer points to a static buffer. + */ +static char * +sanitize_char(char c) +{ + static char buf[5]; + + if (c >= 0x21 && c <= 0x7E) + snprintf(buf, sizeof(buf), "'%c'", c); + else + snprintf(buf, sizeof(buf), "0x%02x", (unsigned char) c); + return buf; +} + +/* + * Convert an arbitrary string to printable form, for error messages. + * + * Anything that's not a printable ASCII character is replaced with + * '?', and the string is truncated at 30 characters. + * + * The returned pointer points to a static buffer. + */ +static char * +sanitize_str(const char *s) +{ + static char buf[30 + 1]; + int i; + + for (i = 0; i < sizeof(buf) - 1; i++) + { + char c = s[i]; + + if (c == '\0') + break; + + if (c >= 0x21 && c <= 0x7E) + buf[i] = c; + else + buf[i] = '?'; + } + buf[i] = '\0'; + return buf; +} + +/* + * Read the next attribute and value in a SCRAM exchange message. + * + * The attribute character is set in *attr_p, the attribute value is the + * return value. + */ +static char * +read_any_attr(char **input, char *attr_p) +{ + char *begin = *input; + char *end; + char attr = *begin; + + if (attr == '\0') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Attribute expected, but found end of string."))); + + /*------ + * attr-val = ALPHA "=" value + * ;; Generic syntax of any attribute sent + * ;; by server or client + *------ + */ + if (!((attr >= 'A' && attr <= 'Z') || + (attr >= 'a' && attr <= 'z'))) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Attribute expected, but found invalid character \"%s\".", + sanitize_char(attr)))); + if (attr_p) + *attr_p = attr; + begin++; + + if (*begin != '=') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Expected character \"=\" for attribute \"%c\".", attr))); + begin++; + + end = begin; + while (*end && *end != ',') + end++; + + if (*end) + { + *end = '\0'; + *input = end + 1; + } + else + *input = end; + + return begin; +} + +/* + * Read and parse the first message from client in the context of a SCRAM + * authentication exchange message. + * + * At this stage, any errors will be reported directly with ereport(ERROR). + */ +static void +read_client_first_message(scram_state *state, const char *input) +{ + char *p = pstrdup(input); + char *channel_binding_type; + + + /*------ + * The syntax for the client-first-message is: (RFC 5802) + * + * saslname = 1*(value-safe-char / "=2C" / "=3D") + * ;; Conforms to <value>. + * + * authzid = "a=" saslname + * ;; Protocol specific. + * + * cb-name = 1*(ALPHA / DIGIT / "." / "-") + * ;; See RFC 5056, Section 7. + * ;; E.g., "tls-server-end-point" or + * ;; "tls-unique". + * + * gs2-cbind-flag = ("p=" cb-name) / "n" / "y" + * ;; "n" -> client doesn't support channel binding. + * ;; "y" -> client does support channel binding + * ;; but thinks the server does not. + * ;; "p" -> client requires channel binding. + * ;; The selected channel binding follows "p=". + * + * gs2-header = gs2-cbind-flag "," [ authzid ] "," + * ;; GS2 header for SCRAM + * ;; (the actual GS2 header includes an optional + * ;; flag to indicate that the GSS mechanism is not + * ;; "standard", but since SCRAM is "standard", we + * ;; don't include that flag). + * + * username = "n=" saslname + * ;; Usernames are prepared using SASLprep. + * + * reserved-mext = "m=" 1*(value-char) + * ;; Reserved for signaling mandatory extensions. + * ;; The exact syntax will be defined in + * ;; the future. + * + * nonce = "r=" c-nonce [s-nonce] + * ;; Second part provided by server. + * + * c-nonce = printable + * + * client-first-message-bare = + * [reserved-mext ","] + * username "," nonce ["," extensions] + * + * client-first-message = + * gs2-header client-first-message-bare + * + * For example: + * n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL + * + * The "n,," in the beginning means that the client doesn't support + * channel binding, and no authzid is given. "n=user" is the username. + * However, in PostgreSQL the username is sent in the startup packet, and + * the username in the SCRAM exchange is ignored. libpq always sends it + * as an empty string. The last part, "r=fyko+d2lbbFgONRv9qkxdawL" is + * the client nonce. + *------ + */ + + /* + * Read gs2-cbind-flag. (For details see also RFC 5802 Section 6 "Channel + * Binding".) + */ + state->cbind_flag = *p; + switch (*p) + { + case 'n': + + /* + * The client does not support channel binding or has simply + * decided to not use it. In that case just let it go. + */ + if (state->channel_binding_in_use) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data."))); + + p++; + if (*p != ',') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Comma expected, but found character \"%s\".", + sanitize_char(*p)))); + p++; + break; + case 'y': + + /* + * The client supports channel binding and thinks that the server + * does not. In this case, the server must fail authentication if + * it supports channel binding. + */ + if (state->channel_binding_in_use) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("The client selected SCRAM-SHA-256-PLUS, but the SCRAM message does not include channel binding data."))); + +#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH + if (state->port->ssl_in_use) + ereport(ERROR, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + errmsg("SCRAM channel binding negotiation error"), + errdetail("The client supports SCRAM channel binding but thinks the server does not. " + "However, this server does support channel binding."))); +#endif + p++; + if (*p != ',') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Comma expected, but found character \"%s\".", + sanitize_char(*p)))); + p++; + break; + case 'p': + + /* + * The client requires channel binding. Channel binding type + * follows, e.g., "p=tls-server-end-point". + */ + if (!state->channel_binding_in_use) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("The client selected SCRAM-SHA-256 without channel binding, but the SCRAM message includes channel binding data."))); + + channel_binding_type = read_attr_value(&p, 'p'); + + /* + * The only channel binding type we support is + * tls-server-end-point. + */ + if (strcmp(channel_binding_type, "tls-server-end-point") != 0) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("unsupported SCRAM channel-binding type \"%s\"", + sanitize_str(channel_binding_type)))); + break; + default: + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Unexpected channel-binding flag \"%s\".", + sanitize_char(*p)))); + } + + /* + * Forbid optional authzid (authorization identity). We don't support it. + */ + if (*p == 'a') + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("client uses authorization identity, but it is not supported"))); + if (*p != ',') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Unexpected attribute \"%s\" in client-first-message.", + sanitize_char(*p)))); + p++; + + state->client_first_message_bare = pstrdup(p); + + /* + * Any mandatory extensions would go here. We don't support any. + * + * RFC 5802 specifies error code "e=extensions-not-supported" for this, + * but it can only be sent in the server-final message. We prefer to fail + * immediately (which the RFC also allows). + */ + if (*p == 'm') + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("client requires an unsupported SCRAM extension"))); + + /* + * Read username. Note: this is ignored. We use the username from the + * startup message instead, still it is kept around if provided as it + * proves to be useful for debugging purposes. + */ + state->client_username = read_attr_value(&p, 'n'); + + /* read nonce and check that it is made of only printable characters */ + state->client_nonce = read_attr_value(&p, 'r'); + if (!is_scram_printable(state->client_nonce)) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("non-printable characters in SCRAM nonce"))); + + /* + * There can be any number of optional extensions after this. We don't + * support any extensions, so ignore them. + */ + while (*p != '\0') + read_any_attr(&p, NULL); + + /* success! */ +} + +/* + * Verify the final nonce contained in the last message received from + * client in an exchange. + */ +static bool +verify_final_nonce(scram_state *state) +{ + int client_nonce_len = strlen(state->client_nonce); + int server_nonce_len = strlen(state->server_nonce); + int final_nonce_len = strlen(state->client_final_nonce); + + if (final_nonce_len != client_nonce_len + server_nonce_len) + return false; + if (memcmp(state->client_final_nonce, state->client_nonce, client_nonce_len) != 0) + return false; + if (memcmp(state->client_final_nonce + client_nonce_len, state->server_nonce, server_nonce_len) != 0) + return false; + + return true; +} + +/* + * Verify the client proof contained in the last message received from + * client in an exchange. Returns true if the verification is a success, + * or false for a failure. + */ +static bool +verify_client_proof(scram_state *state) +{ + uint8 ClientSignature[SCRAM_KEY_LEN]; + uint8 ClientKey[SCRAM_KEY_LEN]; + uint8 client_StoredKey[SCRAM_KEY_LEN]; + pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); + int i; + + /* + * Calculate ClientSignature. Note that we don't log directly a failure + * here even when processing the calculations as this could involve a mock + * authentication. + */ + if (pg_hmac_init(ctx, state->StoredKey, SCRAM_KEY_LEN) < 0 || + pg_hmac_update(ctx, + (uint8 *) state->client_first_message_bare, + strlen(state->client_first_message_bare)) < 0 || + pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 || + pg_hmac_update(ctx, + (uint8 *) state->server_first_message, + strlen(state->server_first_message)) < 0 || + pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 || + pg_hmac_update(ctx, + (uint8 *) state->client_final_message_without_proof, + strlen(state->client_final_message_without_proof)) < 0 || + pg_hmac_final(ctx, ClientSignature, sizeof(ClientSignature)) < 0) + { + elog(ERROR, "could not calculate client signature"); + } + + pg_hmac_free(ctx); + + /* Extract the ClientKey that the client calculated from the proof */ + for (i = 0; i < SCRAM_KEY_LEN; i++) + ClientKey[i] = state->ClientProof[i] ^ ClientSignature[i]; + + /* Hash it one more time, and compare with StoredKey */ + if (scram_H(ClientKey, SCRAM_KEY_LEN, client_StoredKey) < 0) + elog(ERROR, "could not hash stored key"); + + if (memcmp(client_StoredKey, state->StoredKey, SCRAM_KEY_LEN) != 0) + return false; + + return true; +} + +/* + * Build the first server-side message sent to the client in a SCRAM + * communication exchange. + */ +static char * +build_server_first_message(scram_state *state) +{ + /*------ + * The syntax for the server-first-message is: (RFC 5802) + * + * server-first-message = + * [reserved-mext ","] nonce "," salt "," + * iteration-count ["," extensions] + * + * nonce = "r=" c-nonce [s-nonce] + * ;; Second part provided by server. + * + * c-nonce = printable + * + * s-nonce = printable + * + * salt = "s=" base64 + * + * iteration-count = "i=" posit-number + * ;; A positive number. + * + * Example: + * + * r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096 + *------ + */ + + /* + * Per the spec, the nonce may consist of any printable ASCII characters. + * For convenience, however, we don't use the whole range available, + * rather, we generate some random bytes, and base64 encode them. + */ + char raw_nonce[SCRAM_RAW_NONCE_LEN]; + int encoded_len; + + if (!pg_strong_random(raw_nonce, SCRAM_RAW_NONCE_LEN)) + ereport(ERROR, + (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("could not generate random nonce"))); + + encoded_len = pg_b64_enc_len(SCRAM_RAW_NONCE_LEN); + /* don't forget the zero-terminator */ + state->server_nonce = palloc(encoded_len + 1); + encoded_len = pg_b64_encode(raw_nonce, SCRAM_RAW_NONCE_LEN, + state->server_nonce, encoded_len); + if (encoded_len < 0) + ereport(ERROR, + (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("could not encode random nonce"))); + state->server_nonce[encoded_len] = '\0'; + + state->server_first_message = + psprintf("r=%s%s,s=%s,i=%u", + state->client_nonce, state->server_nonce, + state->salt, state->iterations); + + return pstrdup(state->server_first_message); +} + + +/* + * Read and parse the final message received from client. + */ +static void +read_client_final_message(scram_state *state, const char *input) +{ + char attr; + char *channel_binding; + char *value; + char *begin, + *proof; + char *p; + char *client_proof; + int client_proof_len; + + begin = p = pstrdup(input); + + /*------ + * The syntax for the server-first-message is: (RFC 5802) + * + * gs2-header = gs2-cbind-flag "," [ authzid ] "," + * ;; GS2 header for SCRAM + * ;; (the actual GS2 header includes an optional + * ;; flag to indicate that the GSS mechanism is not + * ;; "standard", but since SCRAM is "standard", we + * ;; don't include that flag). + * + * cbind-input = gs2-header [ cbind-data ] + * ;; cbind-data MUST be present for + * ;; gs2-cbind-flag of "p" and MUST be absent + * ;; for "y" or "n". + * + * channel-binding = "c=" base64 + * ;; base64 encoding of cbind-input. + * + * proof = "p=" base64 + * + * client-final-message-without-proof = + * channel-binding "," nonce ["," + * extensions] + * + * client-final-message = + * client-final-message-without-proof "," proof + *------ + */ + + /* + * Read channel binding. This repeats the channel-binding flags and is + * then followed by the actual binding data depending on the type. + */ + channel_binding = read_attr_value(&p, 'c'); + if (state->channel_binding_in_use) + { +#ifdef HAVE_BE_TLS_GET_CERTIFICATE_HASH + const char *cbind_data = NULL; + size_t cbind_data_len = 0; + size_t cbind_header_len; + char *cbind_input; + size_t cbind_input_len; + char *b64_message; + int b64_message_len; + + Assert(state->cbind_flag == 'p'); + + /* Fetch hash data of server's SSL certificate */ + cbind_data = be_tls_get_certificate_hash(state->port, + &cbind_data_len); + + /* should not happen */ + if (cbind_data == NULL || cbind_data_len == 0) + elog(ERROR, "could not get server certificate hash"); + + cbind_header_len = strlen("p=tls-server-end-point,,"); /* p=type,, */ + cbind_input_len = cbind_header_len + cbind_data_len; + cbind_input = palloc(cbind_input_len); + snprintf(cbind_input, cbind_input_len, "p=tls-server-end-point,,"); + memcpy(cbind_input + cbind_header_len, cbind_data, cbind_data_len); + + b64_message_len = pg_b64_enc_len(cbind_input_len); + /* don't forget the zero-terminator */ + b64_message = palloc(b64_message_len + 1); + b64_message_len = pg_b64_encode(cbind_input, cbind_input_len, + b64_message, b64_message_len); + if (b64_message_len < 0) + elog(ERROR, "could not encode channel binding data"); + b64_message[b64_message_len] = '\0'; + + /* + * Compare the value sent by the client with the value expected by the + * server. + */ + if (strcmp(channel_binding, b64_message) != 0) + ereport(ERROR, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + errmsg("SCRAM channel binding check failed"))); +#else + /* shouldn't happen, because we checked this earlier already */ + elog(ERROR, "channel binding not supported by this build"); +#endif + } + else + { + /* + * If we are not using channel binding, the binding data is expected + * to always be "biws", which is "n,," base64-encoded, or "eSws", + * which is "y,,". We also have to check whether the flag is the same + * one that the client originally sent. + */ + if (!(strcmp(channel_binding, "biws") == 0 && state->cbind_flag == 'n') && + !(strcmp(channel_binding, "eSws") == 0 && state->cbind_flag == 'y')) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("unexpected SCRAM channel-binding attribute in client-final-message"))); + } + + state->client_final_nonce = read_attr_value(&p, 'r'); + + /* ignore optional extensions, read until we find "p" attribute */ + do + { + proof = p - 1; + value = read_any_attr(&p, &attr); + } while (attr != 'p'); + + client_proof_len = pg_b64_dec_len(strlen(value)); + client_proof = palloc(client_proof_len); + if (pg_b64_decode(value, strlen(value), client_proof, + client_proof_len) != SCRAM_KEY_LEN) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Malformed proof in client-final-message."))); + memcpy(state->ClientProof, client_proof, SCRAM_KEY_LEN); + pfree(client_proof); + + if (*p != '\0') + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("malformed SCRAM message"), + errdetail("Garbage found at the end of client-final-message."))); + + state->client_final_message_without_proof = palloc(proof - begin + 1); + memcpy(state->client_final_message_without_proof, input, proof - begin); + state->client_final_message_without_proof[proof - begin] = '\0'; +} + +/* + * Build the final server-side message of an exchange. + */ +static char * +build_server_final_message(scram_state *state) +{ + uint8 ServerSignature[SCRAM_KEY_LEN]; + char *server_signature_base64; + int siglen; + pg_hmac_ctx *ctx = pg_hmac_create(PG_SHA256); + + /* calculate ServerSignature */ + if (pg_hmac_init(ctx, state->ServerKey, SCRAM_KEY_LEN) < 0 || + pg_hmac_update(ctx, + (uint8 *) state->client_first_message_bare, + strlen(state->client_first_message_bare)) < 0 || + pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 || + pg_hmac_update(ctx, + (uint8 *) state->server_first_message, + strlen(state->server_first_message)) < 0 || + pg_hmac_update(ctx, (uint8 *) ",", 1) < 0 || + pg_hmac_update(ctx, + (uint8 *) state->client_final_message_without_proof, + strlen(state->client_final_message_without_proof)) < 0 || + pg_hmac_final(ctx, ServerSignature, sizeof(ServerSignature)) < 0) + { + elog(ERROR, "could not calculate server signature"); + } + + pg_hmac_free(ctx); + + siglen = pg_b64_enc_len(SCRAM_KEY_LEN); + /* don't forget the zero-terminator */ + server_signature_base64 = palloc(siglen + 1); + siglen = pg_b64_encode((const char *) ServerSignature, + SCRAM_KEY_LEN, server_signature_base64, + siglen); + if (siglen < 0) + elog(ERROR, "could not encode server signature"); + server_signature_base64[siglen] = '\0'; + + /*------ + * The syntax for the server-final-message is: (RFC 5802) + * + * verifier = "v=" base64 + * ;; base-64 encoded ServerSignature. + * + * server-final-message = (server-error / verifier) + * ["," extensions] + * + *------ + */ + return psprintf("v=%s", server_signature_base64); +} + + +/* + * Deterministically generate salt for mock authentication, using a SHA256 + * hash based on the username and a cluster-level secret key. Returns a + * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL. + */ +static char * +scram_mock_salt(const char *username) +{ + pg_cryptohash_ctx *ctx; + static uint8 sha_digest[PG_SHA256_DIGEST_LENGTH]; + char *mock_auth_nonce = GetMockAuthenticationNonce(); + + /* + * Generate salt using a SHA256 hash of the username and the cluster's + * mock authentication nonce. (This works as long as the salt length is + * not larger than the SHA256 digest length. If the salt is smaller, the + * caller will just ignore the extra data.) + */ + StaticAssertStmt(PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN, + "salt length greater than SHA256 digest length"); + + ctx = pg_cryptohash_create(PG_SHA256); + if (pg_cryptohash_init(ctx) < 0 || + pg_cryptohash_update(ctx, (uint8 *) username, strlen(username)) < 0 || + pg_cryptohash_update(ctx, (uint8 *) mock_auth_nonce, MOCK_AUTH_NONCE_LEN) < 0 || + pg_cryptohash_final(ctx, sha_digest, sizeof(sha_digest)) < 0) + { + pg_cryptohash_free(ctx); + return NULL; + } + pg_cryptohash_free(ctx); + + return (char *) sha_digest; +} diff --git a/src/backend/libpq/auth.c b/src/backend/libpq/auth.c new file mode 100644 index 0000000..be14f2f --- /dev/null +++ b/src/backend/libpq/auth.c @@ -0,0 +1,3492 @@ +/*------------------------------------------------------------------------- + * + * auth.c + * Routines to handle network authentication + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/libpq/auth.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include <sys/param.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <unistd.h> +#ifdef HAVE_SYS_SELECT_H +#include <sys/select.h> +#endif + +#include "commands/user.h" +#include "common/ip.h" +#include "common/md5.h" +#include "common/scram-common.h" +#include "libpq/auth.h" +#include "libpq/crypt.h" +#include "libpq/libpq.h" +#include "libpq/pqformat.h" +#include "libpq/scram.h" +#include "miscadmin.h" +#include "port/pg_bswap.h" +#include "postmaster/postmaster.h" +#include "replication/walsender.h" +#include "storage/ipc.h" +#include "utils/guc.h" +#include "utils/memutils.h" +#include "utils/timestamp.h" + +/*---------------------------------------------------------------- + * Global authentication functions + *---------------------------------------------------------------- + */ +static void sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, + int extralen); +static void auth_failed(Port *port, int status, char *logdetail); +static char *recv_password_packet(Port *port); +static void set_authn_id(Port *port, const char *id); + + +/*---------------------------------------------------------------- + * Password-based authentication methods (password, md5, and scram-sha-256) + *---------------------------------------------------------------- + */ +static int CheckPasswordAuth(Port *port, char **logdetail); +static int CheckPWChallengeAuth(Port *port, char **logdetail); + +static int CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail); +static int CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail); + + +/*---------------------------------------------------------------- + * Ident authentication + *---------------------------------------------------------------- + */ +/* Max size of username ident server can return (per RFC 1413) */ +#define IDENT_USERNAME_MAX 512 + +/* Standard TCP port number for Ident service. Assigned by IANA */ +#define IDENT_PORT 113 + +static int ident_inet(hbaPort *port); + + +/*---------------------------------------------------------------- + * Peer authentication + *---------------------------------------------------------------- + */ +static int auth_peer(hbaPort *port); + + +/*---------------------------------------------------------------- + * PAM authentication + *---------------------------------------------------------------- + */ +#ifdef USE_PAM +#ifdef HAVE_PAM_PAM_APPL_H +#include <pam/pam_appl.h> +#endif +#ifdef HAVE_SECURITY_PAM_APPL_H +#include <security/pam_appl.h> +#endif + +#define PGSQL_PAM_SERVICE "postgresql" /* Service name passed to PAM */ + +static int CheckPAMAuth(Port *port, const char *user, const char *password); +static int pam_passwd_conv_proc(int num_msg, const struct pam_message **msg, + struct pam_response **resp, void *appdata_ptr); + +static struct pam_conv pam_passw_conv = { + &pam_passwd_conv_proc, + NULL +}; + +static const char *pam_passwd = NULL; /* Workaround for Solaris 2.6 + * brokenness */ +static Port *pam_port_cludge; /* Workaround for passing "Port *port" into + * pam_passwd_conv_proc */ +static bool pam_no_password; /* For detecting no-password-given */ +#endif /* USE_PAM */ + + +/*---------------------------------------------------------------- + * BSD authentication + *---------------------------------------------------------------- + */ +#ifdef USE_BSD_AUTH +#include <bsd_auth.h> + +static int CheckBSDAuth(Port *port, char *user); +#endif /* USE_BSD_AUTH */ + + +/*---------------------------------------------------------------- + * LDAP authentication + *---------------------------------------------------------------- + */ +#ifdef USE_LDAP +#ifndef WIN32 +/* We use a deprecated function to keep the codepath the same as win32. */ +#define LDAP_DEPRECATED 1 +#include <ldap.h> +#else +#include <winldap.h> + +/* Correct header from the Platform SDK */ +typedef +ULONG (*__ldap_start_tls_sA) (IN PLDAP ExternalHandle, + OUT PULONG ServerReturnValue, + OUT LDAPMessage **result, + IN PLDAPControlA * ServerControls, + IN PLDAPControlA * ClientControls +); +#endif + +static int CheckLDAPAuth(Port *port); + +/* LDAP_OPT_DIAGNOSTIC_MESSAGE is the newer spelling */ +#ifndef LDAP_OPT_DIAGNOSTIC_MESSAGE +#define LDAP_OPT_DIAGNOSTIC_MESSAGE LDAP_OPT_ERROR_STRING +#endif + +#endif /* USE_LDAP */ + +/*---------------------------------------------------------------- + * Cert authentication + *---------------------------------------------------------------- + */ +#ifdef USE_SSL +static int CheckCertAuth(Port *port); +#endif + + +/*---------------------------------------------------------------- + * Kerberos and GSSAPI GUCs + *---------------------------------------------------------------- + */ +char *pg_krb_server_keyfile; +bool pg_krb_caseins_users; + + +/*---------------------------------------------------------------- + * GSSAPI Authentication + *---------------------------------------------------------------- + */ +#ifdef ENABLE_GSS +#include "libpq/be-gssapi-common.h" + +static int pg_GSS_checkauth(Port *port); +static int pg_GSS_recvauth(Port *port); +#endif /* ENABLE_GSS */ + + +/*---------------------------------------------------------------- + * SSPI Authentication + *---------------------------------------------------------------- + */ +#ifdef ENABLE_SSPI +typedef SECURITY_STATUS + (WINAPI * QUERY_SECURITY_CONTEXT_TOKEN_FN) (PCtxtHandle, void **); +static int pg_SSPI_recvauth(Port *port); +static int pg_SSPI_make_upn(char *accountname, + size_t accountnamesize, + char *domainname, + size_t domainnamesize, + bool update_accountname); +#endif + +/*---------------------------------------------------------------- + * RADIUS Authentication + *---------------------------------------------------------------- + */ +static int CheckRADIUSAuth(Port *port); +static int PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd); + + +/* + * Maximum accepted size of GSS and SSPI authentication tokens. + * We also use this as a limit on ordinary password packet lengths. + * + * Kerberos tickets are usually quite small, but the TGTs issued by Windows + * domain controllers include an authorization field known as the Privilege + * Attribute Certificate (PAC), which contains the user's Windows permissions + * (group memberships etc.). The PAC is copied into all tickets obtained on + * the basis of this TGT (even those issued by Unix realms which the Windows + * realm trusts), and can be several kB in size. The maximum token size + * accepted by Windows systems is determined by the MaxAuthToken Windows + * registry setting. Microsoft recommends that it is not set higher than + * 65535 bytes, so that seems like a reasonable limit for us as well. + */ +#define PG_MAX_AUTH_TOKEN_LENGTH 65535 + +/* + * Maximum accepted size of SASL messages. + * + * The messages that the server or libpq generate are much smaller than this, + * but have some headroom. + */ +#define PG_MAX_SASL_MESSAGE_LENGTH 1024 + +/*---------------------------------------------------------------- + * Global authentication functions + *---------------------------------------------------------------- + */ + +/* + * This hook allows plugins to get control following client authentication, + * but before the user has been informed about the results. It could be used + * to record login events, insert a delay after failed authentication, etc. + */ +ClientAuthentication_hook_type ClientAuthentication_hook = NULL; + +/* + * Tell the user the authentication failed, but not (much about) why. + * + * There is a tradeoff here between security concerns and making life + * unnecessarily difficult for legitimate users. We would not, for example, + * want to report the password we were expecting to receive... + * But it seems useful to report the username and authorization method + * in use, and these are items that must be presumed known to an attacker + * anyway. + * Note that many sorts of failure report additional information in the + * postmaster log, which we hope is only readable by good guys. In + * particular, if logdetail isn't NULL, we send that string to the log. + */ +static void +auth_failed(Port *port, int status, char *logdetail) +{ + const char *errstr; + char *cdetail; + int errcode_return = ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION; + + /* + * If we failed due to EOF from client, just quit; there's no point in + * trying to send a message to the client, and not much point in logging + * the failure in the postmaster log. (Logging the failure might be + * desirable, were it not for the fact that libpq closes the connection + * unceremoniously if challenged for a password when it hasn't got one to + * send. We'll get a useless log entry for every psql connection under + * password auth, even if it's perfectly successful, if we log STATUS_EOF + * events.) + */ + if (status == STATUS_EOF) + proc_exit(0); + + switch (port->hba->auth_method) + { + case uaReject: + case uaImplicitReject: + errstr = gettext_noop("authentication failed for user \"%s\": host rejected"); + break; + case uaTrust: + errstr = gettext_noop("\"trust\" authentication failed for user \"%s\""); + break; + case uaIdent: + errstr = gettext_noop("Ident authentication failed for user \"%s\""); + break; + case uaPeer: + errstr = gettext_noop("Peer authentication failed for user \"%s\""); + break; + case uaPassword: + case uaMD5: + case uaSCRAM: + errstr = gettext_noop("password authentication failed for user \"%s\""); + /* We use it to indicate if a .pgpass password failed. */ + errcode_return = ERRCODE_INVALID_PASSWORD; + break; + case uaGSS: + errstr = gettext_noop("GSSAPI authentication failed for user \"%s\""); + break; + case uaSSPI: + errstr = gettext_noop("SSPI authentication failed for user \"%s\""); + break; + case uaPAM: + errstr = gettext_noop("PAM authentication failed for user \"%s\""); + break; + case uaBSD: + errstr = gettext_noop("BSD authentication failed for user \"%s\""); + break; + case uaLDAP: + errstr = gettext_noop("LDAP authentication failed for user \"%s\""); + break; + case uaCert: + errstr = gettext_noop("certificate authentication failed for user \"%s\""); + break; + case uaRADIUS: + errstr = gettext_noop("RADIUS authentication failed for user \"%s\""); + break; + default: + errstr = gettext_noop("authentication failed for user \"%s\": invalid authentication method"); + break; + } + + cdetail = psprintf(_("Connection matched pg_hba.conf line %d: \"%s\""), + port->hba->linenumber, port->hba->rawline); + if (logdetail) + logdetail = psprintf("%s\n%s", logdetail, cdetail); + else + logdetail = cdetail; + + ereport(FATAL, + (errcode(errcode_return), + errmsg(errstr, port->user_name), + logdetail ? errdetail_log("%s", logdetail) : 0)); + + /* doesn't return */ +} + + +/* + * Sets the authenticated identity for the current user. The provided string + * will be copied into the TopMemoryContext. The ID will be logged if + * log_connections is enabled. + * + * Auth methods should call this routine exactly once, as soon as the user is + * successfully authenticated, even if they have reasons to know that + * authorization will fail later. + * + * The provided string will be copied into TopMemoryContext, to match the + * lifetime of the Port, so it is safe to pass a string that is managed by an + * external library. + */ +static void +set_authn_id(Port *port, const char *id) +{ + Assert(id); + + if (port->authn_id) + { + /* + * An existing authn_id should never be overwritten; that means two + * authentication providers are fighting (or one is fighting itself). + * Don't leak any authn details to the client, but don't let the + * connection continue, either. + */ + ereport(FATAL, + (errmsg("authentication identifier set more than once"), + errdetail_log("previous identifier: \"%s\"; new identifier: \"%s\"", + port->authn_id, id))); + } + + port->authn_id = MemoryContextStrdup(TopMemoryContext, id); + + if (Log_connections) + { + ereport(LOG, + errmsg("connection authenticated: identity=\"%s\" method=%s " + "(%s:%d)", + port->authn_id, hba_authname(port->hba->auth_method), HbaFileName, + port->hba->linenumber)); + } +} + + +/* + * Client authentication starts here. If there is an error, this + * function does not return and the backend process is terminated. + */ +void +ClientAuthentication(Port *port) +{ + int status = STATUS_ERROR; + char *logdetail = NULL; + + /* + * Get the authentication method to use for this frontend/database + * combination. Note: we do not parse the file at this point; this has + * already been done elsewhere. hba.c dropped an error message into the + * server logfile if parsing the hba config file failed. + */ + hba_getauthmethod(port); + + CHECK_FOR_INTERRUPTS(); + + /* + * This is the first point where we have access to the hba record for the + * current connection, so perform any verifications based on the hba + * options field that should be done *before* the authentication here. + */ + if (port->hba->clientcert != clientCertOff) + { + /* If we haven't loaded a root certificate store, fail */ + if (!secure_loaded_verify_locations()) + ereport(FATAL, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("client certificates can only be checked if a root certificate store is available"))); + + /* + * If we loaded a root certificate store, and if a certificate is + * present on the client, then it has been verified against our root + * certificate store, and the connection would have been aborted + * already if it didn't verify ok. + */ + if (!port->peer_cert_valid) + ereport(FATAL, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + errmsg("connection requires a valid client certificate"))); + } + + /* + * Now proceed to do the actual authentication check + */ + switch (port->hba->auth_method) + { + case uaReject: + + /* + * An explicit "reject" entry in pg_hba.conf. This report exposes + * the fact that there's an explicit reject entry, which is + * perhaps not so desirable from a security standpoint; but the + * message for an implicit reject could confuse the DBA a lot when + * the true situation is a match to an explicit reject. And we + * don't want to change the message for an implicit reject. As + * noted below, the additional information shown here doesn't + * expose anything not known to an attacker. + */ + { + char hostinfo[NI_MAXHOST]; + const char *encryption_state; + + pg_getnameinfo_all(&port->raddr.addr, port->raddr.salen, + hostinfo, sizeof(hostinfo), + NULL, 0, + NI_NUMERICHOST); + + encryption_state = +#ifdef ENABLE_GSS + (port->gss && port->gss->enc) ? _("GSS encryption") : +#endif +#ifdef USE_SSL + port->ssl_in_use ? _("SSL encryption") : +#endif + _("no encryption"); + + if (am_walsender && !am_db_walsender) + ereport(FATAL, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + /* translator: last %s describes encryption state */ + errmsg("pg_hba.conf rejects replication connection for host \"%s\", user \"%s\", %s", + hostinfo, port->user_name, + encryption_state))); + else + ereport(FATAL, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + /* translator: last %s describes encryption state */ + errmsg("pg_hba.conf rejects connection for host \"%s\", user \"%s\", database \"%s\", %s", + hostinfo, port->user_name, + port->database_name, + encryption_state))); + break; + } + + case uaImplicitReject: + + /* + * No matching entry, so tell the user we fell through. + * + * NOTE: the extra info reported here is not a security breach, + * because all that info is known at the frontend and must be + * assumed known to bad guys. We're merely helping out the less + * clueful good guys. + */ + { + char hostinfo[NI_MAXHOST]; + const char *encryption_state; + + pg_getnameinfo_all(&port->raddr.addr, port->raddr.salen, + hostinfo, sizeof(hostinfo), + NULL, 0, + NI_NUMERICHOST); + + encryption_state = +#ifdef ENABLE_GSS + (port->gss && port->gss->enc) ? _("GSS encryption") : +#endif +#ifdef USE_SSL + port->ssl_in_use ? _("SSL encryption") : +#endif + _("no encryption"); + +#define HOSTNAME_LOOKUP_DETAIL(port) \ + (port->remote_hostname ? \ + (port->remote_hostname_resolv == +1 ? \ + errdetail_log("Client IP address resolved to \"%s\", forward lookup matches.", \ + port->remote_hostname) : \ + port->remote_hostname_resolv == 0 ? \ + errdetail_log("Client IP address resolved to \"%s\", forward lookup not checked.", \ + port->remote_hostname) : \ + port->remote_hostname_resolv == -1 ? \ + errdetail_log("Client IP address resolved to \"%s\", forward lookup does not match.", \ + port->remote_hostname) : \ + port->remote_hostname_resolv == -2 ? \ + errdetail_log("Could not translate client host name \"%s\" to IP address: %s.", \ + port->remote_hostname, \ + gai_strerror(port->remote_hostname_errcode)) : \ + 0) \ + : (port->remote_hostname_resolv == -2 ? \ + errdetail_log("Could not resolve client IP address to a host name: %s.", \ + gai_strerror(port->remote_hostname_errcode)) : \ + 0)) + + if (am_walsender && !am_db_walsender) + ereport(FATAL, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + /* translator: last %s describes encryption state */ + errmsg("no pg_hba.conf entry for replication connection from host \"%s\", user \"%s\", %s", + hostinfo, port->user_name, + encryption_state), + HOSTNAME_LOOKUP_DETAIL(port))); + else + ereport(FATAL, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + /* translator: last %s describes encryption state */ + errmsg("no pg_hba.conf entry for host \"%s\", user \"%s\", database \"%s\", %s", + hostinfo, port->user_name, + port->database_name, + encryption_state), + HOSTNAME_LOOKUP_DETAIL(port))); + break; + } + + case uaGSS: +#ifdef ENABLE_GSS + /* We might or might not have the gss workspace already */ + if (port->gss == NULL) + port->gss = (pg_gssinfo *) + MemoryContextAllocZero(TopMemoryContext, + sizeof(pg_gssinfo)); + port->gss->auth = true; + + /* + * If GSS state was set up while enabling encryption, we can just + * check the client's principal. Otherwise, ask for it. + */ + if (port->gss->enc) + status = pg_GSS_checkauth(port); + else + { + sendAuthRequest(port, AUTH_REQ_GSS, NULL, 0); + status = pg_GSS_recvauth(port); + } +#else + Assert(false); +#endif + break; + + case uaSSPI: +#ifdef ENABLE_SSPI + if (port->gss == NULL) + port->gss = (pg_gssinfo *) + MemoryContextAllocZero(TopMemoryContext, + sizeof(pg_gssinfo)); + sendAuthRequest(port, AUTH_REQ_SSPI, NULL, 0); + status = pg_SSPI_recvauth(port); +#else + Assert(false); +#endif + break; + + case uaPeer: + status = auth_peer(port); + break; + + case uaIdent: + status = ident_inet(port); + break; + + case uaMD5: + case uaSCRAM: + status = CheckPWChallengeAuth(port, &logdetail); + break; + + case uaPassword: + status = CheckPasswordAuth(port, &logdetail); + break; + + case uaPAM: +#ifdef USE_PAM + status = CheckPAMAuth(port, port->user_name, ""); +#else + Assert(false); +#endif /* USE_PAM */ + break; + + case uaBSD: +#ifdef USE_BSD_AUTH + status = CheckBSDAuth(port, port->user_name); +#else + Assert(false); +#endif /* USE_BSD_AUTH */ + break; + + case uaLDAP: +#ifdef USE_LDAP + status = CheckLDAPAuth(port); +#else + Assert(false); +#endif + break; + case uaRADIUS: + status = CheckRADIUSAuth(port); + break; + case uaCert: + /* uaCert will be treated as if clientcert=verify-full (uaTrust) */ + case uaTrust: + status = STATUS_OK; + break; + } + + if ((status == STATUS_OK && port->hba->clientcert == clientCertFull) + || port->hba->auth_method == uaCert) + { + /* + * Make sure we only check the certificate if we use the cert method + * or verify-full option. + */ +#ifdef USE_SSL + status = CheckCertAuth(port); +#else + Assert(false); +#endif + } + + if (ClientAuthentication_hook) + (*ClientAuthentication_hook) (port, status); + + if (status == STATUS_OK) + sendAuthRequest(port, AUTH_REQ_OK, NULL, 0); + else + auth_failed(port, status, logdetail); +} + + +/* + * Send an authentication request packet to the frontend. + */ +static void +sendAuthRequest(Port *port, AuthRequest areq, const char *extradata, int extralen) +{ + StringInfoData buf; + + CHECK_FOR_INTERRUPTS(); + + pq_beginmessage(&buf, 'R'); + pq_sendint32(&buf, (int32) areq); + if (extralen > 0) + pq_sendbytes(&buf, extradata, extralen); + + pq_endmessage(&buf); + + /* + * Flush message so client will see it, except for AUTH_REQ_OK and + * AUTH_REQ_SASL_FIN, which need not be sent until we are ready for + * queries. + */ + if (areq != AUTH_REQ_OK && areq != AUTH_REQ_SASL_FIN) + pq_flush(); + + CHECK_FOR_INTERRUPTS(); +} + +/* + * Collect password response packet from frontend. + * + * Returns NULL if couldn't get password, else palloc'd string. + */ +static char * +recv_password_packet(Port *port) +{ + StringInfoData buf; + int mtype; + + pq_startmsgread(); + + /* Expect 'p' message type */ + mtype = pq_getbyte(); + if (mtype != 'p') + { + /* + * If the client just disconnects without offering a password, don't + * make a log entry. This is legal per protocol spec and in fact + * commonly done by psql, so complaining just clutters the log. + */ + if (mtype != EOF) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("expected password response, got message type %d", + mtype))); + return NULL; /* EOF or bad message type */ + } + + initStringInfo(&buf); + if (pq_getmessage(&buf, PG_MAX_AUTH_TOKEN_LENGTH)) /* receive password */ + { + /* EOF - pq_getmessage already logged a suitable message */ + pfree(buf.data); + return NULL; + } + + /* + * Apply sanity check: password packet length should agree with length of + * contained string. Note it is safe to use strlen here because + * StringInfo is guaranteed to have an appended '\0'. + */ + if (strlen(buf.data) + 1 != buf.len) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("invalid password packet size"))); + + /* + * Don't allow an empty password. Libpq treats an empty password the same + * as no password at all, and won't even try to authenticate. But other + * clients might, so allowing it would be confusing. + * + * Note that this only catches an empty password sent by the client in + * plaintext. There's also a check in CREATE/ALTER USER that prevents an + * empty string from being stored as a user's password in the first place. + * We rely on that for MD5 and SCRAM authentication, but we still need + * this check here, to prevent an empty password from being used with + * authentication methods that check the password against an external + * system, like PAM, LDAP and RADIUS. + */ + if (buf.len == 1) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PASSWORD), + errmsg("empty password returned by client"))); + + /* Do not echo password to logs, for security. */ + elog(DEBUG5, "received password packet"); + + /* + * Return the received string. Note we do not attempt to do any + * character-set conversion on it; since we don't yet know the client's + * encoding, there wouldn't be much point. + */ + return buf.data; +} + + +/*---------------------------------------------------------------- + * Password-based authentication mechanisms + *---------------------------------------------------------------- + */ + +/* + * Plaintext password authentication. + */ +static int +CheckPasswordAuth(Port *port, char **logdetail) +{ + char *passwd; + int result; + char *shadow_pass; + + sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0); + + passwd = recv_password_packet(port); + if (passwd == NULL) + return STATUS_EOF; /* client wouldn't send password */ + + shadow_pass = get_role_password(port->user_name, logdetail); + if (shadow_pass) + { + result = plain_crypt_verify(port->user_name, shadow_pass, passwd, + logdetail); + } + else + result = STATUS_ERROR; + + if (shadow_pass) + pfree(shadow_pass); + pfree(passwd); + + if (result == STATUS_OK) + set_authn_id(port, port->user_name); + + return result; +} + +/* + * MD5 and SCRAM authentication. + */ +static int +CheckPWChallengeAuth(Port *port, char **logdetail) +{ + int auth_result; + char *shadow_pass; + PasswordType pwtype; + + Assert(port->hba->auth_method == uaSCRAM || + port->hba->auth_method == uaMD5); + + /* First look up the user's password. */ + shadow_pass = get_role_password(port->user_name, logdetail); + + /* + * If the user does not exist, or has no password or it's expired, we + * still go through the motions of authentication, to avoid revealing to + * the client that the user didn't exist. If 'md5' is allowed, we choose + * whether to use 'md5' or 'scram-sha-256' authentication based on current + * password_encryption setting. The idea is that most genuine users + * probably have a password of that type, and if we pretend that this user + * had a password of that type, too, it "blends in" best. + */ + if (!shadow_pass) + pwtype = Password_encryption; + else + pwtype = get_password_type(shadow_pass); + + /* + * If 'md5' authentication is allowed, decide whether to perform 'md5' or + * 'scram-sha-256' authentication based on the type of password the user + * has. If it's an MD5 hash, we must do MD5 authentication, and if it's a + * SCRAM secret, we must do SCRAM authentication. + * + * If MD5 authentication is not allowed, always use SCRAM. If the user + * had an MD5 password, CheckSCRAMAuth() will fail. + */ + if (port->hba->auth_method == uaMD5 && pwtype == PASSWORD_TYPE_MD5) + auth_result = CheckMD5Auth(port, shadow_pass, logdetail); + else + auth_result = CheckSCRAMAuth(port, shadow_pass, logdetail); + + if (shadow_pass) + pfree(shadow_pass); + + /* + * If get_role_password() returned error, return error, even if the + * authentication succeeded. + */ + if (!shadow_pass) + { + Assert(auth_result != STATUS_OK); + return STATUS_ERROR; + } + + if (auth_result == STATUS_OK) + set_authn_id(port, port->user_name); + + return auth_result; +} + +static int +CheckMD5Auth(Port *port, char *shadow_pass, char **logdetail) +{ + char md5Salt[4]; /* Password salt */ + char *passwd; + int result; + + if (Db_user_namespace) + ereport(FATAL, + (errcode(ERRCODE_INVALID_AUTHORIZATION_SPECIFICATION), + errmsg("MD5 authentication is not supported when \"db_user_namespace\" is enabled"))); + + /* include the salt to use for computing the response */ + if (!pg_strong_random(md5Salt, 4)) + { + ereport(LOG, + (errmsg("could not generate random MD5 salt"))); + return STATUS_ERROR; + } + + sendAuthRequest(port, AUTH_REQ_MD5, md5Salt, 4); + + passwd = recv_password_packet(port); + if (passwd == NULL) + return STATUS_EOF; /* client wouldn't send password */ + + if (shadow_pass) + result = md5_crypt_verify(port->user_name, shadow_pass, passwd, + md5Salt, 4, logdetail); + else + result = STATUS_ERROR; + + pfree(passwd); + + return result; +} + +static int +CheckSCRAMAuth(Port *port, char *shadow_pass, char **logdetail) +{ + StringInfoData sasl_mechs; + int mtype; + StringInfoData buf; + void *scram_opaq = NULL; + char *output = NULL; + int outputlen = 0; + const char *input; + int inputlen; + int result; + bool initial; + + /* + * Send the SASL authentication request to user. It includes the list of + * authentication mechanisms that are supported. + */ + initStringInfo(&sasl_mechs); + + pg_be_scram_get_mechanisms(port, &sasl_mechs); + /* Put another '\0' to mark that list is finished. */ + appendStringInfoChar(&sasl_mechs, '\0'); + + sendAuthRequest(port, AUTH_REQ_SASL, sasl_mechs.data, sasl_mechs.len); + pfree(sasl_mechs.data); + + /* + * Loop through SASL message exchange. This exchange can consist of + * multiple messages sent in both directions. First message is always + * from the client. All messages from client to server are password + * packets (type 'p'). + */ + initial = true; + do + { + pq_startmsgread(); + mtype = pq_getbyte(); + if (mtype != 'p') + { + /* Only log error if client didn't disconnect. */ + if (mtype != EOF) + { + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("expected SASL response, got message type %d", + mtype))); + } + else + return STATUS_EOF; + } + + /* Get the actual SASL message */ + initStringInfo(&buf); + if (pq_getmessage(&buf, PG_MAX_SASL_MESSAGE_LENGTH)) + { + /* EOF - pq_getmessage already logged error */ + pfree(buf.data); + return STATUS_ERROR; + } + + elog(DEBUG4, "processing received SASL response of length %d", buf.len); + + /* + * The first SASLInitialResponse message is different from the others. + * It indicates which SASL mechanism the client selected, and contains + * an optional Initial Client Response payload. The subsequent + * SASLResponse messages contain just the SASL payload. + */ + if (initial) + { + const char *selected_mech; + + selected_mech = pq_getmsgrawstring(&buf); + + /* + * Initialize the status tracker for message exchanges. + * + * If the user doesn't exist, or doesn't have a valid password, or + * it's expired, we still go through the motions of SASL + * authentication, but tell the authentication method that the + * authentication is "doomed". That is, it's going to fail, no + * matter what. + * + * This is because we don't want to reveal to an attacker what + * usernames are valid, nor which users have a valid password. + */ + scram_opaq = pg_be_scram_init(port, selected_mech, shadow_pass); + + inputlen = pq_getmsgint(&buf, 4); + if (inputlen == -1) + input = NULL; + else + input = pq_getmsgbytes(&buf, inputlen); + + initial = false; + } + else + { + inputlen = buf.len; + input = pq_getmsgbytes(&buf, buf.len); + } + pq_getmsgend(&buf); + + /* + * The StringInfo guarantees that there's a \0 byte after the + * response. + */ + Assert(input == NULL || input[inputlen] == '\0'); + + /* + * we pass 'logdetail' as NULL when doing a mock authentication, + * because we should already have a better error message in that case + */ + result = pg_be_scram_exchange(scram_opaq, input, inputlen, + &output, &outputlen, + logdetail); + + /* input buffer no longer used */ + pfree(buf.data); + + if (output) + { + /* + * Negotiation generated data to be sent to the client. + */ + elog(DEBUG4, "sending SASL challenge of length %u", outputlen); + + if (result == SASL_EXCHANGE_SUCCESS) + sendAuthRequest(port, AUTH_REQ_SASL_FIN, output, outputlen); + else + sendAuthRequest(port, AUTH_REQ_SASL_CONT, output, outputlen); + + pfree(output); + } + } while (result == SASL_EXCHANGE_CONTINUE); + + /* Oops, Something bad happened */ + if (result != SASL_EXCHANGE_SUCCESS) + { + return STATUS_ERROR; + } + + return STATUS_OK; +} + + +/*---------------------------------------------------------------- + * GSSAPI authentication system + *---------------------------------------------------------------- + */ +#ifdef ENABLE_GSS +static int +pg_GSS_recvauth(Port *port) +{ + OM_uint32 maj_stat, + min_stat, + lmin_s, + gflags; + int mtype; + StringInfoData buf; + gss_buffer_desc gbuf; + + /* + * Use the configured keytab, if there is one. Unfortunately, Heimdal + * doesn't support the cred store extensions, so use the env var. + */ + if (pg_krb_server_keyfile != NULL && pg_krb_server_keyfile[0] != '\0') + { + if (setenv("KRB5_KTNAME", pg_krb_server_keyfile, 1) != 0) + { + /* The only likely failure cause is OOM, so use that errcode */ + ereport(FATAL, + (errcode(ERRCODE_OUT_OF_MEMORY), + errmsg("could not set environment: %m"))); + } + } + + /* + * We accept any service principal that's present in our keytab. This + * increases interoperability between kerberos implementations that see + * for example case sensitivity differently, while not really opening up + * any vector of attack. + */ + port->gss->cred = GSS_C_NO_CREDENTIAL; + + /* + * Initialize sequence with an empty context + */ + port->gss->ctx = GSS_C_NO_CONTEXT; + + /* + * Loop through GSSAPI message exchange. This exchange can consist of + * multiple messages sent in both directions. First message is always from + * the client. All messages from client to server are password packets + * (type 'p'). + */ + do + { + pq_startmsgread(); + + CHECK_FOR_INTERRUPTS(); + + mtype = pq_getbyte(); + if (mtype != 'p') + { + /* Only log error if client didn't disconnect. */ + if (mtype != EOF) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("expected GSS response, got message type %d", + mtype))); + return STATUS_ERROR; + } + + /* Get the actual GSS token */ + initStringInfo(&buf); + if (pq_getmessage(&buf, PG_MAX_AUTH_TOKEN_LENGTH)) + { + /* EOF - pq_getmessage already logged error */ + pfree(buf.data); + return STATUS_ERROR; + } + + /* Map to GSSAPI style buffer */ + gbuf.length = buf.len; + gbuf.value = buf.data; + + elog(DEBUG4, "processing received GSS token of length %u", + (unsigned int) gbuf.length); + + maj_stat = gss_accept_sec_context(&min_stat, + &port->gss->ctx, + port->gss->cred, + &gbuf, + GSS_C_NO_CHANNEL_BINDINGS, + &port->gss->name, + NULL, + &port->gss->outbuf, + &gflags, + NULL, + NULL); + + /* gbuf no longer used */ + pfree(buf.data); + + elog(DEBUG5, "gss_accept_sec_context major: %d, " + "minor: %d, outlen: %u, outflags: %x", + maj_stat, min_stat, + (unsigned int) port->gss->outbuf.length, gflags); + + CHECK_FOR_INTERRUPTS(); + + if (port->gss->outbuf.length != 0) + { + /* + * Negotiation generated data to be sent to the client. + */ + elog(DEBUG4, "sending GSS response token of length %u", + (unsigned int) port->gss->outbuf.length); + + sendAuthRequest(port, AUTH_REQ_GSS_CONT, + port->gss->outbuf.value, port->gss->outbuf.length); + + gss_release_buffer(&lmin_s, &port->gss->outbuf); + } + + if (maj_stat != GSS_S_COMPLETE && maj_stat != GSS_S_CONTINUE_NEEDED) + { + gss_delete_sec_context(&lmin_s, &port->gss->ctx, GSS_C_NO_BUFFER); + pg_GSS_error(_("accepting GSS security context failed"), + maj_stat, min_stat); + return STATUS_ERROR; + } + + if (maj_stat == GSS_S_CONTINUE_NEEDED) + elog(DEBUG4, "GSS continue needed"); + + } while (maj_stat == GSS_S_CONTINUE_NEEDED); + + if (port->gss->cred != GSS_C_NO_CREDENTIAL) + { + /* + * Release service principal credentials + */ + gss_release_cred(&min_stat, &port->gss->cred); + } + return pg_GSS_checkauth(port); +} + +/* + * Check whether the GSSAPI-authenticated user is allowed to connect as the + * claimed username. + */ +static int +pg_GSS_checkauth(Port *port) +{ + int ret; + OM_uint32 maj_stat, + min_stat, + lmin_s; + gss_buffer_desc gbuf; + char *princ; + + /* + * Get the name of the user that authenticated, and compare it to the pg + * username that was specified for the connection. + */ + maj_stat = gss_display_name(&min_stat, port->gss->name, &gbuf, NULL); + if (maj_stat != GSS_S_COMPLETE) + { + pg_GSS_error(_("retrieving GSS user name failed"), + maj_stat, min_stat); + return STATUS_ERROR; + } + + /* + * gbuf.value might not be null-terminated, so turn it into a regular + * null-terminated string. + */ + princ = palloc(gbuf.length + 1); + memcpy(princ, gbuf.value, gbuf.length); + princ[gbuf.length] = '\0'; + gss_release_buffer(&lmin_s, &gbuf); + + /* + * Copy the original name of the authenticated principal into our backend + * memory for display later. + * + * This is also our authenticated identity. Set it now, rather than + * waiting for the usermap check below, because authentication has already + * succeeded and we want the log file to reflect that. + */ + port->gss->princ = MemoryContextStrdup(TopMemoryContext, princ); + set_authn_id(port, princ); + + /* + * Split the username at the realm separator + */ + if (strchr(princ, '@')) + { + char *cp = strchr(princ, '@'); + + /* + * If we are not going to include the realm in the username that is + * passed to the ident map, destructively modify it here to remove the + * realm. Then advance past the separator to check the realm. + */ + if (!port->hba->include_realm) + *cp = '\0'; + cp++; + + if (port->hba->krb_realm != NULL && strlen(port->hba->krb_realm)) + { + /* + * Match the realm part of the name first + */ + if (pg_krb_caseins_users) + ret = pg_strcasecmp(port->hba->krb_realm, cp); + else + ret = strcmp(port->hba->krb_realm, cp); + + if (ret) + { + /* GSS realm does not match */ + elog(DEBUG2, + "GSSAPI realm (%s) and configured realm (%s) don't match", + cp, port->hba->krb_realm); + pfree(princ); + return STATUS_ERROR; + } + } + } + else if (port->hba->krb_realm && strlen(port->hba->krb_realm)) + { + elog(DEBUG2, + "GSSAPI did not return realm but realm matching was requested"); + pfree(princ); + return STATUS_ERROR; + } + + ret = check_usermap(port->hba->usermap, port->user_name, princ, + pg_krb_caseins_users); + + pfree(princ); + + return ret; +} +#endif /* ENABLE_GSS */ + + +/*---------------------------------------------------------------- + * SSPI authentication system + *---------------------------------------------------------------- + */ +#ifdef ENABLE_SSPI + +/* + * Generate an error for SSPI authentication. The caller should apply + * _() to errmsg to make it translatable. + */ +static void +pg_SSPI_error(int severity, const char *errmsg, SECURITY_STATUS r) +{ + char sysmsg[256]; + + if (FormatMessage(FORMAT_MESSAGE_IGNORE_INSERTS | + FORMAT_MESSAGE_FROM_SYSTEM, + NULL, r, 0, + sysmsg, sizeof(sysmsg), NULL) == 0) + ereport(severity, + (errmsg_internal("%s", errmsg), + errdetail_internal("SSPI error %x", (unsigned int) r))); + else + ereport(severity, + (errmsg_internal("%s", errmsg), + errdetail_internal("%s (%x)", sysmsg, (unsigned int) r))); +} + +static int +pg_SSPI_recvauth(Port *port) +{ + int mtype; + StringInfoData buf; + SECURITY_STATUS r; + CredHandle sspicred; + CtxtHandle *sspictx = NULL, + newctx; + TimeStamp expiry; + ULONG contextattr; + SecBufferDesc inbuf; + SecBufferDesc outbuf; + SecBuffer OutBuffers[1]; + SecBuffer InBuffers[1]; + HANDLE token; + TOKEN_USER *tokenuser; + DWORD retlen; + char accountname[MAXPGPATH]; + char domainname[MAXPGPATH]; + DWORD accountnamesize = sizeof(accountname); + DWORD domainnamesize = sizeof(domainname); + SID_NAME_USE accountnameuse; + HMODULE secur32; + char *authn_id; + + QUERY_SECURITY_CONTEXT_TOKEN_FN _QuerySecurityContextToken; + + /* + * Acquire a handle to the server credentials. + */ + r = AcquireCredentialsHandle(NULL, + "negotiate", + SECPKG_CRED_INBOUND, + NULL, + NULL, + NULL, + NULL, + &sspicred, + &expiry); + if (r != SEC_E_OK) + pg_SSPI_error(ERROR, _("could not acquire SSPI credentials"), r); + + /* + * Loop through SSPI message exchange. This exchange can consist of + * multiple messages sent in both directions. First message is always from + * the client. All messages from client to server are password packets + * (type 'p'). + */ + do + { + pq_startmsgread(); + mtype = pq_getbyte(); + if (mtype != 'p') + { + if (sspictx != NULL) + { + DeleteSecurityContext(sspictx); + free(sspictx); + } + FreeCredentialsHandle(&sspicred); + + /* Only log error if client didn't disconnect. */ + if (mtype != EOF) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("expected SSPI response, got message type %d", + mtype))); + return STATUS_ERROR; + } + + /* Get the actual SSPI token */ + initStringInfo(&buf); + if (pq_getmessage(&buf, PG_MAX_AUTH_TOKEN_LENGTH)) + { + /* EOF - pq_getmessage already logged error */ + pfree(buf.data); + if (sspictx != NULL) + { + DeleteSecurityContext(sspictx); + free(sspictx); + } + FreeCredentialsHandle(&sspicred); + return STATUS_ERROR; + } + + /* Map to SSPI style buffer */ + inbuf.ulVersion = SECBUFFER_VERSION; + inbuf.cBuffers = 1; + inbuf.pBuffers = InBuffers; + InBuffers[0].pvBuffer = buf.data; + InBuffers[0].cbBuffer = buf.len; + InBuffers[0].BufferType = SECBUFFER_TOKEN; + + /* Prepare output buffer */ + OutBuffers[0].pvBuffer = NULL; + OutBuffers[0].BufferType = SECBUFFER_TOKEN; + OutBuffers[0].cbBuffer = 0; + outbuf.cBuffers = 1; + outbuf.pBuffers = OutBuffers; + outbuf.ulVersion = SECBUFFER_VERSION; + + elog(DEBUG4, "processing received SSPI token of length %u", + (unsigned int) buf.len); + + r = AcceptSecurityContext(&sspicred, + sspictx, + &inbuf, + ASC_REQ_ALLOCATE_MEMORY, + SECURITY_NETWORK_DREP, + &newctx, + &outbuf, + &contextattr, + NULL); + + /* input buffer no longer used */ + pfree(buf.data); + + if (outbuf.cBuffers > 0 && outbuf.pBuffers[0].cbBuffer > 0) + { + /* + * Negotiation generated data to be sent to the client. + */ + elog(DEBUG4, "sending SSPI response token of length %u", + (unsigned int) outbuf.pBuffers[0].cbBuffer); + + port->gss->outbuf.length = outbuf.pBuffers[0].cbBuffer; + port->gss->outbuf.value = outbuf.pBuffers[0].pvBuffer; + + sendAuthRequest(port, AUTH_REQ_GSS_CONT, + port->gss->outbuf.value, port->gss->outbuf.length); + + FreeContextBuffer(outbuf.pBuffers[0].pvBuffer); + } + + if (r != SEC_E_OK && r != SEC_I_CONTINUE_NEEDED) + { + if (sspictx != NULL) + { + DeleteSecurityContext(sspictx); + free(sspictx); + } + FreeCredentialsHandle(&sspicred); + pg_SSPI_error(ERROR, + _("could not accept SSPI security context"), r); + } + + /* + * Overwrite the current context with the one we just received. If + * sspictx is NULL it was the first loop and we need to allocate a + * buffer for it. On subsequent runs, we can just overwrite the buffer + * contents since the size does not change. + */ + if (sspictx == NULL) + { + sspictx = malloc(sizeof(CtxtHandle)); + if (sspictx == NULL) + ereport(ERROR, + (errmsg("out of memory"))); + } + + memcpy(sspictx, &newctx, sizeof(CtxtHandle)); + + if (r == SEC_I_CONTINUE_NEEDED) + elog(DEBUG4, "SSPI continue needed"); + + } while (r == SEC_I_CONTINUE_NEEDED); + + + /* + * Release service principal credentials + */ + FreeCredentialsHandle(&sspicred); + + + /* + * SEC_E_OK indicates that authentication is now complete. + * + * Get the name of the user that authenticated, and compare it to the pg + * username that was specified for the connection. + * + * MingW is missing the export for QuerySecurityContextToken in the + * secur32 library, so we have to load it dynamically. + */ + + secur32 = LoadLibrary("SECUR32.DLL"); + if (secur32 == NULL) + ereport(ERROR, + (errmsg("could not load library \"%s\": error code %lu", + "SECUR32.DLL", GetLastError()))); + + _QuerySecurityContextToken = (QUERY_SECURITY_CONTEXT_TOKEN_FN) (pg_funcptr_t) + GetProcAddress(secur32, "QuerySecurityContextToken"); + if (_QuerySecurityContextToken == NULL) + { + FreeLibrary(secur32); + ereport(ERROR, + (errmsg_internal("could not locate QuerySecurityContextToken in secur32.dll: error code %lu", + GetLastError()))); + } + + r = (_QuerySecurityContextToken) (sspictx, &token); + if (r != SEC_E_OK) + { + FreeLibrary(secur32); + pg_SSPI_error(ERROR, + _("could not get token from SSPI security context"), r); + } + + FreeLibrary(secur32); + + /* + * No longer need the security context, everything from here on uses the + * token instead. + */ + DeleteSecurityContext(sspictx); + free(sspictx); + + if (!GetTokenInformation(token, TokenUser, NULL, 0, &retlen) && GetLastError() != 122) + ereport(ERROR, + (errmsg_internal("could not get token information buffer size: error code %lu", + GetLastError()))); + + tokenuser = malloc(retlen); + if (tokenuser == NULL) + ereport(ERROR, + (errmsg("out of memory"))); + + if (!GetTokenInformation(token, TokenUser, tokenuser, retlen, &retlen)) + ereport(ERROR, + (errmsg_internal("could not get token information: error code %lu", + GetLastError()))); + + CloseHandle(token); + + if (!LookupAccountSid(NULL, tokenuser->User.Sid, accountname, &accountnamesize, + domainname, &domainnamesize, &accountnameuse)) + ereport(ERROR, + (errmsg_internal("could not look up account SID: error code %lu", + GetLastError()))); + + free(tokenuser); + + if (!port->hba->compat_realm) + { + int status = pg_SSPI_make_upn(accountname, sizeof(accountname), + domainname, sizeof(domainname), + port->hba->upn_username); + + if (status != STATUS_OK) + /* Error already reported from pg_SSPI_make_upn */ + return status; + } + + /* + * We have all of the information necessary to construct the authenticated + * identity. Set it now, rather than waiting for check_usermap below, + * because authentication has already succeeded and we want the log file + * to reflect that. + */ + if (port->hba->compat_realm) + { + /* SAM-compatible format. */ + authn_id = psprintf("%s\\%s", domainname, accountname); + } + else + { + /* Kerberos principal format. */ + authn_id = psprintf("%s@%s", accountname, domainname); + } + + set_authn_id(port, authn_id); + pfree(authn_id); + + /* + * Compare realm/domain if requested. In SSPI, always compare case + * insensitive. + */ + if (port->hba->krb_realm && strlen(port->hba->krb_realm)) + { + if (pg_strcasecmp(port->hba->krb_realm, domainname) != 0) + { + elog(DEBUG2, + "SSPI domain (%s) and configured domain (%s) don't match", + domainname, port->hba->krb_realm); + + return STATUS_ERROR; + } + } + + /* + * We have the username (without domain/realm) in accountname, compare to + * the supplied value. In SSPI, always compare case insensitive. + * + * If set to include realm, append it in <username>@<realm> format. + */ + if (port->hba->include_realm) + { + char *namebuf; + int retval; + + namebuf = psprintf("%s@%s", accountname, domainname); + retval = check_usermap(port->hba->usermap, port->user_name, namebuf, true); + pfree(namebuf); + return retval; + } + else + return check_usermap(port->hba->usermap, port->user_name, accountname, true); +} + +/* + * Replaces the domainname with the Kerberos realm name, + * and optionally the accountname with the Kerberos user name. + */ +static int +pg_SSPI_make_upn(char *accountname, + size_t accountnamesize, + char *domainname, + size_t domainnamesize, + bool update_accountname) +{ + char *samname; + char *upname = NULL; + char *p = NULL; + ULONG upnamesize = 0; + size_t upnamerealmsize; + BOOLEAN res; + + /* + * Build SAM name (DOMAIN\user), then translate to UPN + * (user@kerberos.realm). The realm name is returned in lower case, but + * that is fine because in SSPI auth, string comparisons are always + * case-insensitive. + */ + + samname = psprintf("%s\\%s", domainname, accountname); + res = TranslateName(samname, NameSamCompatible, NameUserPrincipal, + NULL, &upnamesize); + + if ((!res && GetLastError() != ERROR_INSUFFICIENT_BUFFER) + || upnamesize == 0) + { + pfree(samname); + ereport(LOG, + (errcode(ERRCODE_INVALID_ROLE_SPECIFICATION), + errmsg("could not translate name"))); + return STATUS_ERROR; + } + + /* upnamesize includes the terminating NUL. */ + upname = palloc(upnamesize); + + res = TranslateName(samname, NameSamCompatible, NameUserPrincipal, + upname, &upnamesize); + + pfree(samname); + if (res) + p = strchr(upname, '@'); + + if (!res || p == NULL) + { + pfree(upname); + ereport(LOG, + (errcode(ERRCODE_INVALID_ROLE_SPECIFICATION), + errmsg("could not translate name"))); + return STATUS_ERROR; + } + + /* Length of realm name after the '@', including the NUL. */ + upnamerealmsize = upnamesize - (p - upname + 1); + + /* Replace domainname with realm name. */ + if (upnamerealmsize > domainnamesize) + { + pfree(upname); + ereport(LOG, + (errcode(ERRCODE_INVALID_ROLE_SPECIFICATION), + errmsg("realm name too long"))); + return STATUS_ERROR; + } + + /* Length is now safe. */ + strcpy(domainname, p + 1); + + /* Replace account name as well (in case UPN != SAM)? */ + if (update_accountname) + { + if ((p - upname + 1) > accountnamesize) + { + pfree(upname); + ereport(LOG, + (errcode(ERRCODE_INVALID_ROLE_SPECIFICATION), + errmsg("translated account name too long"))); + return STATUS_ERROR; + } + + *p = 0; + strcpy(accountname, upname); + } + + pfree(upname); + return STATUS_OK; +} +#endif /* ENABLE_SSPI */ + + + +/*---------------------------------------------------------------- + * Ident authentication system + *---------------------------------------------------------------- + */ + +/* + * Parse the string "*ident_response" as a response from a query to an Ident + * server. If it's a normal response indicating a user name, return true + * and store the user name at *ident_user. If it's anything else, + * return false. + */ +static bool +interpret_ident_response(const char *ident_response, + char *ident_user) +{ + const char *cursor = ident_response; /* Cursor into *ident_response */ + + /* + * Ident's response, in the telnet tradition, should end in crlf (\r\n). + */ + if (strlen(ident_response) < 2) + return false; + else if (ident_response[strlen(ident_response) - 2] != '\r') + return false; + else + { + while (*cursor != ':' && *cursor != '\r') + cursor++; /* skip port field */ + + if (*cursor != ':') + return false; + else + { + /* We're positioned to colon before response type field */ + char response_type[80]; + int i; /* Index into *response_type */ + + cursor++; /* Go over colon */ + while (pg_isblank(*cursor)) + cursor++; /* skip blanks */ + i = 0; + while (*cursor != ':' && *cursor != '\r' && !pg_isblank(*cursor) && + i < (int) (sizeof(response_type) - 1)) + response_type[i++] = *cursor++; + response_type[i] = '\0'; + while (pg_isblank(*cursor)) + cursor++; /* skip blanks */ + if (strcmp(response_type, "USERID") != 0) + return false; + else + { + /* + * It's a USERID response. Good. "cursor" should be pointing + * to the colon that precedes the operating system type. + */ + if (*cursor != ':') + return false; + else + { + cursor++; /* Go over colon */ + /* Skip over operating system field. */ + while (*cursor != ':' && *cursor != '\r') + cursor++; + if (*cursor != ':') + return false; + else + { + int i; /* Index into *ident_user */ + + cursor++; /* Go over colon */ + while (pg_isblank(*cursor)) + cursor++; /* skip blanks */ + /* Rest of line is user name. Copy it over. */ + i = 0; + while (*cursor != '\r' && i < IDENT_USERNAME_MAX) + ident_user[i++] = *cursor++; + ident_user[i] = '\0'; + return true; + } + } + } + } + } +} + + +/* + * Talk to the ident server on "remote_addr" and find out who + * owns the tcp connection to "local_addr" + * If the username is successfully retrieved, check the usermap. + * + * XXX: Using WaitLatchOrSocket() and doing a CHECK_FOR_INTERRUPTS() if the + * latch was set would improve the responsiveness to timeouts/cancellations. + */ +static int +ident_inet(hbaPort *port) +{ + const SockAddr remote_addr = port->raddr; + const SockAddr local_addr = port->laddr; + char ident_user[IDENT_USERNAME_MAX + 1]; + pgsocket sock_fd = PGINVALID_SOCKET; /* for talking to Ident server */ + int rc; /* Return code from a locally called function */ + bool ident_return; + char remote_addr_s[NI_MAXHOST]; + char remote_port[NI_MAXSERV]; + char local_addr_s[NI_MAXHOST]; + char local_port[NI_MAXSERV]; + char ident_port[NI_MAXSERV]; + char ident_query[80]; + char ident_response[80 + IDENT_USERNAME_MAX]; + struct addrinfo *ident_serv = NULL, + *la = NULL, + hints; + + /* + * Might look a little weird to first convert it to text and then back to + * sockaddr, but it's protocol independent. + */ + pg_getnameinfo_all(&remote_addr.addr, remote_addr.salen, + remote_addr_s, sizeof(remote_addr_s), + remote_port, sizeof(remote_port), + NI_NUMERICHOST | NI_NUMERICSERV); + pg_getnameinfo_all(&local_addr.addr, local_addr.salen, + local_addr_s, sizeof(local_addr_s), + local_port, sizeof(local_port), + NI_NUMERICHOST | NI_NUMERICSERV); + + snprintf(ident_port, sizeof(ident_port), "%d", IDENT_PORT); + hints.ai_flags = AI_NUMERICHOST; + hints.ai_family = remote_addr.addr.ss_family; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + hints.ai_addrlen = 0; + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + rc = pg_getaddrinfo_all(remote_addr_s, ident_port, &hints, &ident_serv); + if (rc || !ident_serv) + { + /* we don't expect this to happen */ + ident_return = false; + goto ident_inet_done; + } + + hints.ai_flags = AI_NUMERICHOST; + hints.ai_family = local_addr.addr.ss_family; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + hints.ai_addrlen = 0; + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + rc = pg_getaddrinfo_all(local_addr_s, NULL, &hints, &la); + if (rc || !la) + { + /* we don't expect this to happen */ + ident_return = false; + goto ident_inet_done; + } + + sock_fd = socket(ident_serv->ai_family, ident_serv->ai_socktype, + ident_serv->ai_protocol); + if (sock_fd == PGINVALID_SOCKET) + { + ereport(LOG, + (errcode_for_socket_access(), + errmsg("could not create socket for Ident connection: %m"))); + ident_return = false; + goto ident_inet_done; + } + + /* + * Bind to the address which the client originally contacted, otherwise + * the ident server won't be able to match up the right connection. This + * is necessary if the PostgreSQL server is running on an IP alias. + */ + rc = bind(sock_fd, la->ai_addr, la->ai_addrlen); + if (rc != 0) + { + ereport(LOG, + (errcode_for_socket_access(), + errmsg("could not bind to local address \"%s\": %m", + local_addr_s))); + ident_return = false; + goto ident_inet_done; + } + + rc = connect(sock_fd, ident_serv->ai_addr, + ident_serv->ai_addrlen); + if (rc != 0) + { + ereport(LOG, + (errcode_for_socket_access(), + errmsg("could not connect to Ident server at address \"%s\", port %s: %m", + remote_addr_s, ident_port))); + ident_return = false; + goto ident_inet_done; + } + + /* The query we send to the Ident server */ + snprintf(ident_query, sizeof(ident_query), "%s,%s\r\n", + remote_port, local_port); + + /* loop in case send is interrupted */ + do + { + CHECK_FOR_INTERRUPTS(); + + rc = send(sock_fd, ident_query, strlen(ident_query), 0); + } while (rc < 0 && errno == EINTR); + + if (rc < 0) + { + ereport(LOG, + (errcode_for_socket_access(), + errmsg("could not send query to Ident server at address \"%s\", port %s: %m", + remote_addr_s, ident_port))); + ident_return = false; + goto ident_inet_done; + } + + do + { + CHECK_FOR_INTERRUPTS(); + + rc = recv(sock_fd, ident_response, sizeof(ident_response) - 1, 0); + } while (rc < 0 && errno == EINTR); + + if (rc < 0) + { + ereport(LOG, + (errcode_for_socket_access(), + errmsg("could not receive response from Ident server at address \"%s\", port %s: %m", + remote_addr_s, ident_port))); + ident_return = false; + goto ident_inet_done; + } + + ident_response[rc] = '\0'; + ident_return = interpret_ident_response(ident_response, ident_user); + if (!ident_return) + ereport(LOG, + (errmsg("invalidly formatted response from Ident server: \"%s\"", + ident_response))); + +ident_inet_done: + if (sock_fd != PGINVALID_SOCKET) + closesocket(sock_fd); + if (ident_serv) + pg_freeaddrinfo_all(remote_addr.addr.ss_family, ident_serv); + if (la) + pg_freeaddrinfo_all(local_addr.addr.ss_family, la); + + if (ident_return) + { + /* + * Success! Store the identity, then check the usermap. Note that + * setting the authenticated identity is done before checking the + * usermap, because at this point authentication has succeeded. + */ + set_authn_id(port, ident_user); + return check_usermap(port->hba->usermap, port->user_name, ident_user, false); + } + return STATUS_ERROR; +} + + +/*---------------------------------------------------------------- + * Peer authentication system + *---------------------------------------------------------------- + */ + +/* + * Ask kernel about the credentials of the connecting process, + * determine the symbolic name of the corresponding user, and check + * if valid per the usermap. + * + * Iff authorized, return STATUS_OK, otherwise return STATUS_ERROR. + */ +static int +auth_peer(hbaPort *port) +{ + uid_t uid; + gid_t gid; +#ifndef WIN32 + struct passwd *pw; + int ret; +#endif + + if (getpeereid(port->sock, &uid, &gid) != 0) + { + /* Provide special error message if getpeereid is a stub */ + if (errno == ENOSYS) + ereport(LOG, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("peer authentication is not supported on this platform"))); + else + ereport(LOG, + (errcode_for_socket_access(), + errmsg("could not get peer credentials: %m"))); + return STATUS_ERROR; + } + +#ifndef WIN32 + errno = 0; /* clear errno before call */ + pw = getpwuid(uid); + if (!pw) + { + int save_errno = errno; + + ereport(LOG, + (errmsg("could not look up local user ID %ld: %s", + (long) uid, + save_errno ? strerror(save_errno) : _("user does not exist")))); + return STATUS_ERROR; + } + + /* + * Make a copy of static getpw*() result area; this is our authenticated + * identity. Set it before calling check_usermap, because authentication + * has already succeeded and we want the log file to reflect that. + */ + set_authn_id(port, pw->pw_name); + + ret = check_usermap(port->hba->usermap, port->user_name, port->authn_id, false); + + return ret; +#else + /* should have failed with ENOSYS above */ + Assert(false); + return STATUS_ERROR; +#endif +} + + +/*---------------------------------------------------------------- + * PAM authentication system + *---------------------------------------------------------------- + */ +#ifdef USE_PAM + +/* + * PAM conversation function + */ + +static int +pam_passwd_conv_proc(int num_msg, const struct pam_message **msg, + struct pam_response **resp, void *appdata_ptr) +{ + const char *passwd; + struct pam_response *reply; + int i; + + if (appdata_ptr) + passwd = (char *) appdata_ptr; + else + { + /* + * Workaround for Solaris 2.6 where the PAM library is broken and does + * not pass appdata_ptr to the conversation routine + */ + passwd = pam_passwd; + } + + *resp = NULL; /* in case of error exit */ + + if (num_msg <= 0 || num_msg > PAM_MAX_NUM_MSG) + return PAM_CONV_ERR; + + /* + * Explicitly not using palloc here - PAM will free this memory in + * pam_end() + */ + if ((reply = calloc(num_msg, sizeof(struct pam_response))) == NULL) + { + ereport(LOG, + (errcode(ERRCODE_OUT_OF_MEMORY), + errmsg("out of memory"))); + return PAM_CONV_ERR; + } + + for (i = 0; i < num_msg; i++) + { + switch (msg[i]->msg_style) + { + case PAM_PROMPT_ECHO_OFF: + if (strlen(passwd) == 0) + { + /* + * Password wasn't passed to PAM the first time around - + * let's go ask the client to send a password, which we + * then stuff into PAM. + */ + sendAuthRequest(pam_port_cludge, AUTH_REQ_PASSWORD, NULL, 0); + passwd = recv_password_packet(pam_port_cludge); + if (passwd == NULL) + { + /* + * Client didn't want to send password. We + * intentionally do not log anything about this, + * either here or at higher levels. + */ + pam_no_password = true; + goto fail; + } + } + if ((reply[i].resp = strdup(passwd)) == NULL) + goto fail; + reply[i].resp_retcode = PAM_SUCCESS; + break; + case PAM_ERROR_MSG: + ereport(LOG, + (errmsg("error from underlying PAM layer: %s", + msg[i]->msg))); + /* FALL THROUGH */ + case PAM_TEXT_INFO: + /* we don't bother to log TEXT_INFO messages */ + if ((reply[i].resp = strdup("")) == NULL) + goto fail; + reply[i].resp_retcode = PAM_SUCCESS; + break; + default: + ereport(LOG, + (errmsg("unsupported PAM conversation %d/\"%s\"", + msg[i]->msg_style, + msg[i]->msg ? msg[i]->msg : "(none)"))); + goto fail; + } + } + + *resp = reply; + return PAM_SUCCESS; + +fail: + /* free up whatever we allocated */ + for (i = 0; i < num_msg; i++) + { + if (reply[i].resp != NULL) + free(reply[i].resp); + } + free(reply); + + return PAM_CONV_ERR; +} + + +/* + * Check authentication against PAM. + */ +static int +CheckPAMAuth(Port *port, const char *user, const char *password) +{ + int retval; + pam_handle_t *pamh = NULL; + + /* + * We can't entirely rely on PAM to pass through appdata --- it appears + * not to work on at least Solaris 2.6. So use these ugly static + * variables instead. + */ + pam_passwd = password; + pam_port_cludge = port; + pam_no_password = false; + + /* + * Set the application data portion of the conversation struct. This is + * later used inside the PAM conversation to pass the password to the + * authentication module. + */ + pam_passw_conv.appdata_ptr = unconstify(char *, password); /* from password above, + * not allocated */ + + /* Optionally, one can set the service name in pg_hba.conf */ + if (port->hba->pamservice && port->hba->pamservice[0] != '\0') + retval = pam_start(port->hba->pamservice, "pgsql@", + &pam_passw_conv, &pamh); + else + retval = pam_start(PGSQL_PAM_SERVICE, "pgsql@", + &pam_passw_conv, &pamh); + + if (retval != PAM_SUCCESS) + { + ereport(LOG, + (errmsg("could not create PAM authenticator: %s", + pam_strerror(pamh, retval)))); + pam_passwd = NULL; /* Unset pam_passwd */ + return STATUS_ERROR; + } + + retval = pam_set_item(pamh, PAM_USER, user); + + if (retval != PAM_SUCCESS) + { + ereport(LOG, + (errmsg("pam_set_item(PAM_USER) failed: %s", + pam_strerror(pamh, retval)))); + pam_passwd = NULL; /* Unset pam_passwd */ + return STATUS_ERROR; + } + + if (port->hba->conntype != ctLocal) + { + char hostinfo[NI_MAXHOST]; + int flags; + + if (port->hba->pam_use_hostname) + flags = 0; + else + flags = NI_NUMERICHOST | NI_NUMERICSERV; + + retval = pg_getnameinfo_all(&port->raddr.addr, port->raddr.salen, + hostinfo, sizeof(hostinfo), NULL, 0, + flags); + if (retval != 0) + { + ereport(WARNING, + (errmsg_internal("pg_getnameinfo_all() failed: %s", + gai_strerror(retval)))); + return STATUS_ERROR; + } + + retval = pam_set_item(pamh, PAM_RHOST, hostinfo); + + if (retval != PAM_SUCCESS) + { + ereport(LOG, + (errmsg("pam_set_item(PAM_RHOST) failed: %s", + pam_strerror(pamh, retval)))); + pam_passwd = NULL; + return STATUS_ERROR; + } + } + + retval = pam_set_item(pamh, PAM_CONV, &pam_passw_conv); + + if (retval != PAM_SUCCESS) + { + ereport(LOG, + (errmsg("pam_set_item(PAM_CONV) failed: %s", + pam_strerror(pamh, retval)))); + pam_passwd = NULL; /* Unset pam_passwd */ + return STATUS_ERROR; + } + + retval = pam_authenticate(pamh, 0); + + if (retval != PAM_SUCCESS) + { + /* If pam_passwd_conv_proc saw EOF, don't log anything */ + if (!pam_no_password) + ereport(LOG, + (errmsg("pam_authenticate failed: %s", + pam_strerror(pamh, retval)))); + pam_passwd = NULL; /* Unset pam_passwd */ + return pam_no_password ? STATUS_EOF : STATUS_ERROR; + } + + retval = pam_acct_mgmt(pamh, 0); + + if (retval != PAM_SUCCESS) + { + /* If pam_passwd_conv_proc saw EOF, don't log anything */ + if (!pam_no_password) + ereport(LOG, + (errmsg("pam_acct_mgmt failed: %s", + pam_strerror(pamh, retval)))); + pam_passwd = NULL; /* Unset pam_passwd */ + return pam_no_password ? STATUS_EOF : STATUS_ERROR; + } + + retval = pam_end(pamh, retval); + + if (retval != PAM_SUCCESS) + { + ereport(LOG, + (errmsg("could not release PAM authenticator: %s", + pam_strerror(pamh, retval)))); + } + + pam_passwd = NULL; /* Unset pam_passwd */ + + if (retval == PAM_SUCCESS) + set_authn_id(port, user); + + return (retval == PAM_SUCCESS ? STATUS_OK : STATUS_ERROR); +} +#endif /* USE_PAM */ + + +/*---------------------------------------------------------------- + * BSD authentication system + *---------------------------------------------------------------- + */ +#ifdef USE_BSD_AUTH +static int +CheckBSDAuth(Port *port, char *user) +{ + char *passwd; + int retval; + + /* Send regular password request to client, and get the response */ + sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0); + + passwd = recv_password_packet(port); + if (passwd == NULL) + return STATUS_EOF; + + /* + * Ask the BSD auth system to verify password. Note that auth_userokay + * will overwrite the password string with zeroes, but it's just a + * temporary string so we don't care. + */ + retval = auth_userokay(user, NULL, "auth-postgresql", passwd); + + pfree(passwd); + + if (!retval) + return STATUS_ERROR; + + set_authn_id(port, user); + return STATUS_OK; +} +#endif /* USE_BSD_AUTH */ + + +/*---------------------------------------------------------------- + * LDAP authentication system + *---------------------------------------------------------------- + */ +#ifdef USE_LDAP + +static int errdetail_for_ldap(LDAP *ldap); + +/* + * Initialize a connection to the LDAP server, including setting up + * TLS if requested. + */ +static int +InitializeLDAPConnection(Port *port, LDAP **ldap) +{ + const char *scheme; + int ldapversion = LDAP_VERSION3; + int r; + + scheme = port->hba->ldapscheme; + if (scheme == NULL) + scheme = "ldap"; +#ifdef WIN32 + if (strcmp(scheme, "ldaps") == 0) + *ldap = ldap_sslinit(port->hba->ldapserver, port->hba->ldapport, 1); + else + *ldap = ldap_init(port->hba->ldapserver, port->hba->ldapport); + if (!*ldap) + { + ereport(LOG, + (errmsg("could not initialize LDAP: error code %d", + (int) LdapGetLastError()))); + + return STATUS_ERROR; + } +#else +#ifdef HAVE_LDAP_INITIALIZE + + /* + * OpenLDAP provides a non-standard extension ldap_initialize() that takes + * a list of URIs, allowing us to request "ldaps" instead of "ldap". It + * also provides ldap_domain2hostlist() to find LDAP servers automatically + * using DNS SRV. They were introduced in the same version, so for now we + * don't have an extra configure check for the latter. + */ + { + StringInfoData uris; + char *hostlist = NULL; + char *p; + bool append_port; + + /* We'll build a space-separated scheme://hostname:port list here */ + initStringInfo(&uris); + + /* + * If pg_hba.conf provided no hostnames, we can ask OpenLDAP to try to + * find some by extracting a domain name from the base DN and looking + * up DSN SRV records for _ldap._tcp.<domain>. + */ + if (!port->hba->ldapserver || port->hba->ldapserver[0] == '\0') + { + char *domain; + + /* ou=blah,dc=foo,dc=bar -> foo.bar */ + if (ldap_dn2domain(port->hba->ldapbasedn, &domain)) + { + ereport(LOG, + (errmsg("could not extract domain name from ldapbasedn"))); + return STATUS_ERROR; + } + + /* Look up a list of LDAP server hosts and port numbers */ + if (ldap_domain2hostlist(domain, &hostlist)) + { + ereport(LOG, + (errmsg("LDAP authentication could not find DNS SRV records for \"%s\"", + domain), + (errhint("Set an LDAP server name explicitly.")))); + ldap_memfree(domain); + return STATUS_ERROR; + } + ldap_memfree(domain); + + /* We have a space-separated list of host:port entries */ + p = hostlist; + append_port = false; + } + else + { + /* We have a space-separated list of hosts from pg_hba.conf */ + p = port->hba->ldapserver; + append_port = true; + } + + /* Convert the list of host[:port] entries to full URIs */ + do + { + size_t size; + + /* Find the span of the next entry */ + size = strcspn(p, " "); + + /* Append a space separator if this isn't the first URI */ + if (uris.len > 0) + appendStringInfoChar(&uris, ' '); + + /* Append scheme://host:port */ + appendStringInfoString(&uris, scheme); + appendStringInfoString(&uris, "://"); + appendBinaryStringInfo(&uris, p, size); + if (append_port) + appendStringInfo(&uris, ":%d", port->hba->ldapport); + + /* Step over this entry and any number of trailing spaces */ + p += size; + while (*p == ' ') + ++p; + } while (*p); + + /* Free memory from OpenLDAP if we looked up SRV records */ + if (hostlist) + ldap_memfree(hostlist); + + /* Finally, try to connect using the URI list */ + r = ldap_initialize(ldap, uris.data); + pfree(uris.data); + if (r != LDAP_SUCCESS) + { + ereport(LOG, + (errmsg("could not initialize LDAP: %s", + ldap_err2string(r)))); + + return STATUS_ERROR; + } + } +#else + if (strcmp(scheme, "ldaps") == 0) + { + ereport(LOG, + (errmsg("ldaps not supported with this LDAP library"))); + + return STATUS_ERROR; + } + *ldap = ldap_init(port->hba->ldapserver, port->hba->ldapport); + if (!*ldap) + { + ereport(LOG, + (errmsg("could not initialize LDAP: %m"))); + + return STATUS_ERROR; + } +#endif +#endif + + if ((r = ldap_set_option(*ldap, LDAP_OPT_PROTOCOL_VERSION, &ldapversion)) != LDAP_SUCCESS) + { + ereport(LOG, + (errmsg("could not set LDAP protocol version: %s", + ldap_err2string(r)), + errdetail_for_ldap(*ldap))); + ldap_unbind(*ldap); + return STATUS_ERROR; + } + + if (port->hba->ldaptls) + { +#ifndef WIN32 + if ((r = ldap_start_tls_s(*ldap, NULL, NULL)) != LDAP_SUCCESS) +#else + static __ldap_start_tls_sA _ldap_start_tls_sA = NULL; + + if (_ldap_start_tls_sA == NULL) + { + /* + * Need to load this function dynamically because it may not exist + * on Windows, and causes a load error for the whole exe if + * referenced. + */ + HANDLE ldaphandle; + + ldaphandle = LoadLibrary("WLDAP32.DLL"); + if (ldaphandle == NULL) + { + /* + * should never happen since we import other files from + * wldap32, but check anyway + */ + ereport(LOG, + (errmsg("could not load library \"%s\": error code %lu", + "WLDAP32.DLL", GetLastError()))); + ldap_unbind(*ldap); + return STATUS_ERROR; + } + _ldap_start_tls_sA = (__ldap_start_tls_sA) (pg_funcptr_t) GetProcAddress(ldaphandle, "ldap_start_tls_sA"); + if (_ldap_start_tls_sA == NULL) + { + ereport(LOG, + (errmsg("could not load function _ldap_start_tls_sA in wldap32.dll"), + errdetail("LDAP over SSL is not supported on this platform."))); + ldap_unbind(*ldap); + FreeLibrary(ldaphandle); + return STATUS_ERROR; + } + + /* + * Leak LDAP handle on purpose, because we need the library to + * stay open. This is ok because it will only ever be leaked once + * per process and is automatically cleaned up on process exit. + */ + } + if ((r = _ldap_start_tls_sA(*ldap, NULL, NULL, NULL, NULL)) != LDAP_SUCCESS) +#endif + { + ereport(LOG, + (errmsg("could not start LDAP TLS session: %s", + ldap_err2string(r)), + errdetail_for_ldap(*ldap))); + ldap_unbind(*ldap); + return STATUS_ERROR; + } + } + + return STATUS_OK; +} + +/* Placeholders recognized by FormatSearchFilter. For now just one. */ +#define LPH_USERNAME "$username" +#define LPH_USERNAME_LEN (sizeof(LPH_USERNAME) - 1) + +/* Not all LDAP implementations define this. */ +#ifndef LDAP_NO_ATTRS +#define LDAP_NO_ATTRS "1.1" +#endif + +/* Not all LDAP implementations define this. */ +#ifndef LDAPS_PORT +#define LDAPS_PORT 636 +#endif + +/* + * Return a newly allocated C string copied from "pattern" with all + * occurrences of the placeholder "$username" replaced with "user_name". + */ +static char * +FormatSearchFilter(const char *pattern, const char *user_name) +{ + StringInfoData output; + + initStringInfo(&output); + while (*pattern != '\0') + { + if (strncmp(pattern, LPH_USERNAME, LPH_USERNAME_LEN) == 0) + { + appendStringInfoString(&output, user_name); + pattern += LPH_USERNAME_LEN; + } + else + appendStringInfoChar(&output, *pattern++); + } + + return output.data; +} + +/* + * Perform LDAP authentication + */ +static int +CheckLDAPAuth(Port *port) +{ + char *passwd; + LDAP *ldap; + int r; + char *fulluser; + const char *server_name; + +#ifdef HAVE_LDAP_INITIALIZE + + /* + * For OpenLDAP, allow empty hostname if we have a basedn. We'll look for + * servers with DNS SRV records via OpenLDAP library facilities. + */ + if ((!port->hba->ldapserver || port->hba->ldapserver[0] == '\0') && + (!port->hba->ldapbasedn || port->hba->ldapbasedn[0] == '\0')) + { + ereport(LOG, + (errmsg("LDAP server not specified, and no ldapbasedn"))); + return STATUS_ERROR; + } +#else + if (!port->hba->ldapserver || port->hba->ldapserver[0] == '\0') + { + ereport(LOG, + (errmsg("LDAP server not specified"))); + return STATUS_ERROR; + } +#endif + + /* + * If we're using SRV records, we don't have a server name so we'll just + * show an empty string in error messages. + */ + server_name = port->hba->ldapserver ? port->hba->ldapserver : ""; + + if (port->hba->ldapport == 0) + { + if (port->hba->ldapscheme != NULL && + strcmp(port->hba->ldapscheme, "ldaps") == 0) + port->hba->ldapport = LDAPS_PORT; + else + port->hba->ldapport = LDAP_PORT; + } + + sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0); + + passwd = recv_password_packet(port); + if (passwd == NULL) + return STATUS_EOF; /* client wouldn't send password */ + + if (InitializeLDAPConnection(port, &ldap) == STATUS_ERROR) + { + /* Error message already sent */ + pfree(passwd); + return STATUS_ERROR; + } + + if (port->hba->ldapbasedn) + { + /* + * First perform an LDAP search to find the DN for the user we are + * trying to log in as. + */ + char *filter; + LDAPMessage *search_message; + LDAPMessage *entry; + char *attributes[] = {LDAP_NO_ATTRS, NULL}; + char *dn; + char *c; + int count; + + /* + * Disallow any characters that we would otherwise need to escape, + * since they aren't really reasonable in a username anyway. Allowing + * them would make it possible to inject any kind of custom filters in + * the LDAP filter. + */ + for (c = port->user_name; *c; c++) + { + if (*c == '*' || + *c == '(' || + *c == ')' || + *c == '\\' || + *c == '/') + { + ereport(LOG, + (errmsg("invalid character in user name for LDAP authentication"))); + ldap_unbind(ldap); + pfree(passwd); + return STATUS_ERROR; + } + } + + /* + * Bind with a pre-defined username/password (if available) for + * searching. If none is specified, this turns into an anonymous bind. + */ + r = ldap_simple_bind_s(ldap, + port->hba->ldapbinddn ? port->hba->ldapbinddn : "", + port->hba->ldapbindpasswd ? port->hba->ldapbindpasswd : ""); + if (r != LDAP_SUCCESS) + { + ereport(LOG, + (errmsg("could not perform initial LDAP bind for ldapbinddn \"%s\" on server \"%s\": %s", + port->hba->ldapbinddn ? port->hba->ldapbinddn : "", + server_name, + ldap_err2string(r)), + errdetail_for_ldap(ldap))); + ldap_unbind(ldap); + pfree(passwd); + return STATUS_ERROR; + } + + /* Build a custom filter or a single attribute filter? */ + if (port->hba->ldapsearchfilter) + filter = FormatSearchFilter(port->hba->ldapsearchfilter, port->user_name); + else if (port->hba->ldapsearchattribute) + filter = psprintf("(%s=%s)", port->hba->ldapsearchattribute, port->user_name); + else + filter = psprintf("(uid=%s)", port->user_name); + + r = ldap_search_s(ldap, + port->hba->ldapbasedn, + port->hba->ldapscope, + filter, + attributes, + 0, + &search_message); + + if (r != LDAP_SUCCESS) + { + ereport(LOG, + (errmsg("could not search LDAP for filter \"%s\" on server \"%s\": %s", + filter, server_name, ldap_err2string(r)), + errdetail_for_ldap(ldap))); + ldap_unbind(ldap); + pfree(passwd); + pfree(filter); + return STATUS_ERROR; + } + + count = ldap_count_entries(ldap, search_message); + if (count != 1) + { + if (count == 0) + ereport(LOG, + (errmsg("LDAP user \"%s\" does not exist", port->user_name), + errdetail("LDAP search for filter \"%s\" on server \"%s\" returned no entries.", + filter, server_name))); + else + ereport(LOG, + (errmsg("LDAP user \"%s\" is not unique", port->user_name), + errdetail_plural("LDAP search for filter \"%s\" on server \"%s\" returned %d entry.", + "LDAP search for filter \"%s\" on server \"%s\" returned %d entries.", + count, + filter, server_name, count))); + + ldap_unbind(ldap); + pfree(passwd); + pfree(filter); + ldap_msgfree(search_message); + return STATUS_ERROR; + } + + entry = ldap_first_entry(ldap, search_message); + dn = ldap_get_dn(ldap, entry); + if (dn == NULL) + { + int error; + + (void) ldap_get_option(ldap, LDAP_OPT_ERROR_NUMBER, &error); + ereport(LOG, + (errmsg("could not get dn for the first entry matching \"%s\" on server \"%s\": %s", + filter, server_name, + ldap_err2string(error)), + errdetail_for_ldap(ldap))); + ldap_unbind(ldap); + pfree(passwd); + pfree(filter); + ldap_msgfree(search_message); + return STATUS_ERROR; + } + fulluser = pstrdup(dn); + + pfree(filter); + ldap_memfree(dn); + ldap_msgfree(search_message); + + /* Unbind and disconnect from the LDAP server */ + r = ldap_unbind_s(ldap); + if (r != LDAP_SUCCESS) + { + ereport(LOG, + (errmsg("could not unbind after searching for user \"%s\" on server \"%s\"", + fulluser, server_name))); + pfree(passwd); + pfree(fulluser); + return STATUS_ERROR; + } + + /* + * Need to re-initialize the LDAP connection, so that we can bind to + * it with a different username. + */ + if (InitializeLDAPConnection(port, &ldap) == STATUS_ERROR) + { + pfree(passwd); + pfree(fulluser); + + /* Error message already sent */ + return STATUS_ERROR; + } + } + else + fulluser = psprintf("%s%s%s", + port->hba->ldapprefix ? port->hba->ldapprefix : "", + port->user_name, + port->hba->ldapsuffix ? port->hba->ldapsuffix : ""); + + r = ldap_simple_bind_s(ldap, fulluser, passwd); + + if (r != LDAP_SUCCESS) + { + ereport(LOG, + (errmsg("LDAP login failed for user \"%s\" on server \"%s\": %s", + fulluser, server_name, ldap_err2string(r)), + errdetail_for_ldap(ldap))); + ldap_unbind(ldap); + pfree(passwd); + pfree(fulluser); + return STATUS_ERROR; + } + + /* Save the original bind DN as the authenticated identity. */ + set_authn_id(port, fulluser); + + ldap_unbind(ldap); + pfree(passwd); + pfree(fulluser); + + return STATUS_OK; +} + +/* + * Add a detail error message text to the current error if one can be + * constructed from the LDAP 'diagnostic message'. + */ +static int +errdetail_for_ldap(LDAP *ldap) +{ + char *message; + int rc; + + rc = ldap_get_option(ldap, LDAP_OPT_DIAGNOSTIC_MESSAGE, &message); + if (rc == LDAP_SUCCESS && message != NULL) + { + errdetail("LDAP diagnostics: %s", message); + ldap_memfree(message); + } + + return 0; +} + +#endif /* USE_LDAP */ + + +/*---------------------------------------------------------------- + * SSL client certificate authentication + *---------------------------------------------------------------- + */ +#ifdef USE_SSL +static int +CheckCertAuth(Port *port) +{ + int status_check_usermap = STATUS_ERROR; + char *peer_username = NULL; + + Assert(port->ssl); + + /* select the correct field to compare */ + switch (port->hba->clientcertname) + { + case clientCertDN: + peer_username = port->peer_dn; + break; + case clientCertCN: + peer_username = port->peer_cn; + } + + /* Make sure we have received a username in the certificate */ + if (peer_username == NULL || + strlen(peer_username) <= 0) + { + ereport(LOG, + (errmsg("certificate authentication failed for user \"%s\": client certificate contains no user name", + port->user_name))); + return STATUS_ERROR; + } + + if (port->hba->auth_method == uaCert) + { + /* + * For cert auth, the client's Subject DN is always our authenticated + * identity, even if we're only using its CN for authorization. Set + * it now, rather than waiting for check_usermap() below, because + * authentication has already succeeded and we want the log file to + * reflect that. + */ + if (!port->peer_dn) + { + /* + * This should not happen as both peer_dn and peer_cn should be + * set in this context. + */ + ereport(LOG, + (errmsg("certificate authentication failed for user \"%s\": unable to retrieve subject DN", + port->user_name))); + return STATUS_ERROR; + } + + set_authn_id(port, port->peer_dn); + } + + /* Just pass the certificate cn/dn to the usermap check */ + status_check_usermap = check_usermap(port->hba->usermap, port->user_name, peer_username, false); + if (status_check_usermap != STATUS_OK) + { + /* + * If clientcert=verify-full was specified and the authentication + * method is other than uaCert, log the reason for rejecting the + * authentication. + */ + if (port->hba->clientcert == clientCertFull && port->hba->auth_method != uaCert) + { + switch (port->hba->clientcertname) + { + case clientCertDN: + ereport(LOG, + (errmsg("certificate validation (clientcert=verify-full) failed for user \"%s\": DN mismatch", + port->user_name))); + break; + case clientCertCN: + ereport(LOG, + (errmsg("certificate validation (clientcert=verify-full) failed for user \"%s\": CN mismatch", + port->user_name))); + } + } + } + return status_check_usermap; +} +#endif + + +/*---------------------------------------------------------------- + * RADIUS authentication + *---------------------------------------------------------------- + */ + +/* + * RADIUS authentication is described in RFC2865 (and several others). + */ + +#define RADIUS_VECTOR_LENGTH 16 +#define RADIUS_HEADER_LENGTH 20 +#define RADIUS_MAX_PASSWORD_LENGTH 128 + +/* Maximum size of a RADIUS packet we will create or accept */ +#define RADIUS_BUFFER_SIZE 1024 + +typedef struct +{ + uint8 attribute; + uint8 length; + uint8 data[FLEXIBLE_ARRAY_MEMBER]; +} radius_attribute; + +typedef struct +{ + uint8 code; + uint8 id; + uint16 length; + uint8 vector[RADIUS_VECTOR_LENGTH]; + /* this is a bit longer than strictly necessary: */ + char pad[RADIUS_BUFFER_SIZE - RADIUS_VECTOR_LENGTH]; +} radius_packet; + +/* RADIUS packet types */ +#define RADIUS_ACCESS_REQUEST 1 +#define RADIUS_ACCESS_ACCEPT 2 +#define RADIUS_ACCESS_REJECT 3 + +/* RADIUS attributes */ +#define RADIUS_USER_NAME 1 +#define RADIUS_PASSWORD 2 +#define RADIUS_SERVICE_TYPE 6 +#define RADIUS_NAS_IDENTIFIER 32 + +/* RADIUS service types */ +#define RADIUS_AUTHENTICATE_ONLY 8 + +/* Seconds to wait - XXX: should be in a config variable! */ +#define RADIUS_TIMEOUT 3 + +static void +radius_add_attribute(radius_packet *packet, uint8 type, const unsigned char *data, int len) +{ + radius_attribute *attr; + + if (packet->length + len > RADIUS_BUFFER_SIZE) + { + /* + * With remotely realistic data, this can never happen. But catch it + * just to make sure we don't overrun a buffer. We'll just skip adding + * the broken attribute, which will in the end cause authentication to + * fail. + */ + elog(WARNING, + "adding attribute code %d with length %d to radius packet would create oversize packet, ignoring", + type, len); + return; + } + + attr = (radius_attribute *) ((unsigned char *) packet + packet->length); + attr->attribute = type; + attr->length = len + 2; /* total size includes type and length */ + memcpy(attr->data, data, len); + packet->length += attr->length; +} + +static int +CheckRADIUSAuth(Port *port) +{ + char *passwd; + ListCell *server, + *secrets, + *radiusports, + *identifiers; + + /* Make sure struct alignment is correct */ + Assert(offsetof(radius_packet, vector) == 4); + + /* Verify parameters */ + if (list_length(port->hba->radiusservers) < 1) + { + ereport(LOG, + (errmsg("RADIUS server not specified"))); + return STATUS_ERROR; + } + + if (list_length(port->hba->radiussecrets) < 1) + { + ereport(LOG, + (errmsg("RADIUS secret not specified"))); + return STATUS_ERROR; + } + + /* Send regular password request to client, and get the response */ + sendAuthRequest(port, AUTH_REQ_PASSWORD, NULL, 0); + + passwd = recv_password_packet(port); + if (passwd == NULL) + return STATUS_EOF; /* client wouldn't send password */ + + if (strlen(passwd) > RADIUS_MAX_PASSWORD_LENGTH) + { + ereport(LOG, + (errmsg("RADIUS authentication does not support passwords longer than %d characters", RADIUS_MAX_PASSWORD_LENGTH))); + pfree(passwd); + return STATUS_ERROR; + } + + /* + * Loop over and try each server in order. + */ + secrets = list_head(port->hba->radiussecrets); + radiusports = list_head(port->hba->radiusports); + identifiers = list_head(port->hba->radiusidentifiers); + foreach(server, port->hba->radiusservers) + { + int ret = PerformRadiusTransaction(lfirst(server), + lfirst(secrets), + radiusports ? lfirst(radiusports) : NULL, + identifiers ? lfirst(identifiers) : NULL, + port->user_name, + passwd); + + /*------ + * STATUS_OK = Login OK + * STATUS_ERROR = Login not OK, but try next server + * STATUS_EOF = Login not OK, and don't try next server + *------ + */ + if (ret == STATUS_OK) + { + set_authn_id(port, port->user_name); + + pfree(passwd); + return STATUS_OK; + } + else if (ret == STATUS_EOF) + { + pfree(passwd); + return STATUS_ERROR; + } + + /* + * secret, port and identifiers either have length 0 (use default), + * length 1 (use the same everywhere) or the same length as servers. + * So if the length is >1, we advance one step. In other cases, we + * don't and will then reuse the correct value. + */ + if (list_length(port->hba->radiussecrets) > 1) + secrets = lnext(port->hba->radiussecrets, secrets); + if (list_length(port->hba->radiusports) > 1) + radiusports = lnext(port->hba->radiusports, radiusports); + if (list_length(port->hba->radiusidentifiers) > 1) + identifiers = lnext(port->hba->radiusidentifiers, identifiers); + } + + /* No servers left to try, so give up */ + pfree(passwd); + return STATUS_ERROR; +} + +static int +PerformRadiusTransaction(const char *server, const char *secret, const char *portstr, const char *identifier, const char *user_name, const char *passwd) +{ + radius_packet radius_send_pack; + radius_packet radius_recv_pack; + radius_packet *packet = &radius_send_pack; + radius_packet *receivepacket = &radius_recv_pack; + char *radius_buffer = (char *) &radius_send_pack; + char *receive_buffer = (char *) &radius_recv_pack; + int32 service = pg_hton32(RADIUS_AUTHENTICATE_ONLY); + uint8 *cryptvector; + int encryptedpasswordlen; + uint8 encryptedpassword[RADIUS_MAX_PASSWORD_LENGTH]; + uint8 *md5trailer; + int packetlength; + pgsocket sock; + +#ifdef HAVE_IPV6 + struct sockaddr_in6 localaddr; + struct sockaddr_in6 remoteaddr; +#else + struct sockaddr_in localaddr; + struct sockaddr_in remoteaddr; +#endif + struct addrinfo hint; + struct addrinfo *serveraddrs; + int port; + ACCEPT_TYPE_ARG3 addrsize; + fd_set fdset; + struct timeval endtime; + int i, + j, + r; + + /* Assign default values */ + if (portstr == NULL) + portstr = "1812"; + if (identifier == NULL) + identifier = "postgresql"; + + MemSet(&hint, 0, sizeof(hint)); + hint.ai_socktype = SOCK_DGRAM; + hint.ai_family = AF_UNSPEC; + port = atoi(portstr); + + r = pg_getaddrinfo_all(server, portstr, &hint, &serveraddrs); + if (r || !serveraddrs) + { + ereport(LOG, + (errmsg("could not translate RADIUS server name \"%s\" to address: %s", + server, gai_strerror(r)))); + if (serveraddrs) + pg_freeaddrinfo_all(hint.ai_family, serveraddrs); + return STATUS_ERROR; + } + /* XXX: add support for multiple returned addresses? */ + + /* Construct RADIUS packet */ + packet->code = RADIUS_ACCESS_REQUEST; + packet->length = RADIUS_HEADER_LENGTH; + if (!pg_strong_random(packet->vector, RADIUS_VECTOR_LENGTH)) + { + ereport(LOG, + (errmsg("could not generate random encryption vector"))); + pg_freeaddrinfo_all(hint.ai_family, serveraddrs); + return STATUS_ERROR; + } + packet->id = packet->vector[0]; + radius_add_attribute(packet, RADIUS_SERVICE_TYPE, (const unsigned char *) &service, sizeof(service)); + radius_add_attribute(packet, RADIUS_USER_NAME, (const unsigned char *) user_name, strlen(user_name)); + radius_add_attribute(packet, RADIUS_NAS_IDENTIFIER, (const unsigned char *) identifier, strlen(identifier)); + + /* + * RADIUS password attributes are calculated as: e[0] = p[0] XOR + * MD5(secret + Request Authenticator) for the first group of 16 octets, + * and then: e[i] = p[i] XOR MD5(secret + e[i-1]) for the following ones + * (if necessary) + */ + encryptedpasswordlen = ((strlen(passwd) + RADIUS_VECTOR_LENGTH - 1) / RADIUS_VECTOR_LENGTH) * RADIUS_VECTOR_LENGTH; + cryptvector = palloc(strlen(secret) + RADIUS_VECTOR_LENGTH); + memcpy(cryptvector, secret, strlen(secret)); + + /* for the first iteration, we use the Request Authenticator vector */ + md5trailer = packet->vector; + for (i = 0; i < encryptedpasswordlen; i += RADIUS_VECTOR_LENGTH) + { + memcpy(cryptvector + strlen(secret), md5trailer, RADIUS_VECTOR_LENGTH); + + /* + * .. and for subsequent iterations the result of the previous XOR + * (calculated below) + */ + md5trailer = encryptedpassword + i; + + if (!pg_md5_binary(cryptvector, strlen(secret) + RADIUS_VECTOR_LENGTH, encryptedpassword + i)) + { + ereport(LOG, + (errmsg("could not perform MD5 encryption of password"))); + pfree(cryptvector); + pg_freeaddrinfo_all(hint.ai_family, serveraddrs); + return STATUS_ERROR; + } + + for (j = i; j < i + RADIUS_VECTOR_LENGTH; j++) + { + if (j < strlen(passwd)) + encryptedpassword[j] = passwd[j] ^ encryptedpassword[j]; + else + encryptedpassword[j] = '\0' ^ encryptedpassword[j]; + } + } + pfree(cryptvector); + + radius_add_attribute(packet, RADIUS_PASSWORD, encryptedpassword, encryptedpasswordlen); + + /* Length needs to be in network order on the wire */ + packetlength = packet->length; + packet->length = pg_hton16(packet->length); + + sock = socket(serveraddrs[0].ai_family, SOCK_DGRAM, 0); + if (sock == PGINVALID_SOCKET) + { + ereport(LOG, + (errmsg("could not create RADIUS socket: %m"))); + pg_freeaddrinfo_all(hint.ai_family, serveraddrs); + return STATUS_ERROR; + } + + memset(&localaddr, 0, sizeof(localaddr)); +#ifdef HAVE_IPV6 + localaddr.sin6_family = serveraddrs[0].ai_family; + localaddr.sin6_addr = in6addr_any; + if (localaddr.sin6_family == AF_INET6) + addrsize = sizeof(struct sockaddr_in6); + else + addrsize = sizeof(struct sockaddr_in); +#else + localaddr.sin_family = serveraddrs[0].ai_family; + localaddr.sin_addr.s_addr = INADDR_ANY; + addrsize = sizeof(struct sockaddr_in); +#endif + + if (bind(sock, (struct sockaddr *) &localaddr, addrsize)) + { + ereport(LOG, + (errmsg("could not bind local RADIUS socket: %m"))); + closesocket(sock); + pg_freeaddrinfo_all(hint.ai_family, serveraddrs); + return STATUS_ERROR; + } + + if (sendto(sock, radius_buffer, packetlength, 0, + serveraddrs[0].ai_addr, serveraddrs[0].ai_addrlen) < 0) + { + ereport(LOG, + (errmsg("could not send RADIUS packet: %m"))); + closesocket(sock); + pg_freeaddrinfo_all(hint.ai_family, serveraddrs); + return STATUS_ERROR; + } + + /* Don't need the server address anymore */ + pg_freeaddrinfo_all(hint.ai_family, serveraddrs); + + /* + * Figure out at what time we should time out. We can't just use a single + * call to select() with a timeout, since somebody can be sending invalid + * packets to our port thus causing us to retry in a loop and never time + * out. + * + * XXX: Using WaitLatchOrSocket() and doing a CHECK_FOR_INTERRUPTS() if + * the latch was set would improve the responsiveness to + * timeouts/cancellations. + */ + gettimeofday(&endtime, NULL); + endtime.tv_sec += RADIUS_TIMEOUT; + + while (true) + { + struct timeval timeout; + struct timeval now; + int64 timeoutval; + + gettimeofday(&now, NULL); + timeoutval = (endtime.tv_sec * 1000000 + endtime.tv_usec) - (now.tv_sec * 1000000 + now.tv_usec); + if (timeoutval <= 0) + { + ereport(LOG, + (errmsg("timeout waiting for RADIUS response from %s", + server))); + closesocket(sock); + return STATUS_ERROR; + } + timeout.tv_sec = timeoutval / 1000000; + timeout.tv_usec = timeoutval % 1000000; + + FD_ZERO(&fdset); + FD_SET(sock, &fdset); + + r = select(sock + 1, &fdset, NULL, NULL, &timeout); + if (r < 0) + { + if (errno == EINTR) + continue; + + /* Anything else is an actual error */ + ereport(LOG, + (errmsg("could not check status on RADIUS socket: %m"))); + closesocket(sock); + return STATUS_ERROR; + } + if (r == 0) + { + ereport(LOG, + (errmsg("timeout waiting for RADIUS response from %s", + server))); + closesocket(sock); + return STATUS_ERROR; + } + + /* + * Attempt to read the response packet, and verify the contents. + * + * Any packet that's not actually a RADIUS packet, or otherwise does + * not validate as an explicit reject, is just ignored and we retry + * for another packet (until we reach the timeout). This is to avoid + * the possibility to denial-of-service the login by flooding the + * server with invalid packets on the port that we're expecting the + * RADIUS response on. + */ + + addrsize = sizeof(remoteaddr); + packetlength = recvfrom(sock, receive_buffer, RADIUS_BUFFER_SIZE, 0, + (struct sockaddr *) &remoteaddr, &addrsize); + if (packetlength < 0) + { + ereport(LOG, + (errmsg("could not read RADIUS response: %m"))); + closesocket(sock); + return STATUS_ERROR; + } + +#ifdef HAVE_IPV6 + if (remoteaddr.sin6_port != pg_hton16(port)) +#else + if (remoteaddr.sin_port != pg_hton16(port)) +#endif + { +#ifdef HAVE_IPV6 + ereport(LOG, + (errmsg("RADIUS response from %s was sent from incorrect port: %d", + server, pg_ntoh16(remoteaddr.sin6_port)))); +#else + ereport(LOG, + (errmsg("RADIUS response from %s was sent from incorrect port: %d", + server, pg_ntoh16(remoteaddr.sin_port)))); +#endif + continue; + } + + if (packetlength < RADIUS_HEADER_LENGTH) + { + ereport(LOG, + (errmsg("RADIUS response from %s too short: %d", server, packetlength))); + continue; + } + + if (packetlength != pg_ntoh16(receivepacket->length)) + { + ereport(LOG, + (errmsg("RADIUS response from %s has corrupt length: %d (actual length %d)", + server, pg_ntoh16(receivepacket->length), packetlength))); + continue; + } + + if (packet->id != receivepacket->id) + { + ereport(LOG, + (errmsg("RADIUS response from %s is to a different request: %d (should be %d)", + server, receivepacket->id, packet->id))); + continue; + } + + /* + * Verify the response authenticator, which is calculated as + * MD5(Code+ID+Length+RequestAuthenticator+Attributes+Secret) + */ + cryptvector = palloc(packetlength + strlen(secret)); + + memcpy(cryptvector, receivepacket, 4); /* code+id+length */ + memcpy(cryptvector + 4, packet->vector, RADIUS_VECTOR_LENGTH); /* request + * authenticator, from + * original packet */ + if (packetlength > RADIUS_HEADER_LENGTH) /* there may be no + * attributes at all */ + memcpy(cryptvector + RADIUS_HEADER_LENGTH, receive_buffer + RADIUS_HEADER_LENGTH, packetlength - RADIUS_HEADER_LENGTH); + memcpy(cryptvector + packetlength, secret, strlen(secret)); + + if (!pg_md5_binary(cryptvector, + packetlength + strlen(secret), + encryptedpassword)) + { + ereport(LOG, + (errmsg("could not perform MD5 encryption of received packet"))); + pfree(cryptvector); + continue; + } + pfree(cryptvector); + + if (memcmp(receivepacket->vector, encryptedpassword, RADIUS_VECTOR_LENGTH) != 0) + { + ereport(LOG, + (errmsg("RADIUS response from %s has incorrect MD5 signature", + server))); + continue; + } + + if (receivepacket->code == RADIUS_ACCESS_ACCEPT) + { + closesocket(sock); + return STATUS_OK; + } + else if (receivepacket->code == RADIUS_ACCESS_REJECT) + { + closesocket(sock); + return STATUS_EOF; + } + else + { + ereport(LOG, + (errmsg("RADIUS response from %s has invalid code (%d) for user \"%s\"", + server, receivepacket->code, user_name))); + continue; + } + } /* while (true) */ +} diff --git a/src/backend/libpq/be-fsstubs.c b/src/backend/libpq/be-fsstubs.c new file mode 100644 index 0000000..63eaccc --- /dev/null +++ b/src/backend/libpq/be-fsstubs.c @@ -0,0 +1,860 @@ +/*------------------------------------------------------------------------- + * + * be-fsstubs.c + * Builtin functions for open/close/read/write operations on large objects + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/libpq/be-fsstubs.c + * + * NOTES + * This should be moved to a more appropriate place. It is here + * for lack of a better place. + * + * These functions store LargeObjectDesc structs in a private MemoryContext, + * which means that large object descriptors hang around until we destroy + * the context at transaction end. It'd be possible to prolong the lifetime + * of the context so that LO FDs are good across transactions (for example, + * we could release the context only if we see that no FDs remain open). + * But we'd need additional state in order to do the right thing at the + * end of an aborted transaction. FDs opened during an aborted xact would + * still need to be closed, since they might not be pointing at valid + * relations at all. Locking semantics are also an interesting problem + * if LOs stay open across transactions. For now, we'll stick with the + * existing documented semantics of LO FDs: they're only good within a + * transaction. + * + * As of PostgreSQL 8.0, much of the angst expressed above is no longer + * relevant, and in fact it'd be pretty easy to allow LO FDs to stay + * open across transactions. (Snapshot relevancy would still be an issue.) + * However backwards compatibility suggests that we should stick to the + * status quo. + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include <fcntl.h> +#include <sys/stat.h> +#include <unistd.h> + +#include "access/xact.h" +#include "libpq/be-fsstubs.h" +#include "libpq/libpq-fs.h" +#include "miscadmin.h" +#include "storage/fd.h" +#include "storage/large_object.h" +#include "utils/acl.h" +#include "utils/builtins.h" +#include "utils/memutils.h" +#include "utils/snapmgr.h" + +/* define this to enable debug logging */ +/* #define FSDB 1 */ +/* chunk size for lo_import/lo_export transfers */ +#define BUFSIZE 8192 + +/* + * LO "FD"s are indexes into the cookies array. + * + * A non-null entry is a pointer to a LargeObjectDesc allocated in the + * LO private memory context "fscxt". The cookies array itself is also + * dynamically allocated in that context. Its current allocated size is + * cookies_size entries, of which any unused entries will be NULL. + */ +static LargeObjectDesc **cookies = NULL; +static int cookies_size = 0; + +static bool lo_cleanup_needed = false; +static MemoryContext fscxt = NULL; + +static int newLOfd(void); +static void closeLOfd(int fd); +static Oid lo_import_internal(text *filename, Oid lobjOid); + + +/***************************************************************************** + * File Interfaces for Large Objects + *****************************************************************************/ + +Datum +be_lo_open(PG_FUNCTION_ARGS) +{ + Oid lobjId = PG_GETARG_OID(0); + int32 mode = PG_GETARG_INT32(1); + LargeObjectDesc *lobjDesc; + int fd; + +#ifdef FSDB + elog(DEBUG4, "lo_open(%u,%d)", lobjId, mode); +#endif + + /* + * Allocate a large object descriptor first. This will also create + * 'fscxt' if this is the first LO opened in this transaction. + */ + fd = newLOfd(); + + lobjDesc = inv_open(lobjId, mode, fscxt); + lobjDesc->subid = GetCurrentSubTransactionId(); + + /* + * We must register the snapshot in TopTransaction's resowner so that it + * stays alive until the LO is closed rather than until the current portal + * shuts down. + */ + if (lobjDesc->snapshot) + lobjDesc->snapshot = RegisterSnapshotOnOwner(lobjDesc->snapshot, + TopTransactionResourceOwner); + + Assert(cookies[fd] == NULL); + cookies[fd] = lobjDesc; + + PG_RETURN_INT32(fd); +} + +Datum +be_lo_close(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + + if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL) + ereport(ERROR, + (errcode(ERRCODE_UNDEFINED_OBJECT), + errmsg("invalid large-object descriptor: %d", fd))); + +#ifdef FSDB + elog(DEBUG4, "lo_close(%d)", fd); +#endif + + closeLOfd(fd); + + PG_RETURN_INT32(0); +} + + +/***************************************************************************** + * Bare Read/Write operations --- these are not fmgr-callable! + * + * We assume the large object supports byte oriented reads and seeks so + * that our work is easier. + * + *****************************************************************************/ + +int +lo_read(int fd, char *buf, int len) +{ + int status; + LargeObjectDesc *lobj; + + if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL) + ereport(ERROR, + (errcode(ERRCODE_UNDEFINED_OBJECT), + errmsg("invalid large-object descriptor: %d", fd))); + lobj = cookies[fd]; + + /* + * Check state. inv_read() would throw an error anyway, but we want the + * error to be about the FD's state not the underlying privilege; it might + * be that the privilege exists but user forgot to ask for read mode. + */ + if ((lobj->flags & IFS_RDLOCK) == 0) + ereport(ERROR, + (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("large object descriptor %d was not opened for reading", + fd))); + + status = inv_read(lobj, buf, len); + + return status; +} + +int +lo_write(int fd, const char *buf, int len) +{ + int status; + LargeObjectDesc *lobj; + + if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL) + ereport(ERROR, + (errcode(ERRCODE_UNDEFINED_OBJECT), + errmsg("invalid large-object descriptor: %d", fd))); + lobj = cookies[fd]; + + /* see comment in lo_read() */ + if ((lobj->flags & IFS_WRLOCK) == 0) + ereport(ERROR, + (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("large object descriptor %d was not opened for writing", + fd))); + + status = inv_write(lobj, buf, len); + + return status; +} + +Datum +be_lo_lseek(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + int32 offset = PG_GETARG_INT32(1); + int32 whence = PG_GETARG_INT32(2); + int64 status; + + if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL) + ereport(ERROR, + (errcode(ERRCODE_UNDEFINED_OBJECT), + errmsg("invalid large-object descriptor: %d", fd))); + + status = inv_seek(cookies[fd], offset, whence); + + /* guard against result overflow */ + if (status != (int32) status) + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("lo_lseek result out of range for large-object descriptor %d", + fd))); + + PG_RETURN_INT32((int32) status); +} + +Datum +be_lo_lseek64(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + int64 offset = PG_GETARG_INT64(1); + int32 whence = PG_GETARG_INT32(2); + int64 status; + + if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL) + ereport(ERROR, + (errcode(ERRCODE_UNDEFINED_OBJECT), + errmsg("invalid large-object descriptor: %d", fd))); + + status = inv_seek(cookies[fd], offset, whence); + + PG_RETURN_INT64(status); +} + +Datum +be_lo_creat(PG_FUNCTION_ARGS) +{ + Oid lobjId; + + lo_cleanup_needed = true; + lobjId = inv_create(InvalidOid); + + PG_RETURN_OID(lobjId); +} + +Datum +be_lo_create(PG_FUNCTION_ARGS) +{ + Oid lobjId = PG_GETARG_OID(0); + + lo_cleanup_needed = true; + lobjId = inv_create(lobjId); + + PG_RETURN_OID(lobjId); +} + +Datum +be_lo_tell(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + int64 offset; + + if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL) + ereport(ERROR, + (errcode(ERRCODE_UNDEFINED_OBJECT), + errmsg("invalid large-object descriptor: %d", fd))); + + offset = inv_tell(cookies[fd]); + + /* guard against result overflow */ + if (offset != (int32) offset) + ereport(ERROR, + (errcode(ERRCODE_NUMERIC_VALUE_OUT_OF_RANGE), + errmsg("lo_tell result out of range for large-object descriptor %d", + fd))); + + PG_RETURN_INT32((int32) offset); +} + +Datum +be_lo_tell64(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + int64 offset; + + if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL) + ereport(ERROR, + (errcode(ERRCODE_UNDEFINED_OBJECT), + errmsg("invalid large-object descriptor: %d", fd))); + + offset = inv_tell(cookies[fd]); + + PG_RETURN_INT64(offset); +} + +Datum +be_lo_unlink(PG_FUNCTION_ARGS) +{ + Oid lobjId = PG_GETARG_OID(0); + + /* + * Must be owner of the large object. It would be cleaner to check this + * in inv_drop(), but we want to throw the error before not after closing + * relevant FDs. + */ + if (!lo_compat_privileges && + !pg_largeobject_ownercheck(lobjId, GetUserId())) + ereport(ERROR, + (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE), + errmsg("must be owner of large object %u", lobjId))); + + /* + * If there are any open LO FDs referencing that ID, close 'em. + */ + if (fscxt != NULL) + { + int i; + + for (i = 0; i < cookies_size; i++) + { + if (cookies[i] != NULL && cookies[i]->id == lobjId) + closeLOfd(i); + } + } + + /* + * inv_drop does not create a need for end-of-transaction cleanup and + * hence we don't need to set lo_cleanup_needed. + */ + PG_RETURN_INT32(inv_drop(lobjId)); +} + +/***************************************************************************** + * Read/Write using bytea + *****************************************************************************/ + +Datum +be_loread(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + int32 len = PG_GETARG_INT32(1); + bytea *retval; + int totalread; + + if (len < 0) + len = 0; + + retval = (bytea *) palloc(VARHDRSZ + len); + totalread = lo_read(fd, VARDATA(retval), len); + SET_VARSIZE(retval, totalread + VARHDRSZ); + + PG_RETURN_BYTEA_P(retval); +} + +Datum +be_lowrite(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + bytea *wbuf = PG_GETARG_BYTEA_PP(1); + int bytestowrite; + int totalwritten; + + bytestowrite = VARSIZE_ANY_EXHDR(wbuf); + totalwritten = lo_write(fd, VARDATA_ANY(wbuf), bytestowrite); + PG_RETURN_INT32(totalwritten); +} + +/***************************************************************************** + * Import/Export of Large Object + *****************************************************************************/ + +/* + * lo_import - + * imports a file as an (inversion) large object. + */ +Datum +be_lo_import(PG_FUNCTION_ARGS) +{ + text *filename = PG_GETARG_TEXT_PP(0); + + PG_RETURN_OID(lo_import_internal(filename, InvalidOid)); +} + +/* + * lo_import_with_oid - + * imports a file as an (inversion) large object specifying oid. + */ +Datum +be_lo_import_with_oid(PG_FUNCTION_ARGS) +{ + text *filename = PG_GETARG_TEXT_PP(0); + Oid oid = PG_GETARG_OID(1); + + PG_RETURN_OID(lo_import_internal(filename, oid)); +} + +static Oid +lo_import_internal(text *filename, Oid lobjOid) +{ + int fd; + int nbytes, + tmp PG_USED_FOR_ASSERTS_ONLY; + char buf[BUFSIZE]; + char fnamebuf[MAXPGPATH]; + LargeObjectDesc *lobj; + Oid oid; + + /* + * open the file to be read in + */ + text_to_cstring_buffer(filename, fnamebuf, sizeof(fnamebuf)); + fd = OpenTransientFile(fnamebuf, O_RDONLY | PG_BINARY); + if (fd < 0) + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not open server file \"%s\": %m", + fnamebuf))); + + /* + * create an inversion object + */ + lo_cleanup_needed = true; + oid = inv_create(lobjOid); + + /* + * read in from the filesystem and write to the inversion object + */ + lobj = inv_open(oid, INV_WRITE, CurrentMemoryContext); + + while ((nbytes = read(fd, buf, BUFSIZE)) > 0) + { + tmp = inv_write(lobj, buf, nbytes); + Assert(tmp == nbytes); + } + + if (nbytes < 0) + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not read server file \"%s\": %m", + fnamebuf))); + + inv_close(lobj); + + if (CloseTransientFile(fd) != 0) + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not close file \"%s\": %m", + fnamebuf))); + + return oid; +} + +/* + * lo_export - + * exports an (inversion) large object. + */ +Datum +be_lo_export(PG_FUNCTION_ARGS) +{ + Oid lobjId = PG_GETARG_OID(0); + text *filename = PG_GETARG_TEXT_PP(1); + int fd; + int nbytes, + tmp; + char buf[BUFSIZE]; + char fnamebuf[MAXPGPATH]; + LargeObjectDesc *lobj; + mode_t oumask; + + /* + * open the inversion object (no need to test for failure) + */ + lo_cleanup_needed = true; + lobj = inv_open(lobjId, INV_READ, CurrentMemoryContext); + + /* + * open the file to be written to + * + * Note: we reduce backend's normal 077 umask to the slightly friendlier + * 022. This code used to drop it all the way to 0, but creating + * world-writable export files doesn't seem wise. + */ + text_to_cstring_buffer(filename, fnamebuf, sizeof(fnamebuf)); + oumask = umask(S_IWGRP | S_IWOTH); + PG_TRY(); + { + fd = OpenTransientFilePerm(fnamebuf, O_CREAT | O_WRONLY | O_TRUNC | PG_BINARY, + S_IRUSR | S_IWUSR | S_IRGRP | S_IROTH); + } + PG_FINALLY(); + { + umask(oumask); + } + PG_END_TRY(); + if (fd < 0) + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not create server file \"%s\": %m", + fnamebuf))); + + /* + * read in from the inversion file and write to the filesystem + */ + while ((nbytes = inv_read(lobj, buf, BUFSIZE)) > 0) + { + tmp = write(fd, buf, nbytes); + if (tmp != nbytes) + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not write server file \"%s\": %m", + fnamebuf))); + } + + if (CloseTransientFile(fd) != 0) + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not close file \"%s\": %m", + fnamebuf))); + + inv_close(lobj); + + PG_RETURN_INT32(1); +} + +/* + * lo_truncate - + * truncate a large object to a specified length + */ +static void +lo_truncate_internal(int32 fd, int64 len) +{ + LargeObjectDesc *lobj; + + if (fd < 0 || fd >= cookies_size || cookies[fd] == NULL) + ereport(ERROR, + (errcode(ERRCODE_UNDEFINED_OBJECT), + errmsg("invalid large-object descriptor: %d", fd))); + lobj = cookies[fd]; + + /* see comment in lo_read() */ + if ((lobj->flags & IFS_WRLOCK) == 0) + ereport(ERROR, + (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("large object descriptor %d was not opened for writing", + fd))); + + inv_truncate(lobj, len); +} + +Datum +be_lo_truncate(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + int32 len = PG_GETARG_INT32(1); + + lo_truncate_internal(fd, len); + PG_RETURN_INT32(0); +} + +Datum +be_lo_truncate64(PG_FUNCTION_ARGS) +{ + int32 fd = PG_GETARG_INT32(0); + int64 len = PG_GETARG_INT64(1); + + lo_truncate_internal(fd, len); + PG_RETURN_INT32(0); +} + +/* + * AtEOXact_LargeObject - + * prepares large objects for transaction commit + */ +void +AtEOXact_LargeObject(bool isCommit) +{ + int i; + + if (!lo_cleanup_needed) + return; /* no LO operations in this xact */ + + /* + * Close LO fds and clear cookies array so that LO fds are no longer good. + * The memory context and resource owner holding them are going away at + * the end-of-transaction anyway, but on commit, we need to close them to + * avoid warnings about leaked resources at commit. On abort we can skip + * this step. + */ + if (isCommit) + { + for (i = 0; i < cookies_size; i++) + { + if (cookies[i] != NULL) + closeLOfd(i); + } + } + + /* Needn't actually pfree since we're about to zap context */ + cookies = NULL; + cookies_size = 0; + + /* Release the LO memory context to prevent permanent memory leaks. */ + if (fscxt) + MemoryContextDelete(fscxt); + fscxt = NULL; + + /* Give inv_api.c a chance to clean up, too */ + close_lo_relation(isCommit); + + lo_cleanup_needed = false; +} + +/* + * AtEOSubXact_LargeObject + * Take care of large objects at subtransaction commit/abort + * + * Reassign LOs created/opened during a committing subtransaction + * to the parent subtransaction. On abort, just close them. + */ +void +AtEOSubXact_LargeObject(bool isCommit, SubTransactionId mySubid, + SubTransactionId parentSubid) +{ + int i; + + if (fscxt == NULL) /* no LO operations in this xact */ + return; + + for (i = 0; i < cookies_size; i++) + { + LargeObjectDesc *lo = cookies[i]; + + if (lo != NULL && lo->subid == mySubid) + { + if (isCommit) + lo->subid = parentSubid; + else + closeLOfd(i); + } + } +} + +/***************************************************************************** + * Support routines for this file + *****************************************************************************/ + +static int +newLOfd(void) +{ + int i, + newsize; + + lo_cleanup_needed = true; + if (fscxt == NULL) + fscxt = AllocSetContextCreate(TopMemoryContext, + "Filesystem", + ALLOCSET_DEFAULT_SIZES); + + /* Try to find a free slot */ + for (i = 0; i < cookies_size; i++) + { + if (cookies[i] == NULL) + return i; + } + + /* No free slot, so make the array bigger */ + if (cookies_size <= 0) + { + /* First time through, arbitrarily make 64-element array */ + i = 0; + newsize = 64; + cookies = (LargeObjectDesc **) + MemoryContextAllocZero(fscxt, newsize * sizeof(LargeObjectDesc *)); + cookies_size = newsize; + } + else + { + /* Double size of array */ + i = cookies_size; + newsize = cookies_size * 2; + cookies = (LargeObjectDesc **) + repalloc(cookies, newsize * sizeof(LargeObjectDesc *)); + MemSet(cookies + cookies_size, 0, + (newsize - cookies_size) * sizeof(LargeObjectDesc *)); + cookies_size = newsize; + } + + return i; +} + +static void +closeLOfd(int fd) +{ + LargeObjectDesc *lobj; + + /* + * Make sure we do not try to free twice if this errors out for some + * reason. Better a leak than a crash. + */ + lobj = cookies[fd]; + cookies[fd] = NULL; + + if (lobj->snapshot) + UnregisterSnapshotFromOwner(lobj->snapshot, + TopTransactionResourceOwner); + inv_close(lobj); +} + +/***************************************************************************** + * Wrappers oriented toward SQL callers + *****************************************************************************/ + +/* + * Read [offset, offset+nbytes) within LO; when nbytes is -1, read to end. + */ +static bytea * +lo_get_fragment_internal(Oid loOid, int64 offset, int32 nbytes) +{ + LargeObjectDesc *loDesc; + int64 loSize; + int64 result_length; + int total_read PG_USED_FOR_ASSERTS_ONLY; + bytea *result = NULL; + + lo_cleanup_needed = true; + loDesc = inv_open(loOid, INV_READ, CurrentMemoryContext); + + /* + * Compute number of bytes we'll actually read, accommodating nbytes == -1 + * and reads beyond the end of the LO. + */ + loSize = inv_seek(loDesc, 0, SEEK_END); + if (loSize > offset) + { + if (nbytes >= 0 && nbytes <= loSize - offset) + result_length = nbytes; /* request is wholly inside LO */ + else + result_length = loSize - offset; /* adjust to end of LO */ + } + else + result_length = 0; /* request is wholly outside LO */ + + /* + * A result_length calculated from loSize may not fit in a size_t. Check + * that the size will satisfy this and subsequently-enforced size limits. + */ + if (result_length > MaxAllocSize - VARHDRSZ) + ereport(ERROR, + (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED), + errmsg("large object read request is too large"))); + + result = (bytea *) palloc(VARHDRSZ + result_length); + + inv_seek(loDesc, offset, SEEK_SET); + total_read = inv_read(loDesc, VARDATA(result), result_length); + Assert(total_read == result_length); + SET_VARSIZE(result, result_length + VARHDRSZ); + + inv_close(loDesc); + + return result; +} + +/* + * Read entire LO + */ +Datum +be_lo_get(PG_FUNCTION_ARGS) +{ + Oid loOid = PG_GETARG_OID(0); + bytea *result; + + result = lo_get_fragment_internal(loOid, 0, -1); + + PG_RETURN_BYTEA_P(result); +} + +/* + * Read range within LO + */ +Datum +be_lo_get_fragment(PG_FUNCTION_ARGS) +{ + Oid loOid = PG_GETARG_OID(0); + int64 offset = PG_GETARG_INT64(1); + int32 nbytes = PG_GETARG_INT32(2); + bytea *result; + + if (nbytes < 0) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("requested length cannot be negative"))); + + result = lo_get_fragment_internal(loOid, offset, nbytes); + + PG_RETURN_BYTEA_P(result); +} + +/* + * Create LO with initial contents given by a bytea argument + */ +Datum +be_lo_from_bytea(PG_FUNCTION_ARGS) +{ + Oid loOid = PG_GETARG_OID(0); + bytea *str = PG_GETARG_BYTEA_PP(1); + LargeObjectDesc *loDesc; + int written PG_USED_FOR_ASSERTS_ONLY; + + lo_cleanup_needed = true; + loOid = inv_create(loOid); + loDesc = inv_open(loOid, INV_WRITE, CurrentMemoryContext); + written = inv_write(loDesc, VARDATA_ANY(str), VARSIZE_ANY_EXHDR(str)); + Assert(written == VARSIZE_ANY_EXHDR(str)); + inv_close(loDesc); + + PG_RETURN_OID(loOid); +} + +/* + * Update range within LO + */ +Datum +be_lo_put(PG_FUNCTION_ARGS) +{ + Oid loOid = PG_GETARG_OID(0); + int64 offset = PG_GETARG_INT64(1); + bytea *str = PG_GETARG_BYTEA_PP(2); + LargeObjectDesc *loDesc; + int written PG_USED_FOR_ASSERTS_ONLY; + + lo_cleanup_needed = true; + loDesc = inv_open(loOid, INV_WRITE, CurrentMemoryContext); + + /* Permission check */ + if (!lo_compat_privileges && + pg_largeobject_aclcheck_snapshot(loDesc->id, + GetUserId(), + ACL_UPDATE, + loDesc->snapshot) != ACLCHECK_OK) + ereport(ERROR, + (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE), + errmsg("permission denied for large object %u", + loDesc->id))); + + inv_seek(loDesc, offset, SEEK_SET); + written = inv_write(loDesc, VARDATA_ANY(str), VARSIZE_ANY_EXHDR(str)); + Assert(written == VARSIZE_ANY_EXHDR(str)); + inv_close(loDesc); + + PG_RETURN_VOID(); +} diff --git a/src/backend/libpq/be-gssapi-common.c b/src/backend/libpq/be-gssapi-common.c new file mode 100644 index 0000000..38f58de --- /dev/null +++ b/src/backend/libpq/be-gssapi-common.c @@ -0,0 +1,94 @@ +/*------------------------------------------------------------------------- + * + * be-gssapi-common.c + * Common code for GSSAPI authentication and encryption + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * IDENTIFICATION + * src/backend/libpq/be-gssapi-common.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include "libpq/be-gssapi-common.h" + +/* + * Fetch all errors of a specific type and append to "s" (buffer of size len). + * If we obtain more than one string, separate them with spaces. + * Call once for GSS_CODE and once for MECH_CODE. + */ +static void +pg_GSS_error_int(char *s, size_t len, OM_uint32 stat, int type) +{ + gss_buffer_desc gmsg; + size_t i = 0; + OM_uint32 lmin_s, + msg_ctx = 0; + + do + { + if (gss_display_status(&lmin_s, stat, type, GSS_C_NO_OID, + &msg_ctx, &gmsg) != GSS_S_COMPLETE) + break; + if (i > 0) + { + if (i < len) + s[i] = ' '; + i++; + } + if (i < len) + memcpy(s + i, gmsg.value, Min(len - i, gmsg.length)); + i += gmsg.length; + gss_release_buffer(&lmin_s, &gmsg); + } + while (msg_ctx); + + /* add nul termination */ + if (i < len) + s[i] = '\0'; + else + { + elog(COMMERROR, "incomplete GSS error report"); + s[len - 1] = '\0'; + } +} + +/* + * Report the GSSAPI error described by maj_stat/min_stat. + * + * errmsg should be an already-translated primary error message. + * The GSSAPI info is appended as errdetail. + * + * The error is always reported with elevel COMMERROR; we daren't try to + * send it to the client, as that'd likely lead to infinite recursion + * when elog.c tries to write to the client. + * + * To avoid memory allocation, total error size is capped (at 128 bytes for + * each of major and minor). No known mechanisms will produce error messages + * beyond this cap. + */ +void +pg_GSS_error(const char *errmsg, + OM_uint32 maj_stat, OM_uint32 min_stat) +{ + char msg_major[128], + msg_minor[128]; + + /* Fetch major status message */ + pg_GSS_error_int(msg_major, sizeof(msg_major), maj_stat, GSS_C_GSS_CODE); + + /* Fetch mechanism minor status message */ + pg_GSS_error_int(msg_minor, sizeof(msg_minor), min_stat, GSS_C_MECH_CODE); + + /* + * errmsg_internal, since translation of the first part must be done + * before calling this function anyway. + */ + ereport(COMMERROR, + (errmsg_internal("%s", errmsg), + errdetail_internal("%s: %s", msg_major, msg_minor))); +} diff --git a/src/backend/libpq/be-secure-common.c b/src/backend/libpq/be-secure-common.c new file mode 100644 index 0000000..7d082d7 --- /dev/null +++ b/src/backend/libpq/be-secure-common.c @@ -0,0 +1,195 @@ +/*------------------------------------------------------------------------- + * + * be-secure-common.c + * + * common implementation-independent SSL support code + * + * While be-secure.c contains the interfaces that the rest of the + * communications code calls, this file contains support routines that are + * used by the library-specific implementations such as be-secure-openssl.c. + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * IDENTIFICATION + * src/backend/libpq/be-secure-common.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include <sys/stat.h> +#include <unistd.h> + +#include "common/string.h" +#include "libpq/libpq.h" +#include "storage/fd.h" + +/* + * Run ssl_passphrase_command + * + * prompt will be substituted for %p. is_server_start determines the loglevel + * of error messages. + * + * The result will be put in buffer buf, which is of size size. The return + * value is the length of the actual result. + */ +int +run_ssl_passphrase_command(const char *prompt, bool is_server_start, char *buf, int size) +{ + int loglevel = is_server_start ? ERROR : LOG; + StringInfoData command; + char *p; + FILE *fh; + int pclose_rc; + size_t len = 0; + + Assert(prompt); + Assert(size > 0); + buf[0] = '\0'; + + initStringInfo(&command); + + for (p = ssl_passphrase_command; *p; p++) + { + if (p[0] == '%') + { + switch (p[1]) + { + case 'p': + appendStringInfoString(&command, prompt); + p++; + break; + case '%': + appendStringInfoChar(&command, '%'); + p++; + break; + default: + appendStringInfoChar(&command, p[0]); + } + } + else + appendStringInfoChar(&command, p[0]); + } + + fh = OpenPipeStream(command.data, "r"); + if (fh == NULL) + { + ereport(loglevel, + (errcode_for_file_access(), + errmsg("could not execute command \"%s\": %m", + command.data))); + goto error; + } + + if (!fgets(buf, size, fh)) + { + if (ferror(fh)) + { + explicit_bzero(buf, size); + ereport(loglevel, + (errcode_for_file_access(), + errmsg("could not read from command \"%s\": %m", + command.data))); + goto error; + } + } + + pclose_rc = ClosePipeStream(fh); + if (pclose_rc == -1) + { + explicit_bzero(buf, size); + ereport(loglevel, + (errcode_for_file_access(), + errmsg("could not close pipe to external command: %m"))); + goto error; + } + else if (pclose_rc != 0) + { + explicit_bzero(buf, size); + ereport(loglevel, + (errcode_for_file_access(), + errmsg("command \"%s\" failed", + command.data), + errdetail_internal("%s", wait_result_to_str(pclose_rc)))); + goto error; + } + + /* strip trailing newline and carriage return */ + len = pg_strip_crlf(buf); + +error: + pfree(command.data); + return len; +} + + +/* + * Check permissions for SSL key files. + */ +bool +check_ssl_key_file_permissions(const char *ssl_key_file, bool isServerStart) +{ + int loglevel = isServerStart ? FATAL : LOG; + struct stat buf; + + if (stat(ssl_key_file, &buf) != 0) + { + ereport(loglevel, + (errcode_for_file_access(), + errmsg("could not access private key file \"%s\": %m", + ssl_key_file))); + return false; + } + + /* Key file must be a regular file */ + if (!S_ISREG(buf.st_mode)) + { + ereport(loglevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("private key file \"%s\" is not a regular file", + ssl_key_file))); + return false; + } + + /* + * Refuse to load key files owned by users other than us or root, and + * require no public access to the key file. If the file is owned by us, + * require mode 0600 or less. If owned by root, require 0640 or less to + * allow read access through either our gid or a supplementary gid that + * allows us to read system-wide certificates. + * + * Note that roughly similar checks are performed in + * src/interfaces/libpq/fe-secure-openssl.c so any changes here may need + * to be made there as well. The environment is different though; this + * code can assume that we're not running as root. + * + * Ideally we would do similar permissions checks on Windows, but it is + * not clear how that would work since Unix-style permissions may not be + * available. + */ +#if !defined(WIN32) && !defined(__CYGWIN__) + if (buf.st_uid != geteuid() && buf.st_uid != 0) + { + ereport(loglevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("private key file \"%s\" must be owned by the database user or root", + ssl_key_file))); + return false; + } + + if ((buf.st_uid == geteuid() && buf.st_mode & (S_IRWXG | S_IRWXO)) || + (buf.st_uid == 0 && buf.st_mode & (S_IWGRP | S_IXGRP | S_IRWXO))) + { + ereport(loglevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("private key file \"%s\" has group or world access", + ssl_key_file), + errdetail("File must have permissions u=rw (0600) or less if owned by the database user, or permissions u=rw,g=r (0640) or less if owned by root."))); + return false; + } +#endif + + return true; +} diff --git a/src/backend/libpq/be-secure-gssapi.c b/src/backend/libpq/be-secure-gssapi.c new file mode 100644 index 0000000..316ca65 --- /dev/null +++ b/src/backend/libpq/be-secure-gssapi.c @@ -0,0 +1,733 @@ +/*------------------------------------------------------------------------- + * + * be-secure-gssapi.c + * GSSAPI encryption support + * + * Portions Copyright (c) 2018-2021, PostgreSQL Global Development Group + * + * IDENTIFICATION + * src/backend/libpq/be-secure-gssapi.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include <unistd.h> + +#include "libpq/auth.h" +#include "libpq/be-gssapi-common.h" +#include "libpq/libpq.h" +#include "libpq/pqformat.h" +#include "miscadmin.h" +#include "pgstat.h" +#include "utils/memutils.h" + + +/* + * Handle the encryption/decryption of data using GSSAPI. + * + * In the encrypted data stream on the wire, we break up the data + * into packets where each packet starts with a uint32-size length + * word (in network byte order), then encrypted data of that length + * immediately following. Decryption yields the same data stream + * that would appear when not using encryption. + * + * Encrypted data typically ends up being larger than the same data + * unencrypted, so we use fixed-size buffers for handling the + * encryption/decryption which are larger than PQComm's buffer will + * typically be to minimize the times where we have to make multiple + * packets (and therefore multiple recv/send calls for a single + * read/write call to us). + * + * NOTE: The client and server have to agree on the max packet size, + * because we have to pass an entire packet to GSSAPI at a time and we + * don't want the other side to send arbitrarily huge packets as we + * would have to allocate memory for them to then pass them to GSSAPI. + * + * Therefore, these two #define's are effectively part of the protocol + * spec and can't ever be changed. + */ +#define PQ_GSS_SEND_BUFFER_SIZE 16384 +#define PQ_GSS_RECV_BUFFER_SIZE 16384 + +/* + * Since we manage at most one GSS-encrypted connection per backend, + * we can just keep all this state in static variables. The char * + * variables point to buffers that are allocated once and re-used. + */ +static char *PqGSSSendBuffer; /* Encrypted data waiting to be sent */ +static int PqGSSSendLength; /* End of data available in PqGSSSendBuffer */ +static int PqGSSSendNext; /* Next index to send a byte from + * PqGSSSendBuffer */ +static int PqGSSSendConsumed; /* Number of *unencrypted* bytes consumed for + * current contents of PqGSSSendBuffer */ + +static char *PqGSSRecvBuffer; /* Received, encrypted data */ +static int PqGSSRecvLength; /* End of data available in PqGSSRecvBuffer */ + +static char *PqGSSResultBuffer; /* Decryption of data in gss_RecvBuffer */ +static int PqGSSResultLength; /* End of data available in PqGSSResultBuffer */ +static int PqGSSResultNext; /* Next index to read a byte from + * PqGSSResultBuffer */ + +static uint32 PqGSSMaxPktSize; /* Maximum size we can encrypt and fit the + * results into our output buffer */ + + +/* + * Attempt to write len bytes of data from ptr to a GSSAPI-encrypted connection. + * + * The connection must be already set up for GSSAPI encryption (i.e., GSSAPI + * transport negotiation is complete). + * + * On success, returns the number of data bytes consumed (possibly less than + * len). On failure, returns -1 with errno set appropriately. For retryable + * errors, caller should call again (passing the same data) once the socket + * is ready. + * + * Dealing with fatal errors here is a bit tricky: we can't invoke elog(FATAL) + * since it would try to write to the client, probably resulting in infinite + * recursion. Instead, use elog(COMMERROR) to log extra info about the + * failure if necessary, and then return an errno indicating connection loss. + */ +ssize_t +be_gssapi_write(Port *port, void *ptr, size_t len) +{ + OM_uint32 major, + minor; + gss_buffer_desc input, + output; + size_t bytes_sent = 0; + size_t bytes_to_encrypt; + size_t bytes_encrypted; + gss_ctx_id_t gctx = port->gss->ctx; + + /* + * When we get a failure, we must not tell the caller we have successfully + * transmitted everything, else it won't retry. Hence a "success" + * (positive) return value must only count source bytes corresponding to + * fully-transmitted encrypted packets. The amount of source data + * corresponding to the current partly-transmitted packet is remembered in + * PqGSSSendConsumed. On a retry, the caller *must* be sending that data + * again, so if it offers a len less than that, something is wrong. + */ + if (len < PqGSSSendConsumed) + { + elog(COMMERROR, "GSSAPI caller failed to retransmit all data needing to be retried"); + errno = ECONNRESET; + return -1; + } + /* Discount whatever source data we already encrypted. */ + bytes_to_encrypt = len - PqGSSSendConsumed; + bytes_encrypted = PqGSSSendConsumed; + + /* + * Loop through encrypting data and sending it out until it's all done or + * secure_raw_write() complains (which would likely mean that the socket + * is non-blocking and the requested send() would block, or there was some + * kind of actual error). + */ + while (bytes_to_encrypt || PqGSSSendLength) + { + int conf_state = 0; + uint32 netlen; + + /* + * Check if we have data in the encrypted output buffer that needs to + * be sent (possibly left over from a previous call), and if so, try + * to send it. If we aren't able to, return that fact back up to the + * caller. + */ + if (PqGSSSendLength) + { + ssize_t ret; + ssize_t amount = PqGSSSendLength - PqGSSSendNext; + + ret = secure_raw_write(port, PqGSSSendBuffer + PqGSSSendNext, amount); + if (ret <= 0) + { + /* + * Report any previously-sent data; if there was none, reflect + * the secure_raw_write result up to our caller. When there + * was some, we're effectively assuming that any interesting + * failure condition will recur on the next try. + */ + if (bytes_sent) + return bytes_sent; + return ret; + } + + /* + * Check if this was a partial write, and if so, move forward that + * far in our buffer and try again. + */ + if (ret != amount) + { + PqGSSSendNext += ret; + continue; + } + + /* We've successfully sent whatever data was in that packet. */ + bytes_sent += PqGSSSendConsumed; + + /* All encrypted data was sent, our buffer is empty now. */ + PqGSSSendLength = PqGSSSendNext = PqGSSSendConsumed = 0; + } + + /* + * Check if there are any bytes left to encrypt. If not, we're done. + */ + if (!bytes_to_encrypt) + break; + + /* + * Check how much we are being asked to send, if it's too much, then + * we will have to loop and possibly be called multiple times to get + * through all the data. + */ + if (bytes_to_encrypt > PqGSSMaxPktSize) + input.length = PqGSSMaxPktSize; + else + input.length = bytes_to_encrypt; + + input.value = (char *) ptr + bytes_encrypted; + + output.value = NULL; + output.length = 0; + + /* Create the next encrypted packet */ + major = gss_wrap(&minor, gctx, 1, GSS_C_QOP_DEFAULT, + &input, &conf_state, &output); + if (major != GSS_S_COMPLETE) + { + pg_GSS_error(_("GSSAPI wrap error"), major, minor); + errno = ECONNRESET; + return -1; + } + if (conf_state == 0) + { + ereport(COMMERROR, + (errmsg("outgoing GSSAPI message would not use confidentiality"))); + errno = ECONNRESET; + return -1; + } + if (output.length > PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32)) + { + ereport(COMMERROR, + (errmsg("server tried to send oversize GSSAPI packet (%zu > %zu)", + (size_t) output.length, + PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32)))); + errno = ECONNRESET; + return -1; + } + + bytes_encrypted += input.length; + bytes_to_encrypt -= input.length; + PqGSSSendConsumed += input.length; + + /* 4 network-order bytes of length, then payload */ + netlen = pg_hton32(output.length); + memcpy(PqGSSSendBuffer + PqGSSSendLength, &netlen, sizeof(uint32)); + PqGSSSendLength += sizeof(uint32); + + memcpy(PqGSSSendBuffer + PqGSSSendLength, output.value, output.length); + PqGSSSendLength += output.length; + + /* Release buffer storage allocated by GSSAPI */ + gss_release_buffer(&minor, &output); + } + + /* If we get here, our counters should all match up. */ + Assert(bytes_sent == len); + Assert(bytes_sent == bytes_encrypted); + + return bytes_sent; +} + +/* + * Read up to len bytes of data into ptr from a GSSAPI-encrypted connection. + * + * The connection must be already set up for GSSAPI encryption (i.e., GSSAPI + * transport negotiation is complete). + * + * Returns the number of data bytes read, or on failure, returns -1 + * with errno set appropriately. For retryable errors, caller should call + * again once the socket is ready. + * + * We treat fatal errors the same as in be_gssapi_write(), even though the + * argument about infinite recursion doesn't apply here. + */ +ssize_t +be_gssapi_read(Port *port, void *ptr, size_t len) +{ + OM_uint32 major, + minor; + gss_buffer_desc input, + output; + ssize_t ret; + size_t bytes_returned = 0; + gss_ctx_id_t gctx = port->gss->ctx; + + /* + * The plan here is to read one incoming encrypted packet into + * PqGSSRecvBuffer, decrypt it into PqGSSResultBuffer, and then dole out + * data from there to the caller. When we exhaust the current input + * packet, read another. + */ + while (bytes_returned < len) + { + int conf_state = 0; + + /* Check if we have data in our buffer that we can return immediately */ + if (PqGSSResultNext < PqGSSResultLength) + { + size_t bytes_in_buffer = PqGSSResultLength - PqGSSResultNext; + size_t bytes_to_copy = Min(bytes_in_buffer, len - bytes_returned); + + /* + * Copy the data from our result buffer into the caller's buffer, + * at the point where we last left off filling their buffer. + */ + memcpy((char *) ptr + bytes_returned, PqGSSResultBuffer + PqGSSResultNext, bytes_to_copy); + PqGSSResultNext += bytes_to_copy; + bytes_returned += bytes_to_copy; + + /* + * At this point, we've either filled the caller's buffer or + * emptied our result buffer. Either way, return to caller. In + * the second case, we could try to read another encrypted packet, + * but the odds are good that there isn't one available. (If this + * isn't true, we chose too small a max packet size.) In any + * case, there's no harm letting the caller process the data we've + * already returned. + */ + break; + } + + /* Result buffer is empty, so reset buffer pointers */ + PqGSSResultLength = PqGSSResultNext = 0; + + /* + * Because we chose above to return immediately as soon as we emit + * some data, bytes_returned must be zero at this point. Therefore + * the failure exits below can just return -1 without worrying about + * whether we already emitted some data. + */ + Assert(bytes_returned == 0); + + /* + * At this point, our result buffer is empty with more bytes being + * requested to be read. We are now ready to load the next packet and + * decrypt it (entirely) into our result buffer. + */ + + /* Collect the length if we haven't already */ + if (PqGSSRecvLength < sizeof(uint32)) + { + ret = secure_raw_read(port, PqGSSRecvBuffer + PqGSSRecvLength, + sizeof(uint32) - PqGSSRecvLength); + + /* If ret <= 0, secure_raw_read already set the correct errno */ + if (ret <= 0) + return ret; + + PqGSSRecvLength += ret; + + /* If we still haven't got the length, return to the caller */ + if (PqGSSRecvLength < sizeof(uint32)) + { + errno = EWOULDBLOCK; + return -1; + } + } + + /* Decode the packet length and check for overlength packet */ + input.length = pg_ntoh32(*(uint32 *) PqGSSRecvBuffer); + + if (input.length > PQ_GSS_RECV_BUFFER_SIZE - sizeof(uint32)) + { + ereport(COMMERROR, + (errmsg("oversize GSSAPI packet sent by the client (%zu > %zu)", + (size_t) input.length, + PQ_GSS_RECV_BUFFER_SIZE - sizeof(uint32)))); + errno = ECONNRESET; + return -1; + } + + /* + * Read as much of the packet as we are able to on this call into + * wherever we left off from the last time we were called. + */ + ret = secure_raw_read(port, PqGSSRecvBuffer + PqGSSRecvLength, + input.length - (PqGSSRecvLength - sizeof(uint32))); + /* If ret <= 0, secure_raw_read already set the correct errno */ + if (ret <= 0) + return ret; + + PqGSSRecvLength += ret; + + /* If we don't yet have the whole packet, return to the caller */ + if (PqGSSRecvLength - sizeof(uint32) < input.length) + { + errno = EWOULDBLOCK; + return -1; + } + + /* + * We now have the full packet and we can perform the decryption and + * refill our result buffer, then loop back up to pass data back to + * the caller. + */ + output.value = NULL; + output.length = 0; + input.value = PqGSSRecvBuffer + sizeof(uint32); + + major = gss_unwrap(&minor, gctx, &input, &output, &conf_state, NULL); + if (major != GSS_S_COMPLETE) + { + pg_GSS_error(_("GSSAPI unwrap error"), major, minor); + errno = ECONNRESET; + return -1; + } + if (conf_state == 0) + { + ereport(COMMERROR, + (errmsg("incoming GSSAPI message did not use confidentiality"))); + errno = ECONNRESET; + return -1; + } + + memcpy(PqGSSResultBuffer, output.value, output.length); + PqGSSResultLength = output.length; + + /* Our receive buffer is now empty, reset it */ + PqGSSRecvLength = 0; + + /* Release buffer storage allocated by GSSAPI */ + gss_release_buffer(&minor, &output); + } + + return bytes_returned; +} + +/* + * Read the specified number of bytes off the wire, waiting using + * WaitLatchOrSocket if we would block. + * + * Results are read into PqGSSRecvBuffer. + * + * Will always return either -1, to indicate a permanent error, or len. + */ +static ssize_t +read_or_wait(Port *port, ssize_t len) +{ + ssize_t ret; + + /* + * Keep going until we either read in everything we were asked to, or we + * error out. + */ + while (PqGSSRecvLength < len) + { + ret = secure_raw_read(port, PqGSSRecvBuffer + PqGSSRecvLength, len - PqGSSRecvLength); + + /* + * If we got back an error and it wasn't just + * EWOULDBLOCK/EAGAIN/EINTR, then give up. + */ + if (ret < 0 && + !(errno == EWOULDBLOCK || errno == EAGAIN || errno == EINTR)) + return -1; + + /* + * Ok, we got back either a positive value, zero, or a negative result + * indicating we should retry. + * + * If it was zero or negative, then we wait on the socket to be + * readable again. + */ + if (ret <= 0) + { + WaitLatchOrSocket(MyLatch, + WL_SOCKET_READABLE | WL_EXIT_ON_PM_DEATH, + port->sock, 0, WAIT_EVENT_GSS_OPEN_SERVER); + + /* + * If we got back zero bytes, and then waited on the socket to be + * readable and got back zero bytes on a second read, then this is + * EOF and the client hung up on us. + * + * If we did get data here, then we can just fall through and + * handle it just as if we got data the first time. + * + * Otherwise loop back to the top and try again. + */ + if (ret == 0) + { + ret = secure_raw_read(port, PqGSSRecvBuffer + PqGSSRecvLength, len - PqGSSRecvLength); + if (ret == 0) + return -1; + } + if (ret < 0) + continue; + } + + PqGSSRecvLength += ret; + } + + return len; +} + +/* + * Start up a GSSAPI-encrypted connection. This performs GSSAPI + * authentication; after this function completes, it is safe to call + * be_gssapi_read and be_gssapi_write. Returns -1 and logs on failure; + * otherwise, returns 0 and marks the connection as ready for GSSAPI + * encryption. + * + * Note that unlike the be_gssapi_read/be_gssapi_write functions, this + * function WILL block on the socket to be ready for read/write (using + * WaitLatchOrSocket) as appropriate while establishing the GSSAPI + * session. + */ +ssize_t +secure_open_gssapi(Port *port) +{ + bool complete_next = false; + OM_uint32 major, + minor; + + /* + * Allocate subsidiary Port data for GSSAPI operations. + */ + port->gss = (pg_gssinfo *) + MemoryContextAllocZero(TopMemoryContext, sizeof(pg_gssinfo)); + + /* + * Allocate buffers and initialize state variables. By malloc'ing the + * buffers at this point, we avoid wasting static data space in processes + * that will never use them, and we ensure that the buffers are + * sufficiently aligned for the length-word accesses that we do in some + * places in this file. + */ + PqGSSSendBuffer = malloc(PQ_GSS_SEND_BUFFER_SIZE); + PqGSSRecvBuffer = malloc(PQ_GSS_RECV_BUFFER_SIZE); + PqGSSResultBuffer = malloc(PQ_GSS_RECV_BUFFER_SIZE); + if (!PqGSSSendBuffer || !PqGSSRecvBuffer || !PqGSSResultBuffer) + ereport(FATAL, + (errcode(ERRCODE_OUT_OF_MEMORY), + errmsg("out of memory"))); + PqGSSSendLength = PqGSSSendNext = PqGSSSendConsumed = 0; + PqGSSRecvLength = PqGSSResultLength = PqGSSResultNext = 0; + + /* + * Use the configured keytab, if there is one. Unfortunately, Heimdal + * doesn't support the cred store extensions, so use the env var. + */ + if (pg_krb_server_keyfile != NULL && pg_krb_server_keyfile[0] != '\0') + { + if (setenv("KRB5_KTNAME", pg_krb_server_keyfile, 1) != 0) + { + /* The only likely failure cause is OOM, so use that errcode */ + ereport(FATAL, + (errcode(ERRCODE_OUT_OF_MEMORY), + errmsg("could not set environment: %m"))); + } + } + + while (true) + { + ssize_t ret; + gss_buffer_desc input, + output = GSS_C_EMPTY_BUFFER; + + /* + * The client always sends first, so try to go ahead and read the + * length and wait on the socket to be readable again if that fails. + */ + ret = read_or_wait(port, sizeof(uint32)); + if (ret < 0) + return ret; + + /* + * Get the length for this packet from the length header. + */ + input.length = pg_ntoh32(*(uint32 *) PqGSSRecvBuffer); + + /* Done with the length, reset our buffer */ + PqGSSRecvLength = 0; + + /* + * During initialization, packets are always fully consumed and + * shouldn't ever be over PQ_GSS_RECV_BUFFER_SIZE in length. + * + * Verify on our side that the client doesn't do something funny. + */ + if (input.length > PQ_GSS_RECV_BUFFER_SIZE) + { + ereport(COMMERROR, + (errmsg("oversize GSSAPI packet sent by the client (%zu > %d)", + (size_t) input.length, + PQ_GSS_RECV_BUFFER_SIZE))); + return -1; + } + + /* + * Get the rest of the packet so we can pass it to GSSAPI to accept + * the context. + */ + ret = read_or_wait(port, input.length); + if (ret < 0) + return ret; + + input.value = PqGSSRecvBuffer; + + /* Process incoming data. (The client sends first.) */ + major = gss_accept_sec_context(&minor, &port->gss->ctx, + GSS_C_NO_CREDENTIAL, &input, + GSS_C_NO_CHANNEL_BINDINGS, + &port->gss->name, NULL, &output, NULL, + NULL, NULL); + if (GSS_ERROR(major)) + { + pg_GSS_error(_("could not accept GSSAPI security context"), + major, minor); + gss_release_buffer(&minor, &output); + return -1; + } + else if (!(major & GSS_S_CONTINUE_NEEDED)) + { + /* + * rfc2744 technically permits context negotiation to be complete + * both with and without a packet to be sent. + */ + complete_next = true; + } + + /* Done handling the incoming packet, reset our buffer */ + PqGSSRecvLength = 0; + + /* + * Check if we have data to send and, if we do, make sure to send it + * all + */ + if (output.length > 0) + { + uint32 netlen = pg_hton32(output.length); + + if (output.length > PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32)) + { + ereport(COMMERROR, + (errmsg("server tried to send oversize GSSAPI packet (%zu > %zu)", + (size_t) output.length, + PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32)))); + gss_release_buffer(&minor, &output); + return -1; + } + + memcpy(PqGSSSendBuffer, (char *) &netlen, sizeof(uint32)); + PqGSSSendLength += sizeof(uint32); + + memcpy(PqGSSSendBuffer + PqGSSSendLength, output.value, output.length); + PqGSSSendLength += output.length; + + /* we don't bother with PqGSSSendConsumed here */ + + while (PqGSSSendNext < PqGSSSendLength) + { + ret = secure_raw_write(port, PqGSSSendBuffer + PqGSSSendNext, + PqGSSSendLength - PqGSSSendNext); + + /* + * If we got back an error and it wasn't just + * EWOULDBLOCK/EAGAIN/EINTR, then give up. + */ + if (ret < 0 && + !(errno == EWOULDBLOCK || errno == EAGAIN || errno == EINTR)) + { + gss_release_buffer(&minor, &output); + return -1; + } + + /* Wait and retry if we couldn't write yet */ + if (ret <= 0) + { + WaitLatchOrSocket(MyLatch, + WL_SOCKET_WRITEABLE | WL_EXIT_ON_PM_DEATH, + port->sock, 0, WAIT_EVENT_GSS_OPEN_SERVER); + continue; + } + + PqGSSSendNext += ret; + } + + /* Done sending the packet, reset our buffer */ + PqGSSSendLength = PqGSSSendNext = 0; + + gss_release_buffer(&minor, &output); + } + + /* + * If we got back that the connection is finished being set up, now + * that we've sent the last packet, exit our loop. + */ + if (complete_next) + break; + } + + /* + * Determine the max packet size which will fit in our buffer, after + * accounting for the length. be_gssapi_write will need this. + */ + major = gss_wrap_size_limit(&minor, port->gss->ctx, 1, GSS_C_QOP_DEFAULT, + PQ_GSS_SEND_BUFFER_SIZE - sizeof(uint32), + &PqGSSMaxPktSize); + + if (GSS_ERROR(major)) + { + pg_GSS_error(_("GSSAPI size check error"), major, minor); + return -1; + } + + port->gss->enc = true; + + return 0; +} + +/* + * Return if GSSAPI authentication was used on this connection. + */ +bool +be_gssapi_get_auth(Port *port) +{ + if (!port || !port->gss) + return false; + + return port->gss->auth; +} + +/* + * Return if GSSAPI encryption is enabled and being used on this connection. + */ +bool +be_gssapi_get_enc(Port *port) +{ + if (!port || !port->gss) + return false; + + return port->gss->enc; +} + +/* + * Return the GSSAPI principal used for authentication on this connection + * (NULL if we did not perform GSSAPI authentication). + */ +const char * +be_gssapi_get_princ(Port *port) +{ + if (!port || !port->gss) + return NULL; + + return port->gss->princ; +} diff --git a/src/backend/libpq/be-secure-openssl.c b/src/backend/libpq/be-secure-openssl.c new file mode 100644 index 0000000..e3b02b1 --- /dev/null +++ b/src/backend/libpq/be-secure-openssl.c @@ -0,0 +1,1526 @@ +/*------------------------------------------------------------------------- + * + * be-secure-openssl.c + * functions for OpenSSL support in the backend. + * + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/libpq/be-secure-openssl.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include <sys/stat.h> +#include <signal.h> +#include <fcntl.h> +#include <ctype.h> +#include <sys/socket.h> +#include <unistd.h> +#include <netdb.h> +#include <netinet/in.h> +#ifdef HAVE_NETINET_TCP_H +#include <netinet/tcp.h> +#include <arpa/inet.h> +#endif + +#include <openssl/ssl.h> +#include <openssl/dh.h> +#include <openssl/conf.h> +#ifndef OPENSSL_NO_ECDH +#include <openssl/ec.h> +#endif + +#include "common/openssl.h" +#include "libpq/libpq.h" +#include "miscadmin.h" +#include "pgstat.h" +#include "storage/fd.h" +#include "storage/latch.h" +#include "tcop/tcopprot.h" +#include "utils/memutils.h" + +/* default init hook can be overridden by a shared library */ +static void default_openssl_tls_init(SSL_CTX *context, bool isServerStart); +openssl_tls_init_hook_typ openssl_tls_init_hook = default_openssl_tls_init; + +static int my_sock_read(BIO *h, char *buf, int size); +static int my_sock_write(BIO *h, const char *buf, int size); +static BIO_METHOD *my_BIO_s_socket(void); +static int my_SSL_set_fd(Port *port, int fd); + +static DH *load_dh_file(char *filename, bool isServerStart); +static DH *load_dh_buffer(const char *, size_t); +static int ssl_external_passwd_cb(char *buf, int size, int rwflag, void *userdata); +static int dummy_ssl_passwd_cb(char *buf, int size, int rwflag, void *userdata); +static int verify_cb(int, X509_STORE_CTX *); +static void info_cb(const SSL *ssl, int type, int args); +static bool initialize_dh(SSL_CTX *context, bool isServerStart); +static bool initialize_ecdh(SSL_CTX *context, bool isServerStart); +static const char *SSLerrmessage(unsigned long ecode); + +static char *X509_NAME_to_cstring(X509_NAME *name); + +static SSL_CTX *SSL_context = NULL; +static bool SSL_initialized = false; +static bool dummy_ssl_passwd_cb_called = false; +static bool ssl_is_server_start; + +static int ssl_protocol_version_to_openssl(int v); +static const char *ssl_protocol_version_to_string(int v); + +/* ------------------------------------------------------------ */ +/* Public interface */ +/* ------------------------------------------------------------ */ + +int +be_tls_init(bool isServerStart) +{ + SSL_CTX *context; + int ssl_ver_min = -1; + int ssl_ver_max = -1; + + /* This stuff need be done only once. */ + if (!SSL_initialized) + { +#ifdef HAVE_OPENSSL_INIT_SSL + OPENSSL_init_ssl(OPENSSL_INIT_LOAD_CONFIG, NULL); +#else + OPENSSL_config(NULL); + SSL_library_init(); + SSL_load_error_strings(); +#endif + SSL_initialized = true; + } + + /* + * Create a new SSL context into which we'll load all the configuration + * settings. If we fail partway through, we can avoid memory leakage by + * freeing this context; we don't install it as active until the end. + * + * We use SSLv23_method() because it can negotiate use of the highest + * mutually supported protocol version, while alternatives like + * TLSv1_2_method() permit only one specific version. Note that we don't + * actually allow SSL v2 or v3, only TLS protocols (see below). + */ + context = SSL_CTX_new(SSLv23_method()); + if (!context) + { + ereport(isServerStart ? FATAL : LOG, + (errmsg("could not create SSL context: %s", + SSLerrmessage(ERR_get_error())))); + goto error; + } + + /* + * Disable OpenSSL's moving-write-buffer sanity check, because it causes + * unnecessary failures in nonblocking send cases. + */ + SSL_CTX_set_mode(context, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + + /* + * Call init hook (usually to set password callback) + */ + (*openssl_tls_init_hook) (context, isServerStart); + + /* used by the callback */ + ssl_is_server_start = isServerStart; + + /* + * Load and verify server's certificate and private key + */ + if (SSL_CTX_use_certificate_chain_file(context, ssl_cert_file) != 1) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not load server certificate file \"%s\": %s", + ssl_cert_file, SSLerrmessage(ERR_get_error())))); + goto error; + } + + if (!check_ssl_key_file_permissions(ssl_key_file, isServerStart)) + goto error; + + /* + * OK, try to load the private key file. + */ + dummy_ssl_passwd_cb_called = false; + + if (SSL_CTX_use_PrivateKey_file(context, + ssl_key_file, + SSL_FILETYPE_PEM) != 1) + { + if (dummy_ssl_passwd_cb_called) + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("private key file \"%s\" cannot be reloaded because it requires a passphrase", + ssl_key_file))); + else + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not load private key file \"%s\": %s", + ssl_key_file, SSLerrmessage(ERR_get_error())))); + goto error; + } + + if (SSL_CTX_check_private_key(context) != 1) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("check of private key failed: %s", + SSLerrmessage(ERR_get_error())))); + goto error; + } + + if (ssl_min_protocol_version) + { + ssl_ver_min = ssl_protocol_version_to_openssl(ssl_min_protocol_version); + + if (ssl_ver_min == -1) + { + ereport(isServerStart ? FATAL : LOG, + /*- translator: first %s is a GUC option name, second %s is its value */ + (errmsg("\"%s\" setting \"%s\" not supported by this build", + "ssl_min_protocol_version", + GetConfigOption("ssl_min_protocol_version", + false, false)))); + goto error; + } + + if (!SSL_CTX_set_min_proto_version(context, ssl_ver_min)) + { + ereport(isServerStart ? FATAL : LOG, + (errmsg("could not set minimum SSL protocol version"))); + goto error; + } + } + + if (ssl_max_protocol_version) + { + ssl_ver_max = ssl_protocol_version_to_openssl(ssl_max_protocol_version); + + if (ssl_ver_max == -1) + { + ereport(isServerStart ? FATAL : LOG, + /*- translator: first %s is a GUC option name, second %s is its value */ + (errmsg("\"%s\" setting \"%s\" not supported by this build", + "ssl_max_protocol_version", + GetConfigOption("ssl_max_protocol_version", + false, false)))); + goto error; + } + + if (!SSL_CTX_set_max_proto_version(context, ssl_ver_max)) + { + ereport(isServerStart ? FATAL : LOG, + (errmsg("could not set maximum SSL protocol version"))); + goto error; + } + } + + /* Check compatibility of min/max protocols */ + if (ssl_min_protocol_version && + ssl_max_protocol_version) + { + /* + * No need to check for invalid values (-1) for each protocol number + * as the code above would have already generated an error. + */ + if (ssl_ver_min > ssl_ver_max) + { + ereport(isServerStart ? FATAL : LOG, + (errmsg("could not set SSL protocol version range"), + errdetail("\"%s\" cannot be higher than \"%s\"", + "ssl_min_protocol_version", + "ssl_max_protocol_version"))); + goto error; + } + } + + /* disallow SSL session tickets */ + SSL_CTX_set_options(context, SSL_OP_NO_TICKET); + + /* disallow SSL session caching, too */ + SSL_CTX_set_session_cache_mode(context, SSL_SESS_CACHE_OFF); + + /* disallow SSL compression */ + SSL_CTX_set_options(context, SSL_OP_NO_COMPRESSION); + +#ifdef SSL_OP_NO_RENEGOTIATION + + /* + * Disallow SSL renegotiation, option available since 1.1.0h. This + * concerns only TLSv1.2 and older protocol versions, as TLSv1.3 has no + * support for renegotiation. + */ + SSL_CTX_set_options(context, SSL_OP_NO_RENEGOTIATION); +#endif + + /* set up ephemeral DH and ECDH keys */ + if (!initialize_dh(context, isServerStart)) + goto error; + if (!initialize_ecdh(context, isServerStart)) + goto error; + + /* set up the allowed cipher list */ + if (SSL_CTX_set_cipher_list(context, SSLCipherSuites) != 1) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not set the cipher list (no valid ciphers available)"))); + goto error; + } + + /* Let server choose order */ + if (SSLPreferServerCiphers) + SSL_CTX_set_options(context, SSL_OP_CIPHER_SERVER_PREFERENCE); + + /* + * Load CA store, so we can verify client certificates if needed. + */ + if (ssl_ca_file[0]) + { + STACK_OF(X509_NAME) * root_cert_list; + + if (SSL_CTX_load_verify_locations(context, ssl_ca_file, NULL) != 1 || + (root_cert_list = SSL_load_client_CA_file(ssl_ca_file)) == NULL) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not load root certificate file \"%s\": %s", + ssl_ca_file, SSLerrmessage(ERR_get_error())))); + goto error; + } + + /* + * Tell OpenSSL to send the list of root certs we trust to clients in + * CertificateRequests. This lets a client with a keystore select the + * appropriate client certificate to send to us. Also, this ensures + * that the SSL context will "own" the root_cert_list and remember to + * free it when no longer needed. + */ + SSL_CTX_set_client_CA_list(context, root_cert_list); + + /* + * Always ask for SSL client cert, but don't fail if it's not + * presented. We might fail such connections later, depending on what + * we find in pg_hba.conf. + */ + SSL_CTX_set_verify(context, + (SSL_VERIFY_PEER | + SSL_VERIFY_CLIENT_ONCE), + verify_cb); + } + + /*---------- + * Load the Certificate Revocation List (CRL). + * http://searchsecurity.techtarget.com/sDefinition/0,,sid14_gci803160,00.html + *---------- + */ + if (ssl_crl_file[0] || ssl_crl_dir[0]) + { + X509_STORE *cvstore = SSL_CTX_get_cert_store(context); + + if (cvstore) + { + /* Set the flags to check against the complete CRL chain */ + if (X509_STORE_load_locations(cvstore, + ssl_crl_file[0] ? ssl_crl_file : NULL, + ssl_crl_dir[0] ? ssl_crl_dir : NULL) + == 1) + { + X509_STORE_set_flags(cvstore, + X509_V_FLAG_CRL_CHECK | X509_V_FLAG_CRL_CHECK_ALL); + } + else if (ssl_crl_dir[0] == 0) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not load SSL certificate revocation list file \"%s\": %s", + ssl_crl_file, SSLerrmessage(ERR_get_error())))); + goto error; + } + else if (ssl_crl_file[0] == 0) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not load SSL certificate revocation list directory \"%s\": %s", + ssl_crl_dir, SSLerrmessage(ERR_get_error())))); + goto error; + } + else + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not load SSL certificate revocation list file \"%s\" or directory \"%s\": %s", + ssl_crl_file, ssl_crl_dir, + SSLerrmessage(ERR_get_error())))); + goto error; + } + } + } + + /* + * Success! Replace any existing SSL_context. + */ + if (SSL_context) + SSL_CTX_free(SSL_context); + + SSL_context = context; + + /* + * Set flag to remember whether CA store has been loaded into SSL_context. + */ + if (ssl_ca_file[0]) + ssl_loaded_verify_locations = true; + else + ssl_loaded_verify_locations = false; + + return 0; + + /* Clean up by releasing working context. */ +error: + if (context) + SSL_CTX_free(context); + return -1; +} + +void +be_tls_destroy(void) +{ + if (SSL_context) + SSL_CTX_free(SSL_context); + SSL_context = NULL; + ssl_loaded_verify_locations = false; +} + +int +be_tls_open_server(Port *port) +{ + int r; + int err; + int waitfor; + unsigned long ecode; + bool give_proto_hint; + + Assert(!port->ssl); + Assert(!port->peer); + + if (!SSL_context) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("could not initialize SSL connection: SSL context not set up"))); + return -1; + } + + /* set up debugging/info callback */ + SSL_CTX_set_info_callback(SSL_context, info_cb); + + if (!(port->ssl = SSL_new(SSL_context))) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("could not initialize SSL connection: %s", + SSLerrmessage(ERR_get_error())))); + return -1; + } + if (!my_SSL_set_fd(port, port->sock)) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("could not set SSL socket: %s", + SSLerrmessage(ERR_get_error())))); + return -1; + } + port->ssl_in_use = true; + +aloop: + + /* + * Prepare to call SSL_get_error() by clearing thread's OpenSSL error + * queue. In general, the current thread's error queue must be empty + * before the TLS/SSL I/O operation is attempted, or SSL_get_error() will + * not work reliably. An extension may have failed to clear the + * per-thread error queue following another call to an OpenSSL I/O + * routine. + */ + ERR_clear_error(); + r = SSL_accept(port->ssl); + if (r <= 0) + { + err = SSL_get_error(port->ssl, r); + + /* + * Other clients of OpenSSL in the backend may fail to call + * ERR_get_error(), but we always do, so as to not cause problems for + * OpenSSL clients that don't call ERR_clear_error() defensively. Be + * sure that this happens by calling now. SSL_get_error() relies on + * the OpenSSL per-thread error queue being intact, so this is the + * earliest possible point ERR_get_error() may be called. + */ + ecode = ERR_get_error(); + switch (err) + { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + /* not allowed during connection establishment */ + Assert(!port->noblock); + + /* + * No need to care about timeouts/interrupts here. At this + * point authentication_timeout still employs + * StartupPacketTimeoutHandler() which directly exits. + */ + if (err == SSL_ERROR_WANT_READ) + waitfor = WL_SOCKET_READABLE | WL_EXIT_ON_PM_DEATH; + else + waitfor = WL_SOCKET_WRITEABLE | WL_EXIT_ON_PM_DEATH; + + (void) WaitLatchOrSocket(MyLatch, waitfor, port->sock, 0, + WAIT_EVENT_SSL_OPEN_SERVER); + goto aloop; + case SSL_ERROR_SYSCALL: + if (r < 0) + ereport(COMMERROR, + (errcode_for_socket_access(), + errmsg("could not accept SSL connection: %m"))); + else + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("could not accept SSL connection: EOF detected"))); + break; + case SSL_ERROR_SSL: + switch (ERR_GET_REASON(ecode)) + { + /* + * UNSUPPORTED_PROTOCOL, WRONG_VERSION_NUMBER, and + * TLSV1_ALERT_PROTOCOL_VERSION have been observed + * when trying to communicate with an old OpenSSL + * library, or when the client and server specify + * disjoint protocol ranges. NO_PROTOCOLS_AVAILABLE + * occurs if there's a local misconfiguration (which + * can happen despite our checks, if openssl.cnf + * injects a limit we didn't account for). It's not + * very clear what would make OpenSSL return the other + * codes listed here, but a hint about protocol + * versions seems like it's appropriate for all. + */ + case SSL_R_NO_PROTOCOLS_AVAILABLE: + case SSL_R_UNSUPPORTED_PROTOCOL: + case SSL_R_BAD_PROTOCOL_VERSION_NUMBER: + case SSL_R_UNKNOWN_PROTOCOL: + case SSL_R_UNKNOWN_SSL_VERSION: + case SSL_R_UNSUPPORTED_SSL_VERSION: + case SSL_R_WRONG_SSL_VERSION: + case SSL_R_WRONG_VERSION_NUMBER: + case SSL_R_TLSV1_ALERT_PROTOCOL_VERSION: +#ifdef SSL_R_VERSION_TOO_HIGH + case SSL_R_VERSION_TOO_HIGH: + case SSL_R_VERSION_TOO_LOW: +#endif + give_proto_hint = true; + break; + default: + give_proto_hint = false; + break; + } + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("could not accept SSL connection: %s", + SSLerrmessage(ecode)), + give_proto_hint ? + errhint("This may indicate that the client does not support any SSL protocol version between %s and %s.", + ssl_min_protocol_version ? + ssl_protocol_version_to_string(ssl_min_protocol_version) : + MIN_OPENSSL_TLS_VERSION, + ssl_max_protocol_version ? + ssl_protocol_version_to_string(ssl_max_protocol_version) : + MAX_OPENSSL_TLS_VERSION) : 0)); + break; + case SSL_ERROR_ZERO_RETURN: + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("could not accept SSL connection: EOF detected"))); + break; + default: + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("unrecognized SSL error code: %d", + err))); + break; + } + return -1; + } + + /* Get client certificate, if available. */ + port->peer = SSL_get_peer_certificate(port->ssl); + + /* and extract the Common Name and Distinguished Name from it. */ + port->peer_cn = NULL; + port->peer_dn = NULL; + port->peer_cert_valid = false; + if (port->peer != NULL) + { + int len; + X509_NAME *x509name = X509_get_subject_name(port->peer); + char *peer_dn; + BIO *bio = NULL; + BUF_MEM *bio_buf = NULL; + + len = X509_NAME_get_text_by_NID(x509name, NID_commonName, NULL, 0); + if (len != -1) + { + char *peer_cn; + + peer_cn = MemoryContextAlloc(TopMemoryContext, len + 1); + r = X509_NAME_get_text_by_NID(x509name, NID_commonName, peer_cn, + len + 1); + peer_cn[len] = '\0'; + if (r != len) + { + /* shouldn't happen */ + pfree(peer_cn); + return -1; + } + + /* + * Reject embedded NULLs in certificate common name to prevent + * attacks like CVE-2009-4034. + */ + if (len != strlen(peer_cn)) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("SSL certificate's common name contains embedded null"))); + pfree(peer_cn); + return -1; + } + + port->peer_cn = peer_cn; + } + + bio = BIO_new(BIO_s_mem()); + if (!bio) + { + pfree(port->peer_cn); + port->peer_cn = NULL; + return -1; + } + + /* + * RFC2253 is the closest thing to an accepted standard format for + * DNs. We have documented how to produce this format from a + * certificate. It uses commas instead of slashes for delimiters, + * which make regular expression matching a bit easier. Also note that + * it prints the Subject fields in reverse order. + */ + X509_NAME_print_ex(bio, x509name, 0, XN_FLAG_RFC2253); + if (BIO_get_mem_ptr(bio, &bio_buf) <= 0) + { + BIO_free(bio); + pfree(port->peer_cn); + port->peer_cn = NULL; + return -1; + } + peer_dn = MemoryContextAlloc(TopMemoryContext, bio_buf->length + 1); + memcpy(peer_dn, bio_buf->data, bio_buf->length); + len = bio_buf->length; + BIO_free(bio); + peer_dn[len] = '\0'; + if (len != strlen(peer_dn)) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("SSL certificate's distinguished name contains embedded null"))); + pfree(peer_dn); + pfree(port->peer_cn); + port->peer_cn = NULL; + return -1; + } + + port->peer_dn = peer_dn; + + port->peer_cert_valid = true; + } + + return 0; +} + +void +be_tls_close(Port *port) +{ + if (port->ssl) + { + SSL_shutdown(port->ssl); + SSL_free(port->ssl); + port->ssl = NULL; + port->ssl_in_use = false; + } + + if (port->peer) + { + X509_free(port->peer); + port->peer = NULL; + } + + if (port->peer_cn) + { + pfree(port->peer_cn); + port->peer_cn = NULL; + } + + if (port->peer_dn) + { + pfree(port->peer_dn); + port->peer_dn = NULL; + } +} + +ssize_t +be_tls_read(Port *port, void *ptr, size_t len, int *waitfor) +{ + ssize_t n; + int err; + unsigned long ecode; + + errno = 0; + ERR_clear_error(); + n = SSL_read(port->ssl, ptr, len); + err = SSL_get_error(port->ssl, n); + ecode = (err != SSL_ERROR_NONE || n < 0) ? ERR_get_error() : 0; + switch (err) + { + case SSL_ERROR_NONE: + /* a-ok */ + break; + case SSL_ERROR_WANT_READ: + *waitfor = WL_SOCKET_READABLE; + errno = EWOULDBLOCK; + n = -1; + break; + case SSL_ERROR_WANT_WRITE: + *waitfor = WL_SOCKET_WRITEABLE; + errno = EWOULDBLOCK; + n = -1; + break; + case SSL_ERROR_SYSCALL: + /* leave it to caller to ereport the value of errno */ + if (n != -1) + { + errno = ECONNRESET; + n = -1; + } + break; + case SSL_ERROR_SSL: + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("SSL error: %s", SSLerrmessage(ecode)))); + errno = ECONNRESET; + n = -1; + break; + case SSL_ERROR_ZERO_RETURN: + /* connection was cleanly shut down by peer */ + n = 0; + break; + default: + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("unrecognized SSL error code: %d", + err))); + errno = ECONNRESET; + n = -1; + break; + } + + return n; +} + +ssize_t +be_tls_write(Port *port, void *ptr, size_t len, int *waitfor) +{ + ssize_t n; + int err; + unsigned long ecode; + + errno = 0; + ERR_clear_error(); + n = SSL_write(port->ssl, ptr, len); + err = SSL_get_error(port->ssl, n); + ecode = (err != SSL_ERROR_NONE || n < 0) ? ERR_get_error() : 0; + switch (err) + { + case SSL_ERROR_NONE: + /* a-ok */ + break; + case SSL_ERROR_WANT_READ: + *waitfor = WL_SOCKET_READABLE; + errno = EWOULDBLOCK; + n = -1; + break; + case SSL_ERROR_WANT_WRITE: + *waitfor = WL_SOCKET_WRITEABLE; + errno = EWOULDBLOCK; + n = -1; + break; + case SSL_ERROR_SYSCALL: + /* leave it to caller to ereport the value of errno */ + if (n != -1) + { + errno = ECONNRESET; + n = -1; + } + break; + case SSL_ERROR_SSL: + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("SSL error: %s", SSLerrmessage(ecode)))); + errno = ECONNRESET; + n = -1; + break; + case SSL_ERROR_ZERO_RETURN: + + /* + * the SSL connection was closed, leave it to the caller to + * ereport it + */ + errno = ECONNRESET; + n = -1; + break; + default: + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("unrecognized SSL error code: %d", + err))); + errno = ECONNRESET; + n = -1; + break; + } + + return n; +} + +/* ------------------------------------------------------------ */ +/* Internal functions */ +/* ------------------------------------------------------------ */ + +/* + * Private substitute BIO: this does the sending and receiving using send() and + * recv() instead. This is so that we can enable and disable interrupts + * just while calling recv(). We cannot have interrupts occurring while + * the bulk of OpenSSL runs, because it uses malloc() and possibly other + * non-reentrant libc facilities. We also need to call send() and recv() + * directly so it gets passed through the socket/signals layer on Win32. + * + * These functions are closely modelled on the standard socket BIO in OpenSSL; + * see sock_read() and sock_write() in OpenSSL's crypto/bio/bss_sock.c. + * XXX OpenSSL 1.0.1e considers many more errcodes than just EINTR as reasons + * to retry; do we need to adopt their logic for that? + */ + +#ifndef HAVE_BIO_GET_DATA +#define BIO_get_data(bio) (bio->ptr) +#define BIO_set_data(bio, data) (bio->ptr = data) +#endif + +static BIO_METHOD *my_bio_methods = NULL; + +static int +my_sock_read(BIO *h, char *buf, int size) +{ + int res = 0; + + if (buf != NULL) + { + res = secure_raw_read(((Port *) BIO_get_data(h)), buf, size); + BIO_clear_retry_flags(h); + if (res <= 0) + { + /* If we were interrupted, tell caller to retry */ + if (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN) + { + BIO_set_retry_read(h); + } + } + } + + return res; +} + +static int +my_sock_write(BIO *h, const char *buf, int size) +{ + int res = 0; + + res = secure_raw_write(((Port *) BIO_get_data(h)), buf, size); + BIO_clear_retry_flags(h); + if (res <= 0) + { + /* If we were interrupted, tell caller to retry */ + if (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN) + { + BIO_set_retry_write(h); + } + } + + return res; +} + +static BIO_METHOD * +my_BIO_s_socket(void) +{ + if (!my_bio_methods) + { + BIO_METHOD *biom = (BIO_METHOD *) BIO_s_socket(); +#ifdef HAVE_BIO_METH_NEW + int my_bio_index; + + my_bio_index = BIO_get_new_index(); + if (my_bio_index == -1) + return NULL; + my_bio_index |= (BIO_TYPE_DESCRIPTOR | BIO_TYPE_SOURCE_SINK); + my_bio_methods = BIO_meth_new(my_bio_index, "PostgreSQL backend socket"); + if (!my_bio_methods) + return NULL; + if (!BIO_meth_set_write(my_bio_methods, my_sock_write) || + !BIO_meth_set_read(my_bio_methods, my_sock_read) || + !BIO_meth_set_gets(my_bio_methods, BIO_meth_get_gets(biom)) || + !BIO_meth_set_puts(my_bio_methods, BIO_meth_get_puts(biom)) || + !BIO_meth_set_ctrl(my_bio_methods, BIO_meth_get_ctrl(biom)) || + !BIO_meth_set_create(my_bio_methods, BIO_meth_get_create(biom)) || + !BIO_meth_set_destroy(my_bio_methods, BIO_meth_get_destroy(biom)) || + !BIO_meth_set_callback_ctrl(my_bio_methods, BIO_meth_get_callback_ctrl(biom))) + { + BIO_meth_free(my_bio_methods); + my_bio_methods = NULL; + return NULL; + } +#else + my_bio_methods = malloc(sizeof(BIO_METHOD)); + if (!my_bio_methods) + return NULL; + memcpy(my_bio_methods, biom, sizeof(BIO_METHOD)); + my_bio_methods->bread = my_sock_read; + my_bio_methods->bwrite = my_sock_write; +#endif + } + return my_bio_methods; +} + +/* This should exactly match OpenSSL's SSL_set_fd except for using my BIO */ +static int +my_SSL_set_fd(Port *port, int fd) +{ + int ret = 0; + BIO *bio; + BIO_METHOD *bio_method; + + bio_method = my_BIO_s_socket(); + if (bio_method == NULL) + { + SSLerr(SSL_F_SSL_SET_FD, ERR_R_BUF_LIB); + goto err; + } + bio = BIO_new(bio_method); + + if (bio == NULL) + { + SSLerr(SSL_F_SSL_SET_FD, ERR_R_BUF_LIB); + goto err; + } + BIO_set_data(bio, port); + + BIO_set_fd(bio, fd, BIO_NOCLOSE); + SSL_set_bio(port->ssl, bio, bio); + ret = 1; +err: + return ret; +} + +/* + * Load precomputed DH parameters. + * + * To prevent "downgrade" attacks, we perform a number of checks + * to verify that the DBA-generated DH parameters file contains + * what we expect it to contain. + */ +static DH * +load_dh_file(char *filename, bool isServerStart) +{ + FILE *fp; + DH *dh = NULL; + int codes; + + /* attempt to open file. It's not an error if it doesn't exist. */ + if ((fp = AllocateFile(filename, "r")) == NULL) + { + ereport(isServerStart ? FATAL : LOG, + (errcode_for_file_access(), + errmsg("could not open DH parameters file \"%s\": %m", + filename))); + return NULL; + } + + dh = PEM_read_DHparams(fp, NULL, NULL, NULL); + FreeFile(fp); + + if (dh == NULL) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not load DH parameters file: %s", + SSLerrmessage(ERR_get_error())))); + return NULL; + } + + /* make sure the DH parameters are usable */ + if (DH_check(dh, &codes) == 0) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid DH parameters: %s", + SSLerrmessage(ERR_get_error())))); + DH_free(dh); + return NULL; + } + if (codes & DH_CHECK_P_NOT_PRIME) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid DH parameters: p is not prime"))); + DH_free(dh); + return NULL; + } + if ((codes & DH_NOT_SUITABLE_GENERATOR) && + (codes & DH_CHECK_P_NOT_SAFE_PRIME)) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid DH parameters: neither suitable generator or safe prime"))); + DH_free(dh); + return NULL; + } + + return dh; +} + +/* + * Load hardcoded DH parameters. + * + * If DH parameters cannot be loaded from a specified file, we can load + * the hardcoded DH parameters supplied with the backend to prevent + * problems. + */ +static DH * +load_dh_buffer(const char *buffer, size_t len) +{ + BIO *bio; + DH *dh = NULL; + + bio = BIO_new_mem_buf(unconstify(char *, buffer), len); + if (bio == NULL) + return NULL; + dh = PEM_read_bio_DHparams(bio, NULL, NULL, NULL); + if (dh == NULL) + ereport(DEBUG2, + (errmsg_internal("DH load buffer: %s", + SSLerrmessage(ERR_get_error())))); + BIO_free(bio); + + return dh; +} + +/* + * Passphrase collection callback using ssl_passphrase_command + */ +static int +ssl_external_passwd_cb(char *buf, int size, int rwflag, void *userdata) +{ + /* same prompt as OpenSSL uses internally */ + const char *prompt = "Enter PEM pass phrase:"; + + Assert(rwflag == 0); + + return run_ssl_passphrase_command(prompt, ssl_is_server_start, buf, size); +} + +/* + * Dummy passphrase callback + * + * If OpenSSL is told to use a passphrase-protected server key, by default + * it will issue a prompt on /dev/tty and try to read a key from there. + * That's no good during a postmaster SIGHUP cycle, not to mention SSL context + * reload in an EXEC_BACKEND postmaster child. So override it with this dummy + * function that just returns an empty passphrase, guaranteeing failure. + */ +static int +dummy_ssl_passwd_cb(char *buf, int size, int rwflag, void *userdata) +{ + /* Set flag to change the error message we'll report */ + dummy_ssl_passwd_cb_called = true; + /* And return empty string */ + Assert(size > 0); + buf[0] = '\0'; + return 0; +} + +/* + * Certificate verification callback + * + * This callback allows us to log intermediate problems during + * verification, but for now we'll see if the final error message + * contains enough information. + * + * This callback also allows us to override the default acceptance + * criteria (e.g., accepting self-signed or expired certs), but + * for now we accept the default checks. + */ +static int +verify_cb(int ok, X509_STORE_CTX *ctx) +{ + return ok; +} + +/* + * This callback is used to copy SSL information messages + * into the PostgreSQL log. + */ +static void +info_cb(const SSL *ssl, int type, int args) +{ + const char *desc; + + desc = SSL_state_string_long(ssl); + + switch (type) + { + case SSL_CB_HANDSHAKE_START: + ereport(DEBUG4, + (errmsg_internal("SSL: handshake start: \"%s\"", desc))); + break; + case SSL_CB_HANDSHAKE_DONE: + ereport(DEBUG4, + (errmsg_internal("SSL: handshake done: \"%s\"", desc))); + break; + case SSL_CB_ACCEPT_LOOP: + ereport(DEBUG4, + (errmsg_internal("SSL: accept loop: \"%s\"", desc))); + break; + case SSL_CB_ACCEPT_EXIT: + ereport(DEBUG4, + (errmsg_internal("SSL: accept exit (%d): \"%s\"", args, desc))); + break; + case SSL_CB_CONNECT_LOOP: + ereport(DEBUG4, + (errmsg_internal("SSL: connect loop: \"%s\"", desc))); + break; + case SSL_CB_CONNECT_EXIT: + ereport(DEBUG4, + (errmsg_internal("SSL: connect exit (%d): \"%s\"", args, desc))); + break; + case SSL_CB_READ_ALERT: + ereport(DEBUG4, + (errmsg_internal("SSL: read alert (0x%04x): \"%s\"", args, desc))); + break; + case SSL_CB_WRITE_ALERT: + ereport(DEBUG4, + (errmsg_internal("SSL: write alert (0x%04x): \"%s\"", args, desc))); + break; + } +} + +/* + * Set DH parameters for generating ephemeral DH keys. The + * DH parameters can take a long time to compute, so they must be + * precomputed. + * + * Since few sites will bother to create a parameter file, we also + * provide a fallback to the parameters provided by the OpenSSL + * project. + * + * These values can be static (once loaded or computed) since the + * OpenSSL library can efficiently generate random keys from the + * information provided. + */ +static bool +initialize_dh(SSL_CTX *context, bool isServerStart) +{ + DH *dh = NULL; + + SSL_CTX_set_options(context, SSL_OP_SINGLE_DH_USE); + + if (ssl_dh_params_file[0]) + dh = load_dh_file(ssl_dh_params_file, isServerStart); + if (!dh) + dh = load_dh_buffer(FILE_DH2048, sizeof(FILE_DH2048)); + if (!dh) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("DH: could not load DH parameters"))); + return false; + } + + if (SSL_CTX_set_tmp_dh(context, dh) != 1) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("DH: could not set DH parameters: %s", + SSLerrmessage(ERR_get_error())))); + DH_free(dh); + return false; + } + + DH_free(dh); + return true; +} + +/* + * Set ECDH parameters for generating ephemeral Elliptic Curve DH + * keys. This is much simpler than the DH parameters, as we just + * need to provide the name of the curve to OpenSSL. + */ +static bool +initialize_ecdh(SSL_CTX *context, bool isServerStart) +{ +#ifndef OPENSSL_NO_ECDH + EC_KEY *ecdh; + int nid; + + nid = OBJ_sn2nid(SSLECDHCurve); + if (!nid) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("ECDH: unrecognized curve name: %s", SSLECDHCurve))); + return false; + } + + ecdh = EC_KEY_new_by_curve_name(nid); + if (!ecdh) + { + ereport(isServerStart ? FATAL : LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("ECDH: could not create key"))); + return false; + } + + SSL_CTX_set_options(context, SSL_OP_SINGLE_ECDH_USE); + SSL_CTX_set_tmp_ecdh(context, ecdh); + EC_KEY_free(ecdh); +#endif + + return true; +} + +/* + * Obtain reason string for passed SSL errcode + * + * ERR_get_error() is used by caller to get errcode to pass here. + * + * Some caution is needed here since ERR_reason_error_string will + * return NULL if it doesn't recognize the error code. We don't + * want to return NULL ever. + */ +static const char * +SSLerrmessage(unsigned long ecode) +{ + const char *errreason; + static char errbuf[36]; + + if (ecode == 0) + return _("no SSL error reported"); + errreason = ERR_reason_error_string(ecode); + if (errreason != NULL) + return errreason; + snprintf(errbuf, sizeof(errbuf), _("SSL error code %lu"), ecode); + return errbuf; +} + +int +be_tls_get_cipher_bits(Port *port) +{ + int bits; + + if (port->ssl) + { + SSL_get_cipher_bits(port->ssl, &bits); + return bits; + } + else + return 0; +} + +const char * +be_tls_get_version(Port *port) +{ + if (port->ssl) + return SSL_get_version(port->ssl); + else + return NULL; +} + +const char * +be_tls_get_cipher(Port *port) +{ + if (port->ssl) + return SSL_get_cipher(port->ssl); + else + return NULL; +} + +void +be_tls_get_peer_subject_name(Port *port, char *ptr, size_t len) +{ + if (port->peer) + strlcpy(ptr, X509_NAME_to_cstring(X509_get_subject_name(port->peer)), len); + else + ptr[0] = '\0'; +} + +void +be_tls_get_peer_issuer_name(Port *port, char *ptr, size_t len) +{ + if (port->peer) + strlcpy(ptr, X509_NAME_to_cstring(X509_get_issuer_name(port->peer)), len); + else + ptr[0] = '\0'; +} + +void +be_tls_get_peer_serial(Port *port, char *ptr, size_t len) +{ + if (port->peer) + { + ASN1_INTEGER *serial; + BIGNUM *b; + char *decimal; + + serial = X509_get_serialNumber(port->peer); + b = ASN1_INTEGER_to_BN(serial, NULL); + decimal = BN_bn2dec(b); + + BN_free(b); + strlcpy(ptr, decimal, len); + OPENSSL_free(decimal); + } + else + ptr[0] = '\0'; +} + +#ifdef HAVE_X509_GET_SIGNATURE_NID +char * +be_tls_get_certificate_hash(Port *port, size_t *len) +{ + X509 *server_cert; + char *cert_hash; + const EVP_MD *algo_type = NULL; + unsigned char hash[EVP_MAX_MD_SIZE]; /* size for SHA-512 */ + unsigned int hash_size; + int algo_nid; + + *len = 0; + server_cert = SSL_get_certificate(port->ssl); + if (server_cert == NULL) + return NULL; + + /* + * Get the signature algorithm of the certificate to determine the hash + * algorithm to use for the result. + */ + if (!OBJ_find_sigid_algs(X509_get_signature_nid(server_cert), + &algo_nid, NULL)) + elog(ERROR, "could not determine server certificate signature algorithm"); + + /* + * The TLS server's certificate bytes need to be hashed with SHA-256 if + * its signature algorithm is MD5 or SHA-1 as per RFC 5929 + * (https://tools.ietf.org/html/rfc5929#section-4.1). If something else + * is used, the same hash as the signature algorithm is used. + */ + switch (algo_nid) + { + case NID_md5: + case NID_sha1: + algo_type = EVP_sha256(); + break; + default: + algo_type = EVP_get_digestbynid(algo_nid); + if (algo_type == NULL) + elog(ERROR, "could not find digest for NID %s", + OBJ_nid2sn(algo_nid)); + break; + } + + /* generate and save the certificate hash */ + if (!X509_digest(server_cert, algo_type, hash, &hash_size)) + elog(ERROR, "could not generate server certificate hash"); + + cert_hash = palloc(hash_size); + memcpy(cert_hash, hash, hash_size); + *len = hash_size; + + return cert_hash; +} +#endif + +/* + * Convert an X509 subject name to a cstring. + * + */ +static char * +X509_NAME_to_cstring(X509_NAME *name) +{ + BIO *membuf = BIO_new(BIO_s_mem()); + int i, + nid, + count = X509_NAME_entry_count(name); + X509_NAME_ENTRY *e; + ASN1_STRING *v; + const char *field_name; + size_t size; + char nullterm; + char *sp; + char *dp; + char *result; + + if (membuf == NULL) + ereport(ERROR, + (errcode(ERRCODE_OUT_OF_MEMORY), + errmsg("could not create BIO"))); + + (void) BIO_set_close(membuf, BIO_CLOSE); + for (i = 0; i < count; i++) + { + e = X509_NAME_get_entry(name, i); + nid = OBJ_obj2nid(X509_NAME_ENTRY_get_object(e)); + if (nid == NID_undef) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("could not get NID for ASN1_OBJECT object"))); + v = X509_NAME_ENTRY_get_data(e); + field_name = OBJ_nid2sn(nid); + if (field_name == NULL) + field_name = OBJ_nid2ln(nid); + if (field_name == NULL) + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("could not convert NID %d to an ASN1_OBJECT structure", nid))); + BIO_printf(membuf, "/%s=", field_name); + ASN1_STRING_print_ex(membuf, v, + ((ASN1_STRFLGS_RFC2253 & ~ASN1_STRFLGS_ESC_MSB) + | ASN1_STRFLGS_UTF8_CONVERT)); + } + + /* ensure null termination of the BIO's content */ + nullterm = '\0'; + BIO_write(membuf, &nullterm, 1); + size = BIO_get_mem_data(membuf, &sp); + dp = pg_any_to_server(sp, size - 1, PG_UTF8); + + result = pstrdup(dp); + if (dp != sp) + pfree(dp); + if (BIO_free(membuf) != 1) + elog(ERROR, "could not free OpenSSL BIO structure"); + + return result; +} + +/* + * Convert TLS protocol version GUC enum to OpenSSL values + * + * This is a straightforward one-to-one mapping, but doing it this way makes + * guc.c independent of OpenSSL availability and version. + * + * If a version is passed that is not supported by the current OpenSSL + * version, then we return -1. If a nonnegative value is returned, + * subsequent code can assume it's working with a supported version. + * + * Note: this is rather similar to libpq's routine in fe-secure-openssl.c, + * so make sure to update both routines if changing this one. + */ +static int +ssl_protocol_version_to_openssl(int v) +{ + switch (v) + { + case PG_TLS_ANY: + return 0; + case PG_TLS1_VERSION: + return TLS1_VERSION; + case PG_TLS1_1_VERSION: +#ifdef TLS1_1_VERSION + return TLS1_1_VERSION; +#else + break; +#endif + case PG_TLS1_2_VERSION: +#ifdef TLS1_2_VERSION + return TLS1_2_VERSION; +#else + break; +#endif + case PG_TLS1_3_VERSION: +#ifdef TLS1_3_VERSION + return TLS1_3_VERSION; +#else + break; +#endif + } + + return -1; +} + +/* + * Likewise provide a mapping to strings. + */ +static const char * +ssl_protocol_version_to_string(int v) +{ + switch (v) + { + case PG_TLS_ANY: + return "any"; + case PG_TLS1_VERSION: + return "TLSv1"; + case PG_TLS1_1_VERSION: + return "TLSv1.1"; + case PG_TLS1_2_VERSION: + return "TLSv1.2"; + case PG_TLS1_3_VERSION: + return "TLSv1.3"; + } + + return "(unrecognized)"; +} + + +static void +default_openssl_tls_init(SSL_CTX *context, bool isServerStart) +{ + if (isServerStart) + { + if (ssl_passphrase_command[0]) + SSL_CTX_set_default_passwd_cb(context, ssl_external_passwd_cb); + } + else + { + if (ssl_passphrase_command[0] && ssl_passphrase_command_supports_reload) + SSL_CTX_set_default_passwd_cb(context, ssl_external_passwd_cb); + else + + /* + * If reloading and no external command is configured, override + * OpenSSL's default handling of passphrase-protected files, + * because we don't want to prompt for a passphrase in an + * already-running server. + */ + SSL_CTX_set_default_passwd_cb(context, dummy_ssl_passwd_cb); + } +} diff --git a/src/backend/libpq/be-secure.c b/src/backend/libpq/be-secure.c new file mode 100644 index 0000000..8ef0832 --- /dev/null +++ b/src/backend/libpq/be-secure.c @@ -0,0 +1,345 @@ +/*------------------------------------------------------------------------- + * + * be-secure.c + * functions related to setting up a secure connection to the frontend. + * Secure connections are expected to provide confidentiality, + * message integrity and endpoint authentication. + * + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/libpq/be-secure.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include <signal.h> +#include <fcntl.h> +#include <ctype.h> +#include <sys/socket.h> +#include <netdb.h> +#include <netinet/in.h> +#ifdef HAVE_NETINET_TCP_H +#include <netinet/tcp.h> +#include <arpa/inet.h> +#endif + +#include "libpq/libpq.h" +#include "miscadmin.h" +#include "pgstat.h" +#include "storage/ipc.h" +#include "storage/proc.h" +#include "tcop/tcopprot.h" +#include "utils/memutils.h" + +char *ssl_library; +char *ssl_cert_file; +char *ssl_key_file; +char *ssl_ca_file; +char *ssl_crl_file; +char *ssl_crl_dir; +char *ssl_dh_params_file; +char *ssl_passphrase_command; +bool ssl_passphrase_command_supports_reload; + +#ifdef USE_SSL +bool ssl_loaded_verify_locations = false; +#endif + +/* GUC variable controlling SSL cipher list */ +char *SSLCipherSuites = NULL; + +/* GUC variable for default ECHD curve. */ +char *SSLECDHCurve; + +/* GUC variable: if false, prefer client ciphers */ +bool SSLPreferServerCiphers; + +int ssl_min_protocol_version; +int ssl_max_protocol_version; + +/* ------------------------------------------------------------ */ +/* Procedures common to all secure sessions */ +/* ------------------------------------------------------------ */ + +/* + * Initialize global context. + * + * If isServerStart is true, report any errors as FATAL (so we don't return). + * Otherwise, log errors at LOG level and return -1 to indicate trouble, + * preserving the old SSL state if any. Returns 0 if OK. + */ +int +secure_initialize(bool isServerStart) +{ +#ifdef USE_SSL + return be_tls_init(isServerStart); +#else + return 0; +#endif +} + +/* + * Destroy global context, if any. + */ +void +secure_destroy(void) +{ +#ifdef USE_SSL + be_tls_destroy(); +#endif +} + +/* + * Indicate if we have loaded the root CA store to verify certificates + */ +bool +secure_loaded_verify_locations(void) +{ +#ifdef USE_SSL + return ssl_loaded_verify_locations; +#else + return false; +#endif +} + +/* + * Attempt to negotiate secure session. + */ +int +secure_open_server(Port *port) +{ + int r = 0; + +#ifdef USE_SSL + r = be_tls_open_server(port); + + ereport(DEBUG2, + (errmsg_internal("SSL connection from DN:\"%s\" CN:\"%s\"", + port->peer_dn ? port->peer_dn : "(anonymous)", + port->peer_cn ? port->peer_cn : "(anonymous)"))); +#endif + + return r; +} + +/* + * Close secure session. + */ +void +secure_close(Port *port) +{ +#ifdef USE_SSL + if (port->ssl_in_use) + be_tls_close(port); +#endif +} + +/* + * Read data from a secure connection. + */ +ssize_t +secure_read(Port *port, void *ptr, size_t len) +{ + ssize_t n; + int waitfor; + + /* Deal with any already-pending interrupt condition. */ + ProcessClientReadInterrupt(false); + +retry: +#ifdef USE_SSL + waitfor = 0; + if (port->ssl_in_use) + { + n = be_tls_read(port, ptr, len, &waitfor); + } + else +#endif +#ifdef ENABLE_GSS + if (port->gss && port->gss->enc) + { + n = be_gssapi_read(port, ptr, len); + waitfor = WL_SOCKET_READABLE; + } + else +#endif + { + n = secure_raw_read(port, ptr, len); + waitfor = WL_SOCKET_READABLE; + } + + /* In blocking mode, wait until the socket is ready */ + if (n < 0 && !port->noblock && (errno == EWOULDBLOCK || errno == EAGAIN)) + { + WaitEvent event; + + Assert(waitfor); + + ModifyWaitEvent(FeBeWaitSet, FeBeWaitSetSocketPos, waitfor, NULL); + + WaitEventSetWait(FeBeWaitSet, -1 /* no timeout */ , &event, 1, + WAIT_EVENT_CLIENT_READ); + + /* + * If the postmaster has died, it's not safe to continue running, + * because it is the postmaster's job to kill us if some other backend + * exits uncleanly. Moreover, we won't run very well in this state; + * helper processes like walwriter and the bgwriter will exit, so + * performance may be poor. Finally, if we don't exit, pg_ctl will be + * unable to restart the postmaster without manual intervention, so no + * new connections can be accepted. Exiting clears the deck for a + * postmaster restart. + * + * (Note that we only make this check when we would otherwise sleep on + * our latch. We might still continue running for a while if the + * postmaster is killed in mid-query, or even through multiple queries + * if we never have to wait for read. We don't want to burn too many + * cycles checking for this very rare condition, and this should cause + * us to exit quickly in most cases.) + */ + if (event.events & WL_POSTMASTER_DEATH) + ereport(FATAL, + (errcode(ERRCODE_ADMIN_SHUTDOWN), + errmsg("terminating connection due to unexpected postmaster exit"))); + + /* Handle interrupt. */ + if (event.events & WL_LATCH_SET) + { + ResetLatch(MyLatch); + ProcessClientReadInterrupt(true); + + /* + * We'll retry the read. Most likely it will return immediately + * because there's still no data available, and we'll wait for the + * socket to become ready again. + */ + } + goto retry; + } + + /* + * Process interrupts that happened during a successful (or non-blocking, + * or hard-failed) read. + */ + ProcessClientReadInterrupt(false); + + return n; +} + +ssize_t +secure_raw_read(Port *port, void *ptr, size_t len) +{ + ssize_t n; + + /* + * Try to read from the socket without blocking. If it succeeds we're + * done, otherwise we'll wait for the socket using the latch mechanism. + */ +#ifdef WIN32 + pgwin32_noblock = true; +#endif + n = recv(port->sock, ptr, len, 0); +#ifdef WIN32 + pgwin32_noblock = false; +#endif + + return n; +} + + +/* + * Write data to a secure connection. + */ +ssize_t +secure_write(Port *port, void *ptr, size_t len) +{ + ssize_t n; + int waitfor; + + /* Deal with any already-pending interrupt condition. */ + ProcessClientWriteInterrupt(false); + +retry: + waitfor = 0; +#ifdef USE_SSL + if (port->ssl_in_use) + { + n = be_tls_write(port, ptr, len, &waitfor); + } + else +#endif +#ifdef ENABLE_GSS + if (port->gss && port->gss->enc) + { + n = be_gssapi_write(port, ptr, len); + waitfor = WL_SOCKET_WRITEABLE; + } + else +#endif + { + n = secure_raw_write(port, ptr, len); + waitfor = WL_SOCKET_WRITEABLE; + } + + if (n < 0 && !port->noblock && (errno == EWOULDBLOCK || errno == EAGAIN)) + { + WaitEvent event; + + Assert(waitfor); + + ModifyWaitEvent(FeBeWaitSet, FeBeWaitSetSocketPos, waitfor, NULL); + + WaitEventSetWait(FeBeWaitSet, -1 /* no timeout */ , &event, 1, + WAIT_EVENT_CLIENT_WRITE); + + /* See comments in secure_read. */ + if (event.events & WL_POSTMASTER_DEATH) + ereport(FATAL, + (errcode(ERRCODE_ADMIN_SHUTDOWN), + errmsg("terminating connection due to unexpected postmaster exit"))); + + /* Handle interrupt. */ + if (event.events & WL_LATCH_SET) + { + ResetLatch(MyLatch); + ProcessClientWriteInterrupt(true); + + /* + * We'll retry the write. Most likely it will return immediately + * because there's still no buffer space available, and we'll wait + * for the socket to become ready again. + */ + } + goto retry; + } + + /* + * Process interrupts that happened during a successful (or non-blocking, + * or hard-failed) write. + */ + ProcessClientWriteInterrupt(false); + + return n; +} + +ssize_t +secure_raw_write(Port *port, const void *ptr, size_t len) +{ + ssize_t n; + +#ifdef WIN32 + pgwin32_noblock = true; +#endif + n = send(port->sock, ptr, len, 0); +#ifdef WIN32 + pgwin32_noblock = false; +#endif + + return n; +} diff --git a/src/backend/libpq/crypt.c b/src/backend/libpq/crypt.c new file mode 100644 index 0000000..3fcad99 --- /dev/null +++ b/src/backend/libpq/crypt.c @@ -0,0 +1,290 @@ +/*------------------------------------------------------------------------- + * + * crypt.c + * Functions for dealing with encrypted passwords stored in + * pg_authid.rolpassword. + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/backend/libpq/crypt.c + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include <unistd.h> + +#include "catalog/pg_authid.h" +#include "common/md5.h" +#include "common/scram-common.h" +#include "libpq/crypt.h" +#include "libpq/scram.h" +#include "miscadmin.h" +#include "utils/builtins.h" +#include "utils/syscache.h" +#include "utils/timestamp.h" + + +/* + * Fetch stored password for a user, for authentication. + * + * On error, returns NULL, and stores a palloc'd string describing the reason, + * for the postmaster log, in *logdetail. The error reason should *not* be + * sent to the client, to avoid giving away user information! + */ +char * +get_role_password(const char *role, char **logdetail) +{ + TimestampTz vuntil = 0; + HeapTuple roleTup; + Datum datum; + bool isnull; + char *shadow_pass; + + /* Get role info from pg_authid */ + roleTup = SearchSysCache1(AUTHNAME, PointerGetDatum(role)); + if (!HeapTupleIsValid(roleTup)) + { + *logdetail = psprintf(_("Role \"%s\" does not exist."), + role); + return NULL; /* no such user */ + } + + datum = SysCacheGetAttr(AUTHNAME, roleTup, + Anum_pg_authid_rolpassword, &isnull); + if (isnull) + { + ReleaseSysCache(roleTup); + *logdetail = psprintf(_("User \"%s\" has no password assigned."), + role); + return NULL; /* user has no password */ + } + shadow_pass = TextDatumGetCString(datum); + + datum = SysCacheGetAttr(AUTHNAME, roleTup, + Anum_pg_authid_rolvaliduntil, &isnull); + if (!isnull) + vuntil = DatumGetTimestampTz(datum); + + ReleaseSysCache(roleTup); + + /* + * Password OK, but check to be sure we are not past rolvaliduntil + */ + if (!isnull && vuntil < GetCurrentTimestamp()) + { + *logdetail = psprintf(_("User \"%s\" has an expired password."), + role); + return NULL; + } + + return shadow_pass; +} + +/* + * What kind of a password type is 'shadow_pass'? + */ +PasswordType +get_password_type(const char *shadow_pass) +{ + char *encoded_salt; + int iterations; + uint8 stored_key[SCRAM_KEY_LEN]; + uint8 server_key[SCRAM_KEY_LEN]; + + if (strncmp(shadow_pass, "md5", 3) == 0 && + strlen(shadow_pass) == MD5_PASSWD_LEN && + strspn(shadow_pass + 3, MD5_PASSWD_CHARSET) == MD5_PASSWD_LEN - 3) + return PASSWORD_TYPE_MD5; + if (parse_scram_secret(shadow_pass, &iterations, &encoded_salt, + stored_key, server_key)) + return PASSWORD_TYPE_SCRAM_SHA_256; + return PASSWORD_TYPE_PLAINTEXT; +} + +/* + * Given a user-supplied password, convert it into a secret of + * 'target_type' kind. + * + * If the password is already in encrypted form, we cannot reverse the + * hash, so it is stored as it is regardless of the requested type. + */ +char * +encrypt_password(PasswordType target_type, const char *role, + const char *password) +{ + PasswordType guessed_type = get_password_type(password); + char *encrypted_password; + + if (guessed_type != PASSWORD_TYPE_PLAINTEXT) + { + /* + * Cannot convert an already-encrypted password from one format to + * another, so return it as it is. + */ + return pstrdup(password); + } + + switch (target_type) + { + case PASSWORD_TYPE_MD5: + encrypted_password = palloc(MD5_PASSWD_LEN + 1); + + if (!pg_md5_encrypt(password, role, strlen(role), + encrypted_password)) + elog(ERROR, "password encryption failed"); + return encrypted_password; + + case PASSWORD_TYPE_SCRAM_SHA_256: + return pg_be_scram_build_secret(password); + + case PASSWORD_TYPE_PLAINTEXT: + elog(ERROR, "cannot encrypt password with 'plaintext'"); + } + + /* + * This shouldn't happen, because the above switch statements should + * handle every combination of source and target password types. + */ + elog(ERROR, "cannot encrypt password to requested type"); + return NULL; /* keep compiler quiet */ +} + +/* + * Check MD5 authentication response, and return STATUS_OK or STATUS_ERROR. + * + * 'shadow_pass' is the user's correct password or password hash, as stored + * in pg_authid.rolpassword. + * 'client_pass' is the response given by the remote user to the MD5 challenge. + * 'md5_salt' is the salt used in the MD5 authentication challenge. + * + * In the error case, optionally store a palloc'd string at *logdetail + * that will be sent to the postmaster log (but not the client). + */ +int +md5_crypt_verify(const char *role, const char *shadow_pass, + const char *client_pass, + const char *md5_salt, int md5_salt_len, + char **logdetail) +{ + int retval; + char crypt_pwd[MD5_PASSWD_LEN + 1]; + + Assert(md5_salt_len > 0); + + if (get_password_type(shadow_pass) != PASSWORD_TYPE_MD5) + { + /* incompatible password hash format. */ + *logdetail = psprintf(_("User \"%s\" has a password that cannot be used with MD5 authentication."), + role); + return STATUS_ERROR; + } + + /* + * Compute the correct answer for the MD5 challenge. + * + * We do not bother setting logdetail for any pg_md5_encrypt failure + * below: the only possible error is out-of-memory, which is unlikely, and + * if it did happen adding a psprintf call would only make things worse. + */ + /* stored password already encrypted, only do salt */ + if (!pg_md5_encrypt(shadow_pass + strlen("md5"), + md5_salt, md5_salt_len, + crypt_pwd)) + { + return STATUS_ERROR; + } + + if (strcmp(client_pass, crypt_pwd) == 0) + retval = STATUS_OK; + else + { + *logdetail = psprintf(_("Password does not match for user \"%s\"."), + role); + retval = STATUS_ERROR; + } + + return retval; +} + +/* + * Check given password for given user, and return STATUS_OK or STATUS_ERROR. + * + * 'shadow_pass' is the user's correct password hash, as stored in + * pg_authid.rolpassword. + * 'client_pass' is the password given by the remote user. + * + * In the error case, optionally store a palloc'd string at *logdetail + * that will be sent to the postmaster log (but not the client). + */ +int +plain_crypt_verify(const char *role, const char *shadow_pass, + const char *client_pass, + char **logdetail) +{ + char crypt_client_pass[MD5_PASSWD_LEN + 1]; + + /* + * Client sent password in plaintext. If we have an MD5 hash stored, hash + * the password the client sent, and compare the hashes. Otherwise + * compare the plaintext passwords directly. + */ + switch (get_password_type(shadow_pass)) + { + case PASSWORD_TYPE_SCRAM_SHA_256: + if (scram_verify_plain_password(role, + client_pass, + shadow_pass)) + { + return STATUS_OK; + } + else + { + *logdetail = psprintf(_("Password does not match for user \"%s\"."), + role); + return STATUS_ERROR; + } + break; + + case PASSWORD_TYPE_MD5: + if (!pg_md5_encrypt(client_pass, + role, + strlen(role), + crypt_client_pass)) + { + /* + * We do not bother setting logdetail for pg_md5_encrypt + * failure: the only possible error is out-of-memory, which is + * unlikely, and if it did happen adding a psprintf call would + * only make things worse. + */ + return STATUS_ERROR; + } + if (strcmp(crypt_client_pass, shadow_pass) == 0) + return STATUS_OK; + else + { + *logdetail = psprintf(_("Password does not match for user \"%s\"."), + role); + return STATUS_ERROR; + } + break; + + case PASSWORD_TYPE_PLAINTEXT: + + /* + * We never store passwords in plaintext, so this shouldn't + * happen. + */ + break; + } + + /* + * This shouldn't happen. Plain "password" authentication is possible + * with any kind of stored password hash. + */ + *logdetail = psprintf(_("Password of user \"%s\" is in unrecognized format."), + role); + return STATUS_ERROR; +} diff --git a/src/backend/libpq/hba.c b/src/backend/libpq/hba.c new file mode 100644 index 0000000..64e59d4 --- /dev/null +++ b/src/backend/libpq/hba.c @@ -0,0 +1,3166 @@ +/*------------------------------------------------------------------------- + * + * hba.c + * Routines to handle host based authentication (that's the scheme + * wherein you authenticate a user by seeing what IP address the system + * says he comes from and choosing authentication method based on it). + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/libpq/hba.c + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include <ctype.h> +#include <pwd.h> +#include <fcntl.h> +#include <sys/param.h> +#include <sys/socket.h> +#include <netinet/in.h> +#include <arpa/inet.h> +#include <unistd.h> + +#include "access/htup_details.h" +#include "catalog/pg_collation.h" +#include "catalog/pg_type.h" +#include "common/ip.h" +#include "common/string.h" +#include "funcapi.h" +#include "libpq/ifaddr.h" +#include "libpq/libpq.h" +#include "miscadmin.h" +#include "postmaster/postmaster.h" +#include "regex/regex.h" +#include "replication/walsender.h" +#include "storage/fd.h" +#include "utils/acl.h" +#include "utils/builtins.h" +#include "utils/guc.h" +#include "utils/lsyscache.h" +#include "utils/memutils.h" +#include "utils/varlena.h" + +#ifdef USE_LDAP +#ifdef WIN32 +#include <winldap.h> +#else +#include <ldap.h> +#endif +#endif + + +#define MAX_TOKEN 256 + +/* callback data for check_network_callback */ +typedef struct check_network_data +{ + IPCompareMethod method; /* test method */ + SockAddr *raddr; /* client's actual address */ + bool result; /* set to true if match */ +} check_network_data; + + +#define token_is_keyword(t, k) (!t->quoted && strcmp(t->string, k) == 0) +#define token_matches(t, k) (strcmp(t->string, k) == 0) + +/* + * A single string token lexed from a config file, together with whether + * the token had been quoted. + */ +typedef struct HbaToken +{ + char *string; + bool quoted; +} HbaToken; + +/* + * TokenizedLine represents one line lexed from a config file. + * Each item in the "fields" list is a sub-list of HbaTokens. + * We don't emit a TokenizedLine for empty or all-comment lines, + * so "fields" is never NIL (nor are any of its sub-lists). + * Exception: if an error occurs during tokenization, we might + * have fields == NIL, in which case err_msg != NULL. + */ +typedef struct TokenizedLine +{ + List *fields; /* List of lists of HbaTokens */ + int line_num; /* Line number */ + char *raw_line; /* Raw line text */ + char *err_msg; /* Error message if any */ +} TokenizedLine; + +/* + * pre-parsed content of HBA config file: list of HbaLine structs. + * parsed_hba_context is the memory context where it lives. + */ +static List *parsed_hba_lines = NIL; +static MemoryContext parsed_hba_context = NULL; + +/* + * pre-parsed content of ident mapping file: list of IdentLine structs. + * parsed_ident_context is the memory context where it lives. + * + * NOTE: the IdentLine structs can contain pre-compiled regular expressions + * that live outside the memory context. Before destroying or resetting the + * memory context, they need to be explicitly free'd. + */ +static List *parsed_ident_lines = NIL; +static MemoryContext parsed_ident_context = NULL; + +/* + * The following character array represents the names of the authentication + * methods that are supported by PostgreSQL. + * + * Note: keep this in sync with the UserAuth enum in hba.h. + */ +static const char *const UserAuthName[] = +{ + "reject", + "implicit reject", /* Not a user-visible option */ + "trust", + "ident", + "password", + "md5", + "scram-sha-256", + "gss", + "sspi", + "pam", + "bsd", + "ldap", + "cert", + "radius", + "peer" +}; + + +static MemoryContext tokenize_file(const char *filename, FILE *file, + List **tok_lines, int elevel); +static List *tokenize_inc_file(List *tokens, const char *outer_filename, + const char *inc_filename, int elevel, char **err_msg); +static bool parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline, + int elevel, char **err_msg); +static ArrayType *gethba_options(HbaLine *hba); +static void fill_hba_line(Tuplestorestate *tuple_store, TupleDesc tupdesc, + int lineno, HbaLine *hba, const char *err_msg); +static void fill_hba_view(Tuplestorestate *tuple_store, TupleDesc tupdesc); + + +/* + * isblank() exists in the ISO C99 spec, but it's not very portable yet, + * so provide our own version. + */ +bool +pg_isblank(const char c) +{ + return c == ' ' || c == '\t' || c == '\r'; +} + + +/* + * Grab one token out of the string pointed to by *lineptr. + * + * Tokens are strings of non-blank characters bounded by blank characters, + * commas, beginning of line, and end of line. Blank means space or tab. + * + * Tokens can be delimited by double quotes (this allows the inclusion of + * blanks or '#', but not newlines). As in SQL, write two double-quotes + * to represent a double quote. + * + * Comments (started by an unquoted '#') are skipped, i.e. the remainder + * of the line is ignored. + * + * (Note that line continuation processing happens before tokenization. + * Thus, if a continuation occurs within quoted text or a comment, the + * quoted text or comment is considered to continue to the next line.) + * + * The token, if any, is returned at *buf (a buffer of size bufsz), and + * *lineptr is advanced past the token. + * + * Also, we set *initial_quote to indicate whether there was quoting before + * the first character. (We use that to prevent "@x" from being treated + * as a file inclusion request. Note that @"x" should be so treated; + * we want to allow that to support embedded spaces in file paths.) + * + * We set *terminating_comma to indicate whether the token is terminated by a + * comma (which is not returned). + * + * In event of an error, log a message at ereport level elevel, and also + * set *err_msg to a string describing the error. Currently the only + * possible error is token too long for buf. + * + * If successful: store null-terminated token at *buf and return true. + * If no more tokens on line: set *buf = '\0' and return false. + * If error: fill buf with truncated or misformatted token and return false. + */ +static bool +next_token(char **lineptr, char *buf, int bufsz, + bool *initial_quote, bool *terminating_comma, + int elevel, char **err_msg) +{ + int c; + char *start_buf = buf; + char *end_buf = buf + (bufsz - 1); + bool in_quote = false; + bool was_quote = false; + bool saw_quote = false; + + Assert(end_buf > start_buf); + + *initial_quote = false; + *terminating_comma = false; + + /* Move over any whitespace and commas preceding the next token */ + while ((c = (*(*lineptr)++)) != '\0' && (pg_isblank(c) || c == ',')) + ; + + /* + * Build a token in buf of next characters up to EOL, unquoted comma, or + * unquoted whitespace. + */ + while (c != '\0' && + (!pg_isblank(c) || in_quote)) + { + /* skip comments to EOL */ + if (c == '#' && !in_quote) + { + while ((c = (*(*lineptr)++)) != '\0') + ; + break; + } + + if (buf >= end_buf) + { + *buf = '\0'; + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("authentication file token too long, skipping: \"%s\"", + start_buf))); + *err_msg = "authentication file token too long"; + /* Discard remainder of line */ + while ((c = (*(*lineptr)++)) != '\0') + ; + /* Un-eat the '\0', in case we're called again */ + (*lineptr)--; + return false; + } + + /* we do not pass back a terminating comma in the token */ + if (c == ',' && !in_quote) + { + *terminating_comma = true; + break; + } + + if (c != '"' || was_quote) + *buf++ = c; + + /* Literal double-quote is two double-quotes */ + if (in_quote && c == '"') + was_quote = !was_quote; + else + was_quote = false; + + if (c == '"') + { + in_quote = !in_quote; + saw_quote = true; + if (buf == start_buf) + *initial_quote = true; + } + + c = *(*lineptr)++; + } + + /* + * Un-eat the char right after the token (critical in case it is '\0', + * else next call will read past end of string). + */ + (*lineptr)--; + + *buf = '\0'; + + return (saw_quote || buf > start_buf); +} + +/* + * Construct a palloc'd HbaToken struct, copying the given string. + */ +static HbaToken * +make_hba_token(const char *token, bool quoted) +{ + HbaToken *hbatoken; + int toklen; + + toklen = strlen(token); + /* we copy string into same palloc block as the struct */ + hbatoken = (HbaToken *) palloc(sizeof(HbaToken) + toklen + 1); + hbatoken->string = (char *) hbatoken + sizeof(HbaToken); + hbatoken->quoted = quoted; + memcpy(hbatoken->string, token, toklen + 1); + + return hbatoken; +} + +/* + * Copy a HbaToken struct into freshly palloc'd memory. + */ +static HbaToken * +copy_hba_token(HbaToken *in) +{ + HbaToken *out = make_hba_token(in->string, in->quoted); + + return out; +} + + +/* + * Tokenize one HBA field from a line, handling file inclusion and comma lists. + * + * filename: current file's pathname (needed to resolve relative pathnames) + * *lineptr: current line pointer, which will be advanced past field + * + * In event of an error, log a message at ereport level elevel, and also + * set *err_msg to a string describing the error. Note that the result + * may be non-NIL anyway, so *err_msg must be tested to determine whether + * there was an error. + * + * The result is a List of HbaToken structs, one for each token in the field, + * or NIL if we reached EOL. + */ +static List * +next_field_expand(const char *filename, char **lineptr, + int elevel, char **err_msg) +{ + char buf[MAX_TOKEN]; + bool trailing_comma; + bool initial_quote; + List *tokens = NIL; + + do + { + if (!next_token(lineptr, buf, sizeof(buf), + &initial_quote, &trailing_comma, + elevel, err_msg)) + break; + + /* Is this referencing a file? */ + if (!initial_quote && buf[0] == '@' && buf[1] != '\0') + tokens = tokenize_inc_file(tokens, filename, buf + 1, + elevel, err_msg); + else + tokens = lappend(tokens, make_hba_token(buf, initial_quote)); + } while (trailing_comma && (*err_msg == NULL)); + + return tokens; +} + +/* + * tokenize_inc_file + * Expand a file included from another file into an hba "field" + * + * Opens and tokenises a file included from another HBA config file with @, + * and returns all values found therein as a flat list of HbaTokens. If a + * @-token is found, recursively expand it. The newly read tokens are + * appended to "tokens" (so that foo,bar,@baz does what you expect). + * All new tokens are allocated in caller's memory context. + * + * In event of an error, log a message at ereport level elevel, and also + * set *err_msg to a string describing the error. Note that the result + * may be non-NIL anyway, so *err_msg must be tested to determine whether + * there was an error. + */ +static List * +tokenize_inc_file(List *tokens, + const char *outer_filename, + const char *inc_filename, + int elevel, + char **err_msg) +{ + char *inc_fullname; + FILE *inc_file; + List *inc_lines; + ListCell *inc_line; + MemoryContext linecxt; + + if (is_absolute_path(inc_filename)) + { + /* absolute path is taken as-is */ + inc_fullname = pstrdup(inc_filename); + } + else + { + /* relative path is relative to dir of calling file */ + inc_fullname = (char *) palloc(strlen(outer_filename) + 1 + + strlen(inc_filename) + 1); + strcpy(inc_fullname, outer_filename); + get_parent_directory(inc_fullname); + join_path_components(inc_fullname, inc_fullname, inc_filename); + canonicalize_path(inc_fullname); + } + + inc_file = AllocateFile(inc_fullname, "r"); + if (inc_file == NULL) + { + int save_errno = errno; + + ereport(elevel, + (errcode_for_file_access(), + errmsg("could not open secondary authentication file \"@%s\" as \"%s\": %m", + inc_filename, inc_fullname))); + *err_msg = psprintf("could not open secondary authentication file \"@%s\" as \"%s\": %s", + inc_filename, inc_fullname, strerror(save_errno)); + pfree(inc_fullname); + return tokens; + } + + /* There is possible recursion here if the file contains @ */ + linecxt = tokenize_file(inc_fullname, inc_file, &inc_lines, elevel); + + FreeFile(inc_file); + pfree(inc_fullname); + + /* Copy all tokens found in the file and append to the tokens list */ + foreach(inc_line, inc_lines) + { + TokenizedLine *tok_line = (TokenizedLine *) lfirst(inc_line); + ListCell *inc_field; + + /* If any line has an error, propagate that up to caller */ + if (tok_line->err_msg) + { + *err_msg = pstrdup(tok_line->err_msg); + break; + } + + foreach(inc_field, tok_line->fields) + { + List *inc_tokens = lfirst(inc_field); + ListCell *inc_token; + + foreach(inc_token, inc_tokens) + { + HbaToken *token = lfirst(inc_token); + + tokens = lappend(tokens, copy_hba_token(token)); + } + } + } + + MemoryContextDelete(linecxt); + return tokens; +} + +/* + * Tokenize the given file. + * + * The output is a list of TokenizedLine structs; see struct definition above. + * + * filename: the absolute path to the target file + * file: the already-opened target file + * tok_lines: receives output list + * elevel: message logging level + * + * Errors are reported by logging messages at ereport level elevel and by + * adding TokenizedLine structs containing non-null err_msg fields to the + * output list. + * + * Return value is a memory context which contains all memory allocated by + * this function (it's a child of caller's context). + */ +static MemoryContext +tokenize_file(const char *filename, FILE *file, List **tok_lines, int elevel) +{ + int line_number = 1; + StringInfoData buf; + MemoryContext linecxt; + MemoryContext oldcxt; + + linecxt = AllocSetContextCreate(CurrentMemoryContext, + "tokenize_file", + ALLOCSET_SMALL_SIZES); + oldcxt = MemoryContextSwitchTo(linecxt); + + initStringInfo(&buf); + + *tok_lines = NIL; + + while (!feof(file) && !ferror(file)) + { + char *lineptr; + List *current_line = NIL; + char *err_msg = NULL; + int last_backslash_buflen = 0; + int continuations = 0; + + /* Collect the next input line, handling backslash continuations */ + resetStringInfo(&buf); + + while (pg_get_line_append(file, &buf)) + { + /* Strip trailing newline, including \r in case we're on Windows */ + buf.len = pg_strip_crlf(buf.data); + + /* + * Check for backslash continuation. The backslash must be after + * the last place we found a continuation, else two backslashes + * followed by two \n's would behave surprisingly. + */ + if (buf.len > last_backslash_buflen && + buf.data[buf.len - 1] == '\\') + { + /* Continuation, so strip it and keep reading */ + buf.data[--buf.len] = '\0'; + last_backslash_buflen = buf.len; + continuations++; + continue; + } + + /* Nope, so we have the whole line */ + break; + } + + if (ferror(file)) + { + /* I/O error! */ + int save_errno = errno; + + ereport(elevel, + (errcode_for_file_access(), + errmsg("could not read file \"%s\": %m", filename))); + err_msg = psprintf("could not read file \"%s\": %s", + filename, strerror(save_errno)); + break; + } + + /* Parse fields */ + lineptr = buf.data; + while (*lineptr && err_msg == NULL) + { + List *current_field; + + current_field = next_field_expand(filename, &lineptr, + elevel, &err_msg); + /* add field to line, unless we are at EOL or comment start */ + if (current_field != NIL) + current_line = lappend(current_line, current_field); + } + + /* Reached EOL; emit line to TokenizedLine list unless it's boring */ + if (current_line != NIL || err_msg != NULL) + { + TokenizedLine *tok_line; + + tok_line = (TokenizedLine *) palloc(sizeof(TokenizedLine)); + tok_line->fields = current_line; + tok_line->line_num = line_number; + tok_line->raw_line = pstrdup(buf.data); + tok_line->err_msg = err_msg; + *tok_lines = lappend(*tok_lines, tok_line); + } + + line_number += continuations + 1; + } + + MemoryContextSwitchTo(oldcxt); + + return linecxt; +} + + +/* + * Does user belong to role? + * + * userid is the OID of the role given as the attempted login identifier. + * We check to see if it is a member of the specified role name. + */ +static bool +is_member(Oid userid, const char *role) +{ + Oid roleid; + + if (!OidIsValid(userid)) + return false; /* if user not exist, say "no" */ + + roleid = get_role_oid(role, true); + + if (!OidIsValid(roleid)) + return false; /* if target role not exist, say "no" */ + + /* + * See if user is directly or indirectly a member of role. For this + * purpose, a superuser is not considered to be automatically a member of + * the role, so group auth only applies to explicit membership. + */ + return is_member_of_role_nosuper(userid, roleid); +} + +/* + * Check HbaToken list for a match to role, allowing group names. + */ +static bool +check_role(const char *role, Oid roleid, List *tokens) +{ + ListCell *cell; + HbaToken *tok; + + foreach(cell, tokens) + { + tok = lfirst(cell); + if (!tok->quoted && tok->string[0] == '+') + { + if (is_member(roleid, tok->string + 1)) + return true; + } + else if (token_matches(tok, role) || + token_is_keyword(tok, "all")) + return true; + } + return false; +} + +/* + * Check to see if db/role combination matches HbaToken list. + */ +static bool +check_db(const char *dbname, const char *role, Oid roleid, List *tokens) +{ + ListCell *cell; + HbaToken *tok; + + foreach(cell, tokens) + { + tok = lfirst(cell); + if (am_walsender && !am_db_walsender) + { + /* + * physical replication walsender connections can only match + * replication keyword + */ + if (token_is_keyword(tok, "replication")) + return true; + } + else if (token_is_keyword(tok, "all")) + return true; + else if (token_is_keyword(tok, "sameuser")) + { + if (strcmp(dbname, role) == 0) + return true; + } + else if (token_is_keyword(tok, "samegroup") || + token_is_keyword(tok, "samerole")) + { + if (is_member(roleid, dbname)) + return true; + } + else if (token_is_keyword(tok, "replication")) + continue; /* never match this if not walsender */ + else if (token_matches(tok, dbname)) + return true; + } + return false; +} + +static bool +ipv4eq(struct sockaddr_in *a, struct sockaddr_in *b) +{ + return (a->sin_addr.s_addr == b->sin_addr.s_addr); +} + +#ifdef HAVE_IPV6 + +static bool +ipv6eq(struct sockaddr_in6 *a, struct sockaddr_in6 *b) +{ + int i; + + for (i = 0; i < 16; i++) + if (a->sin6_addr.s6_addr[i] != b->sin6_addr.s6_addr[i]) + return false; + + return true; +} +#endif /* HAVE_IPV6 */ + +/* + * Check whether host name matches pattern. + */ +static bool +hostname_match(const char *pattern, const char *actual_hostname) +{ + if (pattern[0] == '.') /* suffix match */ + { + size_t plen = strlen(pattern); + size_t hlen = strlen(actual_hostname); + + if (hlen < plen) + return false; + + return (pg_strcasecmp(pattern, actual_hostname + (hlen - plen)) == 0); + } + else + return (pg_strcasecmp(pattern, actual_hostname) == 0); +} + +/* + * Check to see if a connecting IP matches a given host name. + */ +static bool +check_hostname(hbaPort *port, const char *hostname) +{ + struct addrinfo *gai_result, + *gai; + int ret; + bool found; + + /* Quick out if remote host name already known bad */ + if (port->remote_hostname_resolv < 0) + return false; + + /* Lookup remote host name if not already done */ + if (!port->remote_hostname) + { + char remote_hostname[NI_MAXHOST]; + + ret = pg_getnameinfo_all(&port->raddr.addr, port->raddr.salen, + remote_hostname, sizeof(remote_hostname), + NULL, 0, + NI_NAMEREQD); + if (ret != 0) + { + /* remember failure; don't complain in the postmaster log yet */ + port->remote_hostname_resolv = -2; + port->remote_hostname_errcode = ret; + return false; + } + + port->remote_hostname = pstrdup(remote_hostname); + } + + /* Now see if remote host name matches this pg_hba line */ + if (!hostname_match(hostname, port->remote_hostname)) + return false; + + /* If we already verified the forward lookup, we're done */ + if (port->remote_hostname_resolv == +1) + return true; + + /* Lookup IP from host name and check against original IP */ + ret = getaddrinfo(port->remote_hostname, NULL, NULL, &gai_result); + if (ret != 0) + { + /* remember failure; don't complain in the postmaster log yet */ + port->remote_hostname_resolv = -2; + port->remote_hostname_errcode = ret; + return false; + } + + found = false; + for (gai = gai_result; gai; gai = gai->ai_next) + { + if (gai->ai_addr->sa_family == port->raddr.addr.ss_family) + { + if (gai->ai_addr->sa_family == AF_INET) + { + if (ipv4eq((struct sockaddr_in *) gai->ai_addr, + (struct sockaddr_in *) &port->raddr.addr)) + { + found = true; + break; + } + } +#ifdef HAVE_IPV6 + else if (gai->ai_addr->sa_family == AF_INET6) + { + if (ipv6eq((struct sockaddr_in6 *) gai->ai_addr, + (struct sockaddr_in6 *) &port->raddr.addr)) + { + found = true; + break; + } + } +#endif + } + } + + if (gai_result) + freeaddrinfo(gai_result); + + if (!found) + elog(DEBUG2, "pg_hba.conf host name \"%s\" rejected because address resolution did not return a match with IP address of client", + hostname); + + port->remote_hostname_resolv = found ? +1 : -1; + + return found; +} + +/* + * Check to see if a connecting IP matches the given address and netmask. + */ +static bool +check_ip(SockAddr *raddr, struct sockaddr *addr, struct sockaddr *mask) +{ + if (raddr->addr.ss_family == addr->sa_family && + pg_range_sockaddr(&raddr->addr, + (struct sockaddr_storage *) addr, + (struct sockaddr_storage *) mask)) + return true; + return false; +} + +/* + * pg_foreach_ifaddr callback: does client addr match this machine interface? + */ +static void +check_network_callback(struct sockaddr *addr, struct sockaddr *netmask, + void *cb_data) +{ + check_network_data *cn = (check_network_data *) cb_data; + struct sockaddr_storage mask; + + /* Already found a match? */ + if (cn->result) + return; + + if (cn->method == ipCmpSameHost) + { + /* Make an all-ones netmask of appropriate length for family */ + pg_sockaddr_cidr_mask(&mask, NULL, addr->sa_family); + cn->result = check_ip(cn->raddr, addr, (struct sockaddr *) &mask); + } + else + { + /* Use the netmask of the interface itself */ + cn->result = check_ip(cn->raddr, addr, netmask); + } +} + +/* + * Use pg_foreach_ifaddr to check a samehost or samenet match + */ +static bool +check_same_host_or_net(SockAddr *raddr, IPCompareMethod method) +{ + check_network_data cn; + + cn.method = method; + cn.raddr = raddr; + cn.result = false; + + errno = 0; + if (pg_foreach_ifaddr(check_network_callback, &cn) < 0) + { + ereport(LOG, + (errmsg("error enumerating network interfaces: %m"))); + return false; + } + + return cn.result; +} + + +/* + * Macros used to check and report on invalid configuration options. + * On error: log a message at level elevel, set *err_msg, and exit the function. + * These macros are not as general-purpose as they look, because they know + * what the calling function's error-exit value is. + * + * INVALID_AUTH_OPTION = reports when an option is specified for a method where it's + * not supported. + * REQUIRE_AUTH_OPTION = same as INVALID_AUTH_OPTION, except it also checks if the + * method is actually the one specified. Used as a shortcut when + * the option is only valid for one authentication method. + * MANDATORY_AUTH_ARG = check if a required option is set for an authentication method, + * reporting error if it's not. + */ +#define INVALID_AUTH_OPTION(optname, validmethods) \ +do { \ + ereport(elevel, \ + (errcode(ERRCODE_CONFIG_FILE_ERROR), \ + /* translator: the second %s is a list of auth methods */ \ + errmsg("authentication option \"%s\" is only valid for authentication methods %s", \ + optname, _(validmethods)), \ + errcontext("line %d of configuration file \"%s\"", \ + line_num, HbaFileName))); \ + *err_msg = psprintf("authentication option \"%s\" is only valid for authentication methods %s", \ + optname, validmethods); \ + return false; \ +} while (0) + +#define REQUIRE_AUTH_OPTION(methodval, optname, validmethods) \ +do { \ + if (hbaline->auth_method != methodval) \ + INVALID_AUTH_OPTION(optname, validmethods); \ +} while (0) + +#define MANDATORY_AUTH_ARG(argvar, argname, authname) \ +do { \ + if (argvar == NULL) { \ + ereport(elevel, \ + (errcode(ERRCODE_CONFIG_FILE_ERROR), \ + errmsg("authentication method \"%s\" requires argument \"%s\" to be set", \ + authname, argname), \ + errcontext("line %d of configuration file \"%s\"", \ + line_num, HbaFileName))); \ + *err_msg = psprintf("authentication method \"%s\" requires argument \"%s\" to be set", \ + authname, argname); \ + return NULL; \ + } \ +} while (0) + +/* + * Macros for handling pg_ident problems. + * Much as above, but currently the message level is hardwired as LOG + * and there is no provision for an err_msg string. + * + * IDENT_FIELD_ABSENT: + * Log a message and exit the function if the given ident field ListCell is + * not populated. + * + * IDENT_MULTI_VALUE: + * Log a message and exit the function if the given ident token List has more + * than one element. + */ +#define IDENT_FIELD_ABSENT(field) \ +do { \ + if (!field) { \ + ereport(LOG, \ + (errcode(ERRCODE_CONFIG_FILE_ERROR), \ + errmsg("missing entry in file \"%s\" at end of line %d", \ + IdentFileName, line_num))); \ + return NULL; \ + } \ +} while (0) + +#define IDENT_MULTI_VALUE(tokens) \ +do { \ + if (tokens->length > 1) { \ + ereport(LOG, \ + (errcode(ERRCODE_CONFIG_FILE_ERROR), \ + errmsg("multiple values in ident field"), \ + errcontext("line %d of configuration file \"%s\"", \ + line_num, IdentFileName))); \ + return NULL; \ + } \ +} while (0) + + +/* + * Parse one tokenised line from the hba config file and store the result in a + * HbaLine structure. + * + * If parsing fails, log a message at ereport level elevel, store an error + * string in tok_line->err_msg, and return NULL. (Some non-error conditions + * can also result in such messages.) + * + * Note: this function leaks memory when an error occurs. Caller is expected + * to have set a memory context that will be reset if this function returns + * NULL. + */ +static HbaLine * +parse_hba_line(TokenizedLine *tok_line, int elevel) +{ + int line_num = tok_line->line_num; + char **err_msg = &tok_line->err_msg; + char *str; + struct addrinfo *gai_result; + struct addrinfo hints; + int ret; + char *cidr_slash; + char *unsupauth; + ListCell *field; + List *tokens; + ListCell *tokencell; + HbaToken *token; + HbaLine *parsedline; + + parsedline = palloc0(sizeof(HbaLine)); + parsedline->linenumber = line_num; + parsedline->rawline = pstrdup(tok_line->raw_line); + + /* Check the record type. */ + Assert(tok_line->fields != NIL); + field = list_head(tok_line->fields); + tokens = lfirst(field); + if (tokens->length > 1) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("multiple values specified for connection type"), + errhint("Specify exactly one connection type per line."), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "multiple values specified for connection type"; + return NULL; + } + token = linitial(tokens); + if (strcmp(token->string, "local") == 0) + { +#ifdef HAVE_UNIX_SOCKETS + parsedline->conntype = ctLocal; +#else + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("local connections are not supported by this build"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "local connections are not supported by this build"; + return NULL; +#endif + } + else if (strcmp(token->string, "host") == 0 || + strcmp(token->string, "hostssl") == 0 || + strcmp(token->string, "hostnossl") == 0 || + strcmp(token->string, "hostgssenc") == 0 || + strcmp(token->string, "hostnogssenc") == 0) + { + + if (token->string[4] == 's') /* "hostssl" */ + { + parsedline->conntype = ctHostSSL; + /* Log a warning if SSL support is not active */ +#ifdef USE_SSL + if (!EnableSSL) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("hostssl record cannot match because SSL is disabled"), + errhint("Set ssl = on in postgresql.conf."), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "hostssl record cannot match because SSL is disabled"; + } +#else + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("hostssl record cannot match because SSL is not supported by this build"), + errhint("Compile with --with-ssl to use SSL connections."), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "hostssl record cannot match because SSL is not supported by this build"; +#endif + } + else if (token->string[4] == 'g') /* "hostgssenc" */ + { + parsedline->conntype = ctHostGSS; +#ifndef ENABLE_GSS + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("hostgssenc record cannot match because GSSAPI is not supported by this build"), + errhint("Compile with --with-gssapi to use GSSAPI connections."), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "hostgssenc record cannot match because GSSAPI is not supported by this build"; +#endif + } + else if (token->string[4] == 'n' && token->string[6] == 's') + parsedline->conntype = ctHostNoSSL; + else if (token->string[4] == 'n' && token->string[6] == 'g') + parsedline->conntype = ctHostNoGSS; + else + { + /* "host" */ + parsedline->conntype = ctHost; + } + } /* record type */ + else + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid connection type \"%s\"", + token->string), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("invalid connection type \"%s\"", token->string); + return NULL; + } + + /* Get the databases. */ + field = lnext(tok_line->fields, field); + if (!field) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("end-of-line before database specification"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "end-of-line before database specification"; + return NULL; + } + parsedline->databases = NIL; + tokens = lfirst(field); + foreach(tokencell, tokens) + { + parsedline->databases = lappend(parsedline->databases, + copy_hba_token(lfirst(tokencell))); + } + + /* Get the roles. */ + field = lnext(tok_line->fields, field); + if (!field) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("end-of-line before role specification"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "end-of-line before role specification"; + return NULL; + } + parsedline->roles = NIL; + tokens = lfirst(field); + foreach(tokencell, tokens) + { + parsedline->roles = lappend(parsedline->roles, + copy_hba_token(lfirst(tokencell))); + } + + if (parsedline->conntype != ctLocal) + { + /* Read the IP address field. (with or without CIDR netmask) */ + field = lnext(tok_line->fields, field); + if (!field) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("end-of-line before IP address specification"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "end-of-line before IP address specification"; + return NULL; + } + tokens = lfirst(field); + if (tokens->length > 1) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("multiple values specified for host address"), + errhint("Specify one address range per line."), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "multiple values specified for host address"; + return NULL; + } + token = linitial(tokens); + + if (token_is_keyword(token, "all")) + { + parsedline->ip_cmp_method = ipCmpAll; + } + else if (token_is_keyword(token, "samehost")) + { + /* Any IP on this host is allowed to connect */ + parsedline->ip_cmp_method = ipCmpSameHost; + } + else if (token_is_keyword(token, "samenet")) + { + /* Any IP on the host's subnets is allowed to connect */ + parsedline->ip_cmp_method = ipCmpSameNet; + } + else + { + /* IP and netmask are specified */ + parsedline->ip_cmp_method = ipCmpMask; + + /* need a modifiable copy of token */ + str = pstrdup(token->string); + + /* Check if it has a CIDR suffix and if so isolate it */ + cidr_slash = strchr(str, '/'); + if (cidr_slash) + *cidr_slash = '\0'; + + /* Get the IP address either way */ + hints.ai_flags = AI_NUMERICHOST; + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = 0; + hints.ai_protocol = 0; + hints.ai_addrlen = 0; + hints.ai_canonname = NULL; + hints.ai_addr = NULL; + hints.ai_next = NULL; + + ret = pg_getaddrinfo_all(str, NULL, &hints, &gai_result); + if (ret == 0 && gai_result) + { + memcpy(&parsedline->addr, gai_result->ai_addr, + gai_result->ai_addrlen); + parsedline->addrlen = gai_result->ai_addrlen; + } + else if (ret == EAI_NONAME) + parsedline->hostname = str; + else + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid IP address \"%s\": %s", + str, gai_strerror(ret)), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("invalid IP address \"%s\": %s", + str, gai_strerror(ret)); + if (gai_result) + pg_freeaddrinfo_all(hints.ai_family, gai_result); + return NULL; + } + + pg_freeaddrinfo_all(hints.ai_family, gai_result); + + /* Get the netmask */ + if (cidr_slash) + { + if (parsedline->hostname) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("specifying both host name and CIDR mask is invalid: \"%s\"", + token->string), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("specifying both host name and CIDR mask is invalid: \"%s\"", + token->string); + return NULL; + } + + if (pg_sockaddr_cidr_mask(&parsedline->mask, cidr_slash + 1, + parsedline->addr.ss_family) < 0) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid CIDR mask in address \"%s\"", + token->string), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("invalid CIDR mask in address \"%s\"", + token->string); + return NULL; + } + parsedline->masklen = parsedline->addrlen; + pfree(str); + } + else if (!parsedline->hostname) + { + /* Read the mask field. */ + pfree(str); + field = lnext(tok_line->fields, field); + if (!field) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("end-of-line before netmask specification"), + errhint("Specify an address range in CIDR notation, or provide a separate netmask."), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "end-of-line before netmask specification"; + return NULL; + } + tokens = lfirst(field); + if (tokens->length > 1) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("multiple values specified for netmask"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "multiple values specified for netmask"; + return NULL; + } + token = linitial(tokens); + + ret = pg_getaddrinfo_all(token->string, NULL, + &hints, &gai_result); + if (ret || !gai_result) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid IP mask \"%s\": %s", + token->string, gai_strerror(ret)), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("invalid IP mask \"%s\": %s", + token->string, gai_strerror(ret)); + if (gai_result) + pg_freeaddrinfo_all(hints.ai_family, gai_result); + return NULL; + } + + memcpy(&parsedline->mask, gai_result->ai_addr, + gai_result->ai_addrlen); + parsedline->masklen = gai_result->ai_addrlen; + pg_freeaddrinfo_all(hints.ai_family, gai_result); + + if (parsedline->addr.ss_family != parsedline->mask.ss_family) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("IP address and mask do not match"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "IP address and mask do not match"; + return NULL; + } + } + } + } /* != ctLocal */ + + /* Get the authentication method */ + field = lnext(tok_line->fields, field); + if (!field) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("end-of-line before authentication method"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "end-of-line before authentication method"; + return NULL; + } + tokens = lfirst(field); + if (tokens->length > 1) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("multiple values specified for authentication type"), + errhint("Specify exactly one authentication type per line."), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "multiple values specified for authentication type"; + return NULL; + } + token = linitial(tokens); + + unsupauth = NULL; + if (strcmp(token->string, "trust") == 0) + parsedline->auth_method = uaTrust; + else if (strcmp(token->string, "ident") == 0) + parsedline->auth_method = uaIdent; + else if (strcmp(token->string, "peer") == 0) + parsedline->auth_method = uaPeer; + else if (strcmp(token->string, "password") == 0) + parsedline->auth_method = uaPassword; + else if (strcmp(token->string, "gss") == 0) +#ifdef ENABLE_GSS + parsedline->auth_method = uaGSS; +#else + unsupauth = "gss"; +#endif + else if (strcmp(token->string, "sspi") == 0) +#ifdef ENABLE_SSPI + parsedline->auth_method = uaSSPI; +#else + unsupauth = "sspi"; +#endif + else if (strcmp(token->string, "reject") == 0) + parsedline->auth_method = uaReject; + else if (strcmp(token->string, "md5") == 0) + { + if (Db_user_namespace) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("MD5 authentication is not supported when \"db_user_namespace\" is enabled"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "MD5 authentication is not supported when \"db_user_namespace\" is enabled"; + return NULL; + } + parsedline->auth_method = uaMD5; + } + else if (strcmp(token->string, "scram-sha-256") == 0) + parsedline->auth_method = uaSCRAM; + else if (strcmp(token->string, "pam") == 0) +#ifdef USE_PAM + parsedline->auth_method = uaPAM; +#else + unsupauth = "pam"; +#endif + else if (strcmp(token->string, "bsd") == 0) +#ifdef USE_BSD_AUTH + parsedline->auth_method = uaBSD; +#else + unsupauth = "bsd"; +#endif + else if (strcmp(token->string, "ldap") == 0) +#ifdef USE_LDAP + parsedline->auth_method = uaLDAP; +#else + unsupauth = "ldap"; +#endif + else if (strcmp(token->string, "cert") == 0) +#ifdef USE_SSL + parsedline->auth_method = uaCert; +#else + unsupauth = "cert"; +#endif + else if (strcmp(token->string, "radius") == 0) + parsedline->auth_method = uaRADIUS; + else + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid authentication method \"%s\"", + token->string), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("invalid authentication method \"%s\"", + token->string); + return NULL; + } + + if (unsupauth) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid authentication method \"%s\": not supported by this build", + token->string), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("invalid authentication method \"%s\": not supported by this build", + token->string); + return NULL; + } + + /* + * XXX: When using ident on local connections, change it to peer, for + * backwards compatibility. + */ + if (parsedline->conntype == ctLocal && + parsedline->auth_method == uaIdent) + parsedline->auth_method = uaPeer; + + /* Invalid authentication combinations */ + if (parsedline->conntype == ctLocal && + parsedline->auth_method == uaGSS) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("gssapi authentication is not supported on local sockets"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "gssapi authentication is not supported on local sockets"; + return NULL; + } + + if (parsedline->conntype != ctLocal && + parsedline->auth_method == uaPeer) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("peer authentication is only supported on local sockets"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "peer authentication is only supported on local sockets"; + return NULL; + } + + /* + * SSPI authentication can never be enabled on ctLocal connections, + * because it's only supported on Windows, where ctLocal isn't supported. + */ + + + if (parsedline->conntype != ctHostSSL && + parsedline->auth_method == uaCert) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("cert authentication is only supported on hostssl connections"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "cert authentication is only supported on hostssl connections"; + return NULL; + } + + /* + * For GSS and SSPI, set the default value of include_realm to true. + * Having include_realm set to false is dangerous in multi-realm + * situations and is generally considered bad practice. We keep the + * capability around for backwards compatibility, but we might want to + * remove it at some point in the future. Users who still need to strip + * the realm off would be better served by using an appropriate regex in a + * pg_ident.conf mapping. + */ + if (parsedline->auth_method == uaGSS || + parsedline->auth_method == uaSSPI) + parsedline->include_realm = true; + + /* + * For SSPI, include_realm defaults to the SAM-compatible domain (aka + * NetBIOS name) and user names instead of the Kerberos principal name for + * compatibility. + */ + if (parsedline->auth_method == uaSSPI) + { + parsedline->compat_realm = true; + parsedline->upn_username = false; + } + + /* Parse remaining arguments */ + while ((field = lnext(tok_line->fields, field)) != NULL) + { + tokens = lfirst(field); + foreach(tokencell, tokens) + { + char *val; + + token = lfirst(tokencell); + + str = pstrdup(token->string); + val = strchr(str, '='); + if (val == NULL) + { + /* + * Got something that's not a name=value pair. + */ + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("authentication option not in name=value format: %s", token->string), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("authentication option not in name=value format: %s", + token->string); + return NULL; + } + + *val++ = '\0'; /* str now holds "name", val holds "value" */ + if (!parse_hba_auth_opt(str, val, parsedline, elevel, err_msg)) + /* parse_hba_auth_opt already logged the error message */ + return NULL; + pfree(str); + } + } + + /* + * Check if the selected authentication method has any mandatory arguments + * that are not set. + */ + if (parsedline->auth_method == uaLDAP) + { +#ifndef HAVE_LDAP_INITIALIZE + /* Not mandatory for OpenLDAP, because it can use DNS SRV records */ + MANDATORY_AUTH_ARG(parsedline->ldapserver, "ldapserver", "ldap"); +#endif + + /* + * LDAP can operate in two modes: either with a direct bind, using + * ldapprefix and ldapsuffix, or using a search+bind, using + * ldapbasedn, ldapbinddn, ldapbindpasswd and one of + * ldapsearchattribute or ldapsearchfilter. Disallow mixing these + * parameters. + */ + if (parsedline->ldapprefix || parsedline->ldapsuffix) + { + if (parsedline->ldapbasedn || + parsedline->ldapbinddn || + parsedline->ldapbindpasswd || + parsedline->ldapsearchattribute || + parsedline->ldapsearchfilter) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("cannot use ldapbasedn, ldapbinddn, ldapbindpasswd, ldapsearchattribute, ldapsearchfilter, or ldapurl together with ldapprefix"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "cannot use ldapbasedn, ldapbinddn, ldapbindpasswd, ldapsearchattribute, ldapsearchfilter, or ldapurl together with ldapprefix"; + return NULL; + } + } + else if (!parsedline->ldapbasedn) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("authentication method \"ldap\" requires argument \"ldapbasedn\", \"ldapprefix\", or \"ldapsuffix\" to be set"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "authentication method \"ldap\" requires argument \"ldapbasedn\", \"ldapprefix\", or \"ldapsuffix\" to be set"; + return NULL; + } + + /* + * When using search+bind, you can either use a simple attribute + * (defaulting to "uid") or a fully custom search filter. You can't + * do both. + */ + if (parsedline->ldapsearchattribute && parsedline->ldapsearchfilter) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("cannot use ldapsearchattribute together with ldapsearchfilter"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "cannot use ldapsearchattribute together with ldapsearchfilter"; + return NULL; + } + } + + if (parsedline->auth_method == uaRADIUS) + { + MANDATORY_AUTH_ARG(parsedline->radiusservers, "radiusservers", "radius"); + MANDATORY_AUTH_ARG(parsedline->radiussecrets, "radiussecrets", "radius"); + + if (list_length(parsedline->radiusservers) < 1) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("list of RADIUS servers cannot be empty"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "list of RADIUS servers cannot be empty"; + return NULL; + } + + if (list_length(parsedline->radiussecrets) < 1) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("list of RADIUS secrets cannot be empty"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "list of RADIUS secrets cannot be empty"; + return NULL; + } + + /* + * Verify length of option lists - each can be 0 (except for secrets, + * but that's already checked above), 1 (use the same value + * everywhere) or the same as the number of servers. + */ + if (!(list_length(parsedline->radiussecrets) == 1 || + list_length(parsedline->radiussecrets) == list_length(parsedline->radiusservers))) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("the number of RADIUS secrets (%d) must be 1 or the same as the number of RADIUS servers (%d)", + list_length(parsedline->radiussecrets), + list_length(parsedline->radiusservers)), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("the number of RADIUS secrets (%d) must be 1 or the same as the number of RADIUS servers (%d)", + list_length(parsedline->radiussecrets), + list_length(parsedline->radiusservers)); + return NULL; + } + if (!(list_length(parsedline->radiusports) == 0 || + list_length(parsedline->radiusports) == 1 || + list_length(parsedline->radiusports) == list_length(parsedline->radiusservers))) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("the number of RADIUS ports (%d) must be 1 or the same as the number of RADIUS servers (%d)", + list_length(parsedline->radiusports), + list_length(parsedline->radiusservers)), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("the number of RADIUS ports (%d) must be 1 or the same as the number of RADIUS servers (%d)", + list_length(parsedline->radiusports), + list_length(parsedline->radiusservers)); + return NULL; + } + if (!(list_length(parsedline->radiusidentifiers) == 0 || + list_length(parsedline->radiusidentifiers) == 1 || + list_length(parsedline->radiusidentifiers) == list_length(parsedline->radiusservers))) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("the number of RADIUS identifiers (%d) must be 1 or the same as the number of RADIUS servers (%d)", + list_length(parsedline->radiusidentifiers), + list_length(parsedline->radiusservers)), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("the number of RADIUS identifiers (%d) must be 1 or the same as the number of RADIUS servers (%d)", + list_length(parsedline->radiusidentifiers), + list_length(parsedline->radiusservers)); + return NULL; + } + } + + /* + * Enforce any parameters implied by other settings. + */ + if (parsedline->auth_method == uaCert) + { + /* + * For auth method cert, client certificate validation is mandatory, and it implies + * the level of verify-full. + */ + parsedline->clientcert = clientCertFull; + } + + return parsedline; +} + + +/* + * Parse one name-value pair as an authentication option into the given + * HbaLine. Return true if we successfully parse the option, false if we + * encounter an error. In the event of an error, also log a message at + * ereport level elevel, and store a message string into *err_msg. + */ +static bool +parse_hba_auth_opt(char *name, char *val, HbaLine *hbaline, + int elevel, char **err_msg) +{ + int line_num = hbaline->linenumber; + +#ifdef USE_LDAP + hbaline->ldapscope = LDAP_SCOPE_SUBTREE; +#endif + + if (strcmp(name, "map") == 0) + { + if (hbaline->auth_method != uaIdent && + hbaline->auth_method != uaPeer && + hbaline->auth_method != uaGSS && + hbaline->auth_method != uaSSPI && + hbaline->auth_method != uaCert) + INVALID_AUTH_OPTION("map", gettext_noop("ident, peer, gssapi, sspi, and cert")); + hbaline->usermap = pstrdup(val); + } + else if (strcmp(name, "clientcert") == 0) + { + if (hbaline->conntype != ctHostSSL) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("clientcert can only be configured for \"hostssl\" rows"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "clientcert can only be configured for \"hostssl\" rows"; + return false; + } + + if (strcmp(val, "verify-full") == 0) + { + hbaline->clientcert = clientCertFull; + } + else if (strcmp(val, "verify-ca") == 0) + { + if (hbaline->auth_method == uaCert) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("clientcert only accepts \"verify-full\" when using \"cert\" authentication"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "clientcert can only be set to \"verify-full\" when using \"cert\" authentication"; + return false; + } + + hbaline->clientcert = clientCertCA; + } + else + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid value for clientcert: \"%s\"", val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + return false; + } + } + else if (strcmp(name, "clientname") == 0) + { + if (hbaline->conntype != ctHostSSL) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("clientname can only be configured for \"hostssl\" rows"), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = "clientname can only be configured for \"hostssl\" rows"; + return false; + } + + if (strcmp(val, "CN") == 0) + { + hbaline->clientcertname = clientCertCN; + } + else if (strcmp(val, "DN") == 0) + { + hbaline->clientcertname = clientCertDN; + } + else + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid value for clientname: \"%s\"", val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + return false; + } + } + else if (strcmp(name, "pamservice") == 0) + { + REQUIRE_AUTH_OPTION(uaPAM, "pamservice", "pam"); + hbaline->pamservice = pstrdup(val); + } + else if (strcmp(name, "pam_use_hostname") == 0) + { + REQUIRE_AUTH_OPTION(uaPAM, "pam_use_hostname", "pam"); + if (strcmp(val, "1") == 0) + hbaline->pam_use_hostname = true; + else + hbaline->pam_use_hostname = false; + + } + else if (strcmp(name, "ldapurl") == 0) + { +#ifdef LDAP_API_FEATURE_X_OPENLDAP + LDAPURLDesc *urldata; + int rc; +#endif + + REQUIRE_AUTH_OPTION(uaLDAP, "ldapurl", "ldap"); +#ifdef LDAP_API_FEATURE_X_OPENLDAP + rc = ldap_url_parse(val, &urldata); + if (rc != LDAP_SUCCESS) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not parse LDAP URL \"%s\": %s", val, ldap_err2string(rc)))); + *err_msg = psprintf("could not parse LDAP URL \"%s\": %s", + val, ldap_err2string(rc)); + return false; + } + + if (strcmp(urldata->lud_scheme, "ldap") != 0 && + strcmp(urldata->lud_scheme, "ldaps") != 0) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("unsupported LDAP URL scheme: %s", urldata->lud_scheme))); + *err_msg = psprintf("unsupported LDAP URL scheme: %s", + urldata->lud_scheme); + ldap_free_urldesc(urldata); + return false; + } + + if (urldata->lud_scheme) + hbaline->ldapscheme = pstrdup(urldata->lud_scheme); + if (urldata->lud_host) + hbaline->ldapserver = pstrdup(urldata->lud_host); + hbaline->ldapport = urldata->lud_port; + if (urldata->lud_dn) + hbaline->ldapbasedn = pstrdup(urldata->lud_dn); + + if (urldata->lud_attrs) + hbaline->ldapsearchattribute = pstrdup(urldata->lud_attrs[0]); /* only use first one */ + hbaline->ldapscope = urldata->lud_scope; + if (urldata->lud_filter) + hbaline->ldapsearchfilter = pstrdup(urldata->lud_filter); + ldap_free_urldesc(urldata); +#else /* not OpenLDAP */ + ereport(elevel, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("LDAP URLs not supported on this platform"))); + *err_msg = "LDAP URLs not supported on this platform"; +#endif /* not OpenLDAP */ + } + else if (strcmp(name, "ldaptls") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldaptls", "ldap"); + if (strcmp(val, "1") == 0) + hbaline->ldaptls = true; + else + hbaline->ldaptls = false; + } + else if (strcmp(name, "ldapscheme") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapscheme", "ldap"); + if (strcmp(val, "ldap") != 0 && strcmp(val, "ldaps") != 0) + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid ldapscheme value: \"%s\"", val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + hbaline->ldapscheme = pstrdup(val); + } + else if (strcmp(name, "ldapserver") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapserver", "ldap"); + hbaline->ldapserver = pstrdup(val); + } + else if (strcmp(name, "ldapport") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapport", "ldap"); + hbaline->ldapport = atoi(val); + if (hbaline->ldapport == 0) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid LDAP port number: \"%s\"", val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("invalid LDAP port number: \"%s\"", val); + return false; + } + } + else if (strcmp(name, "ldapbinddn") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapbinddn", "ldap"); + hbaline->ldapbinddn = pstrdup(val); + } + else if (strcmp(name, "ldapbindpasswd") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapbindpasswd", "ldap"); + hbaline->ldapbindpasswd = pstrdup(val); + } + else if (strcmp(name, "ldapsearchattribute") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapsearchattribute", "ldap"); + hbaline->ldapsearchattribute = pstrdup(val); + } + else if (strcmp(name, "ldapsearchfilter") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapsearchfilter", "ldap"); + hbaline->ldapsearchfilter = pstrdup(val); + } + else if (strcmp(name, "ldapbasedn") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapbasedn", "ldap"); + hbaline->ldapbasedn = pstrdup(val); + } + else if (strcmp(name, "ldapprefix") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapprefix", "ldap"); + hbaline->ldapprefix = pstrdup(val); + } + else if (strcmp(name, "ldapsuffix") == 0) + { + REQUIRE_AUTH_OPTION(uaLDAP, "ldapsuffix", "ldap"); + hbaline->ldapsuffix = pstrdup(val); + } + else if (strcmp(name, "krb_realm") == 0) + { + if (hbaline->auth_method != uaGSS && + hbaline->auth_method != uaSSPI) + INVALID_AUTH_OPTION("krb_realm", gettext_noop("gssapi and sspi")); + hbaline->krb_realm = pstrdup(val); + } + else if (strcmp(name, "include_realm") == 0) + { + if (hbaline->auth_method != uaGSS && + hbaline->auth_method != uaSSPI) + INVALID_AUTH_OPTION("include_realm", gettext_noop("gssapi and sspi")); + if (strcmp(val, "1") == 0) + hbaline->include_realm = true; + else + hbaline->include_realm = false; + } + else if (strcmp(name, "compat_realm") == 0) + { + if (hbaline->auth_method != uaSSPI) + INVALID_AUTH_OPTION("compat_realm", gettext_noop("sspi")); + if (strcmp(val, "1") == 0) + hbaline->compat_realm = true; + else + hbaline->compat_realm = false; + } + else if (strcmp(name, "upn_username") == 0) + { + if (hbaline->auth_method != uaSSPI) + INVALID_AUTH_OPTION("upn_username", gettext_noop("sspi")); + if (strcmp(val, "1") == 0) + hbaline->upn_username = true; + else + hbaline->upn_username = false; + } + else if (strcmp(name, "radiusservers") == 0) + { + struct addrinfo *gai_result; + struct addrinfo hints; + int ret; + List *parsed_servers; + ListCell *l; + char *dupval = pstrdup(val); + + REQUIRE_AUTH_OPTION(uaRADIUS, "radiusservers", "radius"); + + if (!SplitGUCList(dupval, ',', &parsed_servers)) + { + /* syntax error in list */ + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not parse RADIUS server list \"%s\"", + val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + return false; + } + + /* For each entry in the list, translate it */ + foreach(l, parsed_servers) + { + MemSet(&hints, 0, sizeof(hints)); + hints.ai_socktype = SOCK_DGRAM; + hints.ai_family = AF_UNSPEC; + + ret = pg_getaddrinfo_all((char *) lfirst(l), NULL, &hints, &gai_result); + if (ret || !gai_result) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not translate RADIUS server name \"%s\" to address: %s", + (char *) lfirst(l), gai_strerror(ret)), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + if (gai_result) + pg_freeaddrinfo_all(hints.ai_family, gai_result); + + list_free(parsed_servers); + return false; + } + pg_freeaddrinfo_all(hints.ai_family, gai_result); + } + + /* All entries are OK, so store them */ + hbaline->radiusservers = parsed_servers; + hbaline->radiusservers_s = pstrdup(val); + } + else if (strcmp(name, "radiusports") == 0) + { + List *parsed_ports; + ListCell *l; + char *dupval = pstrdup(val); + + REQUIRE_AUTH_OPTION(uaRADIUS, "radiusports", "radius"); + + if (!SplitGUCList(dupval, ',', &parsed_ports)) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not parse RADIUS port list \"%s\"", + val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("invalid RADIUS port number: \"%s\"", val); + return false; + } + + foreach(l, parsed_ports) + { + if (atoi(lfirst(l)) == 0) + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("invalid RADIUS port number: \"%s\"", val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + + return false; + } + } + hbaline->radiusports = parsed_ports; + hbaline->radiusports_s = pstrdup(val); + } + else if (strcmp(name, "radiussecrets") == 0) + { + List *parsed_secrets; + char *dupval = pstrdup(val); + + REQUIRE_AUTH_OPTION(uaRADIUS, "radiussecrets", "radius"); + + if (!SplitGUCList(dupval, ',', &parsed_secrets)) + { + /* syntax error in list */ + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not parse RADIUS secret list \"%s\"", + val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + return false; + } + + hbaline->radiussecrets = parsed_secrets; + hbaline->radiussecrets_s = pstrdup(val); + } + else if (strcmp(name, "radiusidentifiers") == 0) + { + List *parsed_identifiers; + char *dupval = pstrdup(val); + + REQUIRE_AUTH_OPTION(uaRADIUS, "radiusidentifiers", "radius"); + + if (!SplitGUCList(dupval, ',', &parsed_identifiers)) + { + /* syntax error in list */ + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("could not parse RADIUS identifiers list \"%s\"", + val), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + return false; + } + + hbaline->radiusidentifiers = parsed_identifiers; + hbaline->radiusidentifiers_s = pstrdup(val); + } + else + { + ereport(elevel, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("unrecognized authentication option name: \"%s\"", + name), + errcontext("line %d of configuration file \"%s\"", + line_num, HbaFileName))); + *err_msg = psprintf("unrecognized authentication option name: \"%s\"", + name); + return false; + } + return true; +} + +/* + * Scan the pre-parsed hba file, looking for a match to the port's connection + * request. + */ +static void +check_hba(hbaPort *port) +{ + Oid roleid; + ListCell *line; + HbaLine *hba; + + /* Get the target role's OID. Note we do not error out for bad role. */ + roleid = get_role_oid(port->user_name, true); + + foreach(line, parsed_hba_lines) + { + hba = (HbaLine *) lfirst(line); + + /* Check connection type */ + if (hba->conntype == ctLocal) + { + if (!IS_AF_UNIX(port->raddr.addr.ss_family)) + continue; + } + else + { + if (IS_AF_UNIX(port->raddr.addr.ss_family)) + continue; + + /* Check SSL state */ + if (port->ssl_in_use) + { + /* Connection is SSL, match both "host" and "hostssl" */ + if (hba->conntype == ctHostNoSSL) + continue; + } + else + { + /* Connection is not SSL, match both "host" and "hostnossl" */ + if (hba->conntype == ctHostSSL) + continue; + } + + /* Check GSSAPI state */ +#ifdef ENABLE_GSS + if (port->gss && port->gss->enc && + hba->conntype == ctHostNoGSS) + continue; + else if (!(port->gss && port->gss->enc) && + hba->conntype == ctHostGSS) + continue; +#else + if (hba->conntype == ctHostGSS) + continue; +#endif + + /* Check IP address */ + switch (hba->ip_cmp_method) + { + case ipCmpMask: + if (hba->hostname) + { + if (!check_hostname(port, + hba->hostname)) + continue; + } + else + { + if (!check_ip(&port->raddr, + (struct sockaddr *) &hba->addr, + (struct sockaddr *) &hba->mask)) + continue; + } + break; + case ipCmpAll: + break; + case ipCmpSameHost: + case ipCmpSameNet: + if (!check_same_host_or_net(&port->raddr, + hba->ip_cmp_method)) + continue; + break; + default: + /* shouldn't get here, but deem it no-match if so */ + continue; + } + } /* != ctLocal */ + + /* Check database and role */ + if (!check_db(port->database_name, port->user_name, roleid, + hba->databases)) + continue; + + if (!check_role(port->user_name, roleid, hba->roles)) + continue; + + /* Found a record that matched! */ + port->hba = hba; + return; + } + + /* If no matching entry was found, then implicitly reject. */ + hba = palloc0(sizeof(HbaLine)); + hba->auth_method = uaImplicitReject; + port->hba = hba; +} + +/* + * Read the config file and create a List of HbaLine records for the contents. + * + * The configuration is read into a temporary list, and if any parse error + * occurs the old list is kept in place and false is returned. Only if the + * whole file parses OK is the list replaced, and the function returns true. + * + * On a false result, caller will take care of reporting a FATAL error in case + * this is the initial startup. If it happens on reload, we just keep running + * with the old data. + */ +bool +load_hba(void) +{ + FILE *file; + List *hba_lines = NIL; + ListCell *line; + List *new_parsed_lines = NIL; + bool ok = true; + MemoryContext linecxt; + MemoryContext oldcxt; + MemoryContext hbacxt; + + file = AllocateFile(HbaFileName, "r"); + if (file == NULL) + { + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not open configuration file \"%s\": %m", + HbaFileName))); + return false; + } + + linecxt = tokenize_file(HbaFileName, file, &hba_lines, LOG); + FreeFile(file); + + /* Now parse all the lines */ + Assert(PostmasterContext); + hbacxt = AllocSetContextCreate(PostmasterContext, + "hba parser context", + ALLOCSET_SMALL_SIZES); + oldcxt = MemoryContextSwitchTo(hbacxt); + foreach(line, hba_lines) + { + TokenizedLine *tok_line = (TokenizedLine *) lfirst(line); + HbaLine *newline; + + /* don't parse lines that already have errors */ + if (tok_line->err_msg != NULL) + { + ok = false; + continue; + } + + if ((newline = parse_hba_line(tok_line, LOG)) == NULL) + { + /* Parse error; remember there's trouble */ + ok = false; + + /* + * Keep parsing the rest of the file so we can report errors on + * more than the first line. Error has already been logged, no + * need for more chatter here. + */ + continue; + } + + new_parsed_lines = lappend(new_parsed_lines, newline); + } + + /* + * A valid HBA file must have at least one entry; else there's no way to + * connect to the postmaster. But only complain about this if we didn't + * already have parsing errors. + */ + if (ok && new_parsed_lines == NIL) + { + ereport(LOG, + (errcode(ERRCODE_CONFIG_FILE_ERROR), + errmsg("configuration file \"%s\" contains no entries", + HbaFileName))); + ok = false; + } + + /* Free tokenizer memory */ + MemoryContextDelete(linecxt); + MemoryContextSwitchTo(oldcxt); + + if (!ok) + { + /* File contained one or more errors, so bail out */ + MemoryContextDelete(hbacxt); + return false; + } + + /* Loaded new file successfully, replace the one we use */ + if (parsed_hba_context != NULL) + MemoryContextDelete(parsed_hba_context); + parsed_hba_context = hbacxt; + parsed_hba_lines = new_parsed_lines; + + return true; +} + +/* + * This macro specifies the maximum number of authentication options + * that are possible with any given authentication method that is supported. + * Currently LDAP supports 11, and there are 3 that are not dependent on + * the auth method here. It may not actually be possible to set all of them + * at the same time, but we'll set the macro value high enough to be + * conservative and avoid warnings from static analysis tools. + */ +#define MAX_HBA_OPTIONS 14 + +/* + * Create a text array listing the options specified in the HBA line. + * Return NULL if no options are specified. + */ +static ArrayType * +gethba_options(HbaLine *hba) +{ + int noptions; + Datum options[MAX_HBA_OPTIONS]; + + noptions = 0; + + if (hba->auth_method == uaGSS || hba->auth_method == uaSSPI) + { + if (hba->include_realm) + options[noptions++] = + CStringGetTextDatum("include_realm=true"); + + if (hba->krb_realm) + options[noptions++] = + CStringGetTextDatum(psprintf("krb_realm=%s", hba->krb_realm)); + } + + if (hba->usermap) + options[noptions++] = + CStringGetTextDatum(psprintf("map=%s", hba->usermap)); + + if (hba->clientcert != clientCertOff) + options[noptions++] = + CStringGetTextDatum(psprintf("clientcert=%s", (hba->clientcert == clientCertCA) ? "verify-ca" : "verify-full")); + + if (hba->pamservice) + options[noptions++] = + CStringGetTextDatum(psprintf("pamservice=%s", hba->pamservice)); + + if (hba->auth_method == uaLDAP) + { + if (hba->ldapserver) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapserver=%s", hba->ldapserver)); + + if (hba->ldapport) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapport=%d", hba->ldapport)); + + if (hba->ldaptls) + options[noptions++] = + CStringGetTextDatum("ldaptls=true"); + + if (hba->ldapprefix) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapprefix=%s", hba->ldapprefix)); + + if (hba->ldapsuffix) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapsuffix=%s", hba->ldapsuffix)); + + if (hba->ldapbasedn) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapbasedn=%s", hba->ldapbasedn)); + + if (hba->ldapbinddn) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapbinddn=%s", hba->ldapbinddn)); + + if (hba->ldapbindpasswd) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapbindpasswd=%s", + hba->ldapbindpasswd)); + + if (hba->ldapsearchattribute) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapsearchattribute=%s", + hba->ldapsearchattribute)); + + if (hba->ldapsearchfilter) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapsearchfilter=%s", + hba->ldapsearchfilter)); + + if (hba->ldapscope) + options[noptions++] = + CStringGetTextDatum(psprintf("ldapscope=%d", hba->ldapscope)); + } + + if (hba->auth_method == uaRADIUS) + { + if (hba->radiusservers_s) + options[noptions++] = + CStringGetTextDatum(psprintf("radiusservers=%s", hba->radiusservers_s)); + + if (hba->radiussecrets_s) + options[noptions++] = + CStringGetTextDatum(psprintf("radiussecrets=%s", hba->radiussecrets_s)); + + if (hba->radiusidentifiers_s) + options[noptions++] = + CStringGetTextDatum(psprintf("radiusidentifiers=%s", hba->radiusidentifiers_s)); + + if (hba->radiusports_s) + options[noptions++] = + CStringGetTextDatum(psprintf("radiusports=%s", hba->radiusports_s)); + } + + /* If you add more options, consider increasing MAX_HBA_OPTIONS. */ + Assert(noptions <= MAX_HBA_OPTIONS); + + if (noptions > 0) + return construct_array(options, noptions, TEXTOID, -1, false, TYPALIGN_INT); + else + return NULL; +} + +/* Number of columns in pg_hba_file_rules view */ +#define NUM_PG_HBA_FILE_RULES_ATTS 9 + +/* + * fill_hba_line: build one row of pg_hba_file_rules view, add it to tuplestore + * + * tuple_store: where to store data + * tupdesc: tuple descriptor for the view + * lineno: pg_hba.conf line number (must always be valid) + * hba: parsed line data (can be NULL, in which case err_msg should be set) + * err_msg: error message (NULL if none) + * + * Note: leaks memory, but we don't care since this is run in a short-lived + * memory context. + */ +static void +fill_hba_line(Tuplestorestate *tuple_store, TupleDesc tupdesc, + int lineno, HbaLine *hba, const char *err_msg) +{ + Datum values[NUM_PG_HBA_FILE_RULES_ATTS]; + bool nulls[NUM_PG_HBA_FILE_RULES_ATTS]; + char buffer[NI_MAXHOST]; + HeapTuple tuple; + int index; + ListCell *lc; + const char *typestr; + const char *addrstr; + const char *maskstr; + ArrayType *options; + + Assert(tupdesc->natts == NUM_PG_HBA_FILE_RULES_ATTS); + + memset(values, 0, sizeof(values)); + memset(nulls, 0, sizeof(nulls)); + index = 0; + + /* line_number */ + values[index++] = Int32GetDatum(lineno); + + if (hba != NULL) + { + /* type */ + /* Avoid a default: case so compiler will warn about missing cases */ + typestr = NULL; + switch (hba->conntype) + { + case ctLocal: + typestr = "local"; + break; + case ctHost: + typestr = "host"; + break; + case ctHostSSL: + typestr = "hostssl"; + break; + case ctHostNoSSL: + typestr = "hostnossl"; + break; + case ctHostGSS: + typestr = "hostgssenc"; + break; + case ctHostNoGSS: + typestr = "hostnogssenc"; + break; + } + if (typestr) + values[index++] = CStringGetTextDatum(typestr); + else + nulls[index++] = true; + + /* database */ + if (hba->databases) + { + /* + * Flatten HbaToken list to string list. It might seem that we + * should re-quote any quoted tokens, but that has been rejected + * on the grounds that it makes it harder to compare the array + * elements to other system catalogs. That makes entries like + * "all" or "samerole" formally ambiguous ... but users who name + * databases/roles that way are inflicting their own pain. + */ + List *names = NIL; + + foreach(lc, hba->databases) + { + HbaToken *tok = lfirst(lc); + + names = lappend(names, tok->string); + } + values[index++] = PointerGetDatum(strlist_to_textarray(names)); + } + else + nulls[index++] = true; + + /* user */ + if (hba->roles) + { + /* Flatten HbaToken list to string list; see comment above */ + List *roles = NIL; + + foreach(lc, hba->roles) + { + HbaToken *tok = lfirst(lc); + + roles = lappend(roles, tok->string); + } + values[index++] = PointerGetDatum(strlist_to_textarray(roles)); + } + else + nulls[index++] = true; + + /* address and netmask */ + /* Avoid a default: case so compiler will warn about missing cases */ + addrstr = maskstr = NULL; + switch (hba->ip_cmp_method) + { + case ipCmpMask: + if (hba->hostname) + { + addrstr = hba->hostname; + } + else + { + /* + * Note: if pg_getnameinfo_all fails, it'll set buffer to + * "???", which we want to return. + */ + if (hba->addrlen > 0) + { + if (pg_getnameinfo_all(&hba->addr, hba->addrlen, + buffer, sizeof(buffer), + NULL, 0, + NI_NUMERICHOST) == 0) + clean_ipv6_addr(hba->addr.ss_family, buffer); + addrstr = pstrdup(buffer); + } + if (hba->masklen > 0) + { + if (pg_getnameinfo_all(&hba->mask, hba->masklen, + buffer, sizeof(buffer), + NULL, 0, + NI_NUMERICHOST) == 0) + clean_ipv6_addr(hba->mask.ss_family, buffer); + maskstr = pstrdup(buffer); + } + } + break; + case ipCmpAll: + addrstr = "all"; + break; + case ipCmpSameHost: + addrstr = "samehost"; + break; + case ipCmpSameNet: + addrstr = "samenet"; + break; + } + if (addrstr) + values[index++] = CStringGetTextDatum(addrstr); + else + nulls[index++] = true; + if (maskstr) + values[index++] = CStringGetTextDatum(maskstr); + else + nulls[index++] = true; + + /* auth_method */ + values[index++] = CStringGetTextDatum(hba_authname(hba->auth_method)); + + /* options */ + options = gethba_options(hba); + if (options) + values[index++] = PointerGetDatum(options); + else + nulls[index++] = true; + } + else + { + /* no parsing result, so set relevant fields to nulls */ + memset(&nulls[1], true, (NUM_PG_HBA_FILE_RULES_ATTS - 2) * sizeof(bool)); + } + + /* error */ + if (err_msg) + values[NUM_PG_HBA_FILE_RULES_ATTS - 1] = CStringGetTextDatum(err_msg); + else + nulls[NUM_PG_HBA_FILE_RULES_ATTS - 1] = true; + + tuple = heap_form_tuple(tupdesc, values, nulls); + tuplestore_puttuple(tuple_store, tuple); +} + +/* + * Read the pg_hba.conf file and fill the tuplestore with view records. + */ +static void +fill_hba_view(Tuplestorestate *tuple_store, TupleDesc tupdesc) +{ + FILE *file; + List *hba_lines = NIL; + ListCell *line; + MemoryContext linecxt; + MemoryContext hbacxt; + MemoryContext oldcxt; + + /* + * In the unlikely event that we can't open pg_hba.conf, we throw an + * error, rather than trying to report it via some sort of view entry. + * (Most other error conditions should result in a message in a view + * entry.) + */ + file = AllocateFile(HbaFileName, "r"); + if (file == NULL) + ereport(ERROR, + (errcode_for_file_access(), + errmsg("could not open configuration file \"%s\": %m", + HbaFileName))); + + linecxt = tokenize_file(HbaFileName, file, &hba_lines, DEBUG3); + FreeFile(file); + + /* Now parse all the lines */ + hbacxt = AllocSetContextCreate(CurrentMemoryContext, + "hba parser context", + ALLOCSET_SMALL_SIZES); + oldcxt = MemoryContextSwitchTo(hbacxt); + foreach(line, hba_lines) + { + TokenizedLine *tok_line = (TokenizedLine *) lfirst(line); + HbaLine *hbaline = NULL; + + /* don't parse lines that already have errors */ + if (tok_line->err_msg == NULL) + hbaline = parse_hba_line(tok_line, DEBUG3); + + fill_hba_line(tuple_store, tupdesc, tok_line->line_num, + hbaline, tok_line->err_msg); + } + + /* Free tokenizer memory */ + MemoryContextDelete(linecxt); + /* Free parse_hba_line memory */ + MemoryContextSwitchTo(oldcxt); + MemoryContextDelete(hbacxt); +} + +/* + * SQL-accessible SRF to return all the entries in the pg_hba.conf file. + */ +Datum +pg_hba_file_rules(PG_FUNCTION_ARGS) +{ + Tuplestorestate *tuple_store; + TupleDesc tupdesc; + MemoryContext old_cxt; + ReturnSetInfo *rsi; + + /* + * We must use the Materialize mode to be safe against HBA file changes + * while the cursor is open. It's also more efficient than having to look + * up our current position in the parsed list every time. + */ + rsi = (ReturnSetInfo *) fcinfo->resultinfo; + + /* Check to see if caller supports us returning a tuplestore */ + if (rsi == NULL || !IsA(rsi, ReturnSetInfo)) + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("set-valued function called in context that cannot accept a set"))); + if (!(rsi->allowedModes & SFRM_Materialize)) + ereport(ERROR, + (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("materialize mode required, but it is not allowed in this context"))); + + rsi->returnMode = SFRM_Materialize; + + /* Build a tuple descriptor for our result type */ + if (get_call_result_type(fcinfo, NULL, &tupdesc) != TYPEFUNC_COMPOSITE) + elog(ERROR, "return type must be a row type"); + + /* Build tuplestore to hold the result rows */ + old_cxt = MemoryContextSwitchTo(rsi->econtext->ecxt_per_query_memory); + + tuple_store = + tuplestore_begin_heap(rsi->allowedModes & SFRM_Materialize_Random, + false, work_mem); + rsi->setDesc = tupdesc; + rsi->setResult = tuple_store; + + MemoryContextSwitchTo(old_cxt); + + /* Fill the tuplestore */ + fill_hba_view(tuple_store, tupdesc); + + PG_RETURN_NULL(); +} + + +/* + * Parse one tokenised line from the ident config file and store the result in + * an IdentLine structure. + * + * If parsing fails, log a message and return NULL. + * + * If ident_user is a regular expression (ie. begins with a slash), it is + * compiled and stored in IdentLine structure. + * + * Note: this function leaks memory when an error occurs. Caller is expected + * to have set a memory context that will be reset if this function returns + * NULL. + */ +static IdentLine * +parse_ident_line(TokenizedLine *tok_line) +{ + int line_num = tok_line->line_num; + ListCell *field; + List *tokens; + HbaToken *token; + IdentLine *parsedline; + + Assert(tok_line->fields != NIL); + field = list_head(tok_line->fields); + + parsedline = palloc0(sizeof(IdentLine)); + parsedline->linenumber = line_num; + + /* Get the map token (must exist) */ + tokens = lfirst(field); + IDENT_MULTI_VALUE(tokens); + token = linitial(tokens); + parsedline->usermap = pstrdup(token->string); + + /* Get the ident user token */ + field = lnext(tok_line->fields, field); + IDENT_FIELD_ABSENT(field); + tokens = lfirst(field); + IDENT_MULTI_VALUE(tokens); + token = linitial(tokens); + parsedline->ident_user = pstrdup(token->string); + + /* Get the PG rolename token */ + field = lnext(tok_line->fields, field); + IDENT_FIELD_ABSENT(field); + tokens = lfirst(field); + IDENT_MULTI_VALUE(tokens); + token = linitial(tokens); + parsedline->pg_role = pstrdup(token->string); + + if (parsedline->ident_user[0] == '/') + { + /* + * When system username starts with a slash, treat it as a regular + * expression. Pre-compile it. + */ + int r; + pg_wchar *wstr; + int wlen; + + wstr = palloc((strlen(parsedline->ident_user + 1) + 1) * sizeof(pg_wchar)); + wlen = pg_mb2wchar_with_len(parsedline->ident_user + 1, + wstr, strlen(parsedline->ident_user + 1)); + + r = pg_regcomp(&parsedline->re, wstr, wlen, REG_ADVANCED, C_COLLATION_OID); + if (r) + { + char errstr[100]; + + pg_regerror(r, &parsedline->re, errstr, sizeof(errstr)); + ereport(LOG, + (errcode(ERRCODE_INVALID_REGULAR_EXPRESSION), + errmsg("invalid regular expression \"%s\": %s", + parsedline->ident_user + 1, errstr))); + + pfree(wstr); + return NULL; + } + pfree(wstr); + } + + return parsedline; +} + +/* + * Process one line from the parsed ident config lines. + * + * Compare input parsed ident line to the needed map, pg_role and ident_user. + * *found_p and *error_p are set according to our results. + */ +static void +check_ident_usermap(IdentLine *identLine, const char *usermap_name, + const char *pg_role, const char *ident_user, + bool case_insensitive, bool *found_p, bool *error_p) +{ + *found_p = false; + *error_p = false; + + if (strcmp(identLine->usermap, usermap_name) != 0) + /* Line does not match the map name we're looking for, so just abort */ + return; + + /* Match? */ + if (identLine->ident_user[0] == '/') + { + /* + * When system username starts with a slash, treat it as a regular + * expression. In this case, we process the system username as a + * regular expression that returns exactly one match. This is replaced + * for \1 in the database username string, if present. + */ + int r; + regmatch_t matches[2]; + pg_wchar *wstr; + int wlen; + char *ofs; + char *regexp_pgrole; + + wstr = palloc((strlen(ident_user) + 1) * sizeof(pg_wchar)); + wlen = pg_mb2wchar_with_len(ident_user, wstr, strlen(ident_user)); + + r = pg_regexec(&identLine->re, wstr, wlen, 0, NULL, 2, matches, 0); + if (r) + { + char errstr[100]; + + if (r != REG_NOMATCH) + { + /* REG_NOMATCH is not an error, everything else is */ + pg_regerror(r, &identLine->re, errstr, sizeof(errstr)); + ereport(LOG, + (errcode(ERRCODE_INVALID_REGULAR_EXPRESSION), + errmsg("regular expression match for \"%s\" failed: %s", + identLine->ident_user + 1, errstr))); + *error_p = true; + } + + pfree(wstr); + return; + } + pfree(wstr); + + if ((ofs = strstr(identLine->pg_role, "\\1")) != NULL) + { + int offset; + + /* substitution of the first argument requested */ + if (matches[1].rm_so < 0) + { + ereport(LOG, + (errcode(ERRCODE_INVALID_REGULAR_EXPRESSION), + errmsg("regular expression \"%s\" has no subexpressions as requested by backreference in \"%s\"", + identLine->ident_user + 1, identLine->pg_role))); + *error_p = true; + return; + } + + /* + * length: original length minus length of \1 plus length of match + * plus null terminator + */ + regexp_pgrole = palloc0(strlen(identLine->pg_role) - 2 + (matches[1].rm_eo - matches[1].rm_so) + 1); + offset = ofs - identLine->pg_role; + memcpy(regexp_pgrole, identLine->pg_role, offset); + memcpy(regexp_pgrole + offset, + ident_user + matches[1].rm_so, + matches[1].rm_eo - matches[1].rm_so); + strcat(regexp_pgrole, ofs + 2); + } + else + { + /* no substitution, so copy the match */ + regexp_pgrole = pstrdup(identLine->pg_role); + } + + /* + * now check if the username actually matched what the user is trying + * to connect as + */ + if (case_insensitive) + { + if (pg_strcasecmp(regexp_pgrole, pg_role) == 0) + *found_p = true; + } + else + { + if (strcmp(regexp_pgrole, pg_role) == 0) + *found_p = true; + } + pfree(regexp_pgrole); + + return; + } + else + { + /* Not regular expression, so make complete match */ + if (case_insensitive) + { + if (pg_strcasecmp(identLine->pg_role, pg_role) == 0 && + pg_strcasecmp(identLine->ident_user, ident_user) == 0) + *found_p = true; + } + else + { + if (strcmp(identLine->pg_role, pg_role) == 0 && + strcmp(identLine->ident_user, ident_user) == 0) + *found_p = true; + } + } +} + + +/* + * Scan the (pre-parsed) ident usermap file line by line, looking for a match + * + * See if the user with ident username "auth_user" is allowed to act + * as Postgres user "pg_role" according to usermap "usermap_name". + * + * Special case: Usermap NULL, equivalent to what was previously called + * "sameuser" or "samerole", means don't look in the usermap file. + * That's an implied map wherein "pg_role" must be identical to + * "auth_user" in order to be authorized. + * + * Iff authorized, return STATUS_OK, otherwise return STATUS_ERROR. + */ +int +check_usermap(const char *usermap_name, + const char *pg_role, + const char *auth_user, + bool case_insensitive) +{ + bool found_entry = false, + error = false; + + if (usermap_name == NULL || usermap_name[0] == '\0') + { + if (case_insensitive) + { + if (pg_strcasecmp(pg_role, auth_user) == 0) + return STATUS_OK; + } + else + { + if (strcmp(pg_role, auth_user) == 0) + return STATUS_OK; + } + ereport(LOG, + (errmsg("provided user name (%s) and authenticated user name (%s) do not match", + pg_role, auth_user))); + return STATUS_ERROR; + } + else + { + ListCell *line_cell; + + foreach(line_cell, parsed_ident_lines) + { + check_ident_usermap(lfirst(line_cell), usermap_name, + pg_role, auth_user, case_insensitive, + &found_entry, &error); + if (found_entry || error) + break; + } + } + if (!found_entry && !error) + { + ereport(LOG, + (errmsg("no match in usermap \"%s\" for user \"%s\" authenticated as \"%s\"", + usermap_name, pg_role, auth_user))); + } + return found_entry ? STATUS_OK : STATUS_ERROR; +} + + +/* + * Read the ident config file and create a List of IdentLine records for + * the contents. + * + * This works the same as load_hba(), but for the user config file. + */ +bool +load_ident(void) +{ + FILE *file; + List *ident_lines = NIL; + ListCell *line_cell, + *parsed_line_cell; + List *new_parsed_lines = NIL; + bool ok = true; + MemoryContext linecxt; + MemoryContext oldcxt; + MemoryContext ident_context; + IdentLine *newline; + + file = AllocateFile(IdentFileName, "r"); + if (file == NULL) + { + /* not fatal ... we just won't do any special ident maps */ + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not open usermap file \"%s\": %m", + IdentFileName))); + return false; + } + + linecxt = tokenize_file(IdentFileName, file, &ident_lines, LOG); + FreeFile(file); + + /* Now parse all the lines */ + Assert(PostmasterContext); + ident_context = AllocSetContextCreate(PostmasterContext, + "ident parser context", + ALLOCSET_SMALL_SIZES); + oldcxt = MemoryContextSwitchTo(ident_context); + foreach(line_cell, ident_lines) + { + TokenizedLine *tok_line = (TokenizedLine *) lfirst(line_cell); + + /* don't parse lines that already have errors */ + if (tok_line->err_msg != NULL) + { + ok = false; + continue; + } + + if ((newline = parse_ident_line(tok_line)) == NULL) + { + /* Parse error; remember there's trouble */ + ok = false; + + /* + * Keep parsing the rest of the file so we can report errors on + * more than the first line. Error has already been logged, no + * need for more chatter here. + */ + continue; + } + + new_parsed_lines = lappend(new_parsed_lines, newline); + } + + /* Free tokenizer memory */ + MemoryContextDelete(linecxt); + MemoryContextSwitchTo(oldcxt); + + if (!ok) + { + /* + * File contained one or more errors, so bail out, first being careful + * to clean up whatever we allocated. Most stuff will go away via + * MemoryContextDelete, but we have to clean up regexes explicitly. + */ + foreach(parsed_line_cell, new_parsed_lines) + { + newline = (IdentLine *) lfirst(parsed_line_cell); + if (newline->ident_user[0] == '/') + pg_regfree(&newline->re); + } + MemoryContextDelete(ident_context); + return false; + } + + /* Loaded new file successfully, replace the one we use */ + if (parsed_ident_lines != NIL) + { + foreach(parsed_line_cell, parsed_ident_lines) + { + newline = (IdentLine *) lfirst(parsed_line_cell); + if (newline->ident_user[0] == '/') + pg_regfree(&newline->re); + } + } + if (parsed_ident_context != NULL) + MemoryContextDelete(parsed_ident_context); + + parsed_ident_context = ident_context; + parsed_ident_lines = new_parsed_lines; + + return true; +} + + + +/* + * Determine what authentication method should be used when accessing database + * "database" from frontend "raddr", user "user". Return the method and + * an optional argument (stored in fields of *port), and STATUS_OK. + * + * If the file does not contain any entry matching the request, we return + * method = uaImplicitReject. + */ +void +hba_getauthmethod(hbaPort *port) +{ + check_hba(port); +} + + +/* + * Return the name of the auth method in use ("gss", "md5", "trust", etc.). + * + * The return value is statically allocated (see the UserAuthName array) and + * should not be freed. + */ +const char * +hba_authname(UserAuth auth_method) +{ + /* + * Make sure UserAuthName[] tracks additions to the UserAuth enum + */ + StaticAssertStmt(lengthof(UserAuthName) == USER_AUTH_LAST + 1, + "UserAuthName[] must match the UserAuth enum"); + + return UserAuthName[auth_method]; +} diff --git a/src/backend/libpq/ifaddr.c b/src/backend/libpq/ifaddr.c new file mode 100644 index 0000000..75760f3 --- /dev/null +++ b/src/backend/libpq/ifaddr.c @@ -0,0 +1,594 @@ +/*------------------------------------------------------------------------- + * + * ifaddr.c + * IP netmask calculations, and enumerating network interfaces. + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/libpq/ifaddr.c + * + * This file and the IPV6 implementation were initially provided by + * Nigel Kukard <nkukard@lbsd.net>, Linux Based Systems Design + * http://www.lbsd.net. + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include <unistd.h> +#include <sys/stat.h> +#include <sys/socket.h> +#include <netdb.h> +#include <netinet/in.h> +#ifdef HAVE_NETINET_TCP_H +#include <netinet/tcp.h> +#endif +#include <sys/file.h> + +#include "libpq/ifaddr.h" +#include "port/pg_bswap.h" + +static int range_sockaddr_AF_INET(const struct sockaddr_in *addr, + const struct sockaddr_in *netaddr, + const struct sockaddr_in *netmask); + +#ifdef HAVE_IPV6 +static int range_sockaddr_AF_INET6(const struct sockaddr_in6 *addr, + const struct sockaddr_in6 *netaddr, + const struct sockaddr_in6 *netmask); +#endif + + +/* + * pg_range_sockaddr - is addr within the subnet specified by netaddr/netmask ? + * + * Note: caller must already have verified that all three addresses are + * in the same address family; and AF_UNIX addresses are not supported. + */ +int +pg_range_sockaddr(const struct sockaddr_storage *addr, + const struct sockaddr_storage *netaddr, + const struct sockaddr_storage *netmask) +{ + if (addr->ss_family == AF_INET) + return range_sockaddr_AF_INET((const struct sockaddr_in *) addr, + (const struct sockaddr_in *) netaddr, + (const struct sockaddr_in *) netmask); +#ifdef HAVE_IPV6 + else if (addr->ss_family == AF_INET6) + return range_sockaddr_AF_INET6((const struct sockaddr_in6 *) addr, + (const struct sockaddr_in6 *) netaddr, + (const struct sockaddr_in6 *) netmask); +#endif + else + return 0; +} + +static int +range_sockaddr_AF_INET(const struct sockaddr_in *addr, + const struct sockaddr_in *netaddr, + const struct sockaddr_in *netmask) +{ + if (((addr->sin_addr.s_addr ^ netaddr->sin_addr.s_addr) & + netmask->sin_addr.s_addr) == 0) + return 1; + else + return 0; +} + + +#ifdef HAVE_IPV6 + +static int +range_sockaddr_AF_INET6(const struct sockaddr_in6 *addr, + const struct sockaddr_in6 *netaddr, + const struct sockaddr_in6 *netmask) +{ + int i; + + for (i = 0; i < 16; i++) + { + if (((addr->sin6_addr.s6_addr[i] ^ netaddr->sin6_addr.s6_addr[i]) & + netmask->sin6_addr.s6_addr[i]) != 0) + return 0; + } + + return 1; +} +#endif /* HAVE_IPV6 */ + +/* + * pg_sockaddr_cidr_mask - make a network mask of the appropriate family + * and required number of significant bits + * + * numbits can be null, in which case the mask is fully set. + * + * The resulting mask is placed in *mask, which had better be big enough. + * + * Return value is 0 if okay, -1 if not. + */ +int +pg_sockaddr_cidr_mask(struct sockaddr_storage *mask, char *numbits, int family) +{ + long bits; + char *endptr; + + if (numbits == NULL) + { + bits = (family == AF_INET) ? 32 : 128; + } + else + { + bits = strtol(numbits, &endptr, 10); + if (*numbits == '\0' || *endptr != '\0') + return -1; + } + + switch (family) + { + case AF_INET: + { + struct sockaddr_in mask4; + long maskl; + + if (bits < 0 || bits > 32) + return -1; + memset(&mask4, 0, sizeof(mask4)); + /* avoid "x << 32", which is not portable */ + if (bits > 0) + maskl = (0xffffffffUL << (32 - (int) bits)) + & 0xffffffffUL; + else + maskl = 0; + mask4.sin_addr.s_addr = pg_hton32(maskl); + memcpy(mask, &mask4, sizeof(mask4)); + break; + } + +#ifdef HAVE_IPV6 + case AF_INET6: + { + struct sockaddr_in6 mask6; + int i; + + if (bits < 0 || bits > 128) + return -1; + memset(&mask6, 0, sizeof(mask6)); + for (i = 0; i < 16; i++) + { + if (bits <= 0) + mask6.sin6_addr.s6_addr[i] = 0; + else if (bits >= 8) + mask6.sin6_addr.s6_addr[i] = 0xff; + else + { + mask6.sin6_addr.s6_addr[i] = + (0xff << (8 - (int) bits)) & 0xff; + } + bits -= 8; + } + memcpy(mask, &mask6, sizeof(mask6)); + break; + } +#endif + default: + return -1; + } + + mask->ss_family = family; + return 0; +} + + +/* + * Run the callback function for the addr/mask, after making sure the + * mask is sane for the addr. + */ +static void +run_ifaddr_callback(PgIfAddrCallback callback, void *cb_data, + struct sockaddr *addr, struct sockaddr *mask) +{ + struct sockaddr_storage fullmask; + + if (!addr) + return; + + /* Check that the mask is valid */ + if (mask) + { + if (mask->sa_family != addr->sa_family) + { + mask = NULL; + } + else if (mask->sa_family == AF_INET) + { + if (((struct sockaddr_in *) mask)->sin_addr.s_addr == INADDR_ANY) + mask = NULL; + } +#ifdef HAVE_IPV6 + else if (mask->sa_family == AF_INET6) + { + if (IN6_IS_ADDR_UNSPECIFIED(&((struct sockaddr_in6 *) mask)->sin6_addr)) + mask = NULL; + } +#endif + } + + /* If mask is invalid, generate our own fully-set mask */ + if (!mask) + { + pg_sockaddr_cidr_mask(&fullmask, NULL, addr->sa_family); + mask = (struct sockaddr *) &fullmask; + } + + (*callback) (addr, mask, cb_data); +} + +#ifdef WIN32 + +#include <winsock2.h> +#include <ws2tcpip.h> + +/* + * Enumerate the system's network interface addresses and call the callback + * for each one. Returns 0 if successful, -1 if trouble. + * + * This version is for Win32. Uses the Winsock 2 functions (ie: ws2_32.dll) + */ +int +pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data) +{ + INTERFACE_INFO *ptr, + *ii = NULL; + unsigned long length, + i; + unsigned long n_ii = 0; + SOCKET sock; + int error; + + sock = WSASocket(AF_INET, SOCK_DGRAM, 0, 0, 0, 0); + if (sock == INVALID_SOCKET) + return -1; + + while (n_ii < 1024) + { + n_ii += 64; + ptr = realloc(ii, sizeof(INTERFACE_INFO) * n_ii); + if (!ptr) + { + free(ii); + closesocket(sock); + errno = ENOMEM; + return -1; + } + + ii = ptr; + if (WSAIoctl(sock, SIO_GET_INTERFACE_LIST, 0, 0, + ii, n_ii * sizeof(INTERFACE_INFO), + &length, 0, 0) == SOCKET_ERROR) + { + error = WSAGetLastError(); + if (error == WSAEFAULT || error == WSAENOBUFS) + continue; /* need to make the buffer bigger */ + closesocket(sock); + free(ii); + return -1; + } + + break; + } + + for (i = 0; i < length / sizeof(INTERFACE_INFO); ++i) + run_ifaddr_callback(callback, cb_data, + (struct sockaddr *) &ii[i].iiAddress, + (struct sockaddr *) &ii[i].iiNetmask); + + closesocket(sock); + free(ii); + return 0; +} +#elif HAVE_GETIFADDRS /* && !WIN32 */ + +#ifdef HAVE_IFADDRS_H +#include <ifaddrs.h> +#endif + +/* + * Enumerate the system's network interface addresses and call the callback + * for each one. Returns 0 if successful, -1 if trouble. + * + * This version uses the getifaddrs() interface, which is available on + * BSDs, AIX, and modern Linux. + */ +int +pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data) +{ + struct ifaddrs *ifa, + *l; + + if (getifaddrs(&ifa) < 0) + return -1; + + for (l = ifa; l; l = l->ifa_next) + run_ifaddr_callback(callback, cb_data, + l->ifa_addr, l->ifa_netmask); + + freeifaddrs(ifa); + return 0; +} +#else /* !HAVE_GETIFADDRS && !WIN32 */ + +#include <sys/ioctl.h> + +#ifdef HAVE_NET_IF_H +#include <net/if.h> +#endif + +#ifdef HAVE_SYS_SOCKIO_H +#include <sys/sockio.h> +#endif + +/* + * SIOCGIFCONF does not return IPv6 addresses on Solaris + * and HP/UX. So we prefer SIOCGLIFCONF if it's available. + * + * On HP/UX, however, it *only* returns IPv6 addresses, + * and the structs are named slightly differently too. + * We'd have to do another call with SIOCGIFCONF to get the + * IPv4 addresses as well. We don't currently bother, just + * fall back to SIOCGIFCONF on HP/UX. + */ + +#if defined(SIOCGLIFCONF) && !defined(__hpux) + +/* + * Enumerate the system's network interface addresses and call the callback + * for each one. Returns 0 if successful, -1 if trouble. + * + * This version uses ioctl(SIOCGLIFCONF). + */ +int +pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data) +{ + struct lifconf lifc; + struct lifreq *lifr, + lmask; + struct sockaddr *addr, + *mask; + char *ptr, + *buffer = NULL; + size_t n_buffer = 1024; + pgsocket sock, + fd; + +#ifdef HAVE_IPV6 + pgsocket sock6; +#endif + int i, + total; + + sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock == PGINVALID_SOCKET) + return -1; + + while (n_buffer < 1024 * 100) + { + n_buffer += 1024; + ptr = realloc(buffer, n_buffer); + if (!ptr) + { + free(buffer); + close(sock); + errno = ENOMEM; + return -1; + } + + memset(&lifc, 0, sizeof(lifc)); + lifc.lifc_family = AF_UNSPEC; + lifc.lifc_buf = buffer = ptr; + lifc.lifc_len = n_buffer; + + if (ioctl(sock, SIOCGLIFCONF, &lifc) < 0) + { + if (errno == EINVAL) + continue; + free(buffer); + close(sock); + return -1; + } + + /* + * Some Unixes try to return as much data as possible, with no + * indication of whether enough space allocated. Don't believe we have + * it all unless there's lots of slop. + */ + if (lifc.lifc_len < n_buffer - 1024) + break; + } + +#ifdef HAVE_IPV6 + /* We'll need an IPv6 socket too for the SIOCGLIFNETMASK ioctls */ + sock6 = socket(AF_INET6, SOCK_DGRAM, 0); + if (sock6 == PGINVALID_SOCKET) + { + free(buffer); + close(sock); + return -1; + } +#endif + + total = lifc.lifc_len / sizeof(struct lifreq); + lifr = lifc.lifc_req; + for (i = 0; i < total; ++i) + { + addr = (struct sockaddr *) &lifr[i].lifr_addr; + memcpy(&lmask, &lifr[i], sizeof(struct lifreq)); +#ifdef HAVE_IPV6 + fd = (addr->sa_family == AF_INET6) ? sock6 : sock; +#else + fd = sock; +#endif + if (ioctl(fd, SIOCGLIFNETMASK, &lmask) < 0) + mask = NULL; + else + mask = (struct sockaddr *) &lmask.lifr_addr; + run_ifaddr_callback(callback, cb_data, addr, mask); + } + + free(buffer); + close(sock); +#ifdef HAVE_IPV6 + close(sock6); +#endif + return 0; +} +#elif defined(SIOCGIFCONF) + +/* + * Remaining Unixes use SIOCGIFCONF. Some only return IPv4 information + * here, so this is the least preferred method. Note that there is no + * standard way to iterate the struct ifreq returned in the array. + * On some OSs the structures are padded large enough for any address, + * on others you have to calculate the size of the struct ifreq. + */ + +/* Some OSs have _SIZEOF_ADDR_IFREQ, so just use that */ +#ifndef _SIZEOF_ADDR_IFREQ + +/* Calculate based on sockaddr.sa_len */ +#ifdef HAVE_STRUCT_SOCKADDR_SA_LEN +#define _SIZEOF_ADDR_IFREQ(ifr) \ + ((ifr).ifr_addr.sa_len > sizeof(struct sockaddr) ? \ + (sizeof(struct ifreq) - sizeof(struct sockaddr) + \ + (ifr).ifr_addr.sa_len) : sizeof(struct ifreq)) + +/* Padded ifreq structure, simple */ +#else +#define _SIZEOF_ADDR_IFREQ(ifr) \ + sizeof (struct ifreq) +#endif +#endif /* !_SIZEOF_ADDR_IFREQ */ + +/* + * Enumerate the system's network interface addresses and call the callback + * for each one. Returns 0 if successful, -1 if trouble. + * + * This version uses ioctl(SIOCGIFCONF). + */ +int +pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data) +{ + struct ifconf ifc; + struct ifreq *ifr, + *end, + addr, + mask; + char *ptr, + *buffer = NULL; + size_t n_buffer = 1024; + pgsocket sock; + + sock = socket(AF_INET, SOCK_DGRAM, 0); + if (sock == PGINVALID_SOCKET) + return -1; + + while (n_buffer < 1024 * 100) + { + n_buffer += 1024; + ptr = realloc(buffer, n_buffer); + if (!ptr) + { + free(buffer); + close(sock); + errno = ENOMEM; + return -1; + } + + memset(&ifc, 0, sizeof(ifc)); + ifc.ifc_buf = buffer = ptr; + ifc.ifc_len = n_buffer; + + if (ioctl(sock, SIOCGIFCONF, &ifc) < 0) + { + if (errno == EINVAL) + continue; + free(buffer); + close(sock); + return -1; + } + + /* + * Some Unixes try to return as much data as possible, with no + * indication of whether enough space allocated. Don't believe we have + * it all unless there's lots of slop. + */ + if (ifc.ifc_len < n_buffer - 1024) + break; + } + + end = (struct ifreq *) (buffer + ifc.ifc_len); + for (ifr = ifc.ifc_req; ifr < end;) + { + memcpy(&addr, ifr, sizeof(addr)); + memcpy(&mask, ifr, sizeof(mask)); + if (ioctl(sock, SIOCGIFADDR, &addr, sizeof(addr)) == 0 && + ioctl(sock, SIOCGIFNETMASK, &mask, sizeof(mask)) == 0) + run_ifaddr_callback(callback, cb_data, + &addr.ifr_addr, &mask.ifr_addr); + ifr = (struct ifreq *) ((char *) ifr + _SIZEOF_ADDR_IFREQ(*ifr)); + } + + free(buffer); + close(sock); + return 0; +} +#else /* !defined(SIOCGIFCONF) */ + +/* + * Enumerate the system's network interface addresses and call the callback + * for each one. Returns 0 if successful, -1 if trouble. + * + * This version is our fallback if there's no known way to get the + * interface addresses. Just return the standard loopback addresses. + */ +int +pg_foreach_ifaddr(PgIfAddrCallback callback, void *cb_data) +{ + struct sockaddr_in addr; + struct sockaddr_storage mask; + +#ifdef HAVE_IPV6 + struct sockaddr_in6 addr6; +#endif + + /* addr 127.0.0.1/8 */ + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = pg_ntoh32(0x7f000001); + memset(&mask, 0, sizeof(mask)); + pg_sockaddr_cidr_mask(&mask, "8", AF_INET); + run_ifaddr_callback(callback, cb_data, + (struct sockaddr *) &addr, + (struct sockaddr *) &mask); + +#ifdef HAVE_IPV6 + /* addr ::1/128 */ + memset(&addr6, 0, sizeof(addr6)); + addr6.sin6_family = AF_INET6; + addr6.sin6_addr.s6_addr[15] = 1; + memset(&mask, 0, sizeof(mask)); + pg_sockaddr_cidr_mask(&mask, "128", AF_INET6); + run_ifaddr_callback(callback, cb_data, + (struct sockaddr *) &addr6, + (struct sockaddr *) &mask); +#endif + + return 0; +} +#endif /* !defined(SIOCGIFCONF) */ + +#endif /* !HAVE_GETIFADDRS */ diff --git a/src/backend/libpq/pg_hba.conf.sample b/src/backend/libpq/pg_hba.conf.sample new file mode 100644 index 0000000..5f3f63e --- /dev/null +++ b/src/backend/libpq/pg_hba.conf.sample @@ -0,0 +1,94 @@ +# PostgreSQL Client Authentication Configuration File +# =================================================== +# +# Refer to the "Client Authentication" section in the PostgreSQL +# documentation for a complete description of this file. A short +# synopsis follows. +# +# This file controls: which hosts are allowed to connect, how clients +# are authenticated, which PostgreSQL user names they can use, which +# databases they can access. Records take one of these forms: +# +# local DATABASE USER METHOD [OPTIONS] +# host DATABASE USER ADDRESS METHOD [OPTIONS] +# hostssl DATABASE USER ADDRESS METHOD [OPTIONS] +# hostnossl DATABASE USER ADDRESS METHOD [OPTIONS] +# hostgssenc DATABASE USER ADDRESS METHOD [OPTIONS] +# hostnogssenc DATABASE USER ADDRESS METHOD [OPTIONS] +# +# (The uppercase items must be replaced by actual values.) +# +# The first field is the connection type: +# - "local" is a Unix-domain socket +# - "host" is a TCP/IP socket (encrypted or not) +# - "hostssl" is a TCP/IP socket that is SSL-encrypted +# - "hostnossl" is a TCP/IP socket that is not SSL-encrypted +# - "hostgssenc" is a TCP/IP socket that is GSSAPI-encrypted +# - "hostnogssenc" is a TCP/IP socket that is not GSSAPI-encrypted +# +# DATABASE can be "all", "sameuser", "samerole", "replication", a +# database name, or a comma-separated list thereof. The "all" +# keyword does not match "replication". Access to replication +# must be enabled in a separate record (see example below). +# +# USER can be "all", a user name, a group name prefixed with "+", or a +# comma-separated list thereof. In both the DATABASE and USER fields +# you can also write a file name prefixed with "@" to include names +# from a separate file. +# +# ADDRESS specifies the set of hosts the record matches. It can be a +# host name, or it is made up of an IP address and a CIDR mask that is +# an integer (between 0 and 32 (IPv4) or 128 (IPv6) inclusive) that +# specifies the number of significant bits in the mask. A host name +# that starts with a dot (.) matches a suffix of the actual host name. +# Alternatively, you can write an IP address and netmask in separate +# columns to specify the set of hosts. Instead of a CIDR-address, you +# can write "samehost" to match any of the server's own IP addresses, +# or "samenet" to match any address in any subnet that the server is +# directly connected to. +# +# METHOD can be "trust", "reject", "md5", "password", "scram-sha-256", +# "gss", "sspi", "ident", "peer", "pam", "ldap", "radius" or "cert". +# Note that "password" sends passwords in clear text; "md5" or +# "scram-sha-256" are preferred since they send encrypted passwords. +# +# OPTIONS are a set of options for the authentication in the format +# NAME=VALUE. The available options depend on the different +# authentication methods -- refer to the "Client Authentication" +# section in the documentation for a list of which options are +# available for which authentication methods. +# +# Database and user names containing spaces, commas, quotes and other +# special characters must be quoted. Quoting one of the keywords +# "all", "sameuser", "samerole" or "replication" makes the name lose +# its special character, and just match a database or username with +# that name. +# +# This file is read on server startup and when the server receives a +# SIGHUP signal. If you edit the file on a running system, you have to +# SIGHUP the server for the changes to take effect, run "pg_ctl reload", +# or execute "SELECT pg_reload_conf()". +# +# Put your actual configuration here +# ---------------------------------- +# +# If you want to allow non-local connections, you need to add more +# "host" records. In that case you will also need to make PostgreSQL +# listen on a non-local interface via the listen_addresses +# configuration parameter, or via the -i or -h command line switches. + +@authcomment@ + +# TYPE DATABASE USER ADDRESS METHOD + +@remove-line-for-nolocal@# "local" is for Unix domain socket connections only +@remove-line-for-nolocal@local all all @authmethodlocal@ +# IPv4 local connections: +host all all 127.0.0.1/32 @authmethodhost@ +# IPv6 local connections: +host all all ::1/128 @authmethodhost@ +# Allow replication connections from localhost, by a user with the +# replication privilege. +@remove-line-for-nolocal@local replication all @authmethodlocal@ +host replication all 127.0.0.1/32 @authmethodhost@ +host replication all ::1/128 @authmethodhost@ diff --git a/src/backend/libpq/pg_ident.conf.sample b/src/backend/libpq/pg_ident.conf.sample new file mode 100644 index 0000000..a5870e6 --- /dev/null +++ b/src/backend/libpq/pg_ident.conf.sample @@ -0,0 +1,42 @@ +# PostgreSQL User Name Maps +# ========================= +# +# Refer to the PostgreSQL documentation, chapter "Client +# Authentication" for a complete description. A short synopsis +# follows. +# +# This file controls PostgreSQL user name mapping. It maps external +# user names to their corresponding PostgreSQL user names. Records +# are of the form: +# +# MAPNAME SYSTEM-USERNAME PG-USERNAME +# +# (The uppercase quantities must be replaced by actual values.) +# +# MAPNAME is the (otherwise freely chosen) map name that was used in +# pg_hba.conf. SYSTEM-USERNAME is the detected user name of the +# client. PG-USERNAME is the requested PostgreSQL user name. The +# existence of a record specifies that SYSTEM-USERNAME may connect as +# PG-USERNAME. +# +# If SYSTEM-USERNAME starts with a slash (/), it will be treated as a +# regular expression. Optionally this can contain a capture (a +# parenthesized subexpression). The substring matching the capture +# will be substituted for \1 (backslash-one) if present in +# PG-USERNAME. +# +# Multiple maps may be specified in this file and used by pg_hba.conf. +# +# No map names are defined in the default configuration. If all +# system user names and PostgreSQL user names are the same, you don't +# need anything in this file. +# +# This file is read on server startup and when the postmaster receives +# a SIGHUP signal. If you edit the file on a running system, you have +# to SIGHUP the postmaster for the changes to take effect. You can +# use "pg_ctl reload" to do that. + +# Put your actual configuration here +# ---------------------------------- + +# MAPNAME SYSTEM-USERNAME PG-USERNAME diff --git a/src/backend/libpq/pqcomm.c b/src/backend/libpq/pqcomm.c new file mode 100644 index 0000000..44782f2 --- /dev/null +++ b/src/backend/libpq/pqcomm.c @@ -0,0 +1,1976 @@ +/*------------------------------------------------------------------------- + * + * pqcomm.c + * Communication functions between the Frontend and the Backend + * + * These routines handle the low-level details of communication between + * frontend and backend. They just shove data across the communication + * channel, and are ignorant of the semantics of the data. + * + * To emit an outgoing message, use the routines in pqformat.c to construct + * the message in a buffer and then emit it in one call to pq_putmessage. + * There are no functions to send raw bytes or partial messages; this + * ensures that the channel will not be clogged by an incomplete message if + * execution is aborted by ereport(ERROR) partway through the message. + * + * At one time, libpq was shared between frontend and backend, but now + * the backend's "backend/libpq" is quite separate from "interfaces/libpq". + * All that remains is similarities of names to trap the unwary... + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/backend/libpq/pqcomm.c + * + *------------------------------------------------------------------------- + */ + +/*------------------------ + * INTERFACE ROUTINES + * + * setup/teardown: + * StreamServerPort - Open postmaster's server port + * StreamConnection - Create new connection with client + * StreamClose - Close a client/backend connection + * TouchSocketFiles - Protect socket files against /tmp cleaners + * pq_init - initialize libpq at backend startup + * socket_comm_reset - reset libpq during error recovery + * socket_close - shutdown libpq at backend exit + * + * low-level I/O: + * pq_getbytes - get a known number of bytes from connection + * pq_getmessage - get a message with length word from connection + * pq_getbyte - get next byte from connection + * pq_peekbyte - peek at next byte from connection + * pq_flush - flush pending output + * pq_flush_if_writable - flush pending output if writable without blocking + * pq_getbyte_if_available - get a byte if available without blocking + * + * message-level I/O + * pq_putmessage - send a normal message (suppressed in COPY OUT mode) + * pq_putmessage_noblock - buffer a normal message (suppressed in COPY OUT) + * + *------------------------ + */ +#include "postgres.h" + +#ifdef HAVE_POLL_H +#include <poll.h> +#endif +#include <signal.h> +#include <fcntl.h> +#include <grp.h> +#include <unistd.h> +#include <sys/file.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/time.h> +#include <netdb.h> +#include <netinet/in.h> +#ifdef HAVE_NETINET_TCP_H +#include <netinet/tcp.h> +#endif +#include <utime.h> +#ifdef _MSC_VER /* mstcpip.h is missing on mingw */ +#include <mstcpip.h> +#endif + +#include "common/ip.h" +#include "libpq/libpq.h" +#include "miscadmin.h" +#include "port/pg_bswap.h" +#include "storage/ipc.h" +#include "utils/guc.h" +#include "utils/memutils.h" + +/* + * Cope with the various platform-specific ways to spell TCP keepalive socket + * options. This doesn't cover Windows, which as usual does its own thing. + */ +#if defined(TCP_KEEPIDLE) +/* TCP_KEEPIDLE is the name of this option on Linux and *BSD */ +#define PG_TCP_KEEPALIVE_IDLE TCP_KEEPIDLE +#define PG_TCP_KEEPALIVE_IDLE_STR "TCP_KEEPIDLE" +#elif defined(TCP_KEEPALIVE_THRESHOLD) +/* TCP_KEEPALIVE_THRESHOLD is the name of this option on Solaris >= 11 */ +#define PG_TCP_KEEPALIVE_IDLE TCP_KEEPALIVE_THRESHOLD +#define PG_TCP_KEEPALIVE_IDLE_STR "TCP_KEEPALIVE_THRESHOLD" +#elif defined(TCP_KEEPALIVE) && defined(__darwin__) +/* TCP_KEEPALIVE is the name of this option on macOS */ +/* Caution: Solaris has this symbol but it means something different */ +#define PG_TCP_KEEPALIVE_IDLE TCP_KEEPALIVE +#define PG_TCP_KEEPALIVE_IDLE_STR "TCP_KEEPALIVE" +#endif + +/* + * Configuration options + */ +int Unix_socket_permissions; +char *Unix_socket_group; + +/* Where the Unix socket files are (list of palloc'd strings) */ +static List *sock_paths = NIL; + +/* + * Buffers for low-level I/O. + * + * The receive buffer is fixed size. Send buffer is usually 8k, but can be + * enlarged by pq_putmessage_noblock() if the message doesn't fit otherwise. + */ + +#define PQ_SEND_BUFFER_SIZE 8192 +#define PQ_RECV_BUFFER_SIZE 8192 + +static char *PqSendBuffer; +static int PqSendBufferSize; /* Size send buffer */ +static int PqSendPointer; /* Next index to store a byte in PqSendBuffer */ +static int PqSendStart; /* Next index to send a byte in PqSendBuffer */ + +static char PqRecvBuffer[PQ_RECV_BUFFER_SIZE]; +static int PqRecvPointer; /* Next index to read a byte from PqRecvBuffer */ +static int PqRecvLength; /* End of data available in PqRecvBuffer */ + +/* + * Message status + */ +static bool PqCommBusy; /* busy sending data to the client */ +static bool PqCommReadingMsg; /* in the middle of reading a message */ + + +/* Internal functions */ +static void socket_comm_reset(void); +static void socket_close(int code, Datum arg); +static void socket_set_nonblocking(bool nonblocking); +static int socket_flush(void); +static int socket_flush_if_writable(void); +static bool socket_is_send_pending(void); +static int socket_putmessage(char msgtype, const char *s, size_t len); +static void socket_putmessage_noblock(char msgtype, const char *s, size_t len); +static int internal_putbytes(const char *s, size_t len); +static int internal_flush(void); + +#ifdef HAVE_UNIX_SOCKETS +static int Lock_AF_UNIX(const char *unixSocketDir, const char *unixSocketPath); +static int Setup_AF_UNIX(const char *sock_path); +#endif /* HAVE_UNIX_SOCKETS */ + +static const PQcommMethods PqCommSocketMethods = { + socket_comm_reset, + socket_flush, + socket_flush_if_writable, + socket_is_send_pending, + socket_putmessage, + socket_putmessage_noblock +}; + +const PQcommMethods *PqCommMethods = &PqCommSocketMethods; + +WaitEventSet *FeBeWaitSet; + + +/* -------------------------------- + * pq_init - initialize libpq at backend startup + * -------------------------------- + */ +void +pq_init(void) +{ + int socket_pos PG_USED_FOR_ASSERTS_ONLY; + int latch_pos PG_USED_FOR_ASSERTS_ONLY; + + /* initialize state variables */ + PqSendBufferSize = PQ_SEND_BUFFER_SIZE; + PqSendBuffer = MemoryContextAlloc(TopMemoryContext, PqSendBufferSize); + PqSendPointer = PqSendStart = PqRecvPointer = PqRecvLength = 0; + PqCommBusy = false; + PqCommReadingMsg = false; + + /* set up process-exit hook to close the socket */ + on_proc_exit(socket_close, 0); + + /* + * In backends (as soon as forked) we operate the underlying socket in + * nonblocking mode and use latches to implement blocking semantics if + * needed. That allows us to provide safely interruptible reads and + * writes. + * + * Use COMMERROR on failure, because ERROR would try to send the error to + * the client, which might require changing the mode again, leading to + * infinite recursion. + */ +#ifndef WIN32 + if (!pg_set_noblock(MyProcPort->sock)) + ereport(COMMERROR, + (errmsg("could not set socket to nonblocking mode: %m"))); +#endif + + FeBeWaitSet = CreateWaitEventSet(TopMemoryContext, 3); + socket_pos = AddWaitEventToSet(FeBeWaitSet, WL_SOCKET_WRITEABLE, + MyProcPort->sock, NULL, NULL); + latch_pos = AddWaitEventToSet(FeBeWaitSet, WL_LATCH_SET, PGINVALID_SOCKET, + MyLatch, NULL); + AddWaitEventToSet(FeBeWaitSet, WL_POSTMASTER_DEATH, PGINVALID_SOCKET, + NULL, NULL); + + /* + * The event positions match the order we added them, but let's sanity + * check them to be sure. + */ + Assert(socket_pos == FeBeWaitSetSocketPos); + Assert(latch_pos == FeBeWaitSetLatchPos); +} + +/* -------------------------------- + * socket_comm_reset - reset libpq during error recovery + * + * This is called from error recovery at the outer idle loop. It's + * just to get us out of trouble if we somehow manage to elog() from + * inside a pqcomm.c routine (which ideally will never happen, but...) + * -------------------------------- + */ +static void +socket_comm_reset(void) +{ + /* Do not throw away pending data, but do reset the busy flag */ + PqCommBusy = false; +} + +/* -------------------------------- + * socket_close - shutdown libpq at backend exit + * + * This is the one pg_on_exit_callback in place during BackendInitialize(). + * That function's unusual signal handling constrains that this callback be + * safe to run at any instant. + * -------------------------------- + */ +static void +socket_close(int code, Datum arg) +{ + /* Nothing to do in a standalone backend, where MyProcPort is NULL. */ + if (MyProcPort != NULL) + { +#ifdef ENABLE_GSS + /* + * Shutdown GSSAPI layer. This section does nothing when interrupting + * BackendInitialize(), because pg_GSS_recvauth() makes first use of + * "ctx" and "cred". + * + * Note that we don't bother to free MyProcPort->gss, since we're + * about to exit anyway. + */ + if (MyProcPort->gss) + { + OM_uint32 min_s; + + if (MyProcPort->gss->ctx != GSS_C_NO_CONTEXT) + gss_delete_sec_context(&min_s, &MyProcPort->gss->ctx, NULL); + + if (MyProcPort->gss->cred != GSS_C_NO_CREDENTIAL) + gss_release_cred(&min_s, &MyProcPort->gss->cred); + } +#endif /* ENABLE_GSS */ + + /* + * Cleanly shut down SSL layer. Nowhere else does a postmaster child + * call this, so this is safe when interrupting BackendInitialize(). + */ + secure_close(MyProcPort); + + /* + * Formerly we did an explicit close() here, but it seems better to + * leave the socket open until the process dies. This allows clients + * to perform a "synchronous close" if they care --- wait till the + * transport layer reports connection closure, and you can be sure the + * backend has exited. + * + * We do set sock to PGINVALID_SOCKET to prevent any further I/O, + * though. + */ + MyProcPort->sock = PGINVALID_SOCKET; + } +} + + + +/* + * Streams -- wrapper around Unix socket system calls + * + * + * Stream functions are used for vanilla TCP connection protocol. + */ + + +/* + * StreamServerPort -- open a "listening" port to accept connections. + * + * family should be AF_UNIX or AF_UNSPEC; portNumber is the port number. + * For AF_UNIX ports, hostName should be NULL and unixSocketDir must be + * specified. For TCP ports, hostName is either NULL for all interfaces or + * the interface to listen on, and unixSocketDir is ignored (can be NULL). + * + * Successfully opened sockets are added to the ListenSocket[] array (of + * length MaxListen), at the first position that isn't PGINVALID_SOCKET. + * + * RETURNS: STATUS_OK or STATUS_ERROR + */ + +int +StreamServerPort(int family, const char *hostName, unsigned short portNumber, + const char *unixSocketDir, + pgsocket ListenSocket[], int MaxListen) +{ + pgsocket fd; + int err; + int maxconn; + int ret; + char portNumberStr[32]; + const char *familyDesc; + char familyDescBuf[64]; + const char *addrDesc; + char addrBuf[NI_MAXHOST]; + char *service; + struct addrinfo *addrs = NULL, + *addr; + struct addrinfo hint; + int listen_index = 0; + int added = 0; + +#ifdef HAVE_UNIX_SOCKETS + char unixSocketPath[MAXPGPATH]; +#endif +#if !defined(WIN32) || defined(IPV6_V6ONLY) + int one = 1; +#endif + + /* Initialize hint structure */ + MemSet(&hint, 0, sizeof(hint)); + hint.ai_family = family; + hint.ai_flags = AI_PASSIVE; + hint.ai_socktype = SOCK_STREAM; + +#ifdef HAVE_UNIX_SOCKETS + if (family == AF_UNIX) + { + /* + * Create unixSocketPath from portNumber and unixSocketDir and lock + * that file path + */ + UNIXSOCK_PATH(unixSocketPath, portNumber, unixSocketDir); + if (strlen(unixSocketPath) >= UNIXSOCK_PATH_BUFLEN) + { + ereport(LOG, + (errmsg("Unix-domain socket path \"%s\" is too long (maximum %d bytes)", + unixSocketPath, + (int) (UNIXSOCK_PATH_BUFLEN - 1)))); + return STATUS_ERROR; + } + if (Lock_AF_UNIX(unixSocketDir, unixSocketPath) != STATUS_OK) + return STATUS_ERROR; + service = unixSocketPath; + } + else +#endif /* HAVE_UNIX_SOCKETS */ + { + snprintf(portNumberStr, sizeof(portNumberStr), "%d", portNumber); + service = portNumberStr; + } + + ret = pg_getaddrinfo_all(hostName, service, &hint, &addrs); + if (ret || !addrs) + { + if (hostName) + ereport(LOG, + (errmsg("could not translate host name \"%s\", service \"%s\" to address: %s", + hostName, service, gai_strerror(ret)))); + else + ereport(LOG, + (errmsg("could not translate service \"%s\" to address: %s", + service, gai_strerror(ret)))); + if (addrs) + pg_freeaddrinfo_all(hint.ai_family, addrs); + return STATUS_ERROR; + } + + for (addr = addrs; addr; addr = addr->ai_next) + { + if (!IS_AF_UNIX(family) && IS_AF_UNIX(addr->ai_family)) + { + /* + * Only set up a unix domain socket when they really asked for it. + * The service/port is different in that case. + */ + continue; + } + + /* See if there is still room to add 1 more socket. */ + for (; listen_index < MaxListen; listen_index++) + { + if (ListenSocket[listen_index] == PGINVALID_SOCKET) + break; + } + if (listen_index >= MaxListen) + { + ereport(LOG, + (errmsg("could not bind to all requested addresses: MAXLISTEN (%d) exceeded", + MaxListen))); + break; + } + + /* set up address family name for log messages */ + switch (addr->ai_family) + { + case AF_INET: + familyDesc = _("IPv4"); + break; +#ifdef HAVE_IPV6 + case AF_INET6: + familyDesc = _("IPv6"); + break; +#endif +#ifdef HAVE_UNIX_SOCKETS + case AF_UNIX: + familyDesc = _("Unix"); + break; +#endif + default: + snprintf(familyDescBuf, sizeof(familyDescBuf), + _("unrecognized address family %d"), + addr->ai_family); + familyDesc = familyDescBuf; + break; + } + + /* set up text form of address for log messages */ +#ifdef HAVE_UNIX_SOCKETS + if (addr->ai_family == AF_UNIX) + addrDesc = unixSocketPath; + else +#endif + { + pg_getnameinfo_all((const struct sockaddr_storage *) addr->ai_addr, + addr->ai_addrlen, + addrBuf, sizeof(addrBuf), + NULL, 0, + NI_NUMERICHOST); + addrDesc = addrBuf; + } + + if ((fd = socket(addr->ai_family, SOCK_STREAM, 0)) == PGINVALID_SOCKET) + { + ereport(LOG, + (errcode_for_socket_access(), + /* translator: first %s is IPv4, IPv6, or Unix */ + errmsg("could not create %s socket for address \"%s\": %m", + familyDesc, addrDesc))); + continue; + } + +#ifndef WIN32 + + /* + * Without the SO_REUSEADDR flag, a new postmaster can't be started + * right away after a stop or crash, giving "address already in use" + * error on TCP ports. + * + * On win32, however, this behavior only happens if the + * SO_EXCLUSIVEADDRUSE is set. With SO_REUSEADDR, win32 allows + * multiple servers to listen on the same address, resulting in + * unpredictable behavior. With no flags at all, win32 behaves as Unix + * with SO_REUSEADDR. + */ + if (!IS_AF_UNIX(addr->ai_family)) + { + if ((setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, + (char *) &one, sizeof(one))) == -1) + { + ereport(LOG, + (errcode_for_socket_access(), + /* translator: third %s is IPv4, IPv6, or Unix */ + errmsg("%s(%s) failed for %s address \"%s\": %m", + "setsockopt", "SO_REUSEADDR", + familyDesc, addrDesc))); + closesocket(fd); + continue; + } + } +#endif + +#ifdef IPV6_V6ONLY + if (addr->ai_family == AF_INET6) + { + if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, + (char *) &one, sizeof(one)) == -1) + { + ereport(LOG, + (errcode_for_socket_access(), + /* translator: third %s is IPv4, IPv6, or Unix */ + errmsg("%s(%s) failed for %s address \"%s\": %m", + "setsockopt", "IPV6_V6ONLY", + familyDesc, addrDesc))); + closesocket(fd); + continue; + } + } +#endif + + /* + * Note: This might fail on some OS's, like Linux older than + * 2.4.21-pre3, that don't have the IPV6_V6ONLY socket option, and map + * ipv4 addresses to ipv6. It will show ::ffff:ipv4 for all ipv4 + * connections. + */ + err = bind(fd, addr->ai_addr, addr->ai_addrlen); + if (err < 0) + { + int saved_errno = errno; + + ereport(LOG, + (errcode_for_socket_access(), + /* translator: first %s is IPv4, IPv6, or Unix */ + errmsg("could not bind %s address \"%s\": %m", + familyDesc, addrDesc), + saved_errno == EADDRINUSE ? + (IS_AF_UNIX(addr->ai_family) ? + errhint("Is another postmaster already running on port %d?", + (int) portNumber) : + errhint("Is another postmaster already running on port %d?" + " If not, wait a few seconds and retry.", + (int) portNumber)) : 0)); + closesocket(fd); + continue; + } + +#ifdef HAVE_UNIX_SOCKETS + if (addr->ai_family == AF_UNIX) + { + if (Setup_AF_UNIX(service) != STATUS_OK) + { + closesocket(fd); + break; + } + } +#endif + + /* + * Select appropriate accept-queue length limit. PG_SOMAXCONN is only + * intended to provide a clamp on the request on platforms where an + * overly large request provokes a kernel error (are there any?). + */ + maxconn = MaxBackends * 2; + if (maxconn > PG_SOMAXCONN) + maxconn = PG_SOMAXCONN; + + err = listen(fd, maxconn); + if (err < 0) + { + ereport(LOG, + (errcode_for_socket_access(), + /* translator: first %s is IPv4, IPv6, or Unix */ + errmsg("could not listen on %s address \"%s\": %m", + familyDesc, addrDesc))); + closesocket(fd); + continue; + } + +#ifdef HAVE_UNIX_SOCKETS + if (addr->ai_family == AF_UNIX) + ereport(LOG, + (errmsg("listening on Unix socket \"%s\"", + addrDesc))); + else +#endif + ereport(LOG, + /* translator: first %s is IPv4 or IPv6 */ + (errmsg("listening on %s address \"%s\", port %d", + familyDesc, addrDesc, (int) portNumber))); + + ListenSocket[listen_index] = fd; + added++; + } + + pg_freeaddrinfo_all(hint.ai_family, addrs); + + if (!added) + return STATUS_ERROR; + + return STATUS_OK; +} + + +#ifdef HAVE_UNIX_SOCKETS + +/* + * Lock_AF_UNIX -- configure unix socket file path + */ +static int +Lock_AF_UNIX(const char *unixSocketDir, const char *unixSocketPath) +{ + /* no lock file for abstract sockets */ + if (unixSocketPath[0] == '@') + return STATUS_OK; + + /* + * Grab an interlock file associated with the socket file. + * + * Note: there are two reasons for using a socket lock file, rather than + * trying to interlock directly on the socket itself. First, it's a lot + * more portable, and second, it lets us remove any pre-existing socket + * file without race conditions. + */ + CreateSocketLockFile(unixSocketPath, true, unixSocketDir); + + /* + * Once we have the interlock, we can safely delete any pre-existing + * socket file to avoid failure at bind() time. + */ + (void) unlink(unixSocketPath); + + /* + * Remember socket file pathnames for later maintenance. + */ + sock_paths = lappend(sock_paths, pstrdup(unixSocketPath)); + + return STATUS_OK; +} + + +/* + * Setup_AF_UNIX -- configure unix socket permissions + */ +static int +Setup_AF_UNIX(const char *sock_path) +{ + /* no file system permissions for abstract sockets */ + if (sock_path[0] == '@') + return STATUS_OK; + + /* + * Fix socket ownership/permission if requested. Note we must do this + * before we listen() to avoid a window where unwanted connections could + * get accepted. + */ + Assert(Unix_socket_group); + if (Unix_socket_group[0] != '\0') + { +#ifdef WIN32 + elog(WARNING, "configuration item unix_socket_group is not supported on this platform"); +#else + char *endptr; + unsigned long val; + gid_t gid; + + val = strtoul(Unix_socket_group, &endptr, 10); + if (*endptr == '\0') + { /* numeric group id */ + gid = val; + } + else + { /* convert group name to id */ + struct group *gr; + + gr = getgrnam(Unix_socket_group); + if (!gr) + { + ereport(LOG, + (errmsg("group \"%s\" does not exist", + Unix_socket_group))); + return STATUS_ERROR; + } + gid = gr->gr_gid; + } + if (chown(sock_path, -1, gid) == -1) + { + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not set group of file \"%s\": %m", + sock_path))); + return STATUS_ERROR; + } +#endif + } + + if (chmod(sock_path, Unix_socket_permissions) == -1) + { + ereport(LOG, + (errcode_for_file_access(), + errmsg("could not set permissions of file \"%s\": %m", + sock_path))); + return STATUS_ERROR; + } + return STATUS_OK; +} +#endif /* HAVE_UNIX_SOCKETS */ + + +/* + * StreamConnection -- create a new connection with client using + * server port. Set port->sock to the FD of the new connection. + * + * ASSUME: that this doesn't need to be non-blocking because + * the Postmaster uses select() to tell when the socket is ready for + * accept(). + * + * RETURNS: STATUS_OK or STATUS_ERROR + */ +int +StreamConnection(pgsocket server_fd, Port *port) +{ + /* accept connection and fill in the client (remote) address */ + port->raddr.salen = sizeof(port->raddr.addr); + if ((port->sock = accept(server_fd, + (struct sockaddr *) &port->raddr.addr, + &port->raddr.salen)) == PGINVALID_SOCKET) + { + ereport(LOG, + (errcode_for_socket_access(), + errmsg("could not accept new connection: %m"))); + + /* + * If accept() fails then postmaster.c will still see the server + * socket as read-ready, and will immediately try again. To avoid + * uselessly sucking lots of CPU, delay a bit before trying again. + * (The most likely reason for failure is being out of kernel file + * table slots; we can do little except hope some will get freed up.) + */ + pg_usleep(100000L); /* wait 0.1 sec */ + return STATUS_ERROR; + } + + /* fill in the server (local) address */ + port->laddr.salen = sizeof(port->laddr.addr); + if (getsockname(port->sock, + (struct sockaddr *) &port->laddr.addr, + &port->laddr.salen) < 0) + { + ereport(LOG, + (errmsg("%s() failed: %m", "getsockname"))); + return STATUS_ERROR; + } + + /* select NODELAY and KEEPALIVE options if it's a TCP connection */ + if (!IS_AF_UNIX(port->laddr.addr.ss_family)) + { + int on; +#ifdef WIN32 + int oldopt; + int optlen; + int newopt; +#endif + +#ifdef TCP_NODELAY + on = 1; + if (setsockopt(port->sock, IPPROTO_TCP, TCP_NODELAY, + (char *) &on, sizeof(on)) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "setsockopt", "TCP_NODELAY"))); + return STATUS_ERROR; + } +#endif + on = 1; + if (setsockopt(port->sock, SOL_SOCKET, SO_KEEPALIVE, + (char *) &on, sizeof(on)) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "setsockopt", "SO_KEEPALIVE"))); + return STATUS_ERROR; + } + +#ifdef WIN32 + + /* + * This is a Win32 socket optimization. The OS send buffer should be + * large enough to send the whole Postgres send buffer in one go, or + * performance suffers. The Postgres send buffer can be enlarged if a + * very large message needs to be sent, but we won't attempt to + * enlarge the OS buffer if that happens, so somewhat arbitrarily + * ensure that the OS buffer is at least PQ_SEND_BUFFER_SIZE * 4. + * (That's 32kB with the current default). + * + * The default OS buffer size used to be 8kB in earlier Windows + * versions, but was raised to 64kB in Windows 2012. So it shouldn't + * be necessary to change it in later versions anymore. Changing it + * unnecessarily can even reduce performance, because setting + * SO_SNDBUF in the application disables the "dynamic send buffering" + * feature that was introduced in Windows 7. So before fiddling with + * SO_SNDBUF, check if the current buffer size is already large enough + * and only increase it if necessary. + * + * See https://support.microsoft.com/kb/823764/EN-US/ and + * https://msdn.microsoft.com/en-us/library/bb736549%28v=vs.85%29.aspx + */ + optlen = sizeof(oldopt); + if (getsockopt(port->sock, SOL_SOCKET, SO_SNDBUF, (char *) &oldopt, + &optlen) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "getsockopt", "SO_SNDBUF"))); + return STATUS_ERROR; + } + newopt = PQ_SEND_BUFFER_SIZE * 4; + if (oldopt < newopt) + { + if (setsockopt(port->sock, SOL_SOCKET, SO_SNDBUF, (char *) &newopt, + sizeof(newopt)) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "setsockopt", "SO_SNDBUF"))); + return STATUS_ERROR; + } + } +#endif + + /* + * Also apply the current keepalive parameters. If we fail to set a + * parameter, don't error out, because these aren't universally + * supported. (Note: you might think we need to reset the GUC + * variables to 0 in such a case, but it's not necessary because the + * show hooks for these variables report the truth anyway.) + */ + (void) pq_setkeepalivesidle(tcp_keepalives_idle, port); + (void) pq_setkeepalivesinterval(tcp_keepalives_interval, port); + (void) pq_setkeepalivescount(tcp_keepalives_count, port); + (void) pq_settcpusertimeout(tcp_user_timeout, port); + } + + return STATUS_OK; +} + +/* + * StreamClose -- close a client/backend connection + * + * NOTE: this is NOT used to terminate a session; it is just used to release + * the file descriptor in a process that should no longer have the socket + * open. (For example, the postmaster calls this after passing ownership + * of the connection to a child process.) It is expected that someone else + * still has the socket open. So, we only want to close the descriptor, + * we do NOT want to send anything to the far end. + */ +void +StreamClose(pgsocket sock) +{ + closesocket(sock); +} + +/* + * TouchSocketFiles -- mark socket files as recently accessed + * + * This routine should be called every so often to ensure that the socket + * files have a recent mod date (ordinary operations on sockets usually won't + * change the mod date). That saves them from being removed by + * overenthusiastic /tmp-directory-cleaner daemons. (Another reason we should + * never have put the socket file in /tmp...) + */ +void +TouchSocketFiles(void) +{ + ListCell *l; + + /* Loop through all created sockets... */ + foreach(l, sock_paths) + { + char *sock_path = (char *) lfirst(l); + + /* Ignore errors; there's no point in complaining */ + (void) utime(sock_path, NULL); + } +} + +/* + * RemoveSocketFiles -- unlink socket files at postmaster shutdown + */ +void +RemoveSocketFiles(void) +{ + ListCell *l; + + /* Loop through all created sockets... */ + foreach(l, sock_paths) + { + char *sock_path = (char *) lfirst(l); + + /* Ignore any error. */ + (void) unlink(sock_path); + } + /* Since we're about to exit, no need to reclaim storage */ + sock_paths = NIL; +} + + +/* -------------------------------- + * Low-level I/O routines begin here. + * + * These routines communicate with a frontend client across a connection + * already established by the preceding routines. + * -------------------------------- + */ + +/* -------------------------------- + * socket_set_nonblocking - set socket blocking/non-blocking + * + * Sets the socket non-blocking if nonblocking is true, or sets it + * blocking otherwise. + * -------------------------------- + */ +static void +socket_set_nonblocking(bool nonblocking) +{ + if (MyProcPort == NULL) + ereport(ERROR, + (errcode(ERRCODE_CONNECTION_DOES_NOT_EXIST), + errmsg("there is no client connection"))); + + MyProcPort->noblock = nonblocking; +} + +/* -------------------------------- + * pq_recvbuf - load some bytes into the input buffer + * + * returns 0 if OK, EOF if trouble + * -------------------------------- + */ +static int +pq_recvbuf(void) +{ + if (PqRecvPointer > 0) + { + if (PqRecvLength > PqRecvPointer) + { + /* still some unread data, left-justify it in the buffer */ + memmove(PqRecvBuffer, PqRecvBuffer + PqRecvPointer, + PqRecvLength - PqRecvPointer); + PqRecvLength -= PqRecvPointer; + PqRecvPointer = 0; + } + else + PqRecvLength = PqRecvPointer = 0; + } + + /* Ensure that we're in blocking mode */ + socket_set_nonblocking(false); + + /* Can fill buffer from PqRecvLength and upwards */ + for (;;) + { + int r; + + r = secure_read(MyProcPort, PqRecvBuffer + PqRecvLength, + PQ_RECV_BUFFER_SIZE - PqRecvLength); + + if (r < 0) + { + if (errno == EINTR) + continue; /* Ok if interrupted */ + + /* + * Careful: an ereport() that tries to write to the client would + * cause recursion to here, leading to stack overflow and core + * dump! This message must go *only* to the postmaster log. + */ + ereport(COMMERROR, + (errcode_for_socket_access(), + errmsg("could not receive data from client: %m"))); + return EOF; + } + if (r == 0) + { + /* + * EOF detected. We used to write a log message here, but it's + * better to expect the ultimate caller to do that. + */ + return EOF; + } + /* r contains number of bytes read, so just incr length */ + PqRecvLength += r; + return 0; + } +} + +/* -------------------------------- + * pq_getbyte - get a single byte from connection, or return EOF + * -------------------------------- + */ +int +pq_getbyte(void) +{ + Assert(PqCommReadingMsg); + + while (PqRecvPointer >= PqRecvLength) + { + if (pq_recvbuf()) /* If nothing in buffer, then recv some */ + return EOF; /* Failed to recv data */ + } + return (unsigned char) PqRecvBuffer[PqRecvPointer++]; +} + +/* -------------------------------- + * pq_peekbyte - peek at next byte from connection + * + * Same as pq_getbyte() except we don't advance the pointer. + * -------------------------------- + */ +int +pq_peekbyte(void) +{ + Assert(PqCommReadingMsg); + + while (PqRecvPointer >= PqRecvLength) + { + if (pq_recvbuf()) /* If nothing in buffer, then recv some */ + return EOF; /* Failed to recv data */ + } + return (unsigned char) PqRecvBuffer[PqRecvPointer]; +} + +/* -------------------------------- + * pq_getbyte_if_available - get a single byte from connection, + * if available + * + * The received byte is stored in *c. Returns 1 if a byte was read, + * 0 if no data was available, or EOF if trouble. + * -------------------------------- + */ +int +pq_getbyte_if_available(unsigned char *c) +{ + int r; + + Assert(PqCommReadingMsg); + + if (PqRecvPointer < PqRecvLength) + { + *c = PqRecvBuffer[PqRecvPointer++]; + return 1; + } + + /* Put the socket into non-blocking mode */ + socket_set_nonblocking(true); + + r = secure_read(MyProcPort, c, 1); + if (r < 0) + { + /* + * Ok if no data available without blocking or interrupted (though + * EINTR really shouldn't happen with a non-blocking socket). Report + * other errors. + */ + if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR) + r = 0; + else + { + /* + * Careful: an ereport() that tries to write to the client would + * cause recursion to here, leading to stack overflow and core + * dump! This message must go *only* to the postmaster log. + */ + ereport(COMMERROR, + (errcode_for_socket_access(), + errmsg("could not receive data from client: %m"))); + r = EOF; + } + } + else if (r == 0) + { + /* EOF detected */ + r = EOF; + } + + return r; +} + +/* -------------------------------- + * pq_getbytes - get a known number of bytes from connection + * + * returns 0 if OK, EOF if trouble + * -------------------------------- + */ +int +pq_getbytes(char *s, size_t len) +{ + size_t amount; + + Assert(PqCommReadingMsg); + + while (len > 0) + { + while (PqRecvPointer >= PqRecvLength) + { + if (pq_recvbuf()) /* If nothing in buffer, then recv some */ + return EOF; /* Failed to recv data */ + } + amount = PqRecvLength - PqRecvPointer; + if (amount > len) + amount = len; + memcpy(s, PqRecvBuffer + PqRecvPointer, amount); + PqRecvPointer += amount; + s += amount; + len -= amount; + } + return 0; +} + +/* -------------------------------- + * pq_discardbytes - throw away a known number of bytes + * + * same as pq_getbytes except we do not copy the data to anyplace. + * this is used for resynchronizing after read errors. + * + * returns 0 if OK, EOF if trouble + * -------------------------------- + */ +static int +pq_discardbytes(size_t len) +{ + size_t amount; + + Assert(PqCommReadingMsg); + + while (len > 0) + { + while (PqRecvPointer >= PqRecvLength) + { + if (pq_recvbuf()) /* If nothing in buffer, then recv some */ + return EOF; /* Failed to recv data */ + } + amount = PqRecvLength - PqRecvPointer; + if (amount > len) + amount = len; + PqRecvPointer += amount; + len -= amount; + } + return 0; +} + +/* -------------------------------- + * pq_buffer_has_data - is any buffered data available to read? + * + * This will *not* attempt to read more data. + * -------------------------------- + */ +bool +pq_buffer_has_data(void) +{ + return (PqRecvPointer < PqRecvLength); +} + + +/* -------------------------------- + * pq_startmsgread - begin reading a message from the client. + * + * This must be called before any of the pq_get* functions. + * -------------------------------- + */ +void +pq_startmsgread(void) +{ + /* + * There shouldn't be a read active already, but let's check just to be + * sure. + */ + if (PqCommReadingMsg) + ereport(FATAL, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("terminating connection because protocol synchronization was lost"))); + + PqCommReadingMsg = true; +} + + +/* -------------------------------- + * pq_endmsgread - finish reading message. + * + * This must be called after reading a message with pq_getbytes() + * and friends, to indicate that we have read the whole message. + * pq_getmessage() does this implicitly. + * -------------------------------- + */ +void +pq_endmsgread(void) +{ + Assert(PqCommReadingMsg); + + PqCommReadingMsg = false; +} + +/* -------------------------------- + * pq_is_reading_msg - are we currently reading a message? + * + * This is used in error recovery at the outer idle loop to detect if we have + * lost protocol sync, and need to terminate the connection. pq_startmsgread() + * will check for that too, but it's nicer to detect it earlier. + * -------------------------------- + */ +bool +pq_is_reading_msg(void) +{ + return PqCommReadingMsg; +} + +/* -------------------------------- + * pq_getmessage - get a message with length word from connection + * + * The return value is placed in an expansible StringInfo, which has + * already been initialized by the caller. + * Only the message body is placed in the StringInfo; the length word + * is removed. Also, s->cursor is initialized to zero for convenience + * in scanning the message contents. + * + * maxlen is the upper limit on the length of the + * message we are willing to accept. We abort the connection (by + * returning EOF) if client tries to send more than that. + * + * returns 0 if OK, EOF if trouble + * -------------------------------- + */ +int +pq_getmessage(StringInfo s, int maxlen) +{ + int32 len; + + Assert(PqCommReadingMsg); + + resetStringInfo(s); + + /* Read message length word */ + if (pq_getbytes((char *) &len, 4) == EOF) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("unexpected EOF within message length word"))); + return EOF; + } + + len = pg_ntoh32(len); + + if (len < 4 || len > maxlen) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("invalid message length"))); + return EOF; + } + + len -= 4; /* discount length itself */ + + if (len > 0) + { + /* + * Allocate space for message. If we run out of room (ridiculously + * large message), we will elog(ERROR), but we want to discard the + * message body so as not to lose communication sync. + */ + PG_TRY(); + { + enlargeStringInfo(s, len); + } + PG_CATCH(); + { + if (pq_discardbytes(len) == EOF) + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("incomplete message from client"))); + + /* we discarded the rest of the message so we're back in sync. */ + PqCommReadingMsg = false; + PG_RE_THROW(); + } + PG_END_TRY(); + + /* And grab the message */ + if (pq_getbytes(s->data, len) == EOF) + { + ereport(COMMERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("incomplete message from client"))); + return EOF; + } + s->len = len; + /* Place a trailing null per StringInfo convention */ + s->data[len] = '\0'; + } + + /* finished reading the message. */ + PqCommReadingMsg = false; + + return 0; +} + + +static int +internal_putbytes(const char *s, size_t len) +{ + size_t amount; + + while (len > 0) + { + /* If buffer is full, then flush it out */ + if (PqSendPointer >= PqSendBufferSize) + { + socket_set_nonblocking(false); + if (internal_flush()) + return EOF; + } + amount = PqSendBufferSize - PqSendPointer; + if (amount > len) + amount = len; + memcpy(PqSendBuffer + PqSendPointer, s, amount); + PqSendPointer += amount; + s += amount; + len -= amount; + } + return 0; +} + +/* -------------------------------- + * socket_flush - flush pending output + * + * returns 0 if OK, EOF if trouble + * -------------------------------- + */ +static int +socket_flush(void) +{ + int res; + + /* No-op if reentrant call */ + if (PqCommBusy) + return 0; + PqCommBusy = true; + socket_set_nonblocking(false); + res = internal_flush(); + PqCommBusy = false; + return res; +} + +/* -------------------------------- + * internal_flush - flush pending output + * + * Returns 0 if OK (meaning everything was sent, or operation would block + * and the socket is in non-blocking mode), or EOF if trouble. + * -------------------------------- + */ +static int +internal_flush(void) +{ + static int last_reported_send_errno = 0; + + char *bufptr = PqSendBuffer + PqSendStart; + char *bufend = PqSendBuffer + PqSendPointer; + + while (bufptr < bufend) + { + int r; + + r = secure_write(MyProcPort, bufptr, bufend - bufptr); + + if (r <= 0) + { + if (errno == EINTR) + continue; /* Ok if we were interrupted */ + + /* + * Ok if no data writable without blocking, and the socket is in + * non-blocking mode. + */ + if (errno == EAGAIN || + errno == EWOULDBLOCK) + { + return 0; + } + + /* + * Careful: an ereport() that tries to write to the client would + * cause recursion to here, leading to stack overflow and core + * dump! This message must go *only* to the postmaster log. + * + * If a client disconnects while we're in the midst of output, we + * might write quite a bit of data before we get to a safe query + * abort point. So, suppress duplicate log messages. + */ + if (errno != last_reported_send_errno) + { + last_reported_send_errno = errno; + ereport(COMMERROR, + (errcode_for_socket_access(), + errmsg("could not send data to client: %m"))); + } + + /* + * We drop the buffered data anyway so that processing can + * continue, even though we'll probably quit soon. We also set a + * flag that'll cause the next CHECK_FOR_INTERRUPTS to terminate + * the connection. + */ + PqSendStart = PqSendPointer = 0; + ClientConnectionLost = 1; + InterruptPending = 1; + return EOF; + } + + last_reported_send_errno = 0; /* reset after any successful send */ + bufptr += r; + PqSendStart += r; + } + + PqSendStart = PqSendPointer = 0; + return 0; +} + +/* -------------------------------- + * pq_flush_if_writable - flush pending output if writable without blocking + * + * Returns 0 if OK, or EOF if trouble. + * -------------------------------- + */ +static int +socket_flush_if_writable(void) +{ + int res; + + /* Quick exit if nothing to do */ + if (PqSendPointer == PqSendStart) + return 0; + + /* No-op if reentrant call */ + if (PqCommBusy) + return 0; + + /* Temporarily put the socket into non-blocking mode */ + socket_set_nonblocking(true); + + PqCommBusy = true; + res = internal_flush(); + PqCommBusy = false; + return res; +} + +/* -------------------------------- + * socket_is_send_pending - is there any pending data in the output buffer? + * -------------------------------- + */ +static bool +socket_is_send_pending(void) +{ + return (PqSendStart < PqSendPointer); +} + +/* -------------------------------- + * Message-level I/O routines begin here. + * -------------------------------- + */ + + +/* -------------------------------- + * socket_putmessage - send a normal message (suppressed in COPY OUT mode) + * + * msgtype is a message type code to place before the message body. + * + * len is the length of the message body data at *s. A message length + * word (equal to len+4 because it counts itself too) is inserted by this + * routine. + * + * We suppress messages generated while pqcomm.c is busy. This + * avoids any possibility of messages being inserted within other + * messages. The only known trouble case arises if SIGQUIT occurs + * during a pqcomm.c routine --- quickdie() will try to send a warning + * message, and the most reasonable approach seems to be to drop it. + * + * returns 0 if OK, EOF if trouble + * -------------------------------- + */ +static int +socket_putmessage(char msgtype, const char *s, size_t len) +{ + uint32 n32; + + Assert(msgtype != 0); + + if (PqCommBusy) + return 0; + PqCommBusy = true; + if (internal_putbytes(&msgtype, 1)) + goto fail; + + n32 = pg_hton32((uint32) (len + 4)); + if (internal_putbytes((char *) &n32, 4)) + goto fail; + + if (internal_putbytes(s, len)) + goto fail; + PqCommBusy = false; + return 0; + +fail: + PqCommBusy = false; + return EOF; +} + +/* -------------------------------- + * pq_putmessage_noblock - like pq_putmessage, but never blocks + * + * If the output buffer is too small to hold the message, the buffer + * is enlarged. + */ +static void +socket_putmessage_noblock(char msgtype, const char *s, size_t len) +{ + int res PG_USED_FOR_ASSERTS_ONLY; + int required; + + /* + * Ensure we have enough space in the output buffer for the message header + * as well as the message itself. + */ + required = PqSendPointer + 1 + 4 + len; + if (required > PqSendBufferSize) + { + PqSendBuffer = repalloc(PqSendBuffer, required); + PqSendBufferSize = required; + } + res = pq_putmessage(msgtype, s, len); + Assert(res == 0); /* should not fail when the message fits in + * buffer */ +} + +/* -------------------------------- + * pq_putmessage_v2 - send a message in protocol version 2 + * + * msgtype is a message type code to place before the message body. + * + * We no longer support protocol version 2, but we have kept this + * function so that if a client tries to connect with protocol version 2, + * as a courtesy we can still send the "unsupported protocol version" + * error to the client in the old format. + * + * Like in pq_putmessage(), we suppress messages generated while + * pqcomm.c is busy. + * + * returns 0 if OK, EOF if trouble + * -------------------------------- + */ +int +pq_putmessage_v2(char msgtype, const char *s, size_t len) +{ + Assert(msgtype != 0); + + if (PqCommBusy) + return 0; + PqCommBusy = true; + if (internal_putbytes(&msgtype, 1)) + goto fail; + + if (internal_putbytes(s, len)) + goto fail; + PqCommBusy = false; + return 0; + +fail: + PqCommBusy = false; + return EOF; +} + +/* + * Support for TCP Keepalive parameters + */ + +/* + * On Windows, we need to set both idle and interval at the same time. + * We also cannot reset them to the default (setting to zero will + * actually set them to zero, not default), therefore we fallback to + * the out-of-the-box default instead. + */ +#if defined(WIN32) && defined(SIO_KEEPALIVE_VALS) +static int +pq_setkeepaliveswin32(Port *port, int idle, int interval) +{ + struct tcp_keepalive ka; + DWORD retsize; + + if (idle <= 0) + idle = 2 * 60 * 60; /* default = 2 hours */ + if (interval <= 0) + interval = 1; /* default = 1 second */ + + ka.onoff = 1; + ka.keepalivetime = idle * 1000; + ka.keepaliveinterval = interval * 1000; + + if (WSAIoctl(port->sock, + SIO_KEEPALIVE_VALS, + (LPVOID) &ka, + sizeof(ka), + NULL, + 0, + &retsize, + NULL, + NULL) + != 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: error code %d", + "WSAIoctl", "SIO_KEEPALIVE_VALS", WSAGetLastError()))); + return STATUS_ERROR; + } + if (port->keepalives_idle != idle) + port->keepalives_idle = idle; + if (port->keepalives_interval != interval) + port->keepalives_interval = interval; + return STATUS_OK; +} +#endif + +int +pq_getkeepalivesidle(Port *port) +{ +#if defined(PG_TCP_KEEPALIVE_IDLE) || defined(SIO_KEEPALIVE_VALS) + if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family)) + return 0; + + if (port->keepalives_idle != 0) + return port->keepalives_idle; + + if (port->default_keepalives_idle == 0) + { +#ifndef WIN32 + ACCEPT_TYPE_ARG3 size = sizeof(port->default_keepalives_idle); + + if (getsockopt(port->sock, IPPROTO_TCP, PG_TCP_KEEPALIVE_IDLE, + (char *) &port->default_keepalives_idle, + &size) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "getsockopt", PG_TCP_KEEPALIVE_IDLE_STR))); + port->default_keepalives_idle = -1; /* don't know */ + } +#else /* WIN32 */ + /* We can't get the defaults on Windows, so return "don't know" */ + port->default_keepalives_idle = -1; +#endif /* WIN32 */ + } + + return port->default_keepalives_idle; +#else + return 0; +#endif +} + +int +pq_setkeepalivesidle(int idle, Port *port) +{ + if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family)) + return STATUS_OK; + +/* check SIO_KEEPALIVE_VALS here, not just WIN32, as some toolchains lack it */ +#if defined(PG_TCP_KEEPALIVE_IDLE) || defined(SIO_KEEPALIVE_VALS) + if (idle == port->keepalives_idle) + return STATUS_OK; + +#ifndef WIN32 + if (port->default_keepalives_idle <= 0) + { + if (pq_getkeepalivesidle(port) < 0) + { + if (idle == 0) + return STATUS_OK; /* default is set but unknown */ + else + return STATUS_ERROR; + } + } + + if (idle == 0) + idle = port->default_keepalives_idle; + + if (setsockopt(port->sock, IPPROTO_TCP, PG_TCP_KEEPALIVE_IDLE, + (char *) &idle, sizeof(idle)) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "setsockopt", PG_TCP_KEEPALIVE_IDLE_STR))); + return STATUS_ERROR; + } + + port->keepalives_idle = idle; +#else /* WIN32 */ + return pq_setkeepaliveswin32(port, idle, port->keepalives_interval); +#endif +#else + if (idle != 0) + { + ereport(LOG, + (errmsg("setting the keepalive idle time is not supported"))); + return STATUS_ERROR; + } +#endif + + return STATUS_OK; +} + +int +pq_getkeepalivesinterval(Port *port) +{ +#if defined(TCP_KEEPINTVL) || defined(SIO_KEEPALIVE_VALS) + if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family)) + return 0; + + if (port->keepalives_interval != 0) + return port->keepalives_interval; + + if (port->default_keepalives_interval == 0) + { +#ifndef WIN32 + ACCEPT_TYPE_ARG3 size = sizeof(port->default_keepalives_interval); + + if (getsockopt(port->sock, IPPROTO_TCP, TCP_KEEPINTVL, + (char *) &port->default_keepalives_interval, + &size) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "getsockopt", "TCP_KEEPINTVL"))); + port->default_keepalives_interval = -1; /* don't know */ + } +#else + /* We can't get the defaults on Windows, so return "don't know" */ + port->default_keepalives_interval = -1; +#endif /* WIN32 */ + } + + return port->default_keepalives_interval; +#else + return 0; +#endif +} + +int +pq_setkeepalivesinterval(int interval, Port *port) +{ + if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family)) + return STATUS_OK; + +#if defined(TCP_KEEPINTVL) || defined(SIO_KEEPALIVE_VALS) + if (interval == port->keepalives_interval) + return STATUS_OK; + +#ifndef WIN32 + if (port->default_keepalives_interval <= 0) + { + if (pq_getkeepalivesinterval(port) < 0) + { + if (interval == 0) + return STATUS_OK; /* default is set but unknown */ + else + return STATUS_ERROR; + } + } + + if (interval == 0) + interval = port->default_keepalives_interval; + + if (setsockopt(port->sock, IPPROTO_TCP, TCP_KEEPINTVL, + (char *) &interval, sizeof(interval)) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "setsockopt", "TCP_KEEPINTVL"))); + return STATUS_ERROR; + } + + port->keepalives_interval = interval; +#else /* WIN32 */ + return pq_setkeepaliveswin32(port, port->keepalives_idle, interval); +#endif +#else + if (interval != 0) + { + ereport(LOG, + (errmsg("%s(%s) not supported", "setsockopt", "TCP_KEEPINTVL"))); + return STATUS_ERROR; + } +#endif + + return STATUS_OK; +} + +int +pq_getkeepalivescount(Port *port) +{ +#ifdef TCP_KEEPCNT + if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family)) + return 0; + + if (port->keepalives_count != 0) + return port->keepalives_count; + + if (port->default_keepalives_count == 0) + { + ACCEPT_TYPE_ARG3 size = sizeof(port->default_keepalives_count); + + if (getsockopt(port->sock, IPPROTO_TCP, TCP_KEEPCNT, + (char *) &port->default_keepalives_count, + &size) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "getsockopt", "TCP_KEEPCNT"))); + port->default_keepalives_count = -1; /* don't know */ + } + } + + return port->default_keepalives_count; +#else + return 0; +#endif +} + +int +pq_setkeepalivescount(int count, Port *port) +{ + if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family)) + return STATUS_OK; + +#ifdef TCP_KEEPCNT + if (count == port->keepalives_count) + return STATUS_OK; + + if (port->default_keepalives_count <= 0) + { + if (pq_getkeepalivescount(port) < 0) + { + if (count == 0) + return STATUS_OK; /* default is set but unknown */ + else + return STATUS_ERROR; + } + } + + if (count == 0) + count = port->default_keepalives_count; + + if (setsockopt(port->sock, IPPROTO_TCP, TCP_KEEPCNT, + (char *) &count, sizeof(count)) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "setsockopt", "TCP_KEEPCNT"))); + return STATUS_ERROR; + } + + port->keepalives_count = count; +#else + if (count != 0) + { + ereport(LOG, + (errmsg("%s(%s) not supported", "setsockopt", "TCP_KEEPCNT"))); + return STATUS_ERROR; + } +#endif + + return STATUS_OK; +} + +int +pq_gettcpusertimeout(Port *port) +{ +#ifdef TCP_USER_TIMEOUT + if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family)) + return 0; + + if (port->tcp_user_timeout != 0) + return port->tcp_user_timeout; + + if (port->default_tcp_user_timeout == 0) + { + ACCEPT_TYPE_ARG3 size = sizeof(port->default_tcp_user_timeout); + + if (getsockopt(port->sock, IPPROTO_TCP, TCP_USER_TIMEOUT, + (char *) &port->default_tcp_user_timeout, + &size) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "getsockopt", "TCP_USER_TIMEOUT"))); + port->default_tcp_user_timeout = -1; /* don't know */ + } + } + + return port->default_tcp_user_timeout; +#else + return 0; +#endif +} + +int +pq_settcpusertimeout(int timeout, Port *port) +{ + if (port == NULL || IS_AF_UNIX(port->laddr.addr.ss_family)) + return STATUS_OK; + +#ifdef TCP_USER_TIMEOUT + if (timeout == port->tcp_user_timeout) + return STATUS_OK; + + if (port->default_tcp_user_timeout <= 0) + { + if (pq_gettcpusertimeout(port) < 0) + { + if (timeout == 0) + return STATUS_OK; /* default is set but unknown */ + else + return STATUS_ERROR; + } + } + + if (timeout == 0) + timeout = port->default_tcp_user_timeout; + + if (setsockopt(port->sock, IPPROTO_TCP, TCP_USER_TIMEOUT, + (char *) &timeout, sizeof(timeout)) < 0) + { + ereport(LOG, + (errmsg("%s(%s) failed: %m", "setsockopt", "TCP_USER_TIMEOUT"))); + return STATUS_ERROR; + } + + port->tcp_user_timeout = timeout; +#else + if (timeout != 0) + { + ereport(LOG, + (errmsg("%s(%s) not supported", "setsockopt", "TCP_USER_TIMEOUT"))); + return STATUS_ERROR; + } +#endif + + return STATUS_OK; +} + +/* + * Check if the client is still connected. + */ +bool +pq_check_connection(void) +{ +#if defined(POLLRDHUP) + /* + * POLLRDHUP is a Linux extension to poll(2) to detect sockets closed by + * the other end. We don't have a portable way to do that without + * actually trying to read or write data on other systems. We don't want + * to read because that would be confused by pipelined queries and COPY + * data. Perhaps in future we'll try to write a heartbeat message instead. + */ + struct pollfd pollfd; + int rc; + + pollfd.fd = MyProcPort->sock; + pollfd.events = POLLOUT | POLLIN | POLLRDHUP; + pollfd.revents = 0; + + rc = poll(&pollfd, 1, 0); + + if (rc < 0) + { + ereport(COMMERROR, + (errcode_for_socket_access(), + errmsg("could not poll socket: %m"))); + return false; + } + else if (rc == 1 && (pollfd.revents & (POLLHUP | POLLRDHUP))) + return false; +#endif + + return true; +} diff --git a/src/backend/libpq/pqformat.c b/src/backend/libpq/pqformat.c new file mode 100644 index 0000000..1999898 --- /dev/null +++ b/src/backend/libpq/pqformat.c @@ -0,0 +1,643 @@ +/*------------------------------------------------------------------------- + * + * pqformat.c + * Routines for formatting and parsing frontend/backend messages + * + * Outgoing messages are built up in a StringInfo buffer (which is expansible) + * and then sent in a single call to pq_putmessage. This module provides data + * formatting/conversion routines that are needed to produce valid messages. + * Note in particular the distinction between "raw data" and "text"; raw data + * is message protocol characters and binary values that are not subject to + * character set conversion, while text is converted by character encoding + * rules. + * + * Incoming messages are similarly read into a StringInfo buffer, via + * pq_getmessage, and then parsed and converted from that using the routines + * in this module. + * + * These same routines support reading and writing of external binary formats + * (typsend/typreceive routines). The conversion routines for individual + * data types are exactly the same, only initialization and completion + * are different. + * + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/backend/libpq/pqformat.c + * + *------------------------------------------------------------------------- + */ +/* + * INTERFACE ROUTINES + * Message assembly and output: + * pq_beginmessage - initialize StringInfo buffer + * pq_sendbyte - append a raw byte to a StringInfo buffer + * pq_sendint - append a binary integer to a StringInfo buffer + * pq_sendint64 - append a binary 8-byte int to a StringInfo buffer + * pq_sendfloat4 - append a float4 to a StringInfo buffer + * pq_sendfloat8 - append a float8 to a StringInfo buffer + * pq_sendbytes - append raw data to a StringInfo buffer + * pq_sendcountedtext - append a counted text string (with character set conversion) + * pq_sendtext - append a text string (with conversion) + * pq_sendstring - append a null-terminated text string (with conversion) + * pq_send_ascii_string - append a null-terminated text string (without conversion) + * pq_endmessage - send the completed message to the frontend + * Note: it is also possible to append data to the StringInfo buffer using + * the regular StringInfo routines, but this is discouraged since required + * character set conversion may not occur. + * + * typsend support (construct a bytea value containing external binary data): + * pq_begintypsend - initialize StringInfo buffer + * pq_endtypsend - return the completed string as a "bytea*" + * + * Special-case message output: + * pq_puttextmessage - generate a character set-converted message in one step + * pq_putemptymessage - convenience routine for message with empty body + * + * Message parsing after input: + * pq_getmsgbyte - get a raw byte from a message buffer + * pq_getmsgint - get a binary integer from a message buffer + * pq_getmsgint64 - get a binary 8-byte int from a message buffer + * pq_getmsgfloat4 - get a float4 from a message buffer + * pq_getmsgfloat8 - get a float8 from a message buffer + * pq_getmsgbytes - get raw data from a message buffer + * pq_copymsgbytes - copy raw data from a message buffer + * pq_getmsgtext - get a counted text string (with conversion) + * pq_getmsgstring - get a null-terminated text string (with conversion) + * pq_getmsgrawstring - get a null-terminated text string - NO conversion + * pq_getmsgend - verify message fully consumed + */ + +#include "postgres.h" + +#include <sys/param.h> + +#include "libpq/libpq.h" +#include "libpq/pqformat.h" +#include "mb/pg_wchar.h" +#include "port/pg_bswap.h" + + +/* -------------------------------- + * pq_beginmessage - initialize for sending a message + * -------------------------------- + */ +void +pq_beginmessage(StringInfo buf, char msgtype) +{ + initStringInfo(buf); + + /* + * We stash the message type into the buffer's cursor field, expecting + * that the pq_sendXXX routines won't touch it. We could alternatively + * make it the first byte of the buffer contents, but this seems easier. + */ + buf->cursor = msgtype; +} + +/* -------------------------------- + + * pq_beginmessage_reuse - initialize for sending a message, reuse buffer + * + * This requires the buffer to be allocated in a sufficiently long-lived + * memory context. + * -------------------------------- + */ +void +pq_beginmessage_reuse(StringInfo buf, char msgtype) +{ + resetStringInfo(buf); + + /* + * We stash the message type into the buffer's cursor field, expecting + * that the pq_sendXXX routines won't touch it. We could alternatively + * make it the first byte of the buffer contents, but this seems easier. + */ + buf->cursor = msgtype; +} + +/* -------------------------------- + * pq_sendbytes - append raw data to a StringInfo buffer + * -------------------------------- + */ +void +pq_sendbytes(StringInfo buf, const char *data, int datalen) +{ + /* use variant that maintains a trailing null-byte, out of caution */ + appendBinaryStringInfo(buf, data, datalen); +} + +/* -------------------------------- + * pq_sendcountedtext - append a counted text string (with character set conversion) + * + * The data sent to the frontend by this routine is a 4-byte count field + * followed by the string. The count includes itself or not, as per the + * countincludesself flag (pre-3.0 protocol requires it to include itself). + * The passed text string need not be null-terminated, and the data sent + * to the frontend isn't either. + * -------------------------------- + */ +void +pq_sendcountedtext(StringInfo buf, const char *str, int slen, + bool countincludesself) +{ + int extra = countincludesself ? 4 : 0; + char *p; + + p = pg_server_to_client(str, slen); + if (p != str) /* actual conversion has been done? */ + { + slen = strlen(p); + pq_sendint32(buf, slen + extra); + appendBinaryStringInfoNT(buf, p, slen); + pfree(p); + } + else + { + pq_sendint32(buf, slen + extra); + appendBinaryStringInfoNT(buf, str, slen); + } +} + +/* -------------------------------- + * pq_sendtext - append a text string (with conversion) + * + * The passed text string need not be null-terminated, and the data sent + * to the frontend isn't either. Note that this is not actually useful + * for direct frontend transmissions, since there'd be no way for the + * frontend to determine the string length. But it is useful for binary + * format conversions. + * -------------------------------- + */ +void +pq_sendtext(StringInfo buf, const char *str, int slen) +{ + char *p; + + p = pg_server_to_client(str, slen); + if (p != str) /* actual conversion has been done? */ + { + slen = strlen(p); + appendBinaryStringInfo(buf, p, slen); + pfree(p); + } + else + appendBinaryStringInfo(buf, str, slen); +} + +/* -------------------------------- + * pq_sendstring - append a null-terminated text string (with conversion) + * + * NB: passed text string must be null-terminated, and so is the data + * sent to the frontend. + * -------------------------------- + */ +void +pq_sendstring(StringInfo buf, const char *str) +{ + int slen = strlen(str); + char *p; + + p = pg_server_to_client(str, slen); + if (p != str) /* actual conversion has been done? */ + { + slen = strlen(p); + appendBinaryStringInfoNT(buf, p, slen + 1); + pfree(p); + } + else + appendBinaryStringInfoNT(buf, str, slen + 1); +} + +/* -------------------------------- + * pq_send_ascii_string - append a null-terminated text string (without conversion) + * + * This function intentionally bypasses encoding conversion, instead just + * silently replacing any non-7-bit-ASCII characters with question marks. + * It is used only when we are having trouble sending an error message to + * the client with normal localization and encoding conversion. The caller + * should already have taken measures to ensure the string is just ASCII; + * the extra work here is just to make certain we don't send a badly encoded + * string to the client (which might or might not be robust about that). + * + * NB: passed text string must be null-terminated, and so is the data + * sent to the frontend. + * -------------------------------- + */ +void +pq_send_ascii_string(StringInfo buf, const char *str) +{ + while (*str) + { + char ch = *str++; + + if (IS_HIGHBIT_SET(ch)) + ch = '?'; + appendStringInfoCharMacro(buf, ch); + } + appendStringInfoChar(buf, '\0'); +} + +/* -------------------------------- + * pq_sendfloat4 - append a float4 to a StringInfo buffer + * + * The point of this routine is to localize knowledge of the external binary + * representation of float4, which is a component of several datatypes. + * + * We currently assume that float4 should be byte-swapped in the same way + * as int4. This rule is not perfect but it gives us portability across + * most IEEE-float-using architectures. + * -------------------------------- + */ +void +pq_sendfloat4(StringInfo buf, float4 f) +{ + union + { + float4 f; + uint32 i; + } swap; + + swap.f = f; + pq_sendint32(buf, swap.i); +} + +/* -------------------------------- + * pq_sendfloat8 - append a float8 to a StringInfo buffer + * + * The point of this routine is to localize knowledge of the external binary + * representation of float8, which is a component of several datatypes. + * + * We currently assume that float8 should be byte-swapped in the same way + * as int8. This rule is not perfect but it gives us portability across + * most IEEE-float-using architectures. + * -------------------------------- + */ +void +pq_sendfloat8(StringInfo buf, float8 f) +{ + union + { + float8 f; + int64 i; + } swap; + + swap.f = f; + pq_sendint64(buf, swap.i); +} + +/* -------------------------------- + * pq_endmessage - send the completed message to the frontend + * + * The data buffer is pfree()d, but if the StringInfo was allocated with + * makeStringInfo then the caller must still pfree it. + * -------------------------------- + */ +void +pq_endmessage(StringInfo buf) +{ + /* msgtype was saved in cursor field */ + (void) pq_putmessage(buf->cursor, buf->data, buf->len); + /* no need to complain about any failure, since pqcomm.c already did */ + pfree(buf->data); + buf->data = NULL; +} + +/* -------------------------------- + * pq_endmessage_reuse - send the completed message to the frontend + * + * The data buffer is *not* freed, allowing to reuse the buffer with + * pq_beginmessage_reuse. + -------------------------------- + */ + +void +pq_endmessage_reuse(StringInfo buf) +{ + /* msgtype was saved in cursor field */ + (void) pq_putmessage(buf->cursor, buf->data, buf->len); +} + + +/* -------------------------------- + * pq_begintypsend - initialize for constructing a bytea result + * -------------------------------- + */ +void +pq_begintypsend(StringInfo buf) +{ + initStringInfo(buf); + /* Reserve four bytes for the bytea length word */ + appendStringInfoCharMacro(buf, '\0'); + appendStringInfoCharMacro(buf, '\0'); + appendStringInfoCharMacro(buf, '\0'); + appendStringInfoCharMacro(buf, '\0'); +} + +/* -------------------------------- + * pq_endtypsend - finish constructing a bytea result + * + * The data buffer is returned as the palloc'd bytea value. (We expect + * that it will be suitably aligned for this because it has been palloc'd.) + * We assume the StringInfoData is just a local variable in the caller and + * need not be pfree'd. + * -------------------------------- + */ +bytea * +pq_endtypsend(StringInfo buf) +{ + bytea *result = (bytea *) buf->data; + + /* Insert correct length into bytea length word */ + Assert(buf->len >= VARHDRSZ); + SET_VARSIZE(result, buf->len); + + return result; +} + + +/* -------------------------------- + * pq_puttextmessage - generate a character set-converted message in one step + * + * This is the same as the pqcomm.c routine pq_putmessage, except that + * the message body is a null-terminated string to which encoding + * conversion applies. + * -------------------------------- + */ +void +pq_puttextmessage(char msgtype, const char *str) +{ + int slen = strlen(str); + char *p; + + p = pg_server_to_client(str, slen); + if (p != str) /* actual conversion has been done? */ + { + (void) pq_putmessage(msgtype, p, strlen(p) + 1); + pfree(p); + return; + } + (void) pq_putmessage(msgtype, str, slen + 1); +} + + +/* -------------------------------- + * pq_putemptymessage - convenience routine for message with empty body + * -------------------------------- + */ +void +pq_putemptymessage(char msgtype) +{ + (void) pq_putmessage(msgtype, NULL, 0); +} + + +/* -------------------------------- + * pq_getmsgbyte - get a raw byte from a message buffer + * -------------------------------- + */ +int +pq_getmsgbyte(StringInfo msg) +{ + if (msg->cursor >= msg->len) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("no data left in message"))); + return (unsigned char) msg->data[msg->cursor++]; +} + +/* -------------------------------- + * pq_getmsgint - get a binary integer from a message buffer + * + * Values are treated as unsigned. + * -------------------------------- + */ +unsigned int +pq_getmsgint(StringInfo msg, int b) +{ + unsigned int result; + unsigned char n8; + uint16 n16; + uint32 n32; + + switch (b) + { + case 1: + pq_copymsgbytes(msg, (char *) &n8, 1); + result = n8; + break; + case 2: + pq_copymsgbytes(msg, (char *) &n16, 2); + result = pg_ntoh16(n16); + break; + case 4: + pq_copymsgbytes(msg, (char *) &n32, 4); + result = pg_ntoh32(n32); + break; + default: + elog(ERROR, "unsupported integer size %d", b); + result = 0; /* keep compiler quiet */ + break; + } + return result; +} + +/* -------------------------------- + * pq_getmsgint64 - get a binary 8-byte int from a message buffer + * + * It is tempting to merge this with pq_getmsgint, but we'd have to make the + * result int64 for all data widths --- that could be a big performance + * hit on machines where int64 isn't efficient. + * -------------------------------- + */ +int64 +pq_getmsgint64(StringInfo msg) +{ + uint64 n64; + + pq_copymsgbytes(msg, (char *) &n64, sizeof(n64)); + + return pg_ntoh64(n64); +} + +/* -------------------------------- + * pq_getmsgfloat4 - get a float4 from a message buffer + * + * See notes for pq_sendfloat4. + * -------------------------------- + */ +float4 +pq_getmsgfloat4(StringInfo msg) +{ + union + { + float4 f; + uint32 i; + } swap; + + swap.i = pq_getmsgint(msg, 4); + return swap.f; +} + +/* -------------------------------- + * pq_getmsgfloat8 - get a float8 from a message buffer + * + * See notes for pq_sendfloat8. + * -------------------------------- + */ +float8 +pq_getmsgfloat8(StringInfo msg) +{ + union + { + float8 f; + int64 i; + } swap; + + swap.i = pq_getmsgint64(msg); + return swap.f; +} + +/* -------------------------------- + * pq_getmsgbytes - get raw data from a message buffer + * + * Returns a pointer directly into the message buffer; note this + * may not have any particular alignment. + * -------------------------------- + */ +const char * +pq_getmsgbytes(StringInfo msg, int datalen) +{ + const char *result; + + if (datalen < 0 || datalen > (msg->len - msg->cursor)) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("insufficient data left in message"))); + result = &msg->data[msg->cursor]; + msg->cursor += datalen; + return result; +} + +/* -------------------------------- + * pq_copymsgbytes - copy raw data from a message buffer + * + * Same as above, except data is copied to caller's buffer. + * -------------------------------- + */ +void +pq_copymsgbytes(StringInfo msg, char *buf, int datalen) +{ + if (datalen < 0 || datalen > (msg->len - msg->cursor)) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("insufficient data left in message"))); + memcpy(buf, &msg->data[msg->cursor], datalen); + msg->cursor += datalen; +} + +/* -------------------------------- + * pq_getmsgtext - get a counted text string (with conversion) + * + * Always returns a pointer to a freshly palloc'd result. + * The result has a trailing null, *and* we return its strlen in *nbytes. + * -------------------------------- + */ +char * +pq_getmsgtext(StringInfo msg, int rawbytes, int *nbytes) +{ + char *str; + char *p; + + if (rawbytes < 0 || rawbytes > (msg->len - msg->cursor)) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("insufficient data left in message"))); + str = &msg->data[msg->cursor]; + msg->cursor += rawbytes; + + p = pg_client_to_server(str, rawbytes); + if (p != str) /* actual conversion has been done? */ + *nbytes = strlen(p); + else + { + p = (char *) palloc(rawbytes + 1); + memcpy(p, str, rawbytes); + p[rawbytes] = '\0'; + *nbytes = rawbytes; + } + return p; +} + +/* -------------------------------- + * pq_getmsgstring - get a null-terminated text string (with conversion) + * + * May return a pointer directly into the message buffer, or a pointer + * to a palloc'd conversion result. + * -------------------------------- + */ +const char * +pq_getmsgstring(StringInfo msg) +{ + char *str; + int slen; + + str = &msg->data[msg->cursor]; + + /* + * It's safe to use strlen() here because a StringInfo is guaranteed to + * have a trailing null byte. But check we found a null inside the + * message. + */ + slen = strlen(str); + if (msg->cursor + slen >= msg->len) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("invalid string in message"))); + msg->cursor += slen + 1; + + return pg_client_to_server(str, slen); +} + +/* -------------------------------- + * pq_getmsgrawstring - get a null-terminated text string - NO conversion + * + * Returns a pointer directly into the message buffer. + * -------------------------------- + */ +const char * +pq_getmsgrawstring(StringInfo msg) +{ + char *str; + int slen; + + str = &msg->data[msg->cursor]; + + /* + * It's safe to use strlen() here because a StringInfo is guaranteed to + * have a trailing null byte. But check we found a null inside the + * message. + */ + slen = strlen(str); + if (msg->cursor + slen >= msg->len) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("invalid string in message"))); + msg->cursor += slen + 1; + + return str; +} + +/* -------------------------------- + * pq_getmsgend - verify message fully consumed + * -------------------------------- + */ +void +pq_getmsgend(StringInfo msg) +{ + if (msg->cursor != msg->len) + ereport(ERROR, + (errcode(ERRCODE_PROTOCOL_VIOLATION), + errmsg("invalid message format"))); +} diff --git a/src/backend/libpq/pqmq.c b/src/backend/libpq/pqmq.c new file mode 100644 index 0000000..d1a1f47 --- /dev/null +++ b/src/backend/libpq/pqmq.c @@ -0,0 +1,313 @@ +/*------------------------------------------------------------------------- + * + * pqmq.c + * Use the frontend/backend protocol for communication over a shm_mq + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * src/backend/libpq/pqmq.c + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" + +#include "libpq/libpq.h" +#include "libpq/pqformat.h" +#include "libpq/pqmq.h" +#include "miscadmin.h" +#include "pgstat.h" +#include "tcop/tcopprot.h" +#include "utils/builtins.h" + +static shm_mq_handle *pq_mq_handle; +static bool pq_mq_busy = false; +static pid_t pq_mq_parallel_leader_pid = 0; +static pid_t pq_mq_parallel_leader_backend_id = InvalidBackendId; + +static void pq_cleanup_redirect_to_shm_mq(dsm_segment *seg, Datum arg); +static void mq_comm_reset(void); +static int mq_flush(void); +static int mq_flush_if_writable(void); +static bool mq_is_send_pending(void); +static int mq_putmessage(char msgtype, const char *s, size_t len); +static void mq_putmessage_noblock(char msgtype, const char *s, size_t len); + +static const PQcommMethods PqCommMqMethods = { + mq_comm_reset, + mq_flush, + mq_flush_if_writable, + mq_is_send_pending, + mq_putmessage, + mq_putmessage_noblock +}; + +/* + * Arrange to redirect frontend/backend protocol messages to a shared-memory + * message queue. + */ +void +pq_redirect_to_shm_mq(dsm_segment *seg, shm_mq_handle *mqh) +{ + PqCommMethods = &PqCommMqMethods; + pq_mq_handle = mqh; + whereToSendOutput = DestRemote; + FrontendProtocol = PG_PROTOCOL_LATEST; + on_dsm_detach(seg, pq_cleanup_redirect_to_shm_mq, (Datum) 0); +} + +/* + * When the DSM that contains our shm_mq goes away, we need to stop sending + * messages to it. + */ +static void +pq_cleanup_redirect_to_shm_mq(dsm_segment *seg, Datum arg) +{ + pq_mq_handle = NULL; + whereToSendOutput = DestNone; +} + +/* + * Arrange to SendProcSignal() to the parallel leader each time we transmit + * message data via the shm_mq. + */ +void +pq_set_parallel_leader(pid_t pid, BackendId backend_id) +{ + Assert(PqCommMethods == &PqCommMqMethods); + pq_mq_parallel_leader_pid = pid; + pq_mq_parallel_leader_backend_id = backend_id; +} + +static void +mq_comm_reset(void) +{ + /* Nothing to do. */ +} + +static int +mq_flush(void) +{ + /* Nothing to do. */ + return 0; +} + +static int +mq_flush_if_writable(void) +{ + /* Nothing to do. */ + return 0; +} + +static bool +mq_is_send_pending(void) +{ + /* There's never anything pending. */ + return 0; +} + +/* + * Transmit a libpq protocol message to the shared memory message queue + * selected via pq_mq_handle. We don't include a length word, because the + * receiver will know the length of the message from shm_mq_receive(). + */ +static int +mq_putmessage(char msgtype, const char *s, size_t len) +{ + shm_mq_iovec iov[2]; + shm_mq_result result; + + /* + * If we're sending a message, and we have to wait because the queue is + * full, and then we get interrupted, and that interrupt results in trying + * to send another message, we respond by detaching the queue. There's no + * way to return to the original context, but even if there were, just + * queueing the message would amount to indefinitely postponing the + * response to the interrupt. So we do this instead. + */ + if (pq_mq_busy) + { + if (pq_mq_handle != NULL) + shm_mq_detach(pq_mq_handle); + pq_mq_handle = NULL; + return EOF; + } + + /* + * If the message queue is already gone, just ignore the message. This + * doesn't necessarily indicate a problem; for example, DEBUG messages can + * be generated late in the shutdown sequence, after all DSMs have already + * been detached. + */ + if (pq_mq_handle == NULL) + return 0; + + pq_mq_busy = true; + + iov[0].data = &msgtype; + iov[0].len = 1; + iov[1].data = s; + iov[1].len = len; + + Assert(pq_mq_handle != NULL); + + for (;;) + { + result = shm_mq_sendv(pq_mq_handle, iov, 2, true); + + if (pq_mq_parallel_leader_pid != 0) + SendProcSignal(pq_mq_parallel_leader_pid, + PROCSIG_PARALLEL_MESSAGE, + pq_mq_parallel_leader_backend_id); + + if (result != SHM_MQ_WOULD_BLOCK) + break; + + (void) WaitLatch(MyLatch, WL_LATCH_SET | WL_EXIT_ON_PM_DEATH, 0, + WAIT_EVENT_MQ_PUT_MESSAGE); + ResetLatch(MyLatch); + CHECK_FOR_INTERRUPTS(); + } + + pq_mq_busy = false; + + Assert(result == SHM_MQ_SUCCESS || result == SHM_MQ_DETACHED); + if (result != SHM_MQ_SUCCESS) + return EOF; + return 0; +} + +static void +mq_putmessage_noblock(char msgtype, const char *s, size_t len) +{ + /* + * While the shm_mq machinery does support sending a message in + * non-blocking mode, there's currently no way to try sending beginning to + * send the message that doesn't also commit us to completing the + * transmission. This could be improved in the future, but for now we + * don't need it. + */ + elog(ERROR, "not currently supported"); +} + +/* + * Parse an ErrorResponse or NoticeResponse payload and populate an ErrorData + * structure with the results. + */ +void +pq_parse_errornotice(StringInfo msg, ErrorData *edata) +{ + /* Initialize edata with reasonable defaults. */ + MemSet(edata, 0, sizeof(ErrorData)); + edata->elevel = ERROR; + edata->assoc_context = CurrentMemoryContext; + + /* Loop over fields and extract each one. */ + for (;;) + { + char code = pq_getmsgbyte(msg); + const char *value; + + if (code == '\0') + { + pq_getmsgend(msg); + break; + } + value = pq_getmsgrawstring(msg); + + switch (code) + { + case PG_DIAG_SEVERITY: + /* ignore, trusting we'll get a nonlocalized version */ + break; + case PG_DIAG_SEVERITY_NONLOCALIZED: + if (strcmp(value, "DEBUG") == 0) + { + /* + * We can't reconstruct the exact DEBUG level, but + * presumably it was >= client_min_messages, so select + * DEBUG1 to ensure we'll pass it on to the client. + */ + edata->elevel = DEBUG1; + } + else if (strcmp(value, "LOG") == 0) + { + /* + * It can't be LOG_SERVER_ONLY, or the worker wouldn't + * have sent it to us; so LOG is the correct value. + */ + edata->elevel = LOG; + } + else if (strcmp(value, "INFO") == 0) + edata->elevel = INFO; + else if (strcmp(value, "NOTICE") == 0) + edata->elevel = NOTICE; + else if (strcmp(value, "WARNING") == 0) + edata->elevel = WARNING; + else if (strcmp(value, "ERROR") == 0) + edata->elevel = ERROR; + else if (strcmp(value, "FATAL") == 0) + edata->elevel = FATAL; + else if (strcmp(value, "PANIC") == 0) + edata->elevel = PANIC; + else + elog(ERROR, "unrecognized error severity: \"%s\"", value); + break; + case PG_DIAG_SQLSTATE: + if (strlen(value) != 5) + elog(ERROR, "invalid SQLSTATE: \"%s\"", value); + edata->sqlerrcode = MAKE_SQLSTATE(value[0], value[1], value[2], + value[3], value[4]); + break; + case PG_DIAG_MESSAGE_PRIMARY: + edata->message = pstrdup(value); + break; + case PG_DIAG_MESSAGE_DETAIL: + edata->detail = pstrdup(value); + break; + case PG_DIAG_MESSAGE_HINT: + edata->hint = pstrdup(value); + break; + case PG_DIAG_STATEMENT_POSITION: + edata->cursorpos = pg_strtoint32(value); + break; + case PG_DIAG_INTERNAL_POSITION: + edata->internalpos = pg_strtoint32(value); + break; + case PG_DIAG_INTERNAL_QUERY: + edata->internalquery = pstrdup(value); + break; + case PG_DIAG_CONTEXT: + edata->context = pstrdup(value); + break; + case PG_DIAG_SCHEMA_NAME: + edata->schema_name = pstrdup(value); + break; + case PG_DIAG_TABLE_NAME: + edata->table_name = pstrdup(value); + break; + case PG_DIAG_COLUMN_NAME: + edata->column_name = pstrdup(value); + break; + case PG_DIAG_DATATYPE_NAME: + edata->datatype_name = pstrdup(value); + break; + case PG_DIAG_CONSTRAINT_NAME: + edata->constraint_name = pstrdup(value); + break; + case PG_DIAG_SOURCE_FILE: + edata->filename = pstrdup(value); + break; + case PG_DIAG_SOURCE_LINE: + edata->lineno = pg_strtoint32(value); + break; + case PG_DIAG_SOURCE_FUNCTION: + edata->funcname = pstrdup(value); + break; + default: + elog(ERROR, "unrecognized error field code: %d", (int) code); + break; + } + } +} diff --git a/src/backend/libpq/pqsignal.c b/src/backend/libpq/pqsignal.c new file mode 100644 index 0000000..dedf3a4 --- /dev/null +++ b/src/backend/libpq/pqsignal.c @@ -0,0 +1,148 @@ +/*------------------------------------------------------------------------- + * + * pqsignal.c + * Backend signal(2) support (see also src/port/pqsignal.c) + * + * Portions Copyright (c) 1996-2021, PostgreSQL Global Development Group + * Portions Copyright (c) 1994, Regents of the University of California + * + * + * IDENTIFICATION + * src/backend/libpq/pqsignal.c + * + * ------------------------------------------------------------------------ + */ + +#include "postgres.h" + +#include "libpq/pqsignal.h" + + +/* Global variables */ +sigset_t UnBlockSig, + BlockSig, + StartupBlockSig; + + +/* + * Initialize BlockSig, UnBlockSig, and StartupBlockSig. + * + * BlockSig is the set of signals to block when we are trying to block + * signals. This includes all signals we normally expect to get, but NOT + * signals that should never be turned off. + * + * StartupBlockSig is the set of signals to block during startup packet + * collection; it's essentially BlockSig minus SIGTERM, SIGQUIT, SIGALRM. + * + * UnBlockSig is the set of signals to block when we don't want to block + * signals. + */ +void +pqinitmask(void) +{ + sigemptyset(&UnBlockSig); + + /* Note: InitializeLatchSupport() modifies UnBlockSig. */ + + /* First set all signals, then clear some. */ + sigfillset(&BlockSig); + sigfillset(&StartupBlockSig); + + /* + * Unmark those signals that should never be blocked. Some of these signal + * names don't exist on all platforms. Most do, but might as well ifdef + * them all for consistency... + */ +#ifdef SIGTRAP + sigdelset(&BlockSig, SIGTRAP); + sigdelset(&StartupBlockSig, SIGTRAP); +#endif +#ifdef SIGABRT + sigdelset(&BlockSig, SIGABRT); + sigdelset(&StartupBlockSig, SIGABRT); +#endif +#ifdef SIGILL + sigdelset(&BlockSig, SIGILL); + sigdelset(&StartupBlockSig, SIGILL); +#endif +#ifdef SIGFPE + sigdelset(&BlockSig, SIGFPE); + sigdelset(&StartupBlockSig, SIGFPE); +#endif +#ifdef SIGSEGV + sigdelset(&BlockSig, SIGSEGV); + sigdelset(&StartupBlockSig, SIGSEGV); +#endif +#ifdef SIGBUS + sigdelset(&BlockSig, SIGBUS); + sigdelset(&StartupBlockSig, SIGBUS); +#endif +#ifdef SIGSYS + sigdelset(&BlockSig, SIGSYS); + sigdelset(&StartupBlockSig, SIGSYS); +#endif +#ifdef SIGCONT + sigdelset(&BlockSig, SIGCONT); + sigdelset(&StartupBlockSig, SIGCONT); +#endif + +/* Signals unique to startup */ +#ifdef SIGQUIT + sigdelset(&StartupBlockSig, SIGQUIT); +#endif +#ifdef SIGTERM + sigdelset(&StartupBlockSig, SIGTERM); +#endif +#ifdef SIGALRM + sigdelset(&StartupBlockSig, SIGALRM); +#endif +} + +/* + * Set up a postmaster signal handler for signal "signo" + * + * Returns the previous handler. + * + * This is used only in the postmaster, which has its own odd approach to + * signal handling. For signals with handlers, we block all signals for the + * duration of signal handler execution. We also do not set the SA_RESTART + * flag; this should be safe given the tiny range of code in which the + * postmaster ever unblocks signals. + * + * pqinitmask() must have been invoked previously. + * + * On Windows, this function is just an alias for pqsignal() + * (and note that it's calling the code in src/backend/port/win32/signal.c, + * not src/port/pqsignal.c). On that platform, the postmaster's signal + * handlers still have to block signals for themselves. + */ +pqsigfunc +pqsignal_pm(int signo, pqsigfunc func) +{ +#ifndef WIN32 + struct sigaction act, + oact; + + act.sa_handler = func; + if (func == SIG_IGN || func == SIG_DFL) + { + /* in these cases, act the same as pqsignal() */ + sigemptyset(&act.sa_mask); + act.sa_flags = SA_RESTART; + } + else + { + act.sa_mask = BlockSig; + act.sa_flags = 0; + } +#ifdef SA_NOCLDSTOP + if (signo == SIGCHLD) + act.sa_flags |= SA_NOCLDSTOP; +#endif + if (sigaction(signo, &act, &oact) < 0) + return SIG_ERR; + return oact.sa_handler; +#else /* WIN32 */ + return pqsignal(signo, func); +#endif +} |