summaryrefslogtreecommitdiffstats
path: root/lualib/rspamadm/vault.lua
diff options
context:
space:
mode:
Diffstat (limited to 'lualib/rspamadm/vault.lua')
-rw-r--r--lualib/rspamadm/vault.lua579
1 files changed, 579 insertions, 0 deletions
diff --git a/lualib/rspamadm/vault.lua b/lualib/rspamadm/vault.lua
new file mode 100644
index 0000000..840e504
--- /dev/null
+++ b/lualib/rspamadm/vault.lua
@@ -0,0 +1,579 @@
+--[[
+Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com>
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+]]--
+
+
+local rspamd_logger = require "rspamd_logger"
+local ansicolors = require "ansicolors"
+local ucl = require "ucl"
+local argparse = require "argparse"
+local fun = require "fun"
+local rspamd_http = require "rspamd_http"
+local cr = require "rspamd_cryptobox"
+
+local parser = argparse()
+ :name "rspamadm vault"
+ :description "Perform Hashicorp Vault management"
+ :help_description_margin(32)
+ :command_target("command")
+ :require_command(true)
+
+parser:flag "-s --silent"
+ :description "Do not output extra information"
+parser:option "-a --addr"
+ :description "Vault address (if not defined in VAULT_ADDR env)"
+parser:option "-t --token"
+ :description "Vault token (not recommended, better define VAULT_TOKEN env)"
+parser:option "-p --path"
+ :description "Path to work with in the vault"
+ :default "dkim"
+parser:option "-o --output"
+ :description "Output format ('ucl', 'json', 'json-compact', 'yaml')"
+ :argname("<type>")
+ :convert {
+ ucl = "ucl",
+ json = "json",
+ ['json-compact'] = "json-compact",
+ yaml = "yaml",
+}
+ :default "ucl"
+
+parser:command "list ls l"
+ :description "List elements in the vault"
+
+local show = parser:command "show get"
+ :description "Extract element from the vault"
+show:argument "domain"
+ :description "Domain to create key for"
+ :args "+"
+
+local delete = parser:command "delete del rm remove"
+ :description "Delete element from the vault"
+delete:argument "domain"
+ :description "Domain to create delete key(s) for"
+ :args "+"
+
+local newkey = parser:command "newkey new create"
+ :description "Add new key to the vault"
+newkey:argument "domain"
+ :description "Domain to create key for"
+ :args "+"
+newkey:option "-s --selector"
+ :description "Selector to use"
+ :count "?"
+newkey:option "-A --algorithm"
+ :argname("<type>")
+ :convert {
+ rsa = "rsa",
+ ed25519 = "ed25519",
+ eddsa = "ed25519",
+}
+ :default "rsa"
+newkey:option "-b --bits"
+ :argname("<nbits>")
+ :convert(tonumber)
+ :default "1024"
+newkey:option "-x --expire"
+ :argname("<days>")
+ :convert(tonumber)
+newkey:flag "-r --rewrite"
+
+local roll = parser:command "roll rollover"
+ :description "Perform keys rollover"
+roll:argument "domain"
+ :description "Domain to roll key(s) for"
+ :args "+"
+roll:option "-T --ttl"
+ :description "Validity period for old keys (days)"
+ :convert(tonumber)
+ :default "1"
+roll:flag "-r --remove-expired"
+ :description "Remove expired keys"
+roll:option "-x --expire"
+ :argname("<days>")
+ :convert(tonumber)
+
+local function printf(fmt, ...)
+ if fmt then
+ io.write(rspamd_logger.slog(fmt, ...))
+ end
+ io.write('\n')
+end
+
+local function maybe_printf(opts, fmt, ...)
+ if not opts.silent then
+ printf(fmt, ...)
+ end
+end
+
+local function highlight(str, color)
+ return ansicolors[color or 'white'] .. str .. ansicolors.reset
+end
+
+local function vault_url(opts, path)
+ if path then
+ return string.format('%s/v1/%s/%s', opts.addr, opts.path, path)
+ end
+
+ return string.format('%s/v1/%s', opts.addr, opts.path)
+end
+
+local function is_http_error(err, data)
+ return err or (math.floor(data.code / 100) ~= 2)
+end
+
+local function parse_vault_reply(data)
+ local p = ucl.parser()
+ local res, parser_err = p:parse_string(data)
+
+ if not res then
+ return nil, parser_err
+ else
+ return p:get_object(), nil
+ end
+end
+
+local function maybe_print_vault_data(opts, data, func)
+ if data then
+ local res, parser_err = parse_vault_reply(data)
+
+ if not res then
+ printf('vault reply for cannot be parsed: %s', parser_err)
+ else
+ if func then
+ printf(ucl.to_format(func(res), opts.output))
+ else
+ printf(ucl.to_format(res, opts.output))
+ end
+ end
+ else
+ printf('no data received')
+ end
+end
+
+local function print_dkim_txt_record(b64, selector, alg)
+ local labels = {}
+ local prefix = string.format("v=DKIM1; k=%s; p=", alg)
+ b64 = prefix .. b64
+ if #b64 < 255 then
+ labels = { '"' .. b64 .. '"' }
+ else
+ for sl = 1, #b64, 256 do
+ table.insert(labels, '"' .. b64:sub(sl, sl + 255) .. '"')
+ end
+ end
+
+ printf("%s._domainkey IN TXT ( %s )", selector,
+ table.concat(labels, "\n\t"))
+end
+
+local function show_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, err)
+ os.exit(1)
+ else
+ maybe_print_vault_data(opts, data.content, function(obj)
+ return obj.data.selectors
+ end)
+ end
+end
+
+local function delete_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'delete',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, err)
+ os.exit(1)
+ else
+ printf('deleted key(s) for %s', domain)
+ end
+end
+
+local function list_handler(opts)
+ local uri = vault_url(opts)
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri .. '?list=true',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, err)
+ os.exit(1)
+ else
+ maybe_print_vault_data(opts, data.content, function(obj)
+ return obj.data.keys
+ end)
+ end
+end
+
+-- Returns pair privkey+pubkey
+local function genkey(opts)
+ return cr.gen_dkim_keypair(opts.algorithm, opts.bits)
+end
+
+local function create_and_push_key(opts, domain, existing)
+ local uri = vault_url(opts, domain)
+ local sk, pk = genkey(opts)
+
+ local res = {
+ selectors = {
+ [1] = {
+ selector = opts.selector,
+ domain = domain,
+ key = tostring(sk),
+ pubkey = tostring(pk),
+ alg = opts.algorithm,
+ bits = opts.bits or 0,
+ valid_start = os.time(),
+ }
+ }
+ }
+
+ for _, sel in ipairs(existing) do
+ res.selectors[#res.selectors + 1] = sel
+ end
+
+ if opts.expire then
+ res.selectors[1].valid_end = os.time() + opts.expire * 3600 * 24
+ end
+
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'put',
+ headers = {
+ ['Content-Type'] = 'application/json',
+ ['X-Vault-Token'] = opts.token
+ },
+ body = {
+ ucl.to_format(res, 'json-compact')
+ },
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot get request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, data.content)
+ os.exit(1)
+ else
+ maybe_printf(opts, 'stored key for: %s, selector: %s', domain, opts.selector)
+ maybe_printf(opts, 'please place the corresponding public key as following:')
+
+ if opts.silent then
+ printf('%s', pk)
+ else
+ print_dkim_txt_record(tostring(pk), opts.selector, opts.algorithm)
+ end
+ end
+end
+
+local function newkey_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+
+ if not opts.selector then
+ opts.selector = string.format('%s-%s', opts.algorithm,
+ os.date("!%Y%m%d"))
+ end
+
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'get',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) or not data.content then
+ create_and_push_key(opts, domain, {})
+ else
+ -- Key exists
+ local rep = parse_vault_reply(data.content)
+
+ if not rep or not rep.data then
+ printf('cannot parse reply for %s: %s', uri, data.content)
+ os.exit(1)
+ end
+
+ local elts = rep.data.selectors
+
+ if not elts then
+ create_and_push_key(opts, domain, {})
+ os.exit(0)
+ end
+
+ for _, sel in ipairs(elts) do
+ if sel.alg == opts.algorithm then
+ printf('key with the specific algorithm %s is already presented at %s selector for %s domain',
+ opts.algorithm, sel.selector, domain)
+ os.exit(1)
+ else
+ create_and_push_key(opts, domain, elts)
+ end
+ end
+ end
+end
+
+local function roll_handler(opts, domain)
+ local uri = vault_url(opts, domain)
+ local res = {
+ selectors = {}
+ }
+
+ local err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'get',
+ headers = {
+ ['X-Vault-Token'] = opts.token
+ }
+ }
+
+ if is_http_error(err, data) or not data.content then
+ printf("No keys to roll for domain %s", domain)
+ os.exit(1)
+ else
+ local rep = parse_vault_reply(data.content)
+
+ if not rep or not rep.data then
+ printf('cannot parse reply for %s: %s', uri, data.content)
+ os.exit(1)
+ end
+
+ local elts = rep.data.selectors
+
+ if not elts then
+ printf("No keys to roll for domain %s", domain)
+ os.exit(1)
+ end
+
+ local nkeys = {} -- indexed by algorithm
+
+ local function insert_key(sel, add_expire)
+ if not nkeys[sel.alg] then
+ nkeys[sel.alg] = {}
+ end
+
+ if add_expire then
+ sel.valid_end = os.time() + opts.ttl * 3600 * 24
+ end
+
+ table.insert(nkeys[sel.alg], sel)
+ end
+
+ for _, sel in ipairs(elts) do
+ if sel.valid_end and sel.valid_end < os.time() then
+ if not opts.remove_expired then
+ insert_key(sel, false)
+ else
+ maybe_printf(opts, 'removed expired key for %s (selector %s, expire "%s"',
+ domain, sel.selector, os.date('%c', sel.valid_end))
+ end
+ else
+ insert_key(sel, true)
+ end
+ end
+
+ -- Now we need to ensure that all but one selectors have either expired or just a single key
+ for alg, keys in pairs(nkeys) do
+ table.sort(keys, function(k1, k2)
+ if k1.valid_end and k2.valid_end then
+ return k1.valid_end > k2.valid_end
+ elseif k1.valid_end then
+ return true
+ elseif k2.valid_end then
+ return false
+ end
+ return false
+ end)
+ -- Exclude the key with the highest expiration date and examine the rest
+ if not (#keys == 1 or fun.all(function(k)
+ return k.valid_end and k.valid_end < os.time()
+ end, fun.tail(keys))) then
+ printf('bad keys list for %s and %s algorithm', domain, alg)
+ fun.each(function(k)
+ if not k.valid_end then
+ printf('selector %s, algorithm %s has a key with no expire',
+ k.selector, k.alg)
+ elseif k.valid_end >= os.time() then
+ printf('selector %s, algorithm %s has a key that not yet expired: %s',
+ k.selector, k.alg, os.date('%c', k.valid_end))
+ end
+ end, fun.tail(keys))
+ os.exit(1)
+ end
+ -- Do not create new keys, if we only want to remove expired keys
+ if not opts.remove_expired then
+ -- OK to process
+ -- Insert keys for each algorithm in pairs <old_key(s)>, <new_key>
+ local sk, pk = genkey({ algorithm = alg, bits = keys[1].bits })
+ local selector = string.format('%s-%s', alg,
+ os.date("!%Y%m%d"))
+
+ if selector == keys[1].selector then
+ selector = selector .. '-1'
+ end
+ local nelt = {
+ selector = selector,
+ domain = domain,
+ key = tostring(sk),
+ pubkey = tostring(pk),
+ alg = alg,
+ bits = keys[1].bits,
+ valid_start = os.time(),
+ }
+
+ if opts.expire then
+ nelt.valid_end = os.time() + opts.expire * 3600 * 24
+ end
+
+ table.insert(res.selectors, nelt)
+ end
+ for _, k in ipairs(keys) do
+ table.insert(res.selectors, k)
+ end
+ end
+ end
+
+ -- We can now store res in the vault
+ err, data = rspamd_http.request {
+ config = rspamd_config,
+ ev_base = rspamadm_ev_base,
+ session = rspamadm_session,
+ resolver = rspamadm_dns_resolver,
+ url = uri,
+ method = 'put',
+ headers = {
+ ['Content-Type'] = 'application/json',
+ ['X-Vault-Token'] = opts.token
+ },
+ body = {
+ ucl.to_format(res, 'json-compact')
+ },
+ }
+
+ if is_http_error(err, data) then
+ printf('cannot put request to the vault (%s), HTTP error code %s', uri, data.code)
+ maybe_print_vault_data(opts, data.content)
+ os.exit(1)
+ else
+ for _, key in ipairs(res.selectors) do
+ if not key.valid_end or key.valid_end > os.time() + opts.ttl * 3600 * 24 then
+ maybe_printf(opts, 'rolled key for: %s, new selector: %s', domain, key.selector)
+ maybe_printf(opts, 'please place the corresponding public key as following:')
+
+ if opts.silent then
+ printf('%s', key.pubkey)
+ else
+ print_dkim_txt_record(key.pubkey, key.selector, key.alg)
+ end
+
+ end
+ end
+
+ maybe_printf(opts, 'your old keys will be valid until %s',
+ os.date('%c', os.time() + opts.ttl * 3600 * 24))
+ end
+end
+
+local function handler(args)
+ local opts = parser:parse(args)
+
+ if not opts.addr then
+ opts.addr = os.getenv('VAULT_ADDR')
+ end
+
+ if not opts.token then
+ opts.token = os.getenv('VAULT_TOKEN')
+ else
+ maybe_printf(opts, 'defining token via command line is insecure, define it via environment variable %s',
+ highlight('VAULT_TOKEN', 'red'))
+ end
+
+ if not opts.token or not opts.addr then
+ printf('no token or/and vault addr has been specified, exiting')
+ os.exit(1)
+ end
+
+ local command = opts.command
+
+ if command == 'list' then
+ list_handler(opts)
+ elseif command == 'show' then
+ fun.each(function(d)
+ show_handler(opts, d)
+ end, opts.domain)
+ elseif command == 'newkey' then
+ fun.each(function(d)
+ newkey_handler(opts, d)
+ end, opts.domain)
+ elseif command == 'roll' then
+ fun.each(function(d)
+ roll_handler(opts, d)
+ end, opts.domain)
+ elseif command == 'delete' then
+ fun.each(function(d)
+ delete_handler(opts, d)
+ end, opts.domain)
+ else
+ parser:error(string.format('command %s is not implemented', command))
+ end
+end
+
+return {
+ handler = handler,
+ description = parser._description,
+ name = 'vault'
+}