diff options
Diffstat (limited to 'lualib/lua_redis.lua')
-rw-r--r-- | lualib/lua_redis.lua | 1817 |
1 files changed, 1817 insertions, 0 deletions
diff --git a/lualib/lua_redis.lua b/lualib/lua_redis.lua new file mode 100644 index 0000000..818d955 --- /dev/null +++ b/lualib/lua_redis.lua @@ -0,0 +1,1817 @@ +--[[ +Copyright (c) 2022, Vsevolod Stakhov <vsevolod@rspamd.com> + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +]]-- + +local logger = require "rspamd_logger" +local lutil = require "lua_util" +local rspamd_util = require "rspamd_util" +local ts = require("tableshape").types + +local exports = {} + +local E = {} +local N = "lua_redis" + +local common_schema = { + timeout = (ts.number + ts.string / lutil.parse_time_interval):is_optional():describe("Connection timeout"), + db = ts.string:is_optional():describe("Database number"), + database = ts.string:is_optional():describe("Database number"), + dbname = ts.string:is_optional():describe("Database number"), + prefix = ts.string:is_optional():describe("Key prefix"), + username = ts.string:is_optional():describe("Username"), + password = ts.string:is_optional():describe("Password"), + expand_keys = ts.boolean:is_optional():describe("Expand keys"), + sentinels = (ts.string + ts.array_of(ts.string)):is_optional():describe("Sentinel servers"), + sentinel_watch_time = (ts.number + ts.string / lutil.parse_time_interval):is_optional():describe("Sentinel watch time"), + sentinel_masters_pattern = ts.string:is_optional():describe("Sentinel masters pattern"), + sentinel_master_maxerrors = (ts.number + ts.string / tonumber):is_optional():describe("Sentinel master max errors"), + sentinel_username = ts.string:is_optional():describe("Sentinel username"), + sentinel_password = ts.string:is_optional():describe("Sentinel password"), +} + +local read_schema = lutil.table_merge({ + read_servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local write_schema = lutil.table_merge({ + write_servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local rw_schema = lutil.table_merge({ + read_servers = ts.string + ts.array_of(ts.string), + write_servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local servers_schema = lutil.table_merge({ + servers = ts.string + ts.array_of(ts.string), +}, common_schema) + +local server_schema = lutil.table_merge({ + server = ts.string + ts.array_of(ts.string), +}, common_schema) + +local enrich_schema = function(external) + return ts.one_of { + ts.shape(external), -- no specific redis parameters + ts.shape(lutil.table_merge(read_schema, external)), -- read_servers specified + ts.shape(lutil.table_merge(write_schema, external)), -- write_servers specified + ts.shape(lutil.table_merge(rw_schema, external)), -- both read and write servers defined + ts.shape(lutil.table_merge(servers_schema, external)), -- just servers for both ops + ts.shape(lutil.table_merge(server_schema, external)), -- legacy `server` attribute + } +end + +exports.enrich_schema = enrich_schema + +local function redis_query_sentinel(ev_base, params, initialised) + local function flatten_redis_table(tbl) + local res = {} + for i = 1, #tbl, 2 do + res[tbl[i]] = tbl[i + 1] + end + + return res + end + -- Coroutines syntax + local rspamd_redis = require "rspamd_redis" + local sentinels = params.sentinels + local addr = sentinels:get_upstream_round_robin() + + local host = addr:get_addr() + local masters = {} + local process_masters -- Function that is called to process masters data + + local function masters_cb(err, result) + if not err and result and type(result) == 'table' then + + local pending_subrequests = 0 + + for _, m in ipairs(result) do + local master = flatten_redis_table(m) + + -- Wrap IPv6-addresses in brackets + if (master.ip:match(":")) then + master.ip = "[" .. master.ip .. "]" + end + + if params.sentinel_masters_pattern then + if master.name:match(params.sentinel_masters_pattern) then + lutil.debugm(N, 'found master %s with ip %s and port %s', + master.name, master.ip, master.port) + masters[master.name] = master + else + lutil.debugm(N, 'skip master %s with ip %s and port %s, pattern %s', + master.name, master.ip, master.port, params.sentinel_masters_pattern) + end + else + lutil.debugm(N, 'found master %s with ip %s and port %s', + master.name, master.ip, master.port) + masters[master.name] = master + end + end + + -- For each master we need to get a list of slaves + for k, v in pairs(masters) do + v.slaves = {} + local function slaves_cb(slave_err, slave_result) + if not slave_err and type(slave_result) == 'table' then + for _, s in ipairs(slave_result) do + local slave = flatten_redis_table(s) + lutil.debugm(N, rspamd_config, + 'found slave for master %s with ip %s and port %s', + v.name, slave.ip, slave.port) + -- Wrap IPv6-addresses in brackets + if (slave.ip:match(":")) then + slave.ip = "[" .. slave.ip .. "]" + end + v.slaves[#v.slaves + 1] = slave + end + else + logger.errx('cannot get slaves data from Redis Sentinel %s: %s', + host:to_string(true), slave_err) + addr:fail() + end + + pending_subrequests = pending_subrequests - 1 + + if pending_subrequests == 0 then + -- Finalize masters and slaves + process_masters() + end + end + + local ret = rspamd_redis.make_request { + host = addr:get_addr(), + timeout = params.timeout, + username = params.sentinel_username, + password = params.sentinel_password, + config = rspamd_config, + ev_base = ev_base, + cmd = 'SENTINEL', + args = { 'slaves', k }, + no_pool = true, + callback = slaves_cb + } + + if not ret then + logger.errx(rspamd_config, 'cannot connect sentinel when query slaves at address: %s', + host:to_string(true)) + addr:fail() + else + pending_subrequests = pending_subrequests + 1 + end + end + + addr:ok() + else + logger.errx('cannot get masters data from Redis Sentinel %s: %s', + host:to_string(true), err) + addr:fail() + end + end + + local ret = rspamd_redis.make_request { + host = addr:get_addr(), + timeout = params.timeout, + config = rspamd_config, + ev_base = ev_base, + username = params.sentinel_username, + password = params.sentinel_password, + cmd = 'SENTINEL', + args = { 'masters' }, + no_pool = true, + callback = masters_cb, + } + + if not ret then + logger.errx(rspamd_config, 'cannot connect sentinel at address: %s', + host:to_string(true)) + addr:fail() + end + + process_masters = function() + -- We now form new strings for masters and slaves + local read_servers_tbl, write_servers_tbl = {}, {} + + for _, master in pairs(masters) do + write_servers_tbl[#write_servers_tbl + 1] = string.format( + '%s:%s', master.ip, master.port + ) + read_servers_tbl[#read_servers_tbl + 1] = string.format( + '%s:%s', master.ip, master.port + ) + + for _, slave in ipairs(master.slaves) do + if slave['master-link-status'] == 'ok' then + read_servers_tbl[#read_servers_tbl + 1] = string.format( + '%s:%s', slave.ip, slave.port + ) + end + end + end + + table.sort(read_servers_tbl) + table.sort(write_servers_tbl) + + local read_servers_str = table.concat(read_servers_tbl, ',') + local write_servers_str = table.concat(write_servers_tbl, ',') + + lutil.debugm(N, rspamd_config, + 'new servers list: %s read; %s write', + read_servers_str, + write_servers_str) + + if read_servers_str ~= params.read_servers_str then + local upstream_list = require "rspamd_upstream_list" + + local read_upstreams = upstream_list.create(rspamd_config, + read_servers_str, 6379) + + if read_upstreams then + logger.infox(rspamd_config, 'sentinel %s: replace read servers with new list: %s', + host:to_string(true), read_servers_str) + params.read_servers = read_upstreams + params.read_servers_str = read_servers_str + end + end + + if write_servers_str ~= params.write_servers_str then + local upstream_list = require "rspamd_upstream_list" + + local write_upstreams = upstream_list.create(rspamd_config, + write_servers_str, 6379) + + if write_upstreams then + logger.infox(rspamd_config, 'sentinel %s: replace write servers with new list: %s', + host:to_string(true), write_servers_str) + params.write_servers = write_upstreams + params.write_servers_str = write_servers_str + + local queried = false + + local function monitor_failures(up, _, count) + if count > params.sentinel_master_maxerrors and not queried then + logger.infox(rspamd_config, 'sentinel: master with address %s, caused %s failures, try to query sentinel', + host:to_string(true), count) + queried = true -- Avoid multiple checks caused by this monitor + redis_query_sentinel(ev_base, params, true) + end + end + + write_upstreams:add_watcher('failure', monitor_failures) + end + end + end + +end + +local function add_redis_sentinels(params) + local upstream_list = require "rspamd_upstream_list" + + local upstreams_sentinels = upstream_list.create(rspamd_config, + params.sentinels, 5000) + + if not upstreams_sentinels then + logger.errx(rspamd_config, 'cannot load redis sentinels string: %s', + params.sentinels) + + return + end + + params.sentinels = upstreams_sentinels + + if not params.sentinel_watch_time then + params.sentinel_watch_time = 60 -- Each minute + end + + if not params.sentinel_master_maxerrors then + params.sentinel_master_maxerrors = 2 -- Maximum number of errors before rechecking + end + + rspamd_config:add_on_load(function(_, ev_base, worker) + local initialised = false + if worker:is_scanner() or worker:get_type() == 'fuzzy' then + rspamd_config:add_periodic(ev_base, 0.0, function() + redis_query_sentinel(ev_base, params, initialised) + initialised = true + + return params.sentinel_watch_time + end, false) + end + end) +end + +local cached_results = {} + +local function calculate_redis_hash(params) + local cr = require "rspamd_cryptobox_hash" + + local h = cr.create() + + local function rec_hash(k, v) + if type(v) == 'string' then + h:update(k) + h:update(v) + elseif type(v) == 'number' then + h:update(k) + h:update(tostring(v)) + elseif type(v) == 'table' then + for kk, vv in pairs(v) do + rec_hash(kk, vv) + end + end + end + + rec_hash('top', params) + + return h:base32() +end + +local function process_redis_opts(options, redis_params) + local default_timeout = 1.0 + local default_expand_keys = false + + if not redis_params['timeout'] or redis_params['timeout'] == default_timeout then + if options['timeout'] then + redis_params['timeout'] = tonumber(options['timeout']) + else + redis_params['timeout'] = default_timeout + end + end + + if options['prefix'] and not redis_params['prefix'] then + redis_params['prefix'] = options['prefix'] + end + + if type(options['expand_keys']) == 'boolean' then + redis_params['expand_keys'] = options['expand_keys'] + else + redis_params['expand_keys'] = default_expand_keys + end + + if not redis_params['db'] then + if options['db'] then + redis_params['db'] = tostring(options['db']) + elseif options['dbname'] then + redis_params['db'] = tostring(options['dbname']) + elseif options['database'] then + redis_params['db'] = tostring(options['database']) + end + end + if options['username'] and not redis_params['username'] then + redis_params['username'] = options['username'] + end + if options['password'] and not redis_params['password'] then + redis_params['password'] = options['password'] + end + + if not redis_params.sentinels and options.sentinels then + redis_params.sentinels = options.sentinels + end + + if options['sentinel_masters_pattern'] and not redis_params['sentinel_masters_pattern'] then + redis_params['sentinel_masters_pattern'] = options['sentinel_masters_pattern'] + end + +end + +local function enrich_defaults(rspamd_config, module, redis_params) + if rspamd_config then + local opts = rspamd_config:get_all_opt('redis') + + if opts then + if module then + if opts[module] then + process_redis_opts(opts[module], redis_params) + end + end + + process_redis_opts(opts, redis_params) + end + end +end + +local function maybe_return_cached(redis_params) + local h = calculate_redis_hash(redis_params) + + if cached_results[h] then + lutil.debugm(N, 'reused redis server: %s', redis_params) + return cached_results[h] + end + + redis_params.hash = h + cached_results[h] = redis_params + + if not redis_params.read_only and redis_params.sentinels then + add_redis_sentinels(redis_params) + end + + lutil.debugm(N, 'loaded new redis server: %s', redis_params) + return redis_params +end + +--[[[ +-- @module lua_redis +-- This module contains helper functions for working with Redis +--]] +local function process_redis_options(options, rspamd_config, result) + local default_port = 6379 + local upstream_list = require "rspamd_upstream_list" + local read_only = true + + -- Try to get read servers: + local upstreams_read, upstreams_write + + if options['read_servers'] then + if rspamd_config then + upstreams_read = upstream_list.create(rspamd_config, + options['read_servers'], default_port) + else + upstreams_read = upstream_list.create(options['read_servers'], + default_port) + end + + result.read_servers_str = options['read_servers'] + elseif options['servers'] then + if rspamd_config then + upstreams_read = upstream_list.create(rspamd_config, + options['servers'], default_port) + else + upstreams_read = upstream_list.create(options['servers'], default_port) + end + + result.read_servers_str = options['servers'] + read_only = false + elseif options['server'] then + if rspamd_config then + upstreams_read = upstream_list.create(rspamd_config, + options['server'], default_port) + else + upstreams_read = upstream_list.create(options['server'], default_port) + end + + result.read_servers_str = options['server'] + read_only = false + end + + if upstreams_read then + if options['write_servers'] then + if rspamd_config then + upstreams_write = upstream_list.create(rspamd_config, + options['write_servers'], default_port) + else + upstreams_write = upstream_list.create(options['write_servers'], + default_port) + end + result.write_servers_str = options['write_servers'] + read_only = false + elseif not read_only then + upstreams_write = upstreams_read + result.write_servers_str = result.read_servers_str + end + end + + -- Store options + process_redis_opts(options, result) + + if read_only and not upstreams_write then + result.read_only = true + elseif upstreams_write then + result.read_only = false + end + + if upstreams_read then + result.read_servers = upstreams_read + + if upstreams_write then + result.write_servers = upstreams_write + end + + return true + end + + lutil.debugm(N, rspamd_config, + 'cannot load redis server from obj: %s, processed to %s', + options, result) + + return false +end + +--[[[ +@function try_load_redis_servers(options, rspamd_config, no_fallback) +Tries to load redis servers from the specified `options` object. +Returns `redis_params` table or nil in case of failure + +--]] +exports.try_load_redis_servers = function(options, rspamd_config, no_fallback, module_name) + local result = {} + + if process_redis_options(options, rspamd_config, result) then + if not no_fallback then + enrich_defaults(rspamd_config, module_name, result) + end + return maybe_return_cached(result) + end +end + +-- This function parses redis server definition using either +-- specific server string for this module or global +-- redis section +local function rspamd_parse_redis_server(module_name, module_opts, no_fallback) + local result = {} + + -- Try local options + local opts + lutil.debugm(N, rspamd_config, 'try load redis config for: %s', module_name) + if not module_opts then + opts = rspamd_config:get_all_opt(module_name) + else + opts = module_opts + end + + if opts then + local ret + + if opts.redis then + ret = process_redis_options(opts.redis, rspamd_config, result) + + if ret then + if not no_fallback then + enrich_defaults(rspamd_config, module_name, result) + end + return maybe_return_cached(result) + end + end + + ret = process_redis_options(opts, rspamd_config, result) + + if ret then + if not no_fallback then + enrich_defaults(rspamd_config, module_name, result) + end + return maybe_return_cached(result) + end + end + + if no_fallback then + logger.infox(rspamd_config, "cannot find Redis definitions for %s and fallback is disabled", + module_name) + + return nil + end + + -- Try global options + opts = rspamd_config:get_all_opt('redis') + + if opts then + local ret + + if opts[module_name] then + ret = process_redis_options(opts[module_name], rspamd_config, result) + + if ret then + return maybe_return_cached(result) + end + else + ret = process_redis_options(opts, rspamd_config, result) + + -- Exclude disabled + if opts['disabled_modules'] then + for _, v in ipairs(opts['disabled_modules']) do + if v == module_name then + logger.infox(rspamd_config, "NOT using default redis server for module %s: it is disabled", + module_name) + + return nil + end + end + end + + if ret then + logger.infox(rspamd_config, "use default Redis settings for %s", + module_name) + return maybe_return_cached(result) + end + end + end + + if result.read_servers then + return maybe_return_cached(result) + end + + return nil +end + +--[[[ +-- @function lua_redis.parse_redis_server(module_name, module_opts, no_fallback) +-- Extracts Redis server settings from configuration +-- @param {string} module_name name of module to get settings for +-- @param {table} module_opts settings for module or `nil` to fetch them from configuration +-- @param {boolean} no_fallback should be `true` if global settings must not be used +-- @return {table} redis server settings +-- @example +-- local rconfig = lua_redis.parse_redis_server('my_module') +-- -- rconfig contains upstream_list objects in ['write_servers'] and ['read_servers'] +-- -- ['timeout'] contains timeout in seconds +-- -- ['expand_keys'] if true tells that redis key expansion is enabled +--]] + +exports.rspamd_parse_redis_server = rspamd_parse_redis_server +exports.parse_redis_server = rspamd_parse_redis_server + +local process_cmd = { + bitop = function(args) + local idx_l = {} + for i = 2, #args do + table.insert(idx_l, i) + end + return idx_l + end, + blpop = function(args) + local idx_l = {} + for i = 1, #args - 1 do + table.insert(idx_l, i) + end + return idx_l + end, + eval = function(args) + local idx_l = {} + local numkeys = args[2] + if numkeys and tonumber(numkeys) >= 1 then + for i = 3, numkeys + 2 do + table.insert(idx_l, i) + end + end + return idx_l + end, + set = function(args) + return { 1 } + end, + mget = function(args) + local idx_l = {} + for i = 1, #args do + table.insert(idx_l, i) + end + return idx_l + end, + mset = function(args) + local idx_l = {} + for i = 1, #args, 2 do + table.insert(idx_l, i) + end + return idx_l + end, + sdiffstore = function(args) + local idx_l = {} + for i = 2, #args do + table.insert(idx_l, i) + end + return idx_l + end, + smove = function(args) + return { 1, 2 } + end, + script = function() + end +} +process_cmd.append = process_cmd.set +process_cmd.auth = process_cmd.script +process_cmd.bgrewriteaof = process_cmd.script +process_cmd.bgsave = process_cmd.script +process_cmd.bitcount = process_cmd.set +process_cmd.bitfield = process_cmd.set +process_cmd.bitpos = process_cmd.set +process_cmd.brpop = process_cmd.blpop +process_cmd.brpoplpush = process_cmd.blpop +process_cmd.client = process_cmd.script +process_cmd.cluster = process_cmd.script +process_cmd.command = process_cmd.script +process_cmd.config = process_cmd.script +process_cmd.dbsize = process_cmd.script +process_cmd.debug = process_cmd.script +process_cmd.decr = process_cmd.set +process_cmd.decrby = process_cmd.set +process_cmd.del = process_cmd.mget +process_cmd.discard = process_cmd.script +process_cmd.dump = process_cmd.set +process_cmd.echo = process_cmd.script +process_cmd.evalsha = process_cmd.eval +process_cmd.exec = process_cmd.script +process_cmd.exists = process_cmd.mget +process_cmd.expire = process_cmd.set +process_cmd.expireat = process_cmd.set +process_cmd.flushall = process_cmd.script +process_cmd.flushdb = process_cmd.script +process_cmd.geoadd = process_cmd.set +process_cmd.geohash = process_cmd.set +process_cmd.geopos = process_cmd.set +process_cmd.geodist = process_cmd.set +process_cmd.georadius = process_cmd.set +process_cmd.georadiusbymember = process_cmd.set +process_cmd.get = process_cmd.set +process_cmd.getbit = process_cmd.set +process_cmd.getrange = process_cmd.set +process_cmd.getset = process_cmd.set +process_cmd.hdel = process_cmd.set +process_cmd.hexists = process_cmd.set +process_cmd.hget = process_cmd.set +process_cmd.hgetall = process_cmd.set +process_cmd.hincrby = process_cmd.set +process_cmd.hincrbyfloat = process_cmd.set +process_cmd.hkeys = process_cmd.set +process_cmd.hlen = process_cmd.set +process_cmd.hmget = process_cmd.set +process_cmd.hmset = process_cmd.set +process_cmd.hscan = process_cmd.set +process_cmd.hset = process_cmd.set +process_cmd.hsetnx = process_cmd.set +process_cmd.hstrlen = process_cmd.set +process_cmd.hvals = process_cmd.set +process_cmd.incr = process_cmd.set +process_cmd.incrby = process_cmd.set +process_cmd.incrbyfloat = process_cmd.set +process_cmd.info = process_cmd.script +process_cmd.keys = process_cmd.script +process_cmd.lastsave = process_cmd.script +process_cmd.lindex = process_cmd.set +process_cmd.linsert = process_cmd.set +process_cmd.llen = process_cmd.set +process_cmd.lpop = process_cmd.set +process_cmd.lpush = process_cmd.set +process_cmd.lpushx = process_cmd.set +process_cmd.lrange = process_cmd.set +process_cmd.lrem = process_cmd.set +process_cmd.lset = process_cmd.set +process_cmd.ltrim = process_cmd.set +process_cmd.migrate = process_cmd.script +process_cmd.monitor = process_cmd.script +process_cmd.move = process_cmd.set +process_cmd.msetnx = process_cmd.mset +process_cmd.multi = process_cmd.script +process_cmd.object = process_cmd.script +process_cmd.persist = process_cmd.set +process_cmd.pexpire = process_cmd.set +process_cmd.pexpireat = process_cmd.set +process_cmd.pfadd = process_cmd.set +process_cmd.pfcount = process_cmd.set +process_cmd.pfmerge = process_cmd.mget +process_cmd.ping = process_cmd.script +process_cmd.psetex = process_cmd.set +process_cmd.psubscribe = process_cmd.script +process_cmd.pubsub = process_cmd.script +process_cmd.pttl = process_cmd.set +process_cmd.publish = process_cmd.script +process_cmd.punsubscribe = process_cmd.script +process_cmd.quit = process_cmd.script +process_cmd.randomkey = process_cmd.script +process_cmd.readonly = process_cmd.script +process_cmd.readwrite = process_cmd.script +process_cmd.rename = process_cmd.mget +process_cmd.renamenx = process_cmd.mget +process_cmd.restore = process_cmd.set +process_cmd.role = process_cmd.script +process_cmd.rpop = process_cmd.set +process_cmd.rpoplpush = process_cmd.mget +process_cmd.rpush = process_cmd.set +process_cmd.rpushx = process_cmd.set +process_cmd.sadd = process_cmd.set +process_cmd.save = process_cmd.script +process_cmd.scard = process_cmd.set +process_cmd.sdiff = process_cmd.mget +process_cmd.select = process_cmd.script +process_cmd.setbit = process_cmd.set +process_cmd.setex = process_cmd.set +process_cmd.setnx = process_cmd.set +process_cmd.sinterstore = process_cmd.sdiff +process_cmd.sismember = process_cmd.set +process_cmd.slaveof = process_cmd.script +process_cmd.slowlog = process_cmd.script +process_cmd.smembers = process_cmd.script +process_cmd.sort = process_cmd.set +process_cmd.spop = process_cmd.set +process_cmd.srandmember = process_cmd.set +process_cmd.srem = process_cmd.set +process_cmd.strlen = process_cmd.set +process_cmd.subscribe = process_cmd.script +process_cmd.sunion = process_cmd.mget +process_cmd.sunionstore = process_cmd.mget +process_cmd.swapdb = process_cmd.script +process_cmd.sync = process_cmd.script +process_cmd.time = process_cmd.script +process_cmd.touch = process_cmd.mget +process_cmd.ttl = process_cmd.set +process_cmd.type = process_cmd.set +process_cmd.unsubscribe = process_cmd.script +process_cmd.unlink = process_cmd.mget +process_cmd.unwatch = process_cmd.script +process_cmd.wait = process_cmd.script +process_cmd.watch = process_cmd.mget +process_cmd.zadd = process_cmd.set +process_cmd.zcard = process_cmd.set +process_cmd.zcount = process_cmd.set +process_cmd.zincrby = process_cmd.set +process_cmd.zinterstore = process_cmd.eval +process_cmd.zlexcount = process_cmd.set +process_cmd.zrange = process_cmd.set +process_cmd.zrangebylex = process_cmd.set +process_cmd.zrank = process_cmd.set +process_cmd.zrem = process_cmd.set +process_cmd.zrembylex = process_cmd.set +process_cmd.zrembyrank = process_cmd.set +process_cmd.zrembyscore = process_cmd.set +process_cmd.zrevrange = process_cmd.set +process_cmd.zrevrangebyscore = process_cmd.set +process_cmd.zrevrank = process_cmd.set +process_cmd.zscore = process_cmd.set +process_cmd.zunionstore = process_cmd.eval +process_cmd.scan = process_cmd.script +process_cmd.sscan = process_cmd.set +process_cmd.hscan = process_cmd.set +process_cmd.zscan = process_cmd.set + +local function get_key_indexes(cmd, args) + local idx_l = {} + cmd = string.lower(cmd) + if process_cmd[cmd] then + idx_l = process_cmd[cmd](args) + else + logger.warnx(rspamd_config, "Don't know how to extract keys for %s Redis command", cmd) + end + return idx_l +end + +local gen_meta = { + principal_recipient = function(task) + return task:get_principal_recipient() + end, + principal_recipient_domain = function(task) + local p = task:get_principal_recipient() + if not p then + return + end + return string.match(p, '.*@(.*)') + end, + ip = function(task) + local i = task:get_ip() + if i and i:is_valid() then + return i:to_string() + end + end, + from = function(task) + return ((task:get_from('smtp') or E)[1] or E)['addr'] + end, + from_domain = function(task) + return ((task:get_from('smtp') or E)[1] or E)['domain'] + end, + from_domain_or_helo_domain = function(task) + local d = ((task:get_from('smtp') or E)[1] or E)['domain'] + if d and #d > 0 then + return d + end + return task:get_helo() + end, + mime_from = function(task) + return ((task:get_from('mime') or E)[1] or E)['addr'] + end, + mime_from_domain = function(task) + return ((task:get_from('mime') or E)[1] or E)['domain'] + end, +} + +local function gen_get_esld(f) + return function(task) + local d = f(task) + if not d then + return + end + return rspamd_util.get_tld(d) + end +end + +gen_meta.smtp_from = gen_meta.from +gen_meta.smtp_from_domain = gen_meta.from_domain +gen_meta.smtp_from_domain_or_helo_domain = gen_meta.from_domain_or_helo_domain +gen_meta.esld_principal_recipient_domain = gen_get_esld(gen_meta.principal_recipient_domain) +gen_meta.esld_from_domain = gen_get_esld(gen_meta.from_domain) +gen_meta.esld_smtp_from_domain = gen_meta.esld_from_domain +gen_meta.esld_mime_from_domain = gen_get_esld(gen_meta.mime_from_domain) +gen_meta.esld_from_domain_or_helo_domain = gen_get_esld(gen_meta.from_domain_or_helo_domain) +gen_meta.esld_smtp_from_domain_or_helo_domain = gen_meta.esld_from_domain_or_helo_domain + +local function get_key_expansion_metadata(task) + + local md_mt = { + __index = function(self, k) + k = string.lower(k) + local v = rawget(self, k) + if v then + return v + end + if gen_meta[k] then + v = gen_meta[k](task) + rawset(self, k, v) + end + return v + end, + } + + local lazy_meta = {} + setmetatable(lazy_meta, md_mt) + return lazy_meta + +end + +-- Performs async call to redis hiding all complexity inside function +-- task - rspamd_task +-- redis_params - valid params returned by rspamd_parse_redis_server +-- key - key to select upstream or nil to select round-robin/master-slave +-- is_write - true if need to write to redis server +-- callback - function to be called upon request is completed +-- command - redis command +-- args - table of arguments +-- extra_opts - table of optional request arguments +local function rspamd_redis_make_request(task, redis_params, key, is_write, + callback, command, args, extra_opts) + local addr + local function rspamd_redis_make_request_cb(err, data) + if err then + addr:fail() + else + addr:ok() + end + if callback then + callback(err, data, addr) + end + end + if not task or not redis_params or not command then + return false, nil, nil + end + + local rspamd_redis = require "rspamd_redis" + + if key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(key) + end + end + + if not addr then + logger.errx(task, 'cannot select server to make redis request') + end + + if redis_params['expand_keys'] then + local m = get_key_expansion_metadata(task) + local indexes = get_key_indexes(command, args) + for _, i in ipairs(indexes) do + args[i] = lutil.template(args[i], m) + end + end + + local ip_addr = addr:get_addr() + local options = { + task = task, + callback = rspamd_redis_make_request_cb, + host = ip_addr, + timeout = redis_params['timeout'], + cmd = command, + args = args + } + + if extra_opts then + for k, v in pairs(extra_opts) do + options[k] = v + end + end + + if redis_params['username'] then + options['username'] = redis_params['username'] + end + + if redis_params['password'] then + options['password'] = redis_params['password'] + end + + if redis_params['db'] then + options['dbname'] = redis_params['db'] + end + + lutil.debugm(N, task, 'perform request to redis server' .. + ' (host=%s, timeout=%s): cmd: %s', ip_addr, + options.timeout, options.cmd) + + local ret, conn = rspamd_redis.make_request(options) + + if not ret then + addr:fail() + logger.warnx(task, "cannot make redis request to: %s", tostring(ip_addr)) + end + + return ret, conn, addr +end + +--[[[ +-- @function lua_redis.redis_make_request(task, redis_params, key, is_write, callback, command, args) +-- Sends a request to Redis +-- @param {rspamd_task} task task object +-- @param {table} redis_params redis configuration in format returned by lua_redis.parse_redis_server() +-- @param {string} key key to use for sharding +-- @param {boolean} is_write should be `true` if we are performing a write operating +-- @param {function} callback callback function (first parameter is error if applicable, second is a 2D array (table)) +-- @param {string} command Redis command to run +-- @param {table} args Numerically indexed table containing arguments for command +--]] + +exports.rspamd_redis_make_request = rspamd_redis_make_request +exports.redis_make_request = rspamd_redis_make_request + +local function redis_make_request_taskless(ev_base, cfg, redis_params, key, + is_write, callback, command, args, extra_opts) + if not ev_base or not redis_params or not command then + return false, nil, nil + end + + local addr + local function rspamd_redis_make_request_cb(err, data) + if err then + addr:fail() + else + addr:ok() + end + if callback then + callback(err, data, addr) + end + end + + local rspamd_redis = require "rspamd_redis" + + if key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(key) + end + end + + if not addr then + logger.errx(cfg, 'cannot select server to make redis request') + end + + local options = { + ev_base = ev_base, + config = cfg, + callback = rspamd_redis_make_request_cb, + host = addr:get_addr(), + timeout = redis_params['timeout'], + cmd = command, + args = args + } + if extra_opts then + for k, v in pairs(extra_opts) do + options[k] = v + end + end + + if redis_params['username'] then + options['username'] = redis_params['username'] + end + + if redis_params['password'] then + options['password'] = redis_params['password'] + end + + if redis_params['db'] then + options['dbname'] = redis_params['db'] + end + + lutil.debugm(N, cfg, 'perform taskless request to redis server' .. + ' (host=%s, timeout=%s): cmd: %s', options.host:tostring(true), + options.timeout, options.cmd) + local ret, conn = rspamd_redis.make_request(options) + if not ret then + logger.errx('cannot execute redis request') + addr:fail() + end + + return ret, conn, addr +end + +--[[[ +-- @function lua_redis.redis_make_request_taskless(ev_base, redis_params, key, is_write, callback, command, args) +-- Sends a request to Redis in context where `task` is not available for some specific use-cases +-- Identical to redis_make_request() except in that first parameter is an `event base` object +--]] + +exports.rspamd_redis_make_request_taskless = redis_make_request_taskless +exports.redis_make_request_taskless = redis_make_request_taskless + +local redis_scripts = { +} + +local function script_set_loaded(script) + if script.sha then + script.loaded = true + end + + local wait_table = {} + for _, s in ipairs(script.waitq) do + table.insert(wait_table, s) + end + + script.waitq = {} + + for _, s in ipairs(wait_table) do + s(script.loaded) + end +end + +local function prepare_redis_call(script) + local servers = {} + local options = {} + + if script.redis_params.read_servers then + servers = lutil.table_merge(servers, script.redis_params.read_servers:all_upstreams()) + end + if script.redis_params.write_servers then + servers = lutil.table_merge(servers, script.redis_params.write_servers:all_upstreams()) + end + + -- Call load script on each server, set loaded flag + script.in_flight = #servers + for _, s in ipairs(servers) do + local cur_opts = { + host = s:get_addr(), + timeout = script.redis_params['timeout'], + cmd = 'SCRIPT', + args = { 'LOAD', script.script }, + upstream = s + } + + if script.redis_params['username'] then + cur_opts['username'] = script.redis_params['username'] + end + + if script.redis_params['password'] then + cur_opts['password'] = script.redis_params['password'] + end + + if script.redis_params['db'] then + cur_opts['dbname'] = script.redis_params['db'] + end + + table.insert(options, cur_opts) + end + + return options +end + +local function load_script_task(script, task, is_write) + local rspamd_redis = require "rspamd_redis" + local opts = prepare_redis_call(script) + + for _, opt in ipairs(opts) do + opt.task = task + opt.is_write = is_write + opt.callback = function(err, data) + if err then + logger.errx(task, 'cannot upload script to %s: %s; registered from: %s:%s', + opt.upstream:get_addr():to_string(true), + err, script.caller.short_src, script.caller.currentline) + opt.upstream:fail() + script.fatal_error = err + else + opt.upstream:ok() + logger.infox(task, + "uploaded redis script to %s %s %s, sha: %s", + opt.upstream:get_addr():to_string(true), + script.filename and "from file" or "with id", script.filename or script.id, data) + script.sha = data -- We assume that sha is the same on all servers + end + script.in_flight = script.in_flight - 1 + + if script.in_flight == 0 then + script_set_loaded(script) + end + end + + local ret = rspamd_redis.make_request(opt) + + if not ret then + logger.errx('cannot execute redis request to load script on %s', + opt.upstream:get_addr()) + script.in_flight = script.in_flight - 1 + opt.upstream:fail() + end + + if script.in_flight == 0 then + script_set_loaded(script) + end + end +end + +local function load_script_taskless(script, cfg, ev_base, is_write) + local rspamd_redis = require "rspamd_redis" + local opts = prepare_redis_call(script) + + for _, opt in ipairs(opts) do + opt.config = cfg + opt.ev_base = ev_base + opt.is_write = is_write + opt.callback = function(err, data) + if err then + logger.errx(cfg, 'cannot upload script to %s: %s; registered from: %s:%s, filename: %s', + opt.upstream:get_addr():to_string(true), + err, script.caller.short_src, script.caller.currentline, script.filename) + opt.upstream:fail() + script.fatal_error = err + else + opt.upstream:ok() + logger.infox(cfg, + "uploaded redis script to %s %s %s, sha: %s", + opt.upstream:get_addr():to_string(true), + script.filename and "from file" or "with id", script.filename or script.id, + data) + script.sha = data -- We assume that sha is the same on all servers + script.fatal_error = nil + end + script.in_flight = script.in_flight - 1 + + if script.in_flight == 0 then + script_set_loaded(script) + end + end + local ret = rspamd_redis.make_request(opt) + + if not ret then + logger.errx('cannot execute redis request to load script on %s', + opt.upstream:get_addr()) + script.in_flight = script.in_flight - 1 + opt.upstream:fail() + end + + if script.in_flight == 0 then + script_set_loaded(script) + end + end +end + +local function load_redis_script(script, cfg, ev_base, _) + if script.redis_params then + load_script_taskless(script, cfg, ev_base) + end +end + +local function add_redis_script(script, redis_params, caller_level, maybe_filename) + if not caller_level then + caller_level = 2 + end + local caller = debug.getinfo(caller_level) or debug.getinfo(caller_level - 1) or E + + local new_script = { + caller = caller, + loaded = false, + redis_params = redis_params, + script = script, + waitq = {}, -- callbacks pending for script being loaded + id = #redis_scripts + 1, + filename = maybe_filename, + } + + -- Register on load function + rspamd_config:add_on_load(function(cfg, ev_base, worker) + local mult = 0.0 + rspamd_config:add_periodic(ev_base, 0.0, function() + if not new_script.sha then + load_redis_script(new_script, cfg, ev_base, worker) + mult = mult + 1 + return 1.0 * mult -- Check one more time in one second + end + + return false + end, false) + end) + + table.insert(redis_scripts, new_script) + + return #redis_scripts +end +exports.add_redis_script = add_redis_script + +-- Loads a Redis script from a file, strips comments, and passes the content to +-- `add_redis_script` function. +-- +-- @param filename The name of the file containing the Redis script. +-- @param redis_params The Redis parameters to use for this script. +-- @return The ID of the newly added Redis script. +-- +local function load_redis_script_from_file(filename, redis_params, dir) + local lua_util = require "lua_util" + local rspamd_logger = require "rspamd_logger" + + if not dir then + dir = rspamd_paths.LUALIBDIR + end + local path = filename + if filename:sub(1, 1) ~= package.config:sub(1, 1) then + -- Relative path + path = lua_util.join_path(dir, "redis_scripts", filename) + end + -- Read file contents + local file = io.open(path, "r") + if not file then + rspamd_logger.errx("failed to open Redis script file: %s", path) + return nil + end + local script = file:read("*all") + if not script then + rspamd_logger.errx("failed to load Redis script file: %s", path) + return nil + end + file:close() + script = lua_util.strip_lua_comments(script) + + return add_redis_script(script, redis_params, 3, filename) +end + +exports.load_redis_script_from_file = load_redis_script_from_file + +local function exec_redis_script(id, params, callback, keys, args) + local redis_args = {} + + if not redis_scripts[id] then + logger.errx("cannot find registered script with id %s", id) + return false + end + + local script = redis_scripts[id] + + if script.fatal_error then + callback(script.fatal_error, nil) + return true + end + + if not script.redis_params then + callback('no redis servers defined', nil) + return true + end + + local function do_call(can_reload) + local function redis_cb(err, data) + if not err then + callback(err, data) + elseif string.match(err, 'NOSCRIPT') then + -- Schedule restart + script.sha = nil + if can_reload then + table.insert(script.waitq, do_call) + if script.in_flight == 0 then + -- Reload scripts if this has not been initiated yet + if params.task then + load_script_task(script, params.task) + else + load_script_taskless(script, rspamd_config, params.ev_base) + end + end + else + callback(err, data) + end + else + callback(err, data) + end + end + + if #redis_args == 0 then + table.insert(redis_args, script.sha) + table.insert(redis_args, tostring(#keys)) + for _, k in ipairs(keys) do + table.insert(redis_args, k) + end + + if type(args) == 'table' then + for _, a in ipairs(args) do + table.insert(redis_args, a) + end + end + end + + if params.task then + if not rspamd_redis_make_request(params.task, script.redis_params, + params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then + callback('Cannot make redis request', nil) + end + else + if not redis_make_request_taskless(params.ev_base, rspamd_config, + script.redis_params, + params.key, params.is_write, redis_cb, 'EVALSHA', redis_args) then + callback('Cannot make redis request', nil) + end + end + end + + if script.loaded then + do_call(true) + else + -- Delayed until scripts are loaded + if not params.task then + table.insert(script.waitq, do_call) + else + -- TODO: fix taskfull requests + table.insert(script.waitq, function() + if script.loaded then + do_call(false) + else + callback('NOSCRIPT', nil) + end + end) + load_script_task(script, params.task, params.is_write) + end + end + + return true +end + +exports.exec_redis_script = exec_redis_script + +local function redis_connect_sync(redis_params, is_write, key, cfg, ev_base) + if not redis_params then + return false, nil + end + + local rspamd_redis = require "rspamd_redis" + local addr + + if key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(key) + end + end + + if not addr then + logger.errx(cfg, 'cannot select server to make redis request') + end + + local options = { + host = addr:get_addr(), + timeout = redis_params['timeout'], + config = cfg or rspamd_config, + ev_base = ev_base or rspamadm_ev_base, + session = redis_params.session or rspamadm_session + } + + for k, v in pairs(redis_params) do + options[k] = v + end + + if not options.config then + logger.errx('config is not set') + return false, nil, addr + end + + if not options.ev_base then + logger.errx('ev_base is not set') + return false, nil, addr + end + + if not options.session then + logger.errx('session is not set') + return false, nil, addr + end + + local ret, conn = rspamd_redis.connect_sync(options) + if not ret then + logger.errx('cannot create redis connection: %s', conn) + addr:fail() + + return false, nil, addr + end + + if conn then + local need_exec = false + if redis_params['username'] then + if redis_params['password'] then + conn:add_cmd('AUTH', { redis_params['username'], redis_params['password'] }) + need_exec = true + else + logger.warnx('Redis requires a password when username is supplied') + return false, nil, addr + end + elseif redis_params['password'] then + conn:add_cmd('AUTH', { redis_params['password'] }) + need_exec = true + end + + if redis_params['db'] then + conn:add_cmd('SELECT', { tostring(redis_params['db']) }) + need_exec = true + elseif redis_params['dbname'] then + conn:add_cmd('SELECT', { tostring(redis_params['dbname']) }) + need_exec = true + end + + if need_exec then + local exec_ret, res = conn:exec() + + if not exec_ret then + logger.errx('cannot prepare redis connection (authentication or db selection failure): %s', + res) + addr:fail() + return false, nil, addr + end + end + end + + return ret, conn, addr +end + +exports.redis_connect_sync = redis_connect_sync + +--[[[ +-- @function lua_redis.request(redis_params, attrs, req) +-- Sends a request to Redis synchronously with coroutines or asynchronously using +-- a callback (modern API) +-- @param redis_params a table of redis server parameters +-- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session) +-- @param req a table of request: a command + command options +-- @return {result,data/connection,address} boolean result, connection object in case of async request and results if using coroutines, redis server address +--]] + +exports.request = function(redis_params, attrs, req) + local lua_util = require "lua_util" + + if not attrs or not redis_params or not req then + logger.errx('invalid arguments for redis request') + return false, nil, nil + end + + if not (attrs.task or (attrs.config and attrs.ev_base)) then + logger.errx('invalid attributes for redis request') + return false, nil, nil + end + + local opts = lua_util.shallowcopy(attrs) + + local log_obj = opts.task or opts.config + + local addr + + if opts.callback then + -- Wrap callback + local callback = opts.callback + local function rspamd_redis_make_request_cb(err, data) + if err then + addr:fail() + else + addr:ok() + end + callback(err, data, addr) + end + opts.callback = rspamd_redis_make_request_cb + end + + local rspamd_redis = require "rspamd_redis" + local is_write = opts.is_write + + if opts.key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key) + end + end + + if not addr then + logger.errx(log_obj, 'cannot select server to make redis request') + end + + opts.host = addr:get_addr() + opts.timeout = redis_params.timeout + + if type(req) == 'string' then + opts.cmd = req + else + -- XXX: modifies the input table + opts.cmd = table.remove(req, 1); + opts.args = req + end + + if redis_params.username then + opts.username = redis_params.username + end + + if redis_params.password then + opts.password = redis_params.password + end + + if redis_params.db then + opts.dbname = redis_params.db + end + + lutil.debugm(N, 'perform generic request to redis server' .. + ' (host=%s, timeout=%s): cmd: %s, arguments: %s', addr, + opts.timeout, opts.cmd, opts.args) + + if opts.callback then + local ret, conn = rspamd_redis.make_request(opts) + if not ret then + logger.errx(log_obj, 'cannot execute redis request') + addr:fail() + end + + return ret, conn, addr + else + -- Coroutines version + local ret, conn = rspamd_redis.connect_sync(opts) + if not ret then + logger.errx(log_obj, 'cannot execute redis request') + addr:fail() + else + conn:add_cmd(opts.cmd, opts.args) + return conn:exec() + end + return false, nil, addr + end +end + +--[[[ +-- @function lua_redis.connect(redis_params, attrs) +-- Connects to Redis synchronously with coroutines or asynchronously using a callback (modern API) +-- @param redis_params a table of redis server parameters +-- @param attrs a table of redis request attributes (e.g. task, or ev_base + cfg + session) +-- @return {result,connection,address} boolean result, connection object, redis server address +--]] + +exports.connect = function(redis_params, attrs) + local lua_util = require "lua_util" + + if not attrs or not redis_params then + logger.errx('invalid arguments for redis connect') + return false, nil, nil + end + + if not (attrs.task or (attrs.config and attrs.ev_base)) then + logger.errx('invalid attributes for redis connect') + return false, nil, nil + end + + local opts = lua_util.shallowcopy(attrs) + + local log_obj = opts.task or opts.config + + local addr + + if opts.callback then + -- Wrap callback + local callback = opts.callback + local function rspamd_redis_make_request_cb(err, data) + if err then + addr:fail() + else + addr:ok() + end + callback(err, data, addr) + end + opts.callback = rspamd_redis_make_request_cb + end + + local rspamd_redis = require "rspamd_redis" + local is_write = opts.is_write + + if opts.key then + if is_write then + addr = redis_params['write_servers']:get_upstream_by_hash(attrs.key) + else + addr = redis_params['read_servers']:get_upstream_by_hash(attrs.key) + end + else + if is_write then + addr = redis_params['write_servers']:get_upstream_master_slave(attrs.key) + else + addr = redis_params['read_servers']:get_upstream_round_robin(attrs.key) + end + end + + if not addr then + logger.errx(log_obj, 'cannot select server to make redis connect') + end + + opts.host = addr:get_addr() + opts.timeout = redis_params.timeout + + if redis_params.username then + opts.username = redis_params.username + end + + if redis_params.password then + opts.password = redis_params.password + end + + if redis_params.db then + opts.dbname = redis_params.db + end + + if opts.callback then + local ret, conn = rspamd_redis.connect(opts) + if not ret then + logger.errx(log_obj, 'cannot execute redis connect') + addr:fail() + end + + return ret, conn, addr + else + -- Coroutines version + local ret, conn = rspamd_redis.connect_sync(opts) + if not ret then + logger.errx(log_obj, 'cannot execute redis connect') + addr:fail() + else + return true, conn, addr + end + + return false, nil, addr + end +end + +local redis_prefixes = {} + +--[[[ +-- @function lua_redis.register_prefix(prefix, module, description[, optional]) +-- Register new redis prefix for documentation purposes +-- @param {string} prefix string prefix +-- @param {string} module module name +-- @param {string} description prefix description +-- @param {table} optional optional kv pairs (e.g. pattern) +--]] +local function register_prefix(prefix, module, description, optional) + local pr = { + module = module, + description = description + } + + if optional and type(optional) == 'table' then + for k, v in pairs(optional) do + pr[k] = v + end + end + + redis_prefixes[prefix] = pr +end + +exports.register_prefix = register_prefix + +--[[[ +-- @function lua_redis.prefixes([mname]) +-- Returns prefixes for specific module (or all prefixes). Returns a table prefix -> table +--]] +exports.prefixes = function(mname) + if not mname then + return redis_prefixes + else + local fun = require "fun" + + return fun.totable(fun.filter(function(_, data) + return data.module == mname + end, + redis_prefixes)) + end +end + +return exports |