summaryrefslogtreecommitdiffstats
path: root/ml/dlib/tools/python/test/test_vector.py
blob: ff79ab339815dd3e4e7d425b1a4605c72add3e26 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
from dlib import vector, vectors, vectorss, dot
try:
    import cPickle as pickle  # Use cPickle on Python 2.7
except ImportError:
    import pickle
from pytest import raises


def test_vector_empty_init():
    """A default-constructed vector is empty: length 0, shape (0, 1), empty str/repr."""
    empty = vector()
    assert len(empty) == 0
    assert empty.shape == (0, 1)
    # Both textual forms are checked: str() prints nothing, repr() an empty list.
    assert str(empty) == ""
    assert repr(empty) == "dlib.vector([])"


def test_vector_init_with_number():
    """vector(n) builds an n-row column vector whose elements are all zero."""
    length = 3
    zeros = vector(length)
    assert len(zeros) == length
    assert zeros.shape == (length, 1)
    assert str(zeros) == "0\n0\n0"
    assert repr(zeros) == "dlib.vector([0, 0, 0])"


def test_vector_set_size():
    """set_size() and resize() change a vector's length; after growing, every element reads 0."""
    vec = vector(3)

    # Shrink to nothing with set_size().
    vec.set_size(0)
    assert len(vec) == 0
    assert vec.shape == (0, 1)

    # Grow again with resize() and confirm each slot reads as zero.
    vec.resize(10)
    assert len(vec) == 10
    assert vec.shape == (10, 1)
    assert [vec[k] for k in range(10)] == [0] * 10


def test_vector_init_with_list():
    """vector(list) copies the list's values, in order, into a column vector."""
    values = vector([1, 2, 3])
    assert len(values) == 3
    assert values.shape == (3, 1)
    # str() prints one element per line; repr() mirrors the constructor call.
    assert str(values) == "1\n2\n3"
    assert repr(values) == "dlib.vector([1, 2, 3])"


def test_vector_getitem():
    """Integer indexing works with both non-negative and negative (from-the-end) indices."""
    vec = vector([1, 2, 3])
    first, last = vec[0], vec[-1]
    assert first == 1
    assert last == 3
    # The middle element is reachable from either end.
    assert vec[1] == vec[-2]


def test_vector_slice():
    """Slices with positive, negative and mixed bounds return the expected sub-vector."""
    source = vector([1, 2, 3, 4, 5])
    # Each case pairs a slice object (equivalent to source[a:b]) with the
    # values the resulting sub-vector must contain.
    cases = [
        (slice(1, 4), [2, 3, 4]),
        (slice(-3, -1), [3, 4]),
        (slice(1, -2), [2, 3]),
    ]
    for bounds, expected in cases:
        part = source[bounds]
        assert len(part) == len(expected)
        assert [part[k] for k in range(len(expected))] == expected


def test_vector_invalid_getitem():
    """Indices past either end of the vector raise IndexError."""
    vec = vector([1, 2, 3])
    # One index just before the start (from the end) and one just past the end.
    for bad_index in (-4, 3):
        with raises(IndexError):
            vec[bad_index]


def test_vector_init_with_negative_number():
    """Constructing a vector with a negative length raises an exception."""
    negative_length = -3
    # The exact exception type is not pinned here; any exception passes.
    with raises(Exception):
        vector(negative_length)


def test_dot():
    """dot() of 2-D unit vectors: parallel gives 1, orthogonal 0, opposite -1."""
    e_x = vector([1, 0])
    e_y = vector([0, 1])
    neg_e_x = vector([-1, 0])
    for other, expected in ((e_x, 1), (e_y, 0), (neg_e_x, -1)):
        assert dot(e_x, other) == expected


def test_vector_serialization():
    """A vector survives a pickle round-trip with its length and values intact."""
    v = vector([1, 2, 3])
    # Protocol 2 is available both to Python 2.7's cPickle and to Python 3's pickle.
    ser = pickle.dumps(v, 2)
    deser = pickle.loads(ser)
    assert str(v) == str(deser)
    # str() equality only compares the printed form; also compare the length
    # and every element so a dropped or altered value is caught directly.
    assert len(deser) == len(v)
    assert [deser[i] for i in range(len(deser))] == [v[i] for i in range(len(v))]


def generate_test_vectors():
    """Return a ``vectors`` of three 3-element vectors: [0,1,2], [3,4,5], [6,7,8]."""
    vs = vectors()
    for start in (0, 3, 6):
        vs.append(vector([start, start + 1, start + 2]))
    # Sanity-check the fixture itself before handing it to a test.
    assert len(vs) == 3
    return vs


def generate_test_vectorss():
    """Return a ``vectorss`` holding three independent copies of generate_test_vectors()."""
    vss = vectorss()
    for _ in range(3):
        vss.append(generate_test_vectors())
    # Sanity-check the fixture itself before handing it to a test.
    assert len(vss) == 3
    return vss


def test_vectors_serialization():
    """A ``vectors`` object compares equal to itself after a pickle round-trip."""
    original = generate_test_vectors()
    payload = pickle.dumps(original, 2)
    restored = pickle.loads(payload)
    assert original == restored


def test_vectors_clear():
    """clear() removes every element from a populated ``vectors``."""
    populated = generate_test_vectors()
    populated.clear()
    assert len(populated) == 0


def test_vectors_resize():
    """resize(n) on an empty ``vectors`` yields n empty vectors."""
    count = 100
    vs = vectors()
    vs.resize(count)
    assert len(vs) == count
    assert [len(vs[k]) for k in range(count)] == [0] * count


def test_vectors_extend():
    """extend() appends every vector from an iterable, preserving order and contents."""
    vs = vectors()
    vs.extend([vector([1, 2, 3]), vector([4, 5, 6])])
    assert len(vs) == 2
    # Check contents as well as the count, so a reordered, dropped or
    # duplicated element is detected. str() of a vector prints one value per line.
    assert str(vs[0]) == "1\n2\n3"
    assert str(vs[1]) == "4\n5\n6"


def test_vectorss_serialization():
    """A ``vectorss`` object compares equal to itself after a pickle round-trip."""
    original = generate_test_vectorss()
    payload = pickle.dumps(original, 2)
    restored = pickle.loads(payload)
    assert original == restored


def test_vectorss_clear():
    """clear() removes every element from a populated ``vectorss``."""
    populated = generate_test_vectorss()
    populated.clear()
    assert len(populated) == 0


def test_vectorss_resize():
    """resize(n) on an empty ``vectorss`` yields n empty ``vectors``."""
    count = 100
    vss = vectorss()
    vss.resize(count)
    assert len(vss) == count
    assert [len(vss[k]) for k in range(count)] == [0] * count


def test_vectorss_extend():
    """extend() appends every ``vectors`` from an iterable, preserving contents."""
    vss = vectorss()
    vss.extend([generate_test_vectors(), generate_test_vectors()])
    assert len(vss) == 2
    # Verify the appended elements themselves, not just the count. This uses
    # the same ``vectors`` equality that the serialization tests rely on.
    expected = generate_test_vectors()
    assert vss[0] == expected
    assert vss[1] == expected