summaryrefslogtreecommitdiffstats
path: root/tests/omhttp_server.py
blob: 61e0e63352a9563121936f6eb9e4c70adad9d94a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# call this via "python[3] script name"
import argparse
import json
import os
import zlib
import base64

try:
    from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer # Python 2
except ImportError:
    from http.server import BaseHTTPRequestHandler, HTTPServer # Python 3

# Keep track of data received at each path:
# maps request path -> list of decoded POST bodies, in arrival order.
data = {}

# Server behaviour knobs (overwritten from the command line in __main__) plus
# the running POST counter ('posts').
# 'fail_with_400_after' is present here (with its CLI default of -1, meaning
# "disabled") so the handler also works when the module is used without
# going through __main__.
# NOTE(review): 'fail_after' is stored but not read by the request handler.
metadata = {
    'posts': 0,
    'fail_after': 0,
    'fail_every': -1,
    'fail_with_400_after': -1,
    'decompress': False,
    'userpwd': '',
}


class MyHandler(BaseHTTPRequestHandler):
    """
    Request handler for the omhttp test server.

    POST'd data is kept in the data global dict.
    Keys are the path, values are lists of the received bodies, decoded as
    UTF-8 (after gzip decompression when metadata['decompress'] is set).
    Two post requests to <host>:<port>/post/endpoint means data looks like...
        {"/post/endpoint": ["{\"msgnum\":\"00001\"}", "{\"msgnum\":\"00001\"}"]}

    GET requests return all data posted to that endpoint as a json list.
    Note that rsyslog usually sends escaped json data, so some parsing may be needed.
    A get request for <host>:<port>/post/endpoint responds with...
        ["{\"msgnum\":\"00001\"}", "{\"msgnum\":\"00001\"}"]

    POST behaviour is driven by the metadata global dict:
      userpwd              -- if non-empty, require these HTTP Basic credentials
      fail_with_400_after  -- answer 400 once more than this many POSTs were
                              seen (-1, the default, disables it)
      fail_every           -- answer 500 when the POST count is a multiple of
                              this value, never on the first POST
                              (values <= 0 disable it)
      decompress           -- gunzip the request body before storing it
    """

    def _reply(self, code, body):
        """Send status ``code`` and the raw ``body`` bytes, without
        Content-Type/Content-Length headers."""
        self.send_response(code)
        self.end_headers()
        self.wfile.write(body)

    def _send_json(self, obj):
        """Send ``obj`` serialized as JSON in a 200 response with Content-Length."""
        res = json.dumps(obj).encode('utf8')
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.send_header('Content-Length', str(len(res)))
        self.end_headers()
        self.wfile.write(res)

    def validate_auth(self):
        """Check the HTTP Basic credentials against metadata['userpwd'].

        Expected header format: 'Authorization: Basic <base 64 encoded uid:pwd>'.
        Only the second whitespace-separated token is used; the scheme word
        itself is not checked.

        Returns:
            True if the credentials match (nothing is written to the client).
            False otherwise, after a 401 response has been written.
        """
        if 'Authorization' not in self.headers:
            self._reply(401, b'missing "Authorization" header')
            return False

        try:
            _, b64userpwd = self.headers['Authorization'].split()
            # ValueError covers wrong token count, bad base64 padding
            # (binascii.Error on py3) and UnicodeDecodeError; TypeError is
            # what Python 2's b64decode raises for bad padding.
            userpwd = base64.b64decode(b64userpwd).decode('utf-8')
        except (ValueError, TypeError):
            self._reply(401, b'malformed "Authorization" header')
            return False

        if userpwd != metadata['userpwd']:
            # Format as text, then encode: bytes has no .format() on Python 3.
            self._reply(401, u'invalid auth: {0}'.format(userpwd).encode('utf-8'))
            return False

        return True

    def do_POST(self):
        """Store the request body under self.path, unless a configured
        auth check or failure mode rejects the request first."""
        metadata['posts'] += 1
        posts = metadata['posts']

        if metadata['userpwd'] and not self.validate_auth():
            return

        fail_400_after = metadata.get('fail_with_400_after', -1)
        if fail_400_after != -1 and posts > fail_400_after:
            self._reply(400, b'BAD REQUEST')
            return

        fail_every = metadata['fail_every']
        # fail_every <= 0 means "never"; this also avoids a modulo by zero.
        if posts > 1 and fail_every > 0 and posts % fail_every == 0:
            self._reply(500, b'INTERNAL ERROR')
            return

        content_length = int(self.headers.get('Content-Length') or 0)
        raw_data = self.rfile.read(content_length)

        if metadata['decompress']:
            # wbits=31 (16 + 15): expect a gzip header/trailer around the data.
            post_data = zlib.decompress(raw_data, 31)
        else:
            post_data = raw_data

        data.setdefault(self.path, []).append(post_data.decode('utf-8'))
        self._send_json({'msg': 'ok'})

    def do_GET(self):
        """Respond with the JSON list of bodies POSTed to self.path ([] if none)."""
        self._send_json(data.get(self.path, []))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='omhttp test server')
    parser.add_argument('-p', '--port', action='store', type=int, default=8080,
                        help='port to listen on (0 lets the OS choose; see --port-file)')
    parser.add_argument('--port-file', action='store', type=str, default='',
                        help='file to store listen port number')
    parser.add_argument('-i', '--interface', action='store', type=str, default='localhost',
                        help='interface (host name or address) to listen on')
    # NOTE(review): --fail-after is stored in metadata but the request handler
    # never reads it, so it currently has no effect.
    parser.add_argument('--fail-after', action='store', type=int, default=0, help='start failing after n posts')
    parser.add_argument('--fail-every', action='store', type=int, default=-1, help='fail every n posts')
    parser.add_argument('--fail-with-400-after', action='store', type=int, default=-1, help='fail with 400 after n posts')
    parser.add_argument('--decompress', action='store_true', default=False, help='decompress posted data')
    parser.add_argument('--userpwd', action='store', default='', help='only accept this user:password combination')
    args = parser.parse_args()

    # Publish the CLI settings to the handler through the shared metadata dict.
    metadata['fail_after'] = args.fail_after
    metadata['fail_every'] = args.fail_every
    metadata['fail_with_400_after'] = args.fail_with_400_after
    metadata['decompress'] = args.decompress
    metadata['userpwd'] = args.userpwd

    server = HTTPServer((args.interface, args.port), MyHandler)
    # The port actually bound (differs from --port when --port is 0).
    lstn_port = server.server_address[1]
    pid = os.getpid()
    print('starting omhttp test server at {interface}:{port} with pid {pid}'
          .format(interface=args.interface, port=lstn_port, pid=pid))
    if args.port_file != '':
        with open(args.port_file, "w") as f:
            f.write(str(lstn_port))
    server.serve_forever()