summaryrefslogtreecommitdiffstats
path: root/src/lib/drool/replay.lua
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib/drool/replay.lua')
-rw-r--r--src/lib/drool/replay.lua596
1 files changed, 596 insertions, 0 deletions
diff --git a/src/lib/drool/replay.lua b/src/lib/drool/replay.lua
new file mode 100644
index 0000000..c3b86bb
--- /dev/null
+++ b/src/lib/drool/replay.lua
@@ -0,0 +1,596 @@
+-- DNS Reply Tool (drool)
+--
+-- Copyright (c) 2017-2021, OARC, Inc.
+-- Copyright (c) 2017, Comcast Corporation
+-- All rights reserved.
+--
+-- Redistribution and use in source and binary forms, with or without
+-- modification, are permitted provided that the following conditions
+-- are met:
+--
+-- 1. Redistributions of source code must retain the above copyright
+-- notice, this list of conditions and the following disclaimer.
+--
+-- 2. Redistributions in binary form must reproduce the above copyright
+-- notice, this list of conditions and the following disclaimer in
+-- the documentation and/or other materials provided with the
+-- distribution.
+--
+-- 3. Neither the name of the copyright holder nor the names of its
+-- contributors may be used to endorse or promote products derived
+-- from this software without specific prior written permission.
+--
+-- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+-- "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+-- LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
+-- FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
+-- COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+-- INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+-- BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+-- LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+-- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
+-- LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
+-- ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+-- POSSIBILITY OF SUCH DAMAGE.
+
+module(...,package.seeall)
+
+local ffi = require("ffi")
+local C = ffi.C
+ffi.cdef[[
+struct replay_stats {
+ int64_t sent, received, responses, timeouts, errors;
+};
+void* malloc(size_t);
+void free(void*);
+]]
+
+local object = require("dnsjit.core.objects")
+require("dnsjit.core.timespec_h")
+
+Replay = {}
+
+function Replay.new(getopt)
+ local self = setmetatable({
+ file = nil,
+ host = nil,
+ port = nil,
+ no_responses = false,
+ use_threads = false,
+ print_dns = false,
+ timeout = 10.0,
+ timing = nil,
+ timing_mode = "ignore",
+ timing_opt = nil,
+ layer = nil,
+ input = nil,
+ no_udp = false,
+ no_tcp = false,
+ udp_threads = 4,
+ tcp_threads = 2,
+
+ packets = 0,
+ queries = 0,
+ sent = 0,
+ received = 0,
+ responses = 0,
+ errors = 0,
+ timeouts = 0,
+
+ log = require("dnsjit.core.log").new("replay"),
+
+ _timespec = ffi.new("core_timespec_t"),
+ _udp_channels = {},
+ _tcp_channels = {},
+ _threads = {},
+ _udpcli = nil,
+ _tcpcli = nil,
+ _stat_channels = {},
+ }, { __index = Replay })
+
+ if getopt then
+ getopt.usage_desc = arg[1] .. " replay [options...] file host port"
+ getopt:add("t", "timing", "ignore", "Set the timing mode [mode=option], default ignore", "?")
+ getopt:add("n", "no-responses", false, "Do not wait for responses before sending next request", "?")
+ getopt:add(nil, "no-udp", false, "Do not use UDP", "?")
+ getopt:add(nil, "no-tcp", false, "Do not use TCP", "?")
+ getopt:add("T", "threads", false, "Use threads", "?")
+ getopt:add(nil, "udp-threads", 4, "Set the number of UDP threads to use, default 4", "?")
+ getopt:add(nil, "tcp-threads", 2, "Set the number of TCP threads to use, default 2", "?")
+ getopt:add(nil, "timeout", "10.0", "Set timeout for waiting on responses [seconds.nanoseconds], default 10.0", "?")
+ getopt:add("D", nil, false, "Show DNS queries and responses as processing goes", "?")
+ end
+
+ return self
+end
+
+function Replay:getopt(getopt)
+ local _, file, host, port = unpack(getopt.left)
+
+ if getopt:val("no-udp") and getopt:val("no-tcp") then
+ self.log:fatal("can not disable all transports")
+ end
+ if getopt:val("udp-threads") < 1 then
+ self.log:fatal("--udp-threads must be 1 or greater")
+ end
+ if getopt:val("tcp-threads") < 1 then
+ self.log:fatal("--tcp-threads must be 1 or greater")
+ end
+
+ if file == nil then
+ self.log:fatal("no file given")
+ end
+ self.file = file
+ if host == nil then
+ self.log:fatal("no target host given")
+ end
+ self.host = host
+ if port == nil then
+ self.log:fatal("no target port given")
+ end
+ self.port = port
+ self.use_threads = getopt:val("T")
+ self.timeout = tonumber(getopt:val("timeout"))
+ self.no_responses = getopt:val("n")
+ self.print_dns = getopt:val("D")
+ self.no_udp = getopt:val("no-udp")
+ self.udp_threads = getopt:val("udp-threads")
+ self.no_tcp = getopt:val("no-tcp")
+ self.tcp_threads = getopt:val("tcp-threads")
+
+ if getopt:val("t") ~= "ignore" then
+ self.timing_mode, self.timing_opt = getopt:val("t"):match("(%w+)=([%w%.]+)")
+ if self.timing_mode == nil then
+ self.timing_mode = getopt:val("t")
+ end
+ end
+end
+
+local _thr_func = function(thr)
+ local mode, thrid, host, port, chan, stats, resp, print_dns, to_sec, to_nsec = thr:pop(10)
+ local log = require("dnsjit.core.log").new(mode .. "#" .. thrid)
+ require("dnsjit.core.objects")
+ local ffi = require("ffi")
+ local C = ffi.C
+ ffi.cdef[[
+struct replay_stats {
+ int64_t sent, received, responses, timeouts, errors;
+};
+void* malloc(size_t);
+void free(void*);
+]]
+ local dns = require("dnsjit.core.object.dns").new()
+
+ if print_dns == 1 then
+ print_dns = function(payload)
+ dns.obj_prev = payload
+ dns:print()
+ end
+ else
+ print_dns = nil
+ end
+
+ local cli, recv, ctx, prod
+ if mode == "udp" then
+ cli = require("dnsjit.output.udpcli").new()
+ else
+ cli = require("dnsjit.output.tcpcli").new()
+ end
+ cli:timeout(to_sec, to_nsec)
+ if cli:connect(host, port) ~= 0 then
+ log:fatal("unable to connect to host " .. host .. " port " .. port)
+ end
+ recv, ctx = cli:receive()
+ prod = cli:produce()
+
+ local stat = ffi.cast("struct replay_stats*", C.malloc(ffi.sizeof("struct replay_stats")))
+ ffi.fill(stat, ffi.sizeof("struct replay_stats"))
+ ffi.gc(stat, C.free)
+
+ local send
+ if resp == 0 then
+ send = function(obj)
+ log:info("sending query")
+ recv(ctx, obj)
+ if print_dns then
+ print_dns(obj)
+ end
+ end
+ else
+ send = function(obj)
+ log:info("sending query")
+ recv(ctx, obj)
+ if print_dns then
+ print_dns(obj)
+ end
+
+ local response = prod(ctx)
+ if response == nil then
+ log:warning("producer error")
+ return
+ end
+ local payload = response:cast()
+ if payload.len == 0 then
+ stat.timeouts = stat.timeouts + 1
+ log:info("timeout")
+ return
+ end
+
+ stat.responses = stat.responses + 1
+ log:info("got response")
+ if print_dns then
+ print_dns(response)
+ end
+ end
+ end
+
+ while true do
+ local obj = chan:get()
+ if obj == nil then break end
+ obj = ffi.cast("core_object_t*", obj)
+ dns.obj_prev = obj
+ if dns:parse_header() == 0 and dns.qr == 0 then
+ send(obj)
+ stat.sent = stat.sent + 1
+ end
+ obj:free()
+ end
+
+ stat.errors = cli:errors()
+ ffi.gc(stat, nil)
+ stats:put(stat)
+end
+
+function Replay:setup()
+ self.input = require("dnsjit.input.mmpcap").new()
+ if self.input:open(self.file) ~= 0 then
+ self.log:fatal("unable to open file " .. self.file)
+ end
+
+ if self.timing_mode ~= "ignore" then
+ self.timing = require("dnsjit.filter.timing").new()
+ self.timing:producer(self.input)
+
+ if self.timing_mode == "keep" then
+ else
+ if self.timing_mode == "inc" or self.timing_mode == "increase" then
+ self.timing:increase(tonumber(self.timing_opt))
+ elseif self.timing_mode == "red" or self.timing_mode == "reduce" then
+ self.timing:reduce(tonumber(self.timing_opt))
+ elseif self.timing_mode == "mul" or self.timing_mode == "multiply" then
+ self.timing:multiply(tonumber(self.timing_opt))
+ elseif self.timing_mode == "fix" or self.timing_mode == "fixed" then
+ self.timing:fixed(tonumber(self.timing_opt))
+ else
+ self.log:fatal("Invalid timing mode " .. self.timing_mode)
+ end
+ end
+ end
+
+ self.layer = require("dnsjit.filter.layer").new()
+ if self.timing then
+ self.layer:producer(self.timing)
+ else
+ self.layer:producer(self.input)
+ end
+
+ self._timespec.sec = math.floor(self.timeout)
+ self._timespec.nsec = (self.timeout - math.floor(self.timeout)) * 1000000000
+
+ if self.use_threads then
+ if not self.no_udp then
+ self.log:info("starting " .. self.udp_threads .. " UDP threads")
+ for n = 1, self.udp_threads do
+ local chan = require("dnsjit.core.channel").new()
+ local stats = require("dnsjit.core.channel").new()
+ local thr = require("dnsjit.core.thread").new()
+
+ thr:start(_thr_func)
+ thr:push("udp", n, self.host, self.port, chan, stats)
+ if self.no_responses then
+ thr:push(0)
+ else
+ thr:push(1)
+ end
+ if self.print_dns then
+ thr:push(1)
+ else
+ thr:push(0)
+ end
+ thr:push(tonumber(self._timespec.sec), tonumber(self._timespec.nsec))
+
+ table.insert(self._udp_channels, chan)
+ table.insert(self._stat_channels, stats)
+ table.insert(self._threads, thr)
+ self.log:info("UDP thread " .. n .. " started")
+ end
+ end
+ if not self.no_tcp then
+ self.log:info("starting " .. self.tcp_threads .. " TCP threads")
+ for n = 1, self.tcp_threads do
+ local chan = require("dnsjit.core.channel").new()
+ local stats = require("dnsjit.core.channel").new()
+ local thr = require("dnsjit.core.thread").new()
+
+ thr:start(_thr_func)
+ thr:push("tcp", n, self.host, self.port, chan, stats)
+ if self.no_responses then
+ thr:push(0)
+ else
+ thr:push(1)
+ end
+ if self.print_dns then
+ thr:push(1)
+ else
+ thr:push(0)
+ end
+ thr:push(tonumber(self._timespec.sec), tonumber(self._timespec.nsec))
+
+ table.insert(self._tcp_channels, chan)
+ table.insert(self._stat_channels, stats)
+ table.insert(self._threads, thr)
+ self.log:info("TCP thread " .. n .. " started")
+ end
+ end
+ else
+ if not self.no_udp then
+ self._udpcli = require("dnsjit.output.udpcli").new()
+ self._udpcli:timeout(self._timespec.sec, self._timespec.nsec)
+ if self._udpcli:connect(self.host, self.port) ~= 0 then
+ self.log:fatal("unable to connect to host " .. self.host .. " port " .. self.port .. " with UDP")
+ end
+ end
+ if not self.no_tcp then
+ self._tcpcli = require("dnsjit.output.tcpcli").new()
+ self._tcpcli:timeout(self._timespec.sec, self._timespec.nsec)
+ if self._tcpcli:connect(self.host, self.port) ~= 0 then
+ self.log:fatal("unable to connect to host " .. self.host .. " port " .. self.port .. " with TCP")
+ end
+ end
+ end
+end
+
+function Replay:run()
+ local lprod, lctx = self.layer:produce()
+ local udpcli = self._udpcli
+ local tcpcli = self._tcpcli
+ local log, packets, queries, responses, errors, timeouts = self.log, 0, 0, 0, 0, 0
+ local send
+
+ if self.use_threads then
+ -- TODO: generate code for all udp/tcp channels, see split gen code in test
+ local udpidx, tcpidx = 1, 1
+
+ local send_udp, send_tcp
+ send_udp = function(obj)
+ local chan = self._udp_channels[udpidx]
+ if not chan then
+ udpidx = 1
+ chan = self._udp_channels[1]
+ end
+ chan:put(obj:copy())
+ udpidx = udpidx + 1
+ end
+ send_tcp = function(obj)
+ local chan = self._tcp_channels[tcpidx]
+ if not chan then
+ tcpidx = 1
+ chan = self._tcp_channels[1]
+ end
+ chan:put(obj:copy())
+ tcpidx = tcpidx + 1
+ end
+ if self._udp_channels[1] and self._tcp_channels[1] then
+ send = function(obj)
+ local protocol = obj.obj_prev
+ while protocol ~= nil do
+ if protocol.obj_type == object.UDP then
+ send_udp(obj)
+ break
+ elseif protocol.obj_type == object.TCP then
+ send_tcp(obj)
+ break
+ end
+ protocol = protocol.obj_prev
+ end
+ end
+ elseif self._udp_channels[1] then
+ send = send_udp
+ elseif self._tcp_channels[1] then
+ send = send_tcp
+ end
+ else
+ local urecv, uctx, uprod
+ if udpcli then
+ urecv, uctx = udpcli:receive()
+ uprod = udpcli:produce()
+ end
+ local trecv, tctx, tprod
+ if tcpcli then
+ trecv, tctx = tcpcli:receive()
+ tprod = tcpcli:produce()
+ end
+
+ local dns = require("dnsjit.core.object.dns").new()
+ local print_dns
+ if self.print_dns then
+ print_dns = function(payload)
+ dns.obj_prev = payload
+ dns:print()
+ end
+ end
+ local send_udp, send_tcp
+ if self.no_responses then
+ send_udp = function(obj)
+ log:info("sending udp query")
+ urecv(uctx, obj)
+ if print_dns then
+ print_dns(obj)
+ end
+ end
+ send_tcp = function(obj)
+ log:info("sending tcp query")
+ trecv(tctx, obj)
+ if print_dns then
+ print_dns(obj)
+ end
+ end
+ else
+ send_udp = function(obj)
+ log:info("sending udp query")
+ urecv(uctx, obj)
+ if print_dns then
+ print_dns(obj)
+ end
+
+ local response = uprod(uctx)
+ if response == nil then
+ log:warning("producer error")
+ return
+ end
+ local payload = response:cast()
+ if payload.len == 0 then
+ timeouts = timeouts + 1
+ log:info("timeout")
+ return
+ end
+
+ responses = responses + 1
+ log:info("got response")
+ if print_dns then
+ print_dns(response)
+ end
+ end
+ send_tcp = function(obj)
+ log:info("sending tcp query")
+ trecv(tctx, obj)
+ if print_dns then
+ print_dns(obj)
+ end
+
+ local response = tprod(tctx)
+ if response == nil then
+ log:warning("producer error")
+ return
+ end
+ local payload = response:cast()
+ if payload.len == 0 then
+ timeouts = timeouts + 1
+ log:info("timeout")
+ return
+ end
+
+ responses = responses + 1
+ log:info("got response")
+ if print_dns then
+ print_dns(response)
+ end
+ end
+ end
+ if udpcli and tcpcli then
+ send = function(obj)
+ dns.obj_prev = obj
+ if dns:parse_header() == 0 and dns.qr == 0 then
+ queries = queries + 1
+
+ local protocol = obj.obj_prev
+ while protocol ~= nil do
+ if protocol.obj_type == object.UDP then
+ send_udp(obj)
+ break
+ elseif protocol.obj_type == object.TCP then
+ send_tcp(obj)
+ break
+ end
+ protocol = protocol.obj_prev
+ end
+ end
+ end
+ elseif udpcli then
+ send = function(obj)
+ dns.obj_prev = obj
+ if dns:parse_header() == 0 and dns.qr == 0 then
+ queries = queries + 1
+ send_udp(obj)
+ end
+ end
+ elseif tcpcli then
+ send = function(obj)
+ dns.obj_prev = obj
+ if dns:parse_header() == 0 and dns.qr == 0 then
+ queries = queries + 1
+ send_tcp(obj)
+ end
+ end
+ end
+ end
+
+ while true do
+ local obj = lprod(lctx)
+ if obj == nil then break end
+ packets = packets + 1
+ local payload = obj:cast()
+ if obj:type() == "payload" and payload.len > 0 then
+ send(obj)
+ end
+ end
+
+ if self.use_threads then
+ for _, chan in pairs(self._udp_channels) do
+ chan:put(nil)
+ end
+ for _, chan in pairs(self._tcp_channels) do
+ chan:put(nil)
+ end
+ end
+
+ self.packets = packets
+ self.queries = queries
+ self.sent = 0
+ if udpcli then
+ self.sent = self.sent + udpcli:packets()
+ end
+ if tcpcli then
+ self.sent = self.sent + tcpcli:packets()
+ end
+ -- TODO: received == responses ?
+ self.received = responses
+ self.responses = responses
+ self.timeouts = timeouts
+ self.errors = errors
+ if udpcli then
+ self.errors = self.errors + udpcli:errors()
+ end
+ if tcpcli then
+ self.errors = self.errors + tcpcli:errors()
+ end
+end
+
+function Replay:finish()
+ if self.use_threads then
+ for _, thr in pairs(self._threads) do
+ thr:stop()
+ end
+ for _, stats in pairs(self._stat_channels) do
+ local stat = ffi.cast("struct replay_stats*", stats:get())
+ self.queries = self.queries + stat.sent
+ self.sent = self.sent + stat.sent
+ self.received = self.received + stat.received
+ self.responses = self.responses + stat.responses
+ self.timeouts = self.timeouts + stat.timeouts
+ self.errors = self.errors + stat.errors
+ C.free(stat)
+ end
+ self.queries = tonumber(self.queries)
+ self.sent = tonumber(self.sent)
+ self.received = tonumber(self.received)
+ self.responses = tonumber(self.responses)
+ self.timeouts = tonumber(self.timeouts)
+ self.errors = tonumber(self.errors)
+
+ -- TODO: received == responses ?
+ self.received = self.responses
+ end
+end
+
+return Replay