diff options
Diffstat (limited to 'ml/dlib/tools/python/test/test_matrix.py')
-rw-r--r-- | ml/dlib/tools/python/test/test_matrix.py | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/ml/dlib/tools/python/test/test_matrix.py b/ml/dlib/tools/python/test/test_matrix.py new file mode 100644 index 00000000..cdd9bed1 --- /dev/null +++ b/ml/dlib/tools/python/test/test_matrix.py @@ -0,0 +1,100 @@ +from dlib import matrix +try: + import cPickle as pickle # Use cPickle on Python 2.7 +except ImportError: + import pickle +from pytest import raises + +try: + import numpy + have_numpy = True +except ImportError: + have_numpy = False + + +def test_matrix_empty_init(): + m = matrix() + assert m.nr() == 0 + assert m.nc() == 0 + assert m.shape == (0, 0) + assert len(m) == 0 + assert repr(m) == "< dlib.matrix containing: >" + assert str(m) == "" + + +def test_matrix_from_list(): + m = matrix([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + assert m.nr() == 3 + assert m.nc() == 3 + assert m.shape == (3, 3) + assert len(m) == 3 + assert repr(m) == "< dlib.matrix containing: \n0 1 2 \n3 4 5 \n6 7 8 >" + assert str(m) == "0 1 2 \n3 4 5 \n6 7 8" + + deser = pickle.loads(pickle.dumps(m, 2)) + + for row in range(3): + for col in range(3): + assert m[row][col] == deser[row][col] + + +def test_matrix_from_list_with_invalid_rows(): + with raises(ValueError): + matrix([[0, 1, 2], + [3, 4], + [5, 6, 7]]) + + +def test_matrix_from_list_as_column_vector(): + m = matrix([0, 1, 2]) + assert m.nr() == 3 + assert m.nc() == 1 + assert m.shape == (3, 1) + assert len(m) == 3 + assert repr(m) == "< dlib.matrix containing: \n0 \n1 \n2 >" + assert str(m) == "0 \n1 \n2" + + +if have_numpy: + def test_matrix_from_object_with_2d_shape(): + m1 = numpy.array([[0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + m = matrix(m1) + assert m.nr() == 3 + assert m.nc() == 3 + assert m.shape == (3, 3) + assert len(m) == 3 + assert repr(m) == "< dlib.matrix containing: \n0 1 2 \n3 4 5 \n6 7 8 >" + assert str(m) == "0 1 2 \n3 4 5 \n6 7 8" + + + def test_matrix_from_object_without_2d_shape(): + with raises(IndexError): + m1 = numpy.array([0, 1, 2]) + matrix(m1) + + +def test_matrix_from_object_without_shape(): + with raises(AttributeError): + matrix("invalid") + + +def test_matrix_set_size(): + m = matrix() + m.set_size(5, 5) + + assert m.nr() == 5 + assert m.nc() == 5 + assert m.shape == (5, 5) + assert len(m) == 5 + assert repr(m) == "< dlib.matrix containing: \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 >" + assert str(m) == "0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0 \n0 0 0 0 0" + + deser = pickle.loads(pickle.dumps(m, 2)) + + for row in range(5): + for col in range(5): + assert m[row][col] == deser[row][col] |