diff options
Diffstat (limited to 'ml/dlib/tools/python/test/test_svm_c_trainer.py')
-rw-r--r-- | ml/dlib/tools/python/test/test_svm_c_trainer.py | 65 |
1 files changed, 0 insertions, 65 deletions
diff --git a/ml/dlib/tools/python/test/test_svm_c_trainer.py b/ml/dlib/tools/python/test/test_svm_c_trainer.py deleted file mode 100644 index ba9392e08..000000000 --- a/ml/dlib/tools/python/test/test_svm_c_trainer.py +++ /dev/null @@ -1,65 +0,0 @@ -from __future__ import division - -import pytest -from random import Random -from dlib import (vectors, vector, sparse_vectors, sparse_vector, pair, array, - cross_validate_trainer, - svm_c_trainer_radial_basis, - svm_c_trainer_sparse_radial_basis, - svm_c_trainer_histogram_intersection, - svm_c_trainer_sparse_histogram_intersection, - svm_c_trainer_linear, - svm_c_trainer_sparse_linear, - rvm_trainer_radial_basis, - rvm_trainer_sparse_radial_basis, - rvm_trainer_histogram_intersection, - rvm_trainer_sparse_histogram_intersection, - rvm_trainer_linear, - rvm_trainer_sparse_linear) - - -@pytest.fixture -def training_data(): - r = Random(0) - predictors = vectors() - sparse_predictors = sparse_vectors() - response = array() - for i in range(30): - for c in [-1, 1]: - response.append(c) - values = [r.random() + c * 0.5 for _ in range(3)] - predictors.append(vector(values)) - sp = sparse_vector() - for i, v in enumerate(values): - sp.append(pair(i, v)) - sparse_predictors.append(sp) - return predictors, sparse_predictors, response - - -@pytest.mark.parametrize('trainer, class1_accuracy, class2_accuracy', [ - (svm_c_trainer_radial_basis, 1.0, 1.0), - (svm_c_trainer_sparse_radial_basis, 1.0, 1.0), - (svm_c_trainer_histogram_intersection, 1.0, 1.0), - (svm_c_trainer_sparse_histogram_intersection, 1.0, 1.0), - (svm_c_trainer_linear, 1.0, 23 / 30), - (svm_c_trainer_sparse_linear, 1.0, 23 / 30), - (rvm_trainer_radial_basis, 1.0, 1.0), - (rvm_trainer_sparse_radial_basis, 1.0, 1.0), - (rvm_trainer_histogram_intersection, 1.0, 1.0), - (rvm_trainer_sparse_histogram_intersection, 1.0, 1.0), - (rvm_trainer_linear, 1.0, 0.6), - (rvm_trainer_sparse_linear, 1.0, 0.6) -]) -def test_trainers(training_data, trainer, class1_accuracy, class2_accuracy): - predictors, sparse_predictors, response = training_data - if 'sparse' in trainer.__name__: - predictors = sparse_predictors - cv = cross_validate_trainer(trainer(), predictors, response, folds=10) - assert cv.class1_accuracy == pytest.approx(class1_accuracy) - assert cv.class2_accuracy == pytest.approx(class2_accuracy) - - decision_function = trainer().train(predictors, response) - assert decision_function(predictors[2]) < 0 - assert decision_function(predictors[3]) > 0 - if 'linear' in trainer.__name__: - assert len(decision_function.weights) == 3 |