summaryrefslogtreecommitdiffstats
path: root/src/lib-oauth2/oauth2-jwt.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib-oauth2/oauth2-jwt.c')
-rw-r--r--src/lib-oauth2/oauth2-jwt.c489
1 files changed, 489 insertions, 0 deletions
diff --git a/src/lib-oauth2/oauth2-jwt.c b/src/lib-oauth2/oauth2-jwt.c
new file mode 100644
index 0000000..ec7ad46
--- /dev/null
+++ b/src/lib-oauth2/oauth2-jwt.c
@@ -0,0 +1,489 @@
+/* Copyright (c) 2020 Dovecot authors, see the included COPYING file */
+
+#include "lib.h"
+#include "buffer.h"
+#include "str.h"
+#include "hmac.h"
+#include "array.h"
+#include "hash-method.h"
+#include "istream.h"
+#include "iso8601-date.h"
+#include "json-tree.h"
+#include "array.h"
+#include "base64.h"
+#include "str-sanitize.h"
+#include "dcrypt.h"
+#include "var-expand.h"
+#include "oauth2.h"
+#include "oauth2-private.h"
+#include "dict.h"
+
+#include <time.h>
+
+static const char *get_field(const struct json_tree *tree, const char *key)
+{
+ const struct json_tree_node *root = json_tree_root(tree);
+ const struct json_tree_node *value_node = json_tree_find_key(root, key);
+ if (value_node == NULL || value_node->value_type == JSON_TYPE_OBJECT ||
+ value_node->value_type == JSON_TYPE_ARRAY)
+ return NULL;
+ return json_tree_get_value_str(value_node);
+}
+
+static int get_time_field(const struct json_tree *tree, const char *key,
+ int64_t *value_r)
+{
+ time_t tvalue;
+ const char *value = get_field(tree, key);
+ int tz_offset ATTR_UNUSED;
+ if (value == NULL)
+ return 0;
+ if (str_to_int64(value, value_r) == 0) {
+ if (*value_r < 0)
+ return -1;
+ return 1;
+ } else if (iso8601_date_parse((const unsigned char*)value, strlen(value),
+ &tvalue, &tz_offset)) {
+ if (tvalue < 0)
+ return -1;
+ *value_r = tvalue;
+ return 1;
+ }
+ return -1;
+}
+
+/* Escapes '/' and '%' in identifier to %hex */
+static const char *escape_identifier(const char *identifier)
+{
+ size_t pos = strcspn(identifier, "/%");
+ /* nothing to escape */
+ if (identifier[pos] == '\0')
+ return identifier;
+
+ size_t len = strlen(identifier);
+ string_t *new_id = t_str_new(len);
+ str_append_data(new_id, identifier, pos);
+
+ for (size_t i = pos; i < len; i++) {
+ switch (identifier[i]) {
+ case '/':
+ str_append(new_id, "%2f");
+ break;
+ case '%':
+ str_append(new_id, "%25");
+ break;
+ default:
+ str_append_c(new_id, identifier[i]);
+ break;
+ }
+ }
+ return str_c(new_id);
+}
+
+static int
+oauth2_lookup_hmac_key(const struct oauth2_settings *set, const char *azp,
+ const char *alg, const char *key_id,
+ const buffer_t **hmac_key_r, const char **error_r)
+{
+ const char *base64_key;
+ const char *cache_key_id, *lookup_key;
+ int ret;
+
+ cache_key_id = t_strconcat(azp, ".", alg, ".", key_id, NULL);
+ if (oauth2_validation_key_cache_lookup_hmac_key(
+ set->key_cache, cache_key_id, hmac_key_r) == 0)
+ return 0;
+
+
+ /* do a synchronous dict lookup */
+ lookup_key = t_strconcat(DICT_PATH_SHARED, azp, "/", alg, "/", key_id,
+ NULL);
+ struct dict_op_settings dict_set = {
+ .username = NULL,
+ };
+ if ((ret = dict_lookup(set->key_dict, &dict_set, pool_datastack_create(),
+ lookup_key, &base64_key, error_r)) < 0) {
+ return -1;
+ } else if (ret == 0) {
+ *error_r = t_strdup_printf("%s key '%s' not found",
+ alg, key_id);
+ return -1;
+ }
+
+ /* decode key */
+ buffer_t *key = t_base64_decode_str(base64_key);
+ if (key->used == 0) {
+ *error_r = "Invalid base64 encoded key";
+ return -1;
+ }
+ oauth2_validation_key_cache_insert_hmac_key(set->key_cache,
+ cache_key_id, key);
+ *hmac_key_r = key;
+ return 0;
+}
+
+static int
+oauth2_validate_hmac(const struct oauth2_settings *set, const char *azp,
+ const char *alg, const char *key_id,
+ const char *const *blobs, const char **error_r)
+{
+ const struct hash_method *method;
+
+ if (strcmp(alg, "HS256") == 0)
+ method = hash_method_lookup("sha256");
+ else if (strcmp(alg, "HS384") == 0)
+ method = hash_method_lookup("sha384");
+ else if (strcmp(alg, "HS512") == 0)
+ method = hash_method_lookup("sha512");
+ else {
+ *error_r = t_strdup_printf("unsupported algorithm '%s'", alg);
+ return -1;
+ }
+
+ const buffer_t *key;
+ if (oauth2_lookup_hmac_key(set, azp, alg, key_id, &key, error_r) < 0)
+ return -1;
+
+ struct hmac_context ctx;
+ hmac_init(&ctx, key->data, key->used, method);
+ hmac_update(&ctx, blobs[0], strlen(blobs[0]));
+ hmac_update(&ctx, ".", 1);
+ hmac_update(&ctx, blobs[1], strlen(blobs[1]));
+ unsigned char digest[method->digest_size];
+
+ hmac_final(&ctx, digest);
+
+ buffer_t *their_digest =
+ t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[2]);
+ if (method->digest_size != their_digest->used ||
+ !mem_equals_timing_safe(digest, their_digest->data,
+ method->digest_size)) {
+ *error_r = "Incorrect JWT signature";
+ return -1;
+ }
+ return 0;
+}
+
+static int
+oauth2_lookup_pubkey(const struct oauth2_settings *set, const char *azp,
+ const char *alg, const char *key_id,
+ struct dcrypt_public_key **key_r, const char **error_r)
+{
+ const char *key_str;
+ const char *cache_key_id, *lookup_key;
+ int ret;
+
+ cache_key_id = t_strconcat(azp, ".", alg, ".", key_id, NULL);
+ if (oauth2_validation_key_cache_lookup_pubkey(
+ set->key_cache, cache_key_id, key_r) == 0)
+ return 0;
+
+ /* do a synchronous dict lookup */
+ lookup_key = t_strconcat(DICT_PATH_SHARED, azp, "/", alg, "/", key_id,
+ NULL);
+ struct dict_op_settings dict_set = {
+ .username = NULL,
+ };
+ if ((ret = dict_lookup(set->key_dict, &dict_set, pool_datastack_create(),
+ lookup_key, &key_str, error_r)) < 0) {
+ return -1;
+ } else if (ret == 0) {
+ *error_r = t_strdup_printf("%s key '%s' not found",
+ alg, key_id);
+ return -1;
+ }
+
+ /* try to load key */
+ struct dcrypt_public_key *pubkey;
+ const char *error;
+ if (!dcrypt_key_load_public(&pubkey, key_str, &error)) {
+ *error_r = t_strdup_printf("Cannot load key: %s", error);
+ return -1;
+ }
+
+ /* cache key */
+ oauth2_validation_key_cache_insert_pubkey(set->key_cache, cache_key_id,
+ pubkey);
+ *key_r = pubkey;
+ return 0;
+}
+
+static int
+oauth2_validate_rsa_ecdsa(const struct oauth2_settings *set,
+ const char *azp, const char *alg, const char *key_id,
+ const char *const *blobs, const char **error_r)
+{
+ const char *method;
+ enum dcrypt_padding padding;
+ enum dcrypt_signature_format sig_format;
+
+ if (!dcrypt_is_initialized()) {
+ *error_r = "No crypto library loaded";
+ return -1;
+ }
+
+ if (str_begins(alg, "RS")) {
+ padding = DCRYPT_PADDING_RSA_PKCS1;
+ sig_format = DCRYPT_SIGNATURE_FORMAT_DSS;
+ } else if (str_begins(alg, "PS")) {
+ padding = DCRYPT_PADDING_RSA_PKCS1_PSS;
+ sig_format = DCRYPT_SIGNATURE_FORMAT_DSS;
+ } else if (str_begins(alg, "ES")) {
+ padding = DCRYPT_PADDING_DEFAULT;
+ sig_format = DCRYPT_SIGNATURE_FORMAT_X962;
+ } else {
+ /* this should be checked by caller */
+ i_unreached();
+ }
+
+ if (strcmp(alg+2, "256") == 0) {
+ method = "sha256";
+ } else if (strcmp(alg+2, "384") == 0) {
+ method = "sha384";
+ } else if (strcmp(alg+2, "512") == 0) {
+ method = "sha512";
+ } else {
+ *error_r = t_strdup_printf("Unsupported algorithm '%s'", alg);
+ return -1;
+ }
+
+ buffer_t *signature =
+ t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[2]);
+
+ struct dcrypt_public_key *pubkey;
+ if (oauth2_lookup_pubkey(set, azp, alg, key_id, &pubkey, error_r) < 0)
+ return -1;
+
+ /* data to verify */
+ const char *data = t_strconcat(blobs[0], ".", blobs[1], NULL);
+
+ /* verify signature */
+ bool valid;
+ if (!dcrypt_verify(pubkey, method, sig_format, data, strlen(data),
+ signature->data, signature->used, &valid, padding,
+ error_r)) {
+ valid = FALSE;
+ } else if (!valid) {
+ *error_r = "Bad signature";
+ }
+
+ return valid ? 0 : -1;
+}
+
+static int
+oauth2_validate_signature(const struct oauth2_settings *set, const char *azp,
+ const char *alg, const char *key_id,
+ const char *const *blobs, const char **error_r)
+{
+ if (str_begins(alg, "HS")) {
+ return oauth2_validate_hmac(set, azp, alg, key_id, blobs,
+ error_r);
+ } else if (str_begins(alg, "RS") || str_begins(alg, "PS") ||
+ str_begins(alg, "ES")) {
+ return oauth2_validate_rsa_ecdsa(set, azp, alg, key_id, blobs,
+ error_r);
+ }
+
+ *error_r = t_strdup_printf("Unsupported algorithm '%s'", alg);
+ return -1;
+}
+
+static void
+oauth2_jwt_copy_fields(ARRAY_TYPE(oauth2_field) *fields, struct json_tree *tree)
+{
+ pool_t pool = array_get_pool(fields);
+ ARRAY(const struct json_tree_node*) nodes;
+ const struct json_tree_node *root = json_tree_root(tree);
+
+ t_array_init(&nodes, 1);
+ array_push_back(&nodes, &root);
+
+ while (array_count(&nodes) > 0) {
+ const struct json_tree_node *const *pnode = array_front(&nodes);
+ const struct json_tree_node *node = *pnode;
+ array_pop_front(&nodes);
+ while (node != NULL) {
+ if (node->value_type == JSON_TYPE_OBJECT) {
+ root = node->value.child;
+ array_push_back(&nodes, &root);
+ } else if (node->key != NULL) {
+ struct oauth2_field *field =
+ array_append_space(fields);
+ field->name = p_strdup(pool, node->key);
+ field->value = p_strdup(
+ pool, json_tree_get_value_str(node));
+ }
+ node = node->next;
+ }
+ }
+}
+
+static int
+oauth2_jwt_header_process(struct json_tree *tree, const char **alg_r,
+ const char **kid_r, const char **error_r)
+{
+ const char *typ = get_field(tree, "typ");
+ const char *alg = get_field(tree, "alg");
+ const char *kid = get_field(tree, "kid");
+
+ if (null_strcmp(typ, "JWT") != 0) {
+ *error_r = "Cannot find 'typ' field";
+ return -1;
+ }
+
+ if (alg == NULL) {
+ *error_r = "Cannot find 'alg' field";
+ return -1;
+ }
+
+ /* These are lost when tree is deinitialized.
+ Make sure algorithm is uppercased. */
+ *alg_r = t_str_ucase(alg);
+ *kid_r = t_strdup(kid);
+ return 0;
+}
+
+static int
+oauth2_jwt_body_process(const struct oauth2_settings *set, const char *alg,
+ const char *kid, ARRAY_TYPE(oauth2_field) *fields,
+ struct json_tree *tree, const char *const *blobs,
+ const char **error_r)
+{
+ const char *sub = get_field(tree, "sub");
+
+ int ret;
+ int64_t t0 = time(NULL);
+ /* default IAT and NBF to now */
+ int64_t iat, nbf, exp;
+ int tz_offset ATTR_UNUSED;
+
+ if (sub == NULL) {
+ *error_r = "Missing 'sub' field";
+ return -1;
+ }
+
+ if ((ret = get_time_field(tree, "exp", &exp)) < 1) {
+ *error_r = t_strdup_printf("%s 'exp' field",
+ ret == 0 ? "Missing" : "Malformed");
+ return -1;
+ }
+
+ if ((ret = get_time_field(tree, "nbf", &nbf)) < 0) {
+ *error_r = "Malformed 'nbf' field";
+ return -1;
+ } else if (ret == 0 || nbf == 0)
+ nbf = t0;
+
+ if ((ret = get_time_field(tree, "iat", &iat)) < 0) {
+ *error_r = "Malformed 'iat' field";
+ return -1;
+ } else if (ret == 0 || iat == 0)
+ iat = t0;
+
+ if (nbf > t0) {
+ *error_r = "Token is not valid yet";
+ return -1;
+ }
+ if (iat > t0) {
+ *error_r = "Token is issued in future";
+ return -1;
+ }
+ if (exp < t0) {
+ *error_r = "Token has expired";
+ return -1;
+ }
+
+ /* ensure token dates are not conflicting */
+ if (exp < iat ||
+ exp < nbf) {
+ *error_r = "Token time values are conflicting";
+ return -1;
+ }
+
+ const char *iss = get_field(tree, "iss");
+ if (set->issuers != NULL && *set->issuers != NULL) {
+ if (iss == NULL) {
+ *error_r = "Token is missing 'iss' field";
+ return -1;
+ }
+ if (!str_array_find(set->issuers, iss)) {
+ *error_r = t_strdup_printf("Issuer '%s' is not allowed",
+ str_sanitize_utf8(iss, 128));
+ return -1;
+ }
+ }
+
+ /* see if there is azp */
+ const char *azp = get_field(tree, "azp");
+ if (azp == NULL)
+ azp = "default";
+ else
+ azp = escape_identifier(azp);
+
+ if (oauth2_validate_signature(set, azp, alg, kid, blobs, error_r) < 0)
+ return -1;
+
+ oauth2_jwt_copy_fields(fields, tree);
+ return 0;
+}
+
+int oauth2_try_parse_jwt(const struct oauth2_settings *set,
+ const char *token, ARRAY_TYPE(oauth2_field) *fields,
+ bool *is_jwt_r, const char **error_r)
+{
+ const char *const *blobs = t_strsplit(token, ".");
+ int ret;
+
+ i_assert(set->key_dict != NULL);
+
+ /* we don't know if it's JWT token yet */
+ *is_jwt_r = FALSE;
+
+ if (str_array_length(blobs) != 3) {
+ *error_r = "Not a JWT token";
+ return -1;
+ }
+
+ /* attempt to decode header */
+ buffer_t *header =
+ t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[0]);
+
+ if (header->used == 0) {
+ *error_r = "Not a JWT token";
+ return -1;
+ }
+
+ struct json_tree *header_tree;
+ if (oauth2_json_tree_build(header, &header_tree, error_r) < 0)
+ return -1;
+
+ const char *alg, *kid;
+ ret = oauth2_jwt_header_process(header_tree, &alg, &kid, error_r);
+ json_tree_deinit(&header_tree);
+ if (ret < 0)
+ return -1;
+
+ /* it is now assumed to be a JWT token */
+ *is_jwt_r = TRUE;
+
+ if (kid == NULL)
+ kid = "default";
+ else if (*kid == '\0') {
+ *error_r = "'kid' field is empty";
+ return -1;
+ } else {
+ kid = escape_identifier(kid);
+ }
+
+ /* parse body */
+ struct json_tree *body_tree;
+ buffer_t *body =
+ t_base64url_decode_str(BASE64_DECODE_FLAG_NO_PADDING, blobs[1]);
+ if (oauth2_json_tree_build(body, &body_tree, error_r) == -1)
+ return -1;
+ ret = oauth2_jwt_body_process(set, alg, kid, fields, body_tree, blobs,
+ error_r);
+ json_tree_deinit(&body_tree);
+
+ return ret;
+}