diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2021-07-18 17:45:34 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2021-07-18 17:45:34 +0000 |
commit | 77f4d592283d96b76512caabc0baaa0f33097143 (patch) | |
tree | 06a628805fdcfd2b3e343e95d7a32c16f0b822e9 /src/lib | |
parent | Initial commit. (diff) | |
download | drool-77f4d592283d96b76512caabc0baaa0f33097143.tar.xz drool-77f4d592283d96b76512caabc0baaa0f33097143.zip |
Adding upstream version 2.0.0.upstream/2.0.0upstream
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'src/lib')
-rw-r--r-- | src/lib/drool/replay.lua | 596 | ||||
-rw-r--r-- | src/lib/drool/respdiff.lua | 571 |
2 files changed, 1167 insertions, 0 deletions
diff --git a/src/lib/drool/replay.lua b/src/lib/drool/replay.lua new file mode 100644 index 0000000..c3b86bb --- /dev/null +++ b/src/lib/drool/replay.lua @@ -0,0 +1,596 @@ +-- DNS Reply Tool (drool) +-- +-- Copyright (c) 2017-2021, OARC, Inc. +-- Copyright (c) 2017, Comcast Corporation +-- All rights reserved. +-- +-- Redistribution and use in source and binary forms, with or without +-- modification, are permitted provided that the following conditions +-- are met: +-- +-- 1. Redistributions of source code must retain the above copyright +-- notice, this list of conditions and the following disclaimer. +-- +-- 2. Redistributions in binary form must reproduce the above copyright +-- notice, this list of conditions and the following disclaimer in +-- the documentation and/or other materials provided with the +-- distribution. +-- +-- 3. Neither the name of the copyright holder nor the names of its +-- contributors may be used to endorse or promote products derived +-- from this software without specific prior written permission. +-- +-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +-- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +-- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +-- FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +-- COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +-- INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +-- BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +-- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +-- LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +-- ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +-- POSSIBILITY OF SUCH DAMAGE. + +module(...,package.seeall) + +local ffi = require("ffi") +local C = ffi.C +ffi.cdef[[ +struct replay_stats { + int64_t sent, received, responses, timeouts, errors; +}; +void* malloc(size_t); +void free(void*); +]] + +local object = require("dnsjit.core.objects") +require("dnsjit.core.timespec_h") + +Replay = {} + +function Replay.new(getopt) + local self = setmetatable({ + file = nil, + host = nil, + port = nil, + no_responses = false, + use_threads = false, + print_dns = false, + timeout = 10.0, + timing = nil, + timing_mode = "ignore", + timing_opt = nil, + layer = nil, + input = nil, + no_udp = false, + no_tcp = false, + udp_threads = 4, + tcp_threads = 2, + + packets = 0, + queries = 0, + sent = 0, + received = 0, + responses = 0, + errors = 0, + timeouts = 0, + + log = require("dnsjit.core.log").new("replay"), + + _timespec = ffi.new("core_timespec_t"), + _udp_channels = {}, + _tcp_channels = {}, + _threads = {}, + _udpcli = nil, + _tcpcli = nil, + _stat_channels = {}, + }, { __index = Replay }) + + if getopt then + getopt.usage_desc = arg[1] .. " replay [options...] file host port" + getopt:add("t", "timing", "ignore", "Set the timing mode [mode=option], default ignore", "?") + getopt:add("n", "no-responses", false, "Do not wait for responses before sending next request", "?") + getopt:add(nil, "no-udp", false, "Do not use UDP", "?") + getopt:add(nil, "no-tcp", false, "Do not use TCP", "?") + getopt:add("T", "threads", false, "Use threads", "?") + getopt:add(nil, "udp-threads", 4, "Set the number of UDP threads to use, default 4", "?") + getopt:add(nil, "tcp-threads", 2, "Set the number of TCP threads to use, default 2", "?") + getopt:add(nil, "timeout", "10.0", "Set timeout for waiting on responses [seconds.nanoseconds], default 10.0", "?") + getopt:add("D", nil, false, "Show DNS queries and responses as processing goes", "?") + end + + return self +end + +function Replay:getopt(getopt) + local _, file, host, port = unpack(getopt.left) + + if getopt:val("no-udp") and getopt:val("no-tcp") then + self.log:fatal("can not disable all transports") + end + if getopt:val("udp-threads") < 1 then + self.log:fatal("--udp-threads must be 1 or greater") + end + if getopt:val("tcp-threads") < 1 then + self.log:fatal("--tcp-threads must be 1 or greater") + end + + if file == nil then + self.log:fatal("no file given") + end + self.file = file + if host == nil then + self.log:fatal("no target host given") + end + self.host = host + if port == nil then + self.log:fatal("no target port given") + end + self.port = port + self.use_threads = getopt:val("T") + self.timeout = tonumber(getopt:val("timeout")) + self.no_responses = getopt:val("n") + self.print_dns = getopt:val("D") + self.no_udp = getopt:val("no-udp") + self.udp_threads = getopt:val("udp-threads") + self.no_tcp = getopt:val("no-tcp") + self.tcp_threads = getopt:val("tcp-threads") + + if getopt:val("t") ~= "ignore" then + self.timing_mode, self.timing_opt = getopt:val("t"):match("(%w+)=([%w%.]+)") + if self.timing_mode == nil then + self.timing_mode = getopt:val("t") + end + end +end + +local _thr_func = function(thr) + local mode, thrid, host, port, chan, stats, resp, print_dns, to_sec, to_nsec = thr:pop(10) + local log = require("dnsjit.core.log").new(mode .. "#" .. thrid) + require("dnsjit.core.objects") + local ffi = require("ffi") + local C = ffi.C + ffi.cdef[[ +struct replay_stats { + int64_t sent, received, responses, timeouts, errors; +}; +void* malloc(size_t); +void free(void*); +]] + local dns = require("dnsjit.core.object.dns").new() + + if print_dns == 1 then + print_dns = function(payload) + dns.obj_prev = payload + dns:print() + end + else + print_dns = nil + end + + local cli, recv, ctx, prod + if mode == "udp" then + cli = require("dnsjit.output.udpcli").new() + else + cli = require("dnsjit.output.tcpcli").new() + end + cli:timeout(to_sec, to_nsec) + if cli:connect(host, port) ~= 0 then + log:fatal("unable to connect to host " .. host .. " port " .. port) + end + recv, ctx = cli:receive() + prod = cli:produce() + + local stat = ffi.cast("struct replay_stats*", C.malloc(ffi.sizeof("struct replay_stats"))) + ffi.fill(stat, ffi.sizeof("struct replay_stats")) + ffi.gc(stat, C.free) + + local send + if resp == 0 then + send = function(obj) + log:info("sending query") + recv(ctx, obj) + if print_dns then + print_dns(obj) + end + end + else + send = function(obj) + log:info("sending query") + recv(ctx, obj) + if print_dns then + print_dns(obj) + end + + local response = prod(ctx) + if response == nil then + log:warning("producer error") + return + end + local payload = response:cast() + if payload.len == 0 then + stat.timeouts = stat.timeouts + 1 + log:info("timeout") + return + end + + stat.responses = stat.responses + 1 + log:info("got response") + if print_dns then + print_dns(response) + end + end + end + + while true do + local obj = chan:get() + if obj == nil then break end + obj = ffi.cast("core_object_t*", obj) + dns.obj_prev = obj + if dns:parse_header() == 0 and dns.qr == 0 then + send(obj) + stat.sent = stat.sent + 1 + end + obj:free() + end + + stat.errors = cli:errors() + ffi.gc(stat, nil) + stats:put(stat) +end + +function Replay:setup() + self.input = require("dnsjit.input.mmpcap").new() + if self.input:open(self.file) ~= 0 then + self.log:fatal("unable to open file " .. self.file) + end + + if self.timing_mode ~= "ignore" then + self.timing = require("dnsjit.filter.timing").new() + self.timing:producer(self.input) + + if self.timing_mode == "keep" then + else + if self.timing_mode == "inc" or self.timing_mode == "increase" then + self.timing:increase(tonumber(self.timing_opt)) + elseif self.timing_mode == "red" or self.timing_mode == "reduce" then + self.timing:reduce(tonumber(self.timing_opt)) + elseif self.timing_mode == "mul" or self.timing_mode == "multiply" then + self.timing:multiply(tonumber(self.timing_opt)) + elseif self.timing_mode == "fix" or self.timing_mode == "fixed" then + self.timing:fixed(tonumber(self.timing_opt)) + else + self.log:fatal("Invalid timing mode " .. self.timing_mode) + end + end + end + + self.layer = require("dnsjit.filter.layer").new() + if self.timing then + self.layer:producer(self.timing) + else + self.layer:producer(self.input) + end + + self._timespec.sec = math.floor(self.timeout) + self._timespec.nsec = (self.timeout - math.floor(self.timeout)) * 1000000000 + + if self.use_threads then + if not self.no_udp then + self.log:info("starting " .. self.udp_threads .. " UDP threads") + for n = 1, self.udp_threads do + local chan = require("dnsjit.core.channel").new() + local stats = require("dnsjit.core.channel").new() + local thr = require("dnsjit.core.thread").new() + + thr:start(_thr_func) + thr:push("udp", n, self.host, self.port, chan, stats) + if self.no_responses then + thr:push(0) + else + thr:push(1) + end + if self.print_dns then + thr:push(1) + else + thr:push(0) + end + thr:push(tonumber(self._timespec.sec), tonumber(self._timespec.nsec)) + + table.insert(self._udp_channels, chan) + table.insert(self._stat_channels, stats) + table.insert(self._threads, thr) + self.log:info("UDP thread " .. n .. " started") + end + end + if not self.no_tcp then + self.log:info("starting " .. self.tcp_threads .. " TCP threads") + for n = 1, self.tcp_threads do + local chan = require("dnsjit.core.channel").new() + local stats = require("dnsjit.core.channel").new() + local thr = require("dnsjit.core.thread").new() + + thr:start(_thr_func) + thr:push("tcp", n, self.host, self.port, chan, stats) + if self.no_responses then + thr:push(0) + else + thr:push(1) + end + if self.print_dns then + thr:push(1) + else + thr:push(0) + end + thr:push(tonumber(self._timespec.sec), tonumber(self._timespec.nsec)) + + table.insert(self._tcp_channels, chan) + table.insert(self._stat_channels, stats) + table.insert(self._threads, thr) + self.log:info("TCP thread " .. n .. " started") + end + end + else + if not self.no_udp then + self._udpcli = require("dnsjit.output.udpcli").new() + self._udpcli:timeout(self._timespec.sec, self._timespec.nsec) + if self._udpcli:connect(self.host, self.port) ~= 0 then + self.log:fatal("unable to connect to host " .. self.host .. " port " .. self.port .. " with UDP") + end + end + if not self.no_tcp then + self._tcpcli = require("dnsjit.output.tcpcli").new() + self._tcpcli:timeout(self._timespec.sec, self._timespec.nsec) + if self._tcpcli:connect(self.host, self.port) ~= 0 then + self.log:fatal("unable to connect to host " .. self.host .. " port " .. self.port .. " with TCP") + end + end + end +end + +function Replay:run() + local lprod, lctx = self.layer:produce() + local udpcli = self._udpcli + local tcpcli = self._tcpcli + local log, packets, queries, responses, errors, timeouts = self.log, 0, 0, 0, 0, 0 + local send + + if self.use_threads then + -- TODO: generate code for all udp/tcp channels, see split gen code in test + local udpidx, tcpidx = 1, 1 + + local send_udp, send_tcp + send_udp = function(obj) + local chan = self._udp_channels[udpidx] + if not chan then + udpidx = 1 + chan = self._udp_channels[1] + end + chan:put(obj:copy()) + udpidx = udpidx + 1 + end + send_tcp = function(obj) + local chan = self._tcp_channels[tcpidx] + if not chan then + tcpidx = 1 + chan = self._tcp_channels[1] + end + chan:put(obj:copy()) + tcpidx = tcpidx + 1 + end + if self._udp_channels[1] and self._tcp_channels[1] then + send = function(obj) + local protocol = obj.obj_prev + while protocol ~= nil do + if protocol.obj_type == object.UDP then + send_udp(obj) + break + elseif protocol.obj_type == object.TCP then + send_tcp(obj) + break + end + protocol = protocol.obj_prev + end + end + elseif self._udp_channels[1] then + send = send_udp + elseif self._tcp_channels[1] then + send = send_tcp + end + else + local urecv, uctx, uprod + if udpcli then + urecv, uctx = udpcli:receive() + uprod = udpcli:produce() + end + local trecv, tctx, tprod + if tcpcli then + trecv, tctx = tcpcli:receive() + tprod = tcpcli:produce() + end + + local dns = require("dnsjit.core.object.dns").new() + local print_dns + if self.print_dns then + print_dns = function(payload) + dns.obj_prev = payload + dns:print() + end + end + local send_udp, send_tcp + if self.no_responses then + send_udp = function(obj) + log:info("sending udp query") + urecv(uctx, obj) + if print_dns then + print_dns(obj) + end + end + send_tcp = function(obj) + log:info("sending tcp query") + trecv(tctx, obj) + if print_dns then + print_dns(obj) + end + end + else + send_udp = function(obj) + log:info("sending udp query") + urecv(uctx, obj) + if print_dns then + print_dns(obj) + end + + local response = uprod(uctx) + if response == nil then + log:warning("producer error") + return + end + local payload = response:cast() + if payload.len == 0 then + timeouts = timeouts + 1 + log:info("timeout") + return + end + + responses = responses + 1 + log:info("got response") + if print_dns then + print_dns(response) + end + end + send_tcp = function(obj) + log:info("sending tcp query") + trecv(tctx, obj) + if print_dns then + print_dns(obj) + end + + local response = tprod(tctx) + if response == nil then + log:warning("producer error") + return + end + local payload = response:cast() + if payload.len == 0 then + timeouts = timeouts + 1 + log:info("timeout") + return + end + + responses = responses + 1 + log:info("got response") + if print_dns then + print_dns(response) + end + end + end + if udpcli and tcpcli then + send = function(obj) + dns.obj_prev = obj + if dns:parse_header() == 0 and dns.qr == 0 then + queries = queries + 1 + + local protocol = obj.obj_prev + while protocol ~= nil do + if protocol.obj_type == object.UDP then + send_udp(obj) + break + elseif protocol.obj_type == object.TCP then + send_tcp(obj) + break + end + protocol = protocol.obj_prev + end + end + end + elseif udpcli then + send = function(obj) + dns.obj_prev = obj + if dns:parse_header() == 0 and dns.qr == 0 then + queries = queries + 1 + send_udp(obj) + end + end + elseif tcpcli then + send = function(obj) + dns.obj_prev = obj + if dns:parse_header() == 0 and dns.qr == 0 then + queries = queries + 1 + send_tcp(obj) + end + end + end + end + + while true do + local obj = lprod(lctx) + if obj == nil then break end + packets = packets + 1 + local payload = obj:cast() + if obj:type() == "payload" and payload.len > 0 then + send(obj) + end + end + + if self.use_threads then + for _, chan in pairs(self._udp_channels) do + chan:put(nil) + end + for _, chan in pairs(self._tcp_channels) do + chan:put(nil) + end + end + + self.packets = packets + self.queries = queries + self.sent = 0 + if udpcli then + self.sent = self.sent + udpcli:packets() + end + if tcpcli then + self.sent = self.sent + tcpcli:packets() + end + -- TODO: received == responses ? + self.received = responses + self.responses = responses + self.timeouts = timeouts + self.errors = errors + if udpcli then + self.errors = self.errors + udpcli:errors() + end + if tcpcli then + self.errors = self.errors + tcpcli:errors() + end +end + +function Replay:finish() + if self.use_threads then + for _, thr in pairs(self._threads) do + thr:stop() + end + for _, stats in pairs(self._stat_channels) do + local stat = ffi.cast("struct replay_stats*", stats:get()) + self.queries = self.queries + stat.sent + self.sent = self.sent + stat.sent + self.received = self.received + stat.received + self.responses = self.responses + stat.responses + self.timeouts = self.timeouts + stat.timeouts + self.errors = self.errors + stat.errors + C.free(stat) + end + self.queries = tonumber(self.queries) + self.sent = tonumber(self.sent) + self.received = tonumber(self.received) + self.responses = tonumber(self.responses) + self.timeouts = tonumber(self.timeouts) + self.errors = tonumber(self.errors) + + -- TODO: received == responses ? + self.received = self.responses + end +end + +return Replay diff --git a/src/lib/drool/respdiff.lua b/src/lib/drool/respdiff.lua new file mode 100644 index 0000000..f41bc63 --- /dev/null +++ b/src/lib/drool/respdiff.lua @@ -0,0 +1,571 @@ +-- DNS Reply Tool (drool) +-- +-- Copyright (c) 2017-2021, OARC, Inc. +-- Copyright (c) 2017, Comcast Corporation +-- All rights reserved. +-- +-- Redistribution and use in source and binary forms, with or without +-- modification, are permitted provided that the following conditions +-- are met: +-- +-- 1. Redistributions of source code must retain the above copyright +-- notice, this list of conditions and the following disclaimer. +-- +-- 2. Redistributions in binary form must reproduce the above copyright +-- notice, this list of conditions and the following disclaimer in +-- the documentation and/or other materials provided with the +-- distribution. +-- +-- 3. Neither the name of the copyright holder nor the names of its +-- contributors may be used to endorse or promote products derived +-- from this software without specific prior written permission. +-- +-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +-- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +-- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS +-- FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE +-- COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +-- INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, +-- BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +-- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT +-- LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN +-- ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +-- POSSIBILITY OF SUCH DAMAGE. + +module(...,package.seeall) + +local ffi = require("ffi") +local C = ffi.C +ffi.cdef[[ +struct respdiff_stats { + int64_t sent, received, responses, timeouts, errors; +}; +void* malloc(size_t); +void free(void*); +]] + +local clock = require("dnsjit.lib.clock") +local object = require("dnsjit.core.objects") +require("dnsjit.core.timespec_h") + +Respdiff = {} + +function Respdiff.new(getopt) + local self = setmetatable({ + path = nil, + fname = nil, + file = nil, + hname = nil, + host = nil, + port = nil, + + use_threads = false, + timeout = 10.0, + layer = nil, + input = nil, + respdiff = nil, + no_udp = false, + no_tcp = false, + udp_threads = 4, + tcp_threads = 2, + size = 10485760, + + packets = 0, + queries = 0, + sent = 0, + received = 0, + responses = 0, + errors = 0, + timeouts = 0, + + log = require("dnsjit.core.log").new("respdiff"), + + _timespec = ffi.new("core_timespec_t"), + _udp_channels = {}, + _tcp_channels = {}, + _threads = {}, + _udpcli = nil, + _tcpcli = nil, + _stat_channels = {}, + _result_channels = {}, + }, { __index = Respdiff }) + + if getopt then + getopt.usage_desc = arg[1] .. " respdiff [options...] path name file name host port" + getopt:add(nil, "no-udp", false, "Do not use UDP", "?") + getopt:add(nil, "no-tcp", false, "Do not use TCP", "?") + getopt:add("T", "threads", false, "Use threads", "?") + getopt:add(nil, "udp-threads", 4, "Set the number of UDP threads to use, default 4", "?") + getopt:add(nil, "tcp-threads", 2, "Set the number of TCP threads to use, default 2", "?") + getopt:add(nil, "timeout", "10.0", "Set timeout for waiting on responses [seconds.nanoseconds], default 10.0", "?") + getopt:add(nil, "size", 10485760, "Set the size (in bytes, multiple of OS page size) of the LMDB database, default 10485760.", "?") + end + + return self +end + +function Respdiff:getopt(getopt) + local _, path, fname, file, hname, host, port = unpack(getopt.left) + + if getopt:val("no-udp") and getopt:val("no-tcp") then + self.log:fatal("can not disable all transports") + end + if getopt:val("udp-threads") < 1 then + self.log:fatal("--udp-threads must be 1 or greater") + end + if getopt:val("tcp-threads") < 1 then + self.log:fatal("--tcp-threads must be 1 or greater") + end + + if path == nil then + self.log:fatal("no path given") + end + self.path = path + if fname == nil then + self.log:fatal("no name for file given") + end + self.fname = fname + if file == nil then + self.log:fatal("no file given") + end + self.file = file + if hname == nil then + self.log:fatal("no name for host given") + end + self.hname = hname + if host == nil then + self.log:fatal("no target host given") + end + self.host = host + if port == nil then + self.log:fatal("no target port given") + end + self.port = port + + self.use_threads = getopt:val("T") + self.timeout = tonumber(getopt:val("timeout")) + self.no_udp = getopt:val("no-udp") + self.udp_threads = getopt:val("udp-threads") + self.no_tcp = getopt:val("no-tcp") + self.tcp_threads = getopt:val("tcp-threads") + self.size = getopt:val("size") +end + +local _thr_func = function(thr) + local mode, thrid, host, port, chan, stats, to_sec, to_nsec, result = thr:pop(9) + local log = require("dnsjit.core.log").new(mode .. "#" .. thrid) + require("dnsjit.core.objects") + local ffi = require("ffi") + local C = ffi.C + ffi.cdef[[ +struct respdiff_stats { + int64_t sent, received, responses, timeouts, errors; +}; +void* malloc(size_t); +void free(void*); +]] + + local cli, recv, ctx, prod + if mode == "udp" then + cli = require("dnsjit.output.udpcli").new() + else + cli = require("dnsjit.output.tcpcli").new() + end + cli:timeout(to_sec, to_nsec) + if cli:connect(host, port) ~= 0 then + log:fatal("unable to connect to host " .. host .. " port " .. port) + end + recv, ctx = cli:receive() + prod = cli:produce() + + local stat = ffi.cast("struct respdiff_stats*", C.malloc(ffi.sizeof("struct respdiff_stats"))) + ffi.fill(stat, ffi.sizeof("struct respdiff_stats")) + ffi.gc(stat, C.free) + + while true do + local obj = chan:get() + if obj == nil then break end + obj = ffi.cast("core_object_t*", obj) + local resp = ffi.cast("core_object_t*", obj.obj_prev) + + log:info("sending query") + recv(ctx, obj) + stat.sent = stat.sent + 1 + + local response = prod(ctx) + if response == nil then + log:warning("producer error") + stat.errors = stat.errors + 1 + break + end + local payload = response:cast() + if payload.len == 0 then + stat.timeouts = stat.timeouts + 1 + log:info("timeout") + else + stat.responses = stat.responses + 1 + log:info("got response") + resp.obj_prev = response:copy() + end + result:put(obj) + end + + stat.errors = stat.errors + cli:errors() + ffi.gc(stat, nil) + stats:put(stat) +end + +function Respdiff:setup() + self.input = require("dnsjit.input.mmpcap").new() + if self.input:open(self.file) ~= 0 then + self.log:fatal("unable to open file " .. self.file) + end + + self.layer = require("dnsjit.filter.layer").new() + self.layer:producer(self.input) + + self.respdiff = require("dnsjit.output.respdiff").new(self.path, self.fname, self.hname, self.size) + + self._timespec.sec = math.floor(self.timeout) + self._timespec.nsec = (self.timeout - math.floor(self.timeout)) * 1000000000 + + if self.use_threads then + if not self.no_udp then + self.log:info("starting " .. self.udp_threads .. " UDP threads") + for n = 1, self.udp_threads do + local chan = require("dnsjit.core.channel").new() + local stats = require("dnsjit.core.channel").new() + local result = require("dnsjit.core.channel").new() + local thr = require("dnsjit.core.thread").new() + + thr:start(_thr_func) + thr:push("udp", n, self.host, self.port, chan, stats, tonumber(self._timespec.sec), tonumber(self._timespec.nsec), result) + + table.insert(self._udp_channels, chan) + table.insert(self._stat_channels, stats) + table.insert(self._result_channels, result) + table.insert(self._threads, thr) + self.log:info("UDP thread " .. n .. " started") + end + end + if not self.no_tcp then + self.log:info("starting " .. self.tcp_threads .. " TCP threads") + for n = 1, self.tcp_threads do + local chan = require("dnsjit.core.channel").new() + local stats = require("dnsjit.core.channel").new() + local result = require("dnsjit.core.channel").new() + local thr = require("dnsjit.core.thread").new() + + thr:start(_thr_func) + thr:push("tcp", n, self.host, self.port, chan, stats, tonumber(self._timespec.sec), tonumber(self._timespec.nsec), result) + + table.insert(self._tcp_channels, chan) + table.insert(self._stat_channels, stats) + table.insert(self._result_channels, result) + table.insert(self._threads, thr) + self.log:info("TCP thread " .. n .. " started") + end + end + else + if not self.no_udp then + self._udpcli = require("dnsjit.output.udpcli").new() + self._udpcli:timeout(self._timespec.sec, self._timespec.nsec) + if self._udpcli:connect(self.host, self.port) ~= 0 then + self.log:fatal("unable to connect to host " .. self.host .. " port " .. self.port .. " with UDP") + end + end + if not self.no_tcp then + self._tcpcli = require("dnsjit.output.tcpcli").new() + self._tcpcli:timeout(self._timespec.sec, self._timespec.nsec) + if self._tcpcli:connect(self.host, self.port) ~= 0 then + self.log:fatal("unable to connect to host " .. self.host .. " port " .. self.port .. " with TCP") + end + end + end +end + +function Respdiff:run() + local lprod, lctx = self.layer:produce() + local udpcli = self._udpcli + local tcpcli = self._tcpcli + local log, packets, queries, responses, errors, timeouts = self.log, 0, 0, 0, 0, 0 + local send + local resprecv, respctx = self.respdiff:receive() + + if self.use_threads then + -- TODO: generate code for all udp/tcp channels, see split gen code in test + local udpidx, tcpidx = 1, 1 + + local send_udp = function(obj, resp) + local chan = self._udp_channels[udpidx] + if not chan then + udpidx = 1 + chan = self._udp_channels[1] + end + local obj_copy, resp_copy = obj:copy(), resp:copy() + obj_copy = ffi.cast("core_object_t*", obj_copy) + obj_copy.obj_prev = resp_copy + chan:put(obj_copy) + udpidx = udpidx + 1 + end + local send_tcp = function(obj, resp) + local chan = self._tcp_channels[tcpidx] + if not chan then + tcpidx = 1 + chan = self._tcp_channels[1] + end + local obj_copy, resp_copy = obj:copy(), resp:copy() + obj_copy = ffi.cast("core_object_t*", obj_copy) + obj_copy.obj_prev = resp_copy + chan:put(obj_copy) + tcpidx = tcpidx + 1 + end + if self._udp_channels[1] and self._tcp_channels[1] then + send = function(obj, resp, protocol) + if protocol.obj_type == object.UDP then + send_udp(obj, resp) + elseif protocol.obj_type == object.TCP then + send_tcp(obj, resp) + end + end + elseif self._udp_channels[1] then + send = send_udp + elseif self._tcp_channels[1] then + send = send_tcp + end + else + local urecv, uctx, uprod + if udpcli then + urecv, uctx = udpcli:receive() + uprod = udpcli:produce() + end + local trecv, tctx, tprod + if tcpcli then + trecv, tctx = tcpcli:receive() + tprod = tcpcli:produce() + end + + local send_udp = function(obj, resp) + log:info("sending udp query") + urecv(uctx, obj) + + local response = uprod(uctx) + if response == nil then + log:warning("producer error") + return + end + + obj = ffi.cast("core_object_t*", obj) + obj.obj_prev = resp + resp = ffi.cast("core_object_t*", resp) + + local payload = response:cast() + if payload.len == 0 then + timeouts = timeouts + 1 + log:info("timeout") + resp.obj_prev = nil + else + responses = responses + 1 + log:info("got response") + resp.obj_prev = response + end + resprecv(respctx, obj) + end + local send_tcp = function(obj, resp) + log:info("sending tcp query") + trecv(tctx, obj) + + local response = tprod(tctx) + if response == nil then + log:warning("producer error") + return + end + + obj = ffi.cast("core_object_t*", obj) + obj.obj_prev = resp + resp = ffi.cast("core_object_t*", resp) + + local payload = response:cast() + if payload.len == 0 then + timeouts = timeouts + 1 + log:info("timeout") + resp.obj_prev = nil + else + responses = responses + 1 + log:info("got response") + resp.obj_prev = response + end + resprecv(respctx, obj) + end + if udpcli and tcpcli then + send = function(obj, resp, protocol) + if protocol.obj_type == object.UDP then + send_udp(obj, resp) + elseif protocol.obj_type == object.TCP then + send_tcp(obj, resp) + end + end + elseif udpcli then + send = send_udp + elseif tcpcli then + send = send_tcp + end + end + + self.start_sec = clock:realtime() + + local qtbl = {} + local dns = require("dnsjit.core.object.dns").new() + while true do + local obj = lprod(lctx) + if obj == nil then break end + packets = packets + 1 + local payload = obj:cast() + if obj:type() == "payload" and payload.len > 0 then + + local transport = obj.obj_prev + while transport ~= nil do + if transport.obj_type == object.IP or transport.obj_type == object.IP6 then + break + end + transport = transport.obj_prev + end + local protocol = obj.obj_prev + while protocol ~= nil do + if protocol.obj_type == object.UDP or protocol.obj_type == object.TCP then + break + end + protocol = protocol.obj_prev + end + + if transport ~= nil and protocol ~= nil then + transport = transport:cast() + protocol = protocol:cast() + + dns.obj_prev = obj + if dns:parse_header() == 0 then + if dns.qr == 0 then + local k = string.format("%s %d %s %d", transport:source(), protocol.sport, transport:destination(), protocol.dport) + log:info("query " .. k .. " id " .. dns.id) + qtbl[k] = { + id = dns.id, + payload = payload:copy(), + } + else + local k = string.format("%s %d %s %d", transport:destination(), protocol.dport, transport:source(), protocol.sport) + local q = qtbl[k] + if q and q.id == dns.id then + log:info("response " .. k .. " id " .. dns.id) + queries = queries + 1 + send(q.payload:uncast(), obj, protocol) + qtbl[k] = nil + end + end + end + end + end + + if self.use_threads and responses < queries then + for _, result in pairs(self._result_channels) do + local res = result:try_get() + if res ~= nil then + res = ffi.cast("core_object_t*", res) + resprecv(respctx, res) + responses = responses + 1 + + if res.obj_prev.obj_prev ~= nil then + ffi.cast("core_object_t*", res.obj_prev.obj_prev):free() + end + ffi.cast("core_object_t*", res.obj_prev):free() + res:free() + end + end + end + end + + if self.use_threads then + for _, chan in pairs(self._udp_channels) do + chan:put(nil) + end + for _, chan in pairs(self._tcp_channels) do + chan:put(nil) + end + end + + self.packets = packets + self.queries = queries + self.sent = 0 + if udpcli then + self.sent = self.sent + udpcli:packets() + end + if tcpcli then + self.sent = self.sent + tcpcli:packets() + end + -- TODO: received == responses ? + self.received = responses + self.responses = responses + self.timeouts = timeouts + self.errors = errors + if udpcli then + self.errors = self.errors + udpcli:errors() + end + if tcpcli then + self.errors = self.errors + tcpcli:errors() + end +end + +function Respdiff:finish() + if self.use_threads then + local left = 0 - self.responses + self.responses = 0 + + for _, thr in pairs(self._threads) do + thr:stop() + end + for _, stats in pairs(self._stat_channels) do + local stat = ffi.cast("struct respdiff_stats*", stats:get()) + self.sent = self.sent + stat.sent + self.received = self.received + stat.received + self.responses = self.responses + stat.responses + self.timeouts = self.timeouts + stat.timeouts + self.errors = self.errors + stat.errors + C.free(stat) + end + self.sent = tonumber(self.sent) + self.received = tonumber(self.received) + self.responses = tonumber(self.responses) + self.timeouts = tonumber(self.timeouts) + self.errors = tonumber(self.errors) + + -- TODO: received == responses ? + self.received = self.responses + + left = left + self.responses + self.timeouts + local resprecv, respctx = self.respdiff:receive() + local tries = 0 + while left > 0 and tries < 10000 do + for _, result in pairs(self._result_channels) do + local res = result:try_get() + if res ~= nil then + res = ffi.cast("core_object_t*", res) + resprecv(respctx, res) + left = left - 1 + tries = 0 + + if res.obj_prev.obj_prev ~= nil then + ffi.cast("core_object_t*", res.obj_prev.obj_prev):free() + end + ffi.cast("core_object_t*", res.obj_prev):free() + res:free() + end + end + tries = tries + 1 + end + end + + local end_sec = clock:realtime() + self.respdiff:commit(self.start_sec, end_sec) +end + +return Respdiff |