summaryrefslogtreecommitdiffstats
path: root/ml/dlib/tools/python/test/test_vector.py
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/tools/python/test/test_vector.py')
-rw-r--r--ml/dlib/tools/python/test/test_vector.py170
1 files changed, 170 insertions, 0 deletions
diff --git a/ml/dlib/tools/python/test/test_vector.py b/ml/dlib/tools/python/test/test_vector.py
new file mode 100644
index 000000000..ff79ab339
--- /dev/null
+++ b/ml/dlib/tools/python/test/test_vector.py
@@ -0,0 +1,170 @@
+from dlib import vector, vectors, vectorss, dot
+try:
+ import cPickle as pickle # Use cPickle on Python 2.7
+except ImportError:
+ import pickle
+from pytest import raises
+
+
+def test_vector_empty_init():
+ v = vector()
+ assert len(v) == 0
+ assert v.shape == (0, 1)
+ assert str(v) == ""
+ assert repr(v) == "dlib.vector([])"
+
+
+def test_vector_init_with_number():
+ v = vector(3)
+ assert len(v) == 3
+ assert v.shape == (3, 1)
+ assert str(v) == "0\n0\n0"
+ assert repr(v) == "dlib.vector([0, 0, 0])"
+
+
+def test_vector_set_size():
+ v = vector(3)
+
+ v.set_size(0)
+ assert len(v) == 0
+ assert v.shape == (0, 1)
+
+ v.resize(10)
+ assert len(v) == 10
+ assert v.shape == (10, 1)
+ for i in range(10):
+ assert v[i] == 0
+
+
+def test_vector_init_with_list():
+ v = vector([1, 2, 3])
+ assert len(v) == 3
+ assert v.shape == (3, 1)
+ assert str(v) == "1\n2\n3"
+ assert repr(v) == "dlib.vector([1, 2, 3])"
+
+
+def test_vector_getitem():
+ v = vector([1, 2, 3])
+ assert v[0] == 1
+ assert v[-1] == 3
+ assert v[1] == v[-2]
+
+
+def test_vector_slice():
+ v = vector([1, 2, 3, 4, 5])
+ v_slice = v[1:4]
+ assert len(v_slice) == 3
+ for idx, val in enumerate([2, 3, 4]):
+ assert v_slice[idx] == val
+
+ v_slice = v[-3:-1]
+ assert len(v_slice) == 2
+ for idx, val in enumerate([3, 4]):
+ assert v_slice[idx] == val
+
+ v_slice = v[1:-2]
+ assert len(v_slice) == 2
+ for idx, val in enumerate([2, 3]):
+ assert v_slice[idx] == val
+
+
+def test_vector_invalid_getitem():
+ v = vector([1, 2, 3])
+ with raises(IndexError):
+ v[-4]
+ with raises(IndexError):
+ v[3]
+
+
+def test_vector_init_with_negative_number():
+ with raises(Exception):
+ vector(-3)
+
+
+def test_dot():
+ v1 = vector([1, 0])
+ v2 = vector([0, 1])
+ v3 = vector([-1, 0])
+ assert dot(v1, v1) == 1
+ assert dot(v1, v2) == 0
+ assert dot(v1, v3) == -1
+
+
+def test_vector_serialization():
+ v = vector([1, 2, 3])
+ ser = pickle.dumps(v, 2)
+ deser = pickle.loads(ser)
+ assert str(v) == str(deser)
+
+
+def generate_test_vectors():
+ vs = vectors()
+ vs.append(vector([0, 1, 2]))
+ vs.append(vector([3, 4, 5]))
+ vs.append(vector([6, 7, 8]))
+ assert len(vs) == 3
+ return vs
+
+
+def generate_test_vectorss():
+ vss = vectorss()
+ vss.append(generate_test_vectors())
+ vss.append(generate_test_vectors())
+ vss.append(generate_test_vectors())
+ assert len(vss) == 3
+ return vss
+
+
+def test_vectors_serialization():
+ vs = generate_test_vectors()
+ ser = pickle.dumps(vs, 2)
+ deser = pickle.loads(ser)
+ assert vs == deser
+
+
+def test_vectors_clear():
+ vs = generate_test_vectors()
+ vs.clear()
+ assert len(vs) == 0
+
+
+def test_vectors_resize():
+ vs = vectors()
+ vs.resize(100)
+ assert len(vs) == 100
+ for i in range(100):
+ assert len(vs[i]) == 0
+
+
+def test_vectors_extend():
+ vs = vectors()
+ vs.extend([vector([1, 2, 3]), vector([4, 5, 6])])
+ assert len(vs) == 2
+
+
+def test_vectorss_serialization():
+ vss = generate_test_vectorss()
+ ser = pickle.dumps(vss, 2)
+ deser = pickle.loads(ser)
+ assert vss == deser
+
+
+def test_vectorss_clear():
+ vss = generate_test_vectorss()
+ vss.clear()
+ assert len(vss) == 0
+
+
+def test_vectorss_resize():
+ vss = vectorss()
+ vss.resize(100)
+ assert len(vss) == 100
+ for i in range(100):
+ assert len(vss[i]) == 0
+
+
+def test_vectorss_extend():
+ vss = vectorss()
+ vss.extend([generate_test_vectors(), generate_test_vectors()])
+ assert len(vss) == 2