summaryrefslogtreecommitdiffstats
path: root/src/ml/dlib/tools/python/test/test_svm_c_trainer.py
blob: ba9392e08b6943354d54923d91d5f1fa1219b5ce (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import division

import pytest
from random import Random
from dlib import (vectors, vector, sparse_vectors, sparse_vector, pair, array,
                  cross_validate_trainer,
                  svm_c_trainer_radial_basis,
                  svm_c_trainer_sparse_radial_basis,
                  svm_c_trainer_histogram_intersection,
                  svm_c_trainer_sparse_histogram_intersection,
                  svm_c_trainer_linear,
                  svm_c_trainer_sparse_linear,
                  rvm_trainer_radial_basis,
                  rvm_trainer_sparse_radial_basis,
                  rvm_trainer_histogram_intersection,
                  rvm_trainer_sparse_histogram_intersection,
                  rvm_trainer_linear,
                  rvm_trainer_sparse_linear)


@pytest.fixture
def training_data():
    """Build a small, reproducible two-class training set.

    For each of 30 rounds, one sample labelled -1 and one labelled +1 are
    generated, so the labels are interleaved as -1, +1, -1, +1, ...  Each
    sample has 3 features computed as ``Random.random() + label * 0.5``.
    Class -1 features therefore lie in [-0.5, 0.5) and class +1 features in
    [0.5, 1.5).

    Returns:
        A ``(predictors, sparse_predictors, response)`` tuple:
        ``predictors`` holds the samples as dense ``vector`` objects,
        ``sparse_predictors`` holds the same values as ``sparse_vector``
        objects of ``(feature_index, value)`` pairs, and ``response`` holds
        the matching labels.
    """
    # Fixed seed so the expected accuracies asserted by the tests stay stable.
    rng = Random(0)
    predictors = vectors()
    sparse_predictors = sparse_vectors()
    response = array()
    # The round counter is unused; naming it ``_`` avoids it being shadowed
    # by the feature-index loop below (previously both were called ``i``).
    for _ in range(30):
        for label in [-1, 1]:
            response.append(label)
            values = [rng.random() + label * 0.5 for _ in range(3)]
            predictors.append(vector(values))
            sp = sparse_vector()
            for feature_idx, value in enumerate(values):
                sp.append(pair(feature_idx, value))
            sparse_predictors.append(sp)
    return predictors, sparse_predictors, response


@pytest.mark.parametrize(('trainer', 'class1_accuracy', 'class2_accuracy'), [
    (svm_c_trainer_radial_basis, 1.0, 1.0),
    (svm_c_trainer_sparse_radial_basis, 1.0, 1.0),
    (svm_c_trainer_histogram_intersection, 1.0, 1.0),
    (svm_c_trainer_sparse_histogram_intersection, 1.0, 1.0),
    (svm_c_trainer_linear, 1.0, 23 / 30),
    (svm_c_trainer_sparse_linear, 1.0, 23 / 30),
    (rvm_trainer_radial_basis, 1.0, 1.0),
    (rvm_trainer_sparse_radial_basis, 1.0, 1.0),
    (rvm_trainer_histogram_intersection, 1.0, 1.0),
    (rvm_trainer_sparse_histogram_intersection, 1.0, 1.0),
    (rvm_trainer_linear, 1.0, 0.6),
    (rvm_trainer_sparse_linear, 1.0, 0.6)
])
def test_trainers(training_data, trainer, class1_accuracy, class2_accuracy):
    """Cross-validate a trainer and sanity-check a fully trained model.

    The trainer's class name selects the input format: names containing
    ``'sparse'`` get the sparse samples, and all others get the dense ones.
    The 10-fold cross-validation accuracies must match the expected values.
    A decision function trained on all samples must then score sample 2
    negative and sample 3 positive. For trainers whose name contains
    ``'linear'``, it must also expose exactly 3 weights (one per feature).
    """
    dense_samples, sparse_samples, labels = training_data
    trainer_name = trainer.__name__
    samples = sparse_samples if 'sparse' in trainer_name else dense_samples

    cv_result = cross_validate_trainer(trainer(), samples, labels, folds=10)
    assert cv_result.class1_accuracy == pytest.approx(class1_accuracy)
    assert cv_result.class2_accuracy == pytest.approx(class2_accuracy)

    model = trainer().train(samples, labels)
    # The fixture interleaves labels -1, +1, ..., so index 2 carries label -1
    # and index 3 carries label +1.
    assert model(samples[2]) < 0
    assert model(samples[3]) > 0

    if 'linear' in trainer_name:
        assert len(model.weights) == 3