summaryrefslogtreecommitdiffstats
path: root/tests/omhttp_server.py
blob: 22c718434d47aaece005a1ec15d7240a4b06349d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# call this via "python[3] script name"
import argparse
import json
import os
import zlib
import base64
import random
import time

try:
    from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer # Python 2
except ImportError:
    from http.server import BaseHTTPRequestHandler, HTTPServer # Python 3

# Keep track of data received at each path: {path: [decoded request body, ...]}
data = {}

# Server configuration and counters read by MyHandler.
# Every key the handler indexes has a default here. The values match the
# argparse defaults in the __main__ block, so the handler also works when the
# module is imported without going through the command-line entry point.
metadata = {
    'posts': 0,                        # number of POSTs received so far
    'fail_after': 0,
    'fail_every': -1,                  # -1 disables periodic failures
    'fail_with': 500,                  # status code used by fail_every
    'fail_with_400_after': -1,         # -1 disables
    'fail_with_401_or_403_after': -1,  # -1 disables
    'fail_with_delay_secs': 0,         # sleep before some injected failures
    'decompress': False,               # gunzip POST bodies before storing
    'userpwd': '',                     # empty string disables basic auth
}


class MyHandler(BaseHTTPRequestHandler):
    """
    Request handler for the omhttp test server.

    POST'd data is kept in the data global dict.
    Keys are the path, values are the raw received data.
    Two post requests to <host>:<port>/post/endpoint means data looks like...
        {"/post/endpoint": ["{\"msgnum\":\"00001\"}", "{\"msgnum\":\"00001\"}"]}

    GET requests return all data posted to that endpoint as a json list.
    Note that rsyslog usually sends escaped json data, so some parsing may be needed.
    A get request for <host>:<port>/post/endpoint responds with...
        ["{\"msgnum\":\"00001\"}", "{\"msgnum\":\"00001\"}"]

    Authentication, failure injection and decompression are configured through
    the module-level ``metadata`` dict.
    """

    def _send_json(self, obj):
        """Send a 200 response whose body is *obj* serialized as UTF-8 JSON."""
        res = json.dumps(obj).encode('utf8')
        self.send_response(200)
        self.send_header('Content-Type', 'application/json; charset=utf-8')
        self.send_header('Content-Length', str(len(res)))
        self.end_headers()
        self.wfile.write(res)

    def _reject_auth(self, message):
        """Send a 401 response with body *message* (bytes) and return False."""
        self.send_response(401)
        self.end_headers()
        self.wfile.write(message)
        return False

    def validate_auth(self):
        """
        Check the request's basic-auth credentials against metadata['userpwd'].

        Expected header format: 'Authorization: Basic <base 64 encoded uid:pwd>'.
        Only the credential part is compared; the scheme word is not checked.

        Returns True when the decoded credentials match. Otherwise a 401
        response has already been written and False is returned.
        """
        if 'Authorization' not in self.headers:
            return self._reject_auth(b'missing "Authorization" header')

        parts = self.headers['Authorization'].split()
        if len(parts) != 2:
            return self._reject_auth(b'malformed "Authorization" header')

        try:
            userpwd = base64.b64decode(parts[1]).decode('utf-8')
        except (TypeError, ValueError):
            # binascii.Error and UnicodeDecodeError are ValueError subclasses on
            # Python 3; Python 2's b64decode raises TypeError on bad padding.
            return self._reject_auth(b'malformed "Authorization" header')

        if userpwd != metadata['userpwd']:
            # bytes has no .format() on Python 3: format as text, then encode.
            return self._reject_auth('invalid auth: {0}'.format(userpwd).encode('utf-8'))

        return True

    def do_POST(self):
        """
        Handle a POST.

        Optionally authenticate the request, then optionally inject a failure.
        If neither rejects it, append the body (decoded as UTF-8) to
        data[self.path] and reply with {"msg": "ok"}.

        Failure injection is driven by metadata and checked in this order:
          * fail_with_400_after: reply 400 once more than n posts have arrived
            (after sleeping fail_with_delay_secs, if set).
          * fail_with_401_or_403_after: reply with a random choice of 401 or
            403 once more than n posts have arrived.
          * fail_every: every n-th post (never the first) fails with
            metadata['fail_with'], or 500 when that is falsy (after sleeping
            fail_with_delay_secs, if set).
        metadata['fail_after'] is not consulted here.
        """
        # Counted before the auth and failure checks, so rejected requests
        # still advance the counter.
        metadata['posts'] += 1

        if metadata['userpwd'] and not self.validate_auth():
            return

        if metadata['fail_with_400_after'] != -1 and metadata['posts'] > metadata['fail_with_400_after']:
            if metadata['fail_with_delay_secs']:
                print("sleeping for: {0}".format(metadata['fail_with_delay_secs']))
                time.sleep(metadata['fail_with_delay_secs'])
            self.send_response(400)
            self.end_headers()
            self.wfile.write(b'BAD REQUEST')
            return

        if metadata['fail_with_401_or_403_after'] != -1 and metadata['posts'] > metadata['fail_with_401_or_403_after']:
            status = random.choice([401, 403])
            self.send_response(status)
            self.end_headers()
            self.wfile.write(b'BAD REQUEST')
            return

        if metadata['posts'] > 1 and metadata['fail_every'] != -1 and metadata['posts'] % metadata['fail_every'] == 0:
            if metadata['fail_with_delay_secs']:
                print("sleeping for: {0}".format(metadata['fail_with_delay_secs']))
                time.sleep(metadata['fail_with_delay_secs'])
            code = metadata['fail_with'] if metadata['fail_with'] else 500
            self.send_response(code)
            self.end_headers()
            self.wfile.write(b'INTERNAL ERROR')
            return

        content_length = int(self.headers['Content-Length'] or 0)
        raw_data = self.rfile.read(content_length)

        if metadata['decompress']:
            # wbits=31 (16 + MAX_WBITS) expects a gzip header and trailer.
            post_data = zlib.decompress(raw_data, 31)
        else:
            post_data = raw_data

        data.setdefault(self.path, []).append(post_data.decode('utf-8'))
        self._send_json({'msg': 'ok'})

    def do_GET(self):
        """Reply with the JSON list of bodies POSTed to self.path ([] if none)."""
        self._send_json(data.get(self.path, []))


if __name__ == '__main__':
    # Parse command-line options, copy them into the shared metadata dict
    # read by MyHandler, then serve until the process is killed.
    parser = argparse.ArgumentParser(description='HTTP test server for rsyslog omhttp tests')
    parser.add_argument('-p', '--port', action='store', type=int, default=8080, help='port')
    parser.add_argument('--port-file', action='store', type=str, default='', help='file to store listen port number')
    parser.add_argument('-i', '--interface', action='store', type=str, default='localhost', help='interface to listen on')
    parser.add_argument('--fail-after', action='store', type=int, default=0, help='start failing after n posts')
    parser.add_argument('--fail-every', action='store', type=int, default=-1, help='fail every n posts')
    parser.add_argument('--fail-with', action='store', type=int, default=500, help='on failure, fail with this code')
    parser.add_argument('--fail-with-400-after', action='store', type=int, default=-1, help='fail with 400 after n posts')
    parser.add_argument('--fail-with-401-or-403-after', action='store', type=int, default=-1, help='fail with 401 or 403 after n posts')
    parser.add_argument('--fail-with-delay-secs', action='store', type=int, default=0, help='fail with n secs of delay')
    parser.add_argument('--decompress', action='store_true', default=False, help='decompress posted data')
    parser.add_argument('--userpwd', action='store', default='', help='only accept this user:password combination')
    args = parser.parse_args()

    # NOTE(review): fail_after is stored here, but MyHandler.do_POST never reads
    # it, so --fail-after currently has no effect.
    metadata['fail_after'] = args.fail_after
    metadata['fail_every'] = args.fail_every
    metadata['fail_with'] = args.fail_with
    metadata['fail_with_400_after'] = args.fail_with_400_after
    metadata['fail_with_401_or_403_after'] = args.fail_with_401_or_403_after
    metadata['fail_with_delay_secs'] = args.fail_with_delay_secs
    metadata['decompress'] = args.decompress
    metadata['userpwd'] = args.userpwd

    server = HTTPServer((args.interface, args.port), MyHandler)
    # Read the port from the bound socket address rather than args.port, so the
    # real port is reported even when the OS chose it (e.g. --port 0).
    lstn_port = server.server_address[1]
    pid = os.getpid()
    print('starting omhttp test server at {interface}:{port} with pid {pid}'
          .format(interface=args.interface, port=lstn_port, pid=pid))
    if args.port_file:
        with open(args.port_file, "w") as f:
            f.write(str(lstn_port))
    server.serve_forever()