summaryrefslogtreecommitdiffstats
path: root/src/jaegertracing/thrift/test/py/TSimpleJSONProtocolTest.py
blob: 72987602bcf76ed4aae4929ca15394a2bdbbf81b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
#!/usr/bin/env python

#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from ThriftTest.ttypes import Bonk, VersioningTestV1, VersioningTestV2
from thrift.protocol import TJSONProtocol
from thrift.transport import TTransport

import json
import unittest


class SimpleJSONProtocolTest(unittest.TestCase):
    """Tests for Thrift's TSimpleJSONProtocol.

    Structs are written with the protocol, and the emitted bytes are parsed
    back with the standard ``json`` module so they can be compared against
    plain dicts.  Reading through the protocol is expected to raise
    ``NotImplementedError`` (see ``testWriteOnly``).
    """

    protocol_factory = TJSONProtocol.TSimpleJSONProtocolFactory()

    def _assertDictEqual(self, a, b, msg=None):
        """Assert that dicts ``a`` and ``b`` hold the same keys and values.

        Delegates to ``TestCase.assertDictEqual`` when the running unittest
        provides it (Python 2.7+ and all Python 3 versions); otherwise uses a
        simpler key-by-key comparison.
        """
        if hasattr(self, 'assertDictEqual'):
            self.assertDictEqual(a, b, msg)
            return

        # Substitute implementation, not as informative as unittest's.
        # Uses only APIs present on both Python 2 and 3: ``items()`` instead
        # of the Python 2-only ``iteritems()``, and ``assertEqual`` instead of
        # the deprecated ``assertEquals`` alias (removed in Python 3.12).
        self.assertEqual(len(a), len(b), msg)
        for key, value in a.items():
            self.assertTrue(key in b, msg)
            self.assertEqual(b[key], value, msg)

    def _serialize(self, obj):
        """Write Thrift struct ``obj`` with the protocol and return the bytes."""
        trans = TTransport.TMemoryBuffer()
        prot = self.protocol_factory.getProtocol(trans)
        obj.write(prot)
        return trans.getvalue()

    def _deserialize(self, objtype, data):
        """Read an ``objtype`` instance from ``data`` using the protocol."""
        prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
        ret = objtype()
        ret.read(prot)
        return ret

    def testWriteOnly(self):
        """Reading through the simple JSON protocol raises NotImplementedError."""
        self.assertRaises(NotImplementedError,
                          self._deserialize, VersioningTestV1, b'{}')

    def testSimpleMessage(self):
        """A flat struct serializes to a JSON object keyed by field name."""
        v1obj = VersioningTestV1(
            begin_in_both=12345,
            old_string='aaa',
            end_in_both=54321)
        expected = dict(begin_in_both=v1obj.begin_in_both,
                        old_string=v1obj.old_string,
                        end_in_both=v1obj.end_in_both)
        actual = json.loads(self._serialize(v1obj).decode('ascii'))

        self._assertDictEqual(expected, actual)

    def testComplicated(self):
        """Nested structs, lists, sets and maps serialize to JSON equivalents."""
        v2obj = VersioningTestV2(
            begin_in_both=12345,
            newint=1,
            newbyte=2,
            newshort=3,
            newlong=4,
            newdouble=5.0,
            newstruct=Bonk(message="Hello!", type=123),
            newlist=[7, 8, 9],
            newset=set([42, 1, 8]),
            newmap={1: 2, 2: 3},
            newstring="Hola!",
            end_in_both=54321)
        expected = dict(begin_in_both=v2obj.begin_in_both,
                        newint=v2obj.newint,
                        newbyte=v2obj.newbyte,
                        newshort=v2obj.newshort,
                        newlong=v2obj.newlong,
                        newdouble=v2obj.newdouble,
                        newstruct=dict(message=v2obj.newstruct.message,
                                       type=v2obj.newstruct.type),
                        newlist=v2obj.newlist,
                        # NOTE(review): assumes the serializer emits set
                        # elements in the set's own iteration order, so
                        # list(newset) matches the JSON array -- confirm.
                        newset=list(v2obj.newset),
                        newmap=v2obj.newmap,
                        newstring=v2obj.newstring,
                        end_in_both=v2obj.end_in_both)

        # Round-trip through json because JSON object keys are always
        # strings: the int map keys must become strings to compare equal.
        expected = json.loads(json.dumps(expected))
        actual = json.loads(self._serialize(v2obj).decode('ascii'))
        self._assertDictEqual(expected, actual)


if __name__ == '__main__':
    # Run the tests in this module when it is executed directly as a script.
    unittest.main()