summaryrefslogtreecommitdiffstats
path: root/examples/respdiff.lua
blob: 801384429ced3c519e3807dd124b85cd00b8acc4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env dnsjit
local ffi = require("ffi")
local clock = require("dnsjit.lib.clock")
local log = require("dnsjit.core.log")
log.display_file_line(true)

-- Command line: a repeatable -v flag followed by six positional arguments.
local getopt = require("dnsjit.lib.getopt").new({
    { "v", "verbose", 0, "Enable and increase verbosity for each time given", "?+" },
})
local pcap, host, port, path, origname, recvname = unpack(getopt:parse())
if getopt:val("help") then
    getopt:usage()
    return
end

-- Every -v given switches on one more log level, least verbose first.
local verbosity = getopt:val("v")
local levels = { "warning", "notice", "info", "debug" }
for threshold = 0, #levels - 1 do
    if verbosity > threshold then
        log.enable(levels[threshold + 1])
    end
end

-- All six positional arguments are mandatory.
local missing_args = pcap == nil or host == nil or port == nil
    or path == nil or origname == nil or recvname == nil
if missing_args then
    print("usage: " .. arg[1] .. " <pcap> <host> <port> <LMDB path> <origname> <recvname>")
    return
end

local object = require("dnsjit.core.objects")

-- Shared DNS header parser, reused for captured and replayed messages.
local dns = require("dnsjit.core.object.dns").new()

-- Capture side: memory-mapped pcap input feeding the layer filter.
local input = require("dnsjit.input.mmpcap").new()
input:open(pcap)
local layer = require("dnsjit.filter.layer").new()
layer:producer(input)

-- Replay clients and their entry points; filled in on first use.
local udpcli, tcpcli
local udprecv, udpctx, tcprecv, tcpctx
local udpprod, tcpprod

local prod, pctx = layer:produce()

-- Pending captured queries, keyed by their flow.
local queries = {}

--- Allocate a core_object_payload_t (ffi.new zero-fills it) and tag it as
-- a PAYLOAD object.
local function new_payload()
    local p = ffi.new("core_object_payload_t")
    p.obj_type = object.PAYLOAD
    return p
end

-- Object handed to the replay client; payload/len are set per query.
local clipayload = new_payload()
local cliobject = ffi.cast("core_object_t*", clipayload)

local respdiff = require("dnsjit.output.respdiff").new(path, origname, recvname)
local resprecv, respctx = respdiff:receive()

-- Chain query -> original response -> replayed response via obj_prev.
-- Only the head (query_payload_obj) is passed to respdiff's receiver;
-- presumably it follows the chain to reach the other two (see
-- dnsjit.output.respdiff).
local query_payload = new_payload()
local original_payload = new_payload()
local response_payload = new_payload()
local query_payload_obj = ffi.cast("core_object_t*", query_payload)
query_payload.obj_prev = ffi.cast("core_object_t*", original_payload)
original_payload.obj_prev = ffi.cast("core_object_t*", response_payload)

-- Contexts returned by the replay clients' produce(). They are kept apart
-- from the receive() contexts (udpctx/tcpctx) so each producer is called with
-- the context its own API call handed out.
local udppctx, tcppctx

--- Build the lookup key of a flow as "<src> <sport> <dst> <dport>".
-- Queries are stored in their own direction; responses look them up with
-- source and destination swapped.
local function flow_key(src, sport, dst, dport)
    return string.format("%s %d %s %d", src, sport, dst, dport)
end

--- Return the first object below `obj` in its obj_prev chain whose obj_type
-- is `t1` or `t2`, or nil when the chain has no such layer.
local function find_layer(obj, t1, t2)
    local cur = obj.obj_prev
    while cur ~= nil do
        if cur.obj_type == t1 or cur.obj_type == t2 then
            return cur
        end
        cur = cur.obj_prev
    end
    return nil
end

--- Send the query currently held in `cliobject` to host:port over `proto`
-- ("udp" or "tcp"). Returns the producer function and context from which
-- the reply is read. Each client is created and connected on first use.
local function send_query(proto)
    if proto == "udp" then
        if not udpcli then
            udpcli = require("dnsjit.output.udpcli").new()
            udpcli:connect(host, port)
            udprecv, udpctx = udpcli:receive()
            udpprod, udppctx = udpcli:produce()
        end
        udprecv(udpctx, cliobject)
        return udpprod, udppctx
    elseif proto == "tcp" then
        if not tcpcli then
            tcpcli = require("dnsjit.output.tcpcli").new()
            tcpcli:connect(host, port)
            tcprecv, tcpctx = tcpcli:receive()
            tcpprod, tcppctx = tcpcli:produce()
        end
        tcprecv(tcpctx, cliobject)
        return tcpprod, tcppctx
    end
    error("unsupported protocol: " .. tostring(proto))
end

--- Read replies from `rprod`/`rctx` until one parses as DNS with id `id`.
-- Replies carrying another id (e.g. late answers to earlier queries) are
-- skipped. Returns the matching reply payload, or nil when the producer
-- reports a timeout (an empty payload). Re-points the shared `dns` parser at
-- each reply it reads.
local function await_response(rprod, rctx, id)
    while true do
        local response = rprod(rctx)
        if response == nil then
            log.fatal("producer error")
        end
        local rpl = response:cast()
        if rpl.len == 0 then
            -- Give up on this query rather than retrying forever; a reply
            -- that was lost would otherwise hang the entire run.
            log.info("timed out")
            return nil
        end
        dns.obj_prev = response
        if dns:parse_header() == 0 and dns.id == id then
            return rpl
        end
    end
end

local start_sec = clock:realtime()
while true do
    local obj = prod(pctx)
    if obj == nil then break end
    local payload = obj:cast()
    if obj:type() == "payload" and payload.len > 0 then
        local transport = find_layer(obj, object.IP, object.IP6)
        local protocol = find_layer(obj, object.UDP, object.TCP)

        dns:reset()
        -- Captured TCP payloads are parsed as carrying the 2-byte DNS length
        -- prefix. Set the flag on both branches so it never depends on what
        -- dns:reset() clears.
        if protocol ~= nil and protocol.obj_type == object.TCP then
            dns.includes_dnslen = 1
        else
            dns.includes_dnslen = 0
        end
        dns.obj_prev = obj
        if transport ~= nil and protocol ~= nil and dns:parse_header() == 0 then
            transport = transport:cast()
            protocol = protocol:cast()

            if dns.qr == 0 then
                -- Query: keep a private copy of its bytes until the matching
                -- response appears in the capture.
                local q = {
                    id = dns.id,
                    proto = protocol:type(),
                    payload = ffi.new("uint8_t[?]", payload.len),
                    len = tonumber(payload.len),
                }
                ffi.copy(q.payload, payload.payload, payload.len)
                local k = flow_key(transport:source(), protocol.sport,
                    transport:destination(), protocol.dport)
                queries[k] = q
            else
                -- Response: replay the matching query against the target and
                -- hand query, original response and replayed response to
                -- respdiff.
                local k = flow_key(transport:destination(), protocol.dport,
                    transport:source(), protocol.sport)
                local q = queries[k]
                if q then
                    queries[k] = nil
                    clipayload.payload = q.payload
                    clipayload.len = q.len

                    -- NOTE(review): for TCP the captured query (parsed above
                    -- with includes_dnslen = 1) is forwarded unchanged, and the
                    -- reply is parsed with the same setting. Confirm this
                    -- matches the framing output.tcpcli sends and produces.
                    local rprod, rctx = send_query(q.proto)
                    local rpl = await_response(rprod, rctx, q.id)
                    if rpl ~= nil then
                        query_payload.payload = q.payload
                        query_payload.len = q.len
                        original_payload.payload = payload.payload
                        original_payload.len = payload.len
                        response_payload.payload = rpl.payload
                        response_payload.len = rpl.len

                        resprecv(respctx, query_payload_obj)
                    end
                end
            end
        end
    end
end
local end_sec = clock:realtime()

respdiff:commit(start_sec, end_sec)