summary refs log tree commit diff stats
path: root/examples/respdiff.lua
blob: 831f349f36decc4d9da3dfc69a53a2560f621cc9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#!/usr/bin/env dnsjit
local ffi = require("ffi")
local clock = require("dnsjit.lib.clock")
local log = require("dnsjit.core.log")
-- Include the source file and line number in every log message.
log.display_file_line(true)

-- Only -v/--verbose is declared here; the "help" option queried below is not
-- declared in this script, so presumably dnsjit.lib.getopt provides it.
-- parse() returns the positional arguments as a list.
local getopt = require("dnsjit.lib.getopt").new({
    { "v", "verbose", 0, "Enable and increase verbosity for each time given", "?+" },
})
local pcap, host, port, path, origname, recvname = unpack(getopt:parse())

if getopt:val("help") then
    getopt:usage()
    return
end

-- Each -v given switches on one more log level, in this order:
-- -v warning, -vv +notice, -vvv +info, -vvvv +debug.
local verbosity = getopt:val("v")
local levels = { "warning", "notice", "info", "debug" }
for i = 1, #levels do
    if verbosity > i - 1 then
        log.enable(levels[i])
    end
end

-- All six positional arguments are mandatory.
if pcap == nil or host == nil or port == nil
    or path == nil or origname == nil or recvname == nil then
    print("usage: "..arg[1].." <pcap> <host> <port> <LMDB path> <origname> <recvname>")
    return
end

-- Object type constants (object.PAYLOAD, object.IP, ...) and a DNS object
-- that the main loop points at packets to parse their headers.
local object = require("dnsjit.core.objects")
local dns = require("dnsjit.core.object.dns").new()

-- Input pipeline: the pcap reader feeds the layer filter, whose producer
-- yields the decoded objects consumed by the main loop.
local input = require("dnsjit.input.mmpcap").new()
input:open(pcap)
local layer = require("dnsjit.filter.layer").new()
layer:producer(input)

-- Replay clients and their receive/produce handles. They stay nil here and
-- are filled in lazily by the main loop, once per protocol.
local udpcli, tcpcli
local udprecv, udpctx, tcprecv, tcpctx
local udpprod, tcpprod

-- Producer for decoded pcap objects, and the queries still waiting for
-- their response.
local prod, pctx = layer:produce()
local queries = {}

-- Allocate a core_object_payload_t tagged with obj_type = object.PAYLOAD.
local function new_payload()
    local p = ffi.new("core_object_payload_t")
    p.obj_type = object.PAYLOAD
    return p
end

-- Payload object used to hand a stored query to a replay client.
local clipayload = new_payload()
local cliobject = ffi.cast("core_object_t*", clipayload)

-- respdiff output. Each record is handed over as a chain of three payloads
-- linked via obj_prev: query -> original response -> replayed response.
local respdiff = require("dnsjit.output.respdiff").new(path, origname, recvname)
local resprecv, respctx = respdiff:receive()
local query_payload = new_payload()
local original_payload = new_payload()
local response_payload = new_payload()
local query_payload_obj = ffi.cast("core_object_t*", query_payload)
query_payload.obj_prev = ffi.cast("core_object_t*", original_payload)
original_payload.obj_prev = ffi.cast("core_object_t*", response_payload)

-- Return the first object on obj's obj_prev chain (obj itself excluded) whose
-- obj_type is type_a or type_b, or nil if the chain ends first.
local function find_layer(obj, type_a, type_b)
    local cur = obj.obj_prev
    while cur ~= nil do
        if cur.obj_type == type_a or cur.obj_type == type_b then
            return cur
        end
        cur = cur.obj_prev
    end
    return nil
end

-- Key pairing a query with its response in `queries`: client address and
-- port, server address and port, and the DNS message ID. The ID keeps several
-- outstanding queries on the same connection (e.g. pipelined TCP queries)
-- from overwriting each other and being matched to the wrong response.
local function query_key(client_addr, client_port, server_addr, server_port, id)
    return string.format("%s %d %s %d %d", client_addr, client_port, server_addr, server_port, id)
end

-- Return the send function, producer and context of the replay client for
-- proto ("udp" or "tcp"), connecting it to host:port on first use.
-- Returns nil for any other protocol.
local function replay_client(proto)
    if proto == "udp" then
        if not udpcli then
            udpcli = require("dnsjit.output.udpcli").new()
            udpcli:connect(host, port)
            udprecv, udpctx = udpcli:receive()
            -- The producer is driven with the context from receive(), so
            -- produce()'s second return value is not kept.
            udpprod = udpcli:produce()
        end
        return udprecv, udpprod, udpctx
    elseif proto == "tcp" then
        if not tcpcli then
            tcpcli = require("dnsjit.output.tcpcli").new()
            tcpcli:connect(host, port)
            tcprecv, tcpctx = tcpcli:receive()
            tcpprod = tcpcli:produce()
        end
        return tcprecv, tcpprod, tcpctx
    end
    return nil
end

-- Send the stored query q to the replay target, then read responses until one
-- parses and carries q's DNS ID. Returns that response's payload (only valid
-- until the client produces again), or nil if q.proto has no replay client.
-- NOTE(review): timeouts are only logged and the wait continues, with no retry
-- limit, so a target that never answers with q.id stalls the whole run.
local function replay(q)
    local send, recv_prod, ctx = replay_client(q.proto)
    if send == nil then
        log.warning("unsupported protocol, query skipped")
        return nil
    end

    clipayload.payload = q.payload
    clipayload.len = q.len
    send(ctx, cliobject)

    while true do
        local response = recv_prod(ctx)
        if response == nil then
            log.fatal("producer error")
        end
        local rpl = response:cast()
        if rpl.len == 0 then
            log.info("timed out")
        else
            dns.obj_prev = response
            if dns:parse_header() == 0 and dns.id == q.id then
                return rpl
            end
        end
    end
end

-- Main loop: remember every query seen in the pcap. When its original response
-- appears, replay the query against host:port and pass query, original
-- response and replayed response to respdiff. Queries whose response never
-- shows up in the pcap stay in `queries` until exit.
local start_sec = clock:realtime()
while true do
    local obj = prod(pctx)
    if obj == nil then break end
    local payload = obj:cast()
    if obj:type() == "payload" and payload.len > 0 then
        dns.obj_prev = obj
        if dns:parse_header() == 0 then
            local ip = find_layer(obj, object.IP, object.IP6)
            local l4 = find_layer(obj, object.UDP, object.TCP)

            if ip ~= nil and l4 ~= nil then
                ip = ip:cast()
                l4 = l4:cast()

                if dns.qr == 0 then
                    -- Query: keep a private copy of its wire data until the
                    -- matching response is seen.
                    local q = {
                        id = dns.id,
                        proto = l4:type(),
                        payload = ffi.new("uint8_t[?]", payload.len),
                        len = tonumber(payload.len),
                    }
                    ffi.copy(q.payload, payload.payload, payload.len)
                    queries[query_key(ip:source(), l4.sport, ip:destination(), l4.dport, q.id)] = q
                else
                    -- Response: addresses and ports are reversed relative to
                    -- the query it answers.
                    local k = query_key(ip:destination(), l4.dport, ip:source(), l4.sport, dns.id)
                    local q = queries[k]
                    if q then
                        queries[k] = nil
                        local rpl = replay(q)
                        if rpl ~= nil then
                            query_payload.payload = q.payload
                            query_payload.len = q.len
                            original_payload.payload = payload.payload
                            original_payload.len = payload.len
                            response_payload.payload = rpl.payload
                            response_payload.len = rpl.len

                            resprecv(respctx, query_payload_obj)
                        end
                    end
                end
            end
        end
    end
end
local end_sec = clock:realtime()

respdiff:commit(start_sec, end_sec)