diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-10 21:30:40 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-04-10 21:30:40 +0000 |
commit | 133a45c109da5310add55824db21af5239951f93 (patch) | |
tree | ba6ac4c0a950a0dda56451944315d66409923918 /contrib/lua-fun | |
parent | Initial commit. (diff) | |
download | rspamd-133a45c109da5310add55824db21af5239951f93.tar.xz rspamd-133a45c109da5310add55824db21af5239951f93.zip |
Adding upstream version 3.8.1.upstream/3.8.1upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'contrib/lua-fun')
-rw-r--r-- | contrib/lua-fun/COPYING.md | 27 | ||||
-rw-r--r-- | contrib/lua-fun/fun.lua | 1076 |
2 files changed, 1103 insertions, 0 deletions
diff --git a/contrib/lua-fun/COPYING.md b/contrib/lua-fun/COPYING.md new file mode 100644 index 0000000..42ee231 --- /dev/null +++ b/contrib/lua-fun/COPYING.md @@ -0,0 +1,27 @@ +Copying +======= + +**Lua Fun** source codes, logo and documentation are distributed under the +**[MIT/X11 License]** - same as Lua and LuaJIT. + +Copyright (c) 2013 Roman Tsisyk <roman@tsisyk.com> + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +[MIT/X11 License]: http://www.opensource.org/licenses/mit-license.php diff --git a/contrib/lua-fun/fun.lua b/contrib/lua-fun/fun.lua new file mode 100644 index 0000000..2c712d3 --- /dev/null +++ b/contrib/lua-fun/fun.lua @@ -0,0 +1,1076 @@ +--- +--- Lua Fun - a high-performance functional programming library for LuaJIT +--- +--- Copyright (c) 2013-2017 Roman Tsisyk <roman@tsisyk.com> +--- +--- Distributed under the MIT/X11 License. See COPYING.md for more details. +--- + +local exports = {} +local methods = {} + +-- Rspamd specific +local rspamd_text = require "rspamd_text" +local text_cookie = rspamd_text.cookie + +-- compatibility with Lua 5.1/5.2 +local unpack = rawget(table, "unpack") or unpack + +-------------------------------------------------------------------------------- +-- Tools +-------------------------------------------------------------------------------- + +local return_if_not_empty = function(state_x, ...) + if state_x == nil then + return nil + end + return ... +end + +local call_if_not_empty = function(fun, state_x, ...) + if state_x == nil then + return nil + end + return state_x, fun(...) +end + +local function deepcopy(orig) -- used by cycle() + local orig_type = type(orig) + local copy + if orig_type == 'table' then + copy = {} + for orig_key, orig_value in next, orig, nil do + copy[deepcopy(orig_key)] = deepcopy(orig_value) + end + else + copy = orig + end + return copy +end + +local iterator_mt = { + -- usually called by for-in loop + __call = function(self, param, state) + return self.gen(param, state) + end; + __tostring = function(self) + return '<generator>' + end; + -- add all exported methods + __index = methods; +} + +local wrap = function(gen, param, state) + return setmetatable({ + gen = gen, + param = param, + state = state + }, iterator_mt), param, state +end +exports.wrap = wrap + +local unwrap = function(self) + return self.gen, self.param, self.state +end +methods.unwrap = unwrap + +-------------------------------------------------------------------------------- +-- Basic Functions +-------------------------------------------------------------------------------- + +local nil_gen = function(_param, _state) + return nil +end + +local string_gen = function(param, state) + local state = state + 1 + if state > #param then + return nil + end + local r = string.sub(param, state, state) + return state, r +end + +local text_gen = function(param, state) + local state = state + 1 + if state > #param then + return nil + end + local r = string.char(param:byte(state)) + return state, r +end + +local ipairs_gen = ipairs({}) -- get the generating function from ipairs + +local pairs_gen = pairs({ a = 0 }) -- get the generating function from pairs +local map_gen = function(tab, key) + local value + local key, value = pairs_gen(tab, key) + return key, key, value +end + +local rawiter = function(obj, param, state) + assert(obj ~= nil, "invalid iterator") + if type(obj) == "table" then + local mt = getmetatable(obj); + if mt ~= nil then + if mt == iterator_mt then + return obj.gen, obj.param, obj.state + elseif mt.__ipairs ~= nil then + return mt.__ipairs(obj) + elseif mt.__pairs ~= nil then + return mt.__pairs(obj) + end + end + if #obj > 0 then + -- array + return ipairs(obj) + else + -- hash + return map_gen, obj, nil + end + elseif (type(obj) == "function") then + return obj, param, state + elseif (type(obj) == "string") then + if #obj == 0 then + return nil_gen, nil, nil + end + return string_gen, obj, 0 + elseif (type(obj) == "userdata" and obj.cookie == text_cookie) then + if #obj == 0 then + return nil_gen, nil, nil + end + return text_gen, obj, 0 + end + error(string.format('object %s of type "%s" is not iterable', + obj, type(obj))) +end + +local iter = function(obj, param, state) + return wrap(rawiter(obj, param, state)) +end +exports.iter = iter + +local method0 = function(fun) + return function(self) + return fun(self.gen, self.param, self.state) + end +end + +local method1 = function(fun) + return function(self, arg1) + return fun(arg1, self.gen, self.param, self.state) + end +end + +local method2 = function(fun) + return function(self, arg1, arg2) + return fun(arg1, arg2, self.gen, self.param, self.state) + end +end + +local export0 = function(fun) + return function(gen, param, state) + return fun(rawiter(gen, param, state)) + end +end + +local export1 = function(fun) + return function(arg1, gen, param, state) + return fun(arg1, rawiter(gen, param, state)) + end +end + +local export2 = function(fun) + return function(arg1, arg2, gen, param, state) + return fun(arg1, arg2, rawiter(gen, param, state)) + end +end + +local each = function(fun, gen, param, state) + repeat + state = call_if_not_empty(fun, gen(param, state)) + until state == nil +end +methods.each = method1(each) +exports.each = export1(each) +methods.for_each = methods.each +exports.for_each = exports.each +methods.foreach = methods.each +exports.foreach = exports.each + +-------------------------------------------------------------------------------- +-- Generators +-------------------------------------------------------------------------------- + +local range_gen = function(param, state) + local stop, step = param[1], param[2] + local state = state + step + if state > stop then + return nil + end + return state, state +end + +local range_rev_gen = function(param, state) + local stop, step = param[1], param[2] + local state = state + step + if state < stop then + return nil + end + return state, state +end + +local range = function(start, stop, step) + if step == nil then + if stop == nil then + if start == 0 then + return nil_gen, nil, nil + end + stop = start + start = stop > 0 and 1 or -1 + end + step = start <= stop and 1 or -1 + end + + assert(type(start) == "number", "start must be a number") + assert(type(stop) == "number", "stop must be a number") + assert(type(step) == "number", "step must be a number") + assert(step ~= 0, "step must not be zero") + + if (step > 0) then + return wrap(range_gen, {stop, step}, start - step) + elseif (step < 0) then + return wrap(range_rev_gen, {stop, step}, start - step) + end +end +exports.range = range + +local duplicate_table_gen = function(param_x, state_x) + return state_x + 1, unpack(param_x) +end + +local duplicate_fun_gen = function(param_x, state_x) + return state_x + 1, param_x(state_x) +end + +local duplicate_gen = function(param_x, state_x) + return state_x + 1, param_x +end + +local duplicate = function(...) + if select('#', ...) <= 1 then + return wrap(duplicate_gen, select(1, ...), 0) + else + return wrap(duplicate_table_gen, {...}, 0) + end +end +exports.duplicate = duplicate +exports.replicate = duplicate +exports.xrepeat = duplicate + +local tabulate = function(fun) + assert(type(fun) == "function") + return wrap(duplicate_fun_gen, fun, 0) +end +exports.tabulate = tabulate + +local zeros = function() + return wrap(duplicate_gen, 0, 0) +end +exports.zeros = zeros + +local ones = function() + return wrap(duplicate_gen, 1, 0) +end +exports.ones = ones + +local rands_gen = function(param_x, _state_x) + return 0, math.random(param_x[1], param_x[2]) +end + +local rands_nil_gen = function(_param_x, _state_x) + return 0, math.random() +end + +local rands = function(n, m) + if n == nil and m == nil then + return wrap(rands_nil_gen, 0, 0) + end + assert(type(n) == "number", "invalid first arg to rands") + if m == nil then + m = n + n = 0 + else + assert(type(m) == "number", "invalid second arg to rands") + end + assert(n < m, "empty interval") + return wrap(rands_gen, {n, m - 1}, 0) +end +exports.rands = rands + +-------------------------------------------------------------------------------- +-- Slicing +-------------------------------------------------------------------------------- + +local nth = function(n, gen_x, param_x, state_x) + assert(n > 0, "invalid first argument to nth") + -- An optimization for arrays and strings + if gen_x == ipairs_gen then + return param_x[n] + elseif gen_x == string_gen then + if n <= #param_x then + return string.sub(param_x, n, n) + else + return nil + end + end + for i=1,n-1,1 do + state_x = gen_x(param_x, state_x) + if state_x == nil then + return nil + end + end + return return_if_not_empty(gen_x(param_x, state_x)) +end +methods.nth = method1(nth) +exports.nth = export1(nth) + +local head_call = function(state, ...) + if state == nil then + error("head: iterator is empty") + end + return ... +end + +local head = function(gen, param, state) + return head_call(gen(param, state)) +end +methods.head = method0(head) +exports.head = export0(head) +exports.car = exports.head +methods.car = methods.head + +local tail = function(gen, param, state) + state = gen(param, state) + if state == nil then + return wrap(nil_gen, nil, nil) + end + return wrap(gen, param, state) +end +methods.tail = method0(tail) +exports.tail = export0(tail) +exports.cdr = exports.tail +methods.cdr = methods.tail + +local take_n_gen_x = function(i, state_x, ...) + if state_x == nil then + return nil + end + return {i, state_x}, ... +end + +local take_n_gen = function(param, state) + local n, gen_x, param_x = param[1], param[2], param[3] + local i, state_x = state[1], state[2] + if i >= n then + return nil + end + return take_n_gen_x(i + 1, gen_x(param_x, state_x)) +end + +local take_n = function(n, gen, param, state) + assert(n >= 0, "invalid first argument to take_n") + return wrap(take_n_gen, {n, gen, param}, {0, state}) +end +methods.take_n = method1(take_n) +exports.take_n = export1(take_n) + +local take_while_gen_x = function(fun, state_x, ...) + if state_x == nil or not fun(...) then + return nil + end + return state_x, ... +end + +local take_while_gen = function(param, state_x) + local fun, gen_x, param_x = param[1], param[2], param[3] + return take_while_gen_x(fun, gen_x(param_x, state_x)) +end + +local take_while = function(fun, gen, param, state) + assert(type(fun) == "function", "invalid first argument to take_while") + return wrap(take_while_gen, {fun, gen, param}, state) +end +methods.take_while = method1(take_while) +exports.take_while = export1(take_while) + +local take = function(n_or_fun, gen, param, state) + if type(n_or_fun) == "number" then + return take_n(n_or_fun, gen, param, state) + else + return take_while(n_or_fun, gen, param, state) + end +end +methods.take = method1(take) +exports.take = export1(take) + +local drop_n = function(n, gen, param, state) + assert(n >= 0, "invalid first argument to drop_n") + local i + for i=1,n,1 do + state = gen(param, state) + if state == nil then + return wrap(nil_gen, nil, nil) + end + end + return wrap(gen, param, state) +end +methods.drop_n = method1(drop_n) +exports.drop_n = export1(drop_n) + +local drop_while_x = function(fun, state_x, ...) + if state_x == nil or not fun(...) then + return state_x, false + end + return state_x, true, ... +end + +local drop_while = function(fun, gen_x, param_x, state_x) + assert(type(fun) == "function", "invalid first argument to drop_while") + local cont, state_x_prev + repeat + state_x_prev = deepcopy(state_x) + state_x, cont = drop_while_x(fun, gen_x(param_x, state_x)) + until not cont + if state_x == nil then + return wrap(nil_gen, nil, nil) + end + return wrap(gen_x, param_x, state_x_prev) +end +methods.drop_while = method1(drop_while) +exports.drop_while = export1(drop_while) + +local drop = function(n_or_fun, gen_x, param_x, state_x) + if type(n_or_fun) == "number" then + return drop_n(n_or_fun, gen_x, param_x, state_x) + else + return drop_while(n_or_fun, gen_x, param_x, state_x) + end +end +methods.drop = method1(drop) +exports.drop = export1(drop) + +local split = function(n_or_fun, gen_x, param_x, state_x) + return take(n_or_fun, gen_x, param_x, state_x), + drop(n_or_fun, gen_x, param_x, state_x) +end +methods.split = method1(split) +exports.split = export1(split) +methods.split_at = methods.split +exports.split_at = exports.split +methods.span = methods.split +exports.span = exports.split + +-------------------------------------------------------------------------------- +-- Indexing +-------------------------------------------------------------------------------- + +local index = function(x, gen, param, state) + local i = 1 + for _k, r in gen, param, state do + if r == x then + return i + end + i = i + 1 + end + return nil +end +methods.index = method1(index) +exports.index = export1(index) +methods.index_of = methods.index +exports.index_of = exports.index +methods.elem_index = methods.index +exports.elem_index = exports.index + +local indexes_gen = function(param, state) + local x, gen_x, param_x = param[1], param[2], param[3] + local i, state_x = state[1], state[2] + local r + while true do + state_x, r = gen_x(param_x, state_x) + if state_x == nil then + return nil + end + i = i + 1 + if r == x then + return {i, state_x}, i + end + end +end + +local indexes = function(x, gen, param, state) + return wrap(indexes_gen, {x, gen, param}, {0, state}) +end +methods.indexes = method1(indexes) +exports.indexes = export1(indexes) +methods.elem_indexes = methods.indexes +exports.elem_indexes = exports.indexes +methods.indices = methods.indexes +exports.indices = exports.indexes +methods.elem_indices = methods.indexes +exports.elem_indices = exports.indexes + +-------------------------------------------------------------------------------- +-- Filtering +-------------------------------------------------------------------------------- + +local filter1_gen = function(fun, gen_x, param_x, state_x, a) + while true do + if state_x == nil or fun(a) then break; end + state_x, a = gen_x(param_x, state_x) + end + return state_x, a +end + +-- call each other +local filterm_gen +local filterm_gen_shrink = function(fun, gen_x, param_x, state_x) + return filterm_gen(fun, gen_x, param_x, gen_x(param_x, state_x)) +end + +filterm_gen = function(fun, gen_x, param_x, state_x, ...) + if state_x == nil then + return nil + end + if fun(...) then + return state_x, ... + end + return filterm_gen_shrink(fun, gen_x, param_x, state_x) +end + +local filter_detect = function(fun, gen_x, param_x, state_x, ...) + if select('#', ...) < 2 then + return filter1_gen(fun, gen_x, param_x, state_x, ...) + else + return filterm_gen(fun, gen_x, param_x, state_x, ...) + end +end + +local filter_gen = function(param, state_x) + local fun, gen_x, param_x = param[1], param[2], param[3] + return filter_detect(fun, gen_x, param_x, gen_x(param_x, state_x)) +end + +local filter = function(fun, gen, param, state) + return wrap(filter_gen, {fun, gen, param}, state) +end +methods.filter = method1(filter) +exports.filter = export1(filter) +methods.remove_if = methods.filter +exports.remove_if = exports.filter + +local grep = function(fun_or_regexp, gen, param, state) + local fun = fun_or_regexp + if type(fun_or_regexp) == "string" then + fun = function(x) return string.find(x, fun_or_regexp) ~= nil end + end + return filter(fun, gen, param, state) +end +methods.grep = method1(grep) +exports.grep = export1(grep) + +local partition = function(fun, gen, param, state) + local neg_fun = function(...) + return not fun(...) + end + return filter(fun, gen, param, state), + filter(neg_fun, gen, param, state) +end +methods.partition = method1(partition) +exports.partition = export1(partition) + +-------------------------------------------------------------------------------- +-- Reducing +-------------------------------------------------------------------------------- + +local foldl_call = function(fun, start, state, ...) + if state == nil then + return nil, start + end + return state, fun(start, ...) +end + +local foldl = function(fun, start, gen_x, param_x, state_x) + while true do + state_x, start = foldl_call(fun, start, gen_x(param_x, state_x)) + if state_x == nil then + break; + end + end + return start +end +methods.foldl = method2(foldl) +exports.foldl = export2(foldl) +methods.reduce = methods.foldl +exports.reduce = exports.foldl + +local length = function(gen, param, state) + if gen == ipairs_gen or gen == string_gen then + return #param + end + local len = 0 + repeat + state = gen(param, state) + len = len + 1 + until state == nil + return len - 1 +end +methods.length = method0(length) +exports.length = export0(length) + +local is_null = function(gen, param, state) + return gen(param, deepcopy(state)) == nil +end +methods.is_null = method0(is_null) +exports.is_null = export0(is_null) + +local is_prefix_of = function(iter_x, iter_y) + local gen_x, param_x, state_x = iter(iter_x) + local gen_y, param_y, state_y = iter(iter_y) + + local r_x, r_y + for i=1,10,1 do + state_x, r_x = gen_x(param_x, state_x) + state_y, r_y = gen_y(param_y, state_y) + if state_x == nil then + return true + end + if state_y == nil or r_x ~= r_y then + return false + end + end +end +methods.is_prefix_of = is_prefix_of +exports.is_prefix_of = is_prefix_of + +local all = function(fun, gen_x, param_x, state_x) + local r + repeat + state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x)) + until state_x == nil or not r + return state_x == nil +end +methods.all = method1(all) +exports.all = export1(all) +methods.every = methods.all +exports.every = exports.all + +local any = function(fun, gen_x, param_x, state_x) + local r + repeat + state_x, r = call_if_not_empty(fun, gen_x(param_x, state_x)) + until state_x == nil or r + return not not r +end +methods.any = method1(any) +exports.any = export1(any) +methods.some = methods.any +exports.some = exports.any + +local sum = function(gen, param, state) + local s = 0 + local r = 0 + repeat + s = s + r + state, r = gen(param, state) + until state == nil + return s +end +methods.sum = method0(sum) +exports.sum = export0(sum) + +local product = function(gen, param, state) + local p = 1 + local r = 1 + repeat + p = p * r + state, r = gen(param, state) + until state == nil + return p +end +methods.product = method0(product) +exports.product = export0(product) + +local min_cmp = function(m, n) + if n < m then return n else return m end +end + +local max_cmp = function(m, n) + if n > m then return n else return m end +end + +local min = function(gen, param, state) + local state, m = gen(param, state) + if state == nil then + error("min: iterator is empty") + end + + local cmp + if type(m) == "number" then + -- An optimization: use math.min for numbers + cmp = math.min + else + cmp = min_cmp + end + + for _, r in gen, param, state do + m = cmp(m, r) + end + return m +end +methods.min = method0(min) +exports.min = export0(min) +methods.minimum = methods.min +exports.minimum = exports.min + +local min_by = function(cmp, gen_x, param_x, state_x) + local state_x, m = gen_x(param_x, state_x) + if state_x == nil then + error("min: iterator is empty") + end + + for _, r in gen_x, param_x, state_x do + m = cmp(m, r) + end + return m +end +methods.min_by = method1(min_by) +exports.min_by = export1(min_by) +methods.minimum_by = methods.min_by +exports.minimum_by = exports.min_by + +local max = function(gen_x, param_x, state_x) + local state_x, m = gen_x(param_x, state_x) + if state_x == nil then + error("max: iterator is empty") + end + + local cmp + if type(m) == "number" then + -- An optimization: use math.max for numbers + cmp = math.max + else + cmp = max_cmp + end + + for _, r in gen_x, param_x, state_x do + m = cmp(m, r) + end + return m +end +methods.max = method0(max) +exports.max = export0(max) +methods.maximum = methods.max +exports.maximum = exports.max + +local max_by = function(cmp, gen_x, param_x, state_x) + local state_x, m = gen_x(param_x, state_x) + if state_x == nil then + error("max: iterator is empty") + end + + for _, r in gen_x, param_x, state_x do + m = cmp(m, r) + end + return m +end +methods.max_by = method1(max_by) +exports.max_by = export1(max_by) +methods.maximum_by = methods.maximum_by +exports.maximum_by = exports.maximum_by + +local totable = function(gen_x, param_x, state_x) + local tab, key, val = {} + while true do + state_x, val = gen_x(param_x, state_x) + if state_x == nil then + break + end + table.insert(tab, val) + end + return tab +end +methods.totable = method0(totable) +exports.totable = export0(totable) + +local tomap = function(gen_x, param_x, state_x) + local tab, key, val = {} + while true do + state_x, key, val = gen_x(param_x, state_x) + if state_x == nil then + break + end + tab[key] = val + end + return tab +end +methods.tomap = method0(tomap) +exports.tomap = export0(tomap) + +-------------------------------------------------------------------------------- +-- Transformations +-------------------------------------------------------------------------------- + +local map_gen = function(param, state) + local gen_x, param_x, fun = param[1], param[2], param[3] + return call_if_not_empty(fun, gen_x(param_x, state)) +end + +local map = function(fun, gen, param, state) + return wrap(map_gen, {gen, param, fun}, state) +end +methods.map = method1(map) +exports.map = export1(map) + +local enumerate_gen_call = function(state, i, state_x, ...) + if state_x == nil then + return nil + end + return {i + 1, state_x}, i, ... +end + +local enumerate_gen = function(param, state) + local gen_x, param_x = param[1], param[2] + local i, state_x = state[1], state[2] + return enumerate_gen_call(state, i, gen_x(param_x, state_x)) +end + +local enumerate = function(gen, param, state) + return wrap(enumerate_gen, {gen, param}, {1, state}) +end +methods.enumerate = method0(enumerate) +exports.enumerate = export0(enumerate) + +local intersperse_call = function(i, state_x, ...) + if state_x == nil then + return nil + end + return {i + 1, state_x}, ... +end + +local intersperse_gen = function(param, state) + local x, gen_x, param_x = param[1], param[2], param[3] + local i, state_x = state[1], state[2] + if i % 2 == 1 then + return {i + 1, state_x}, x + else + return intersperse_call(i, gen_x(param_x, state_x)) + end +end + +-- TODO: interperse must not add x to the tail +local intersperse = function(x, gen, param, state) + return wrap(intersperse_gen, {x, gen, param}, {0, state}) +end +methods.intersperse = method1(intersperse) +exports.intersperse = export1(intersperse) + +-------------------------------------------------------------------------------- +-- Compositions +-------------------------------------------------------------------------------- + +local function zip_gen_r(param, state, state_new, ...) + if #state_new == #param / 2 then + return state_new, ... + end + + local i = #state_new + 1 + local gen_x, param_x = param[2 * i - 1], param[2 * i] + local state_x, r = gen_x(param_x, state[i]) + if state_x == nil then + return nil + end + table.insert(state_new, state_x) + return zip_gen_r(param, state, state_new, r, ...) +end + +local zip_gen = function(param, state) + return zip_gen_r(param, state, {}) +end + +-- A special hack for zip/chain to skip last two state, if a wrapped iterator +-- has been passed +local numargs = function(...) + local n = select('#', ...) + if n >= 3 then + -- Fix last argument + local it = select(n - 2, ...) + if type(it) == 'table' and getmetatable(it) == iterator_mt and + it.param == select(n - 1, ...) and it.state == select(n, ...) then + return n - 2 + end + end + return n +end + +local zip = function(...) + local n = numargs(...) + if n == 0 then + return wrap(nil_gen, nil, nil) + end + local param = { [2 * n] = 0 } + local state = { [n] = 0 } + + local i, gen_x, param_x, state_x + for i=1,n,1 do + local it = select(n - i + 1, ...) + gen_x, param_x, state_x = rawiter(it) + param[2 * i - 1] = gen_x + param[2 * i] = param_x + state[i] = state_x + end + + return wrap(zip_gen, param, state) +end +methods.zip = zip +exports.zip = zip + +local cycle_gen_call = function(param, state_x, ...) + if state_x == nil then + local gen_x, param_x, state_x0 = param[1], param[2], param[3] + return gen_x(param_x, deepcopy(state_x0)) + end + return state_x, ... +end + +local cycle_gen = function(param, state_x) + local gen_x, param_x, state_x0 = param[1], param[2], param[3] + return cycle_gen_call(param, gen_x(param_x, state_x)) +end + +local cycle = function(gen, param, state) + return wrap(cycle_gen, {gen, param, state}, deepcopy(state)) +end +methods.cycle = method0(cycle) +exports.cycle = export0(cycle) + +-- call each other +local chain_gen_r1 +local chain_gen_r2 = function(param, state, state_x, ...) + if state_x == nil then + local i = state[1] + i = i + 1 + if param[3 * i - 1] == nil then + return nil + end + local state_x = param[3 * i] + return chain_gen_r1(param, {i, state_x}) + end + return {state[1], state_x}, ... +end + +chain_gen_r1 = function(param, state) + local i, state_x = state[1], state[2] + local gen_x, param_x = param[3 * i - 2], param[3 * i - 1] + return chain_gen_r2(param, state, gen_x(param_x, state[2])) +end + +local chain = function(...) + local n = numargs(...) + if n == 0 then + return wrap(nil_gen, nil, nil) + end + + local param = { [3 * n] = 0 } + local i, gen_x, param_x, state_x + for i=1,n,1 do + local elem = select(i, ...) + gen_x, param_x, state_x = iter(elem) + param[3 * i - 2] = gen_x + param[3 * i - 1] = param_x + param[3 * i] = state_x + end + + return wrap(chain_gen_r1, param, {1, param[3]}) +end +methods.chain = chain +exports.chain = chain + +-------------------------------------------------------------------------------- +-- Operators +-------------------------------------------------------------------------------- + +local operator = { + ---------------------------------------------------------------------------- + -- Comparison operators + ---------------------------------------------------------------------------- + lt = function(a, b) return a < b end, + le = function(a, b) return a <= b end, + eq = function(a, b) return a == b end, + ne = function(a, b) return a ~= b end, + ge = function(a, b) return a >= b end, + gt = function(a, b) return a > b end, + + ---------------------------------------------------------------------------- + -- Arithmetic operators + ---------------------------------------------------------------------------- + add = function(a, b) return a + b end, + div = function(a, b) return a / b end, + floordiv = function(a, b) return math.floor(a/b) end, + intdiv = function(a, b) + local q = a / b + if a >= 0 then return math.floor(q) else return math.ceil(q) end + end, + mod = function(a, b) return a % b end, + mul = function(a, b) return a * b end, + neq = function(a) return -a end, + unm = function(a) return -a end, -- an alias + pow = function(a, b) return a ^ b end, + sub = function(a, b) return a - b end, + truediv = function(a, b) return a / b end, + + ---------------------------------------------------------------------------- + -- String operators + ---------------------------------------------------------------------------- + concat = function(a, b) return a..b end, + len = function(a) return #a end, + length = function(a) return #a end, -- an alias + + ---------------------------------------------------------------------------- + -- Logical operators + ---------------------------------------------------------------------------- + land = function(a, b) return a and b end, + lor = function(a, b) return a or b end, + lnot = function(a) return not a end, + truth = function(a) return not not a end, +} +exports.operator = operator +methods.operator = operator +exports.op = operator +methods.op = operator + +-------------------------------------------------------------------------------- +-- module definitions +-------------------------------------------------------------------------------- + +-- a special syntax sugar to export all functions to the global table +setmetatable(exports, { + __call = function(t, override) + for k, v in pairs(t) do + if _G[k] ~= nil then + local msg = 'function ' .. k .. ' already exists in global scope.' + if override then + _G[k] = v + print('WARNING: ' .. msg .. ' Overwritten.') + else + print('NOTICE: ' .. msg .. ' Skipped.') + end + else + _G[k] = v + end + end + end, +}) + +return exports |