summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/track_association_function.h
blob: bf5ef36c70ba5d5d6034d9ce7b2a250faca1cf00 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
// Copyright (C) 2014  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_
#define DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_


#include "track_association_function_abstract.h"
#include <vector>
#include <iostream>
#include "../algs.h"
#include "../serialize.h"
#include "assignment_function.h"

namespace dlib
{

// ----------------------------------------------------------------------------------------

    // Feature extractor that pairs a detection (the left-hand element) with a
    // track (the right-hand element).  The object itself only carries two counts
    // given at construction time; the feature computation is forwarded to the
    // track via track_type::get_similarity_features().
    template <typename detection_type>
    class feature_extractor_track_association
    {
    public:
        typedef typename detection_type::track_type track_type;
        typedef typename track_type::feature_vector_type feature_vector_type;

        // Element types of the two sides being matched.
        typedef detection_type lhs_element;
        typedef track_type rhs_element;

        // Default construction leaves both counts at zero.
        feature_extractor_track_association()
            : feature_count(0),
              nonnegative_count(0)
        {
        }

        // Stores the supplied counts verbatim; no validation is performed here.
        explicit feature_extractor_track_association (
            unsigned long num_dims_,
            unsigned long num_nonnegative_
        )
            : feature_count(num_dims_),
              nonnegative_count(num_nonnegative_)
        {
        }

        // Returns the first count given to the constructor (0 if default constructed).
        unsigned long num_features() const
        {
            return feature_count;
        }

        // Returns the second count given to the constructor (0 if default constructed).
        unsigned long num_nonnegative_weights() const
        {
            return nonnegative_count;
        }

        // Fills feats by asking the track for its similarity features against det.
        void get_features (
            const detection_type& det,
            const track_type& track,
            feature_vector_type& feats
        ) const
        {
            track.get_similarity_features(det, feats);
        }

        // Writes the two counts to out: feature count first, then the
        // non-negative count.  deserialize() below reads them in the same order.
        friend void serialize (const feature_extractor_track_association& item, std::ostream& out)
        {
            serialize(item.feature_count, out);
            serialize(item.nonnegative_count, out);
        }

        // Reads the two counts back from in, in the order serialize() wrote them.
        friend void deserialize (feature_extractor_track_association& item, std::istream& in)
        {
            deserialize(item.feature_count, in);
            deserialize(item.nonnegative_count, in);
        }

    private:
        unsigned long feature_count;      // value reported by num_features()
        unsigned long nonnegative_count;  // value reported by num_nonnegative_weights()
    };

// ----------------------------------------------------------------------------------------

    // Function object that uses an assignment_function to associate a batch of
    // detections with a set of tracks and then updates the tracks accordingly.
    template <typename detection_type_>
    class track_association_function
    {
    public:

        typedef detection_type_ detection_type;
        typedef typename detection_type::track_type track_type;
        typedef assignment_function<feature_extractor_track_association<detection_type> > association_function_type;

        // Leaves the wrapped assignment function default constructed.
        track_association_function()
        {
        }

        // Wraps a copy of the given assignment function.
        track_association_function (
            const association_function_type& assoc_
        )
            : assoc(assoc_)
        {
        }

        // Read-only access to the wrapped assignment function.
        const association_function_type& get_assignment_function() const
        {
            return assoc;
        }

        // Associates dets with tracks and modifies tracks in place:
        //   - a detection assigned to a track is fed to that track's update_track(),
        //   - a detection whose assignment is -1 starts a new track that is
        //     appended to tracks,
        //   - every track that existed on entry and received no detection has
        //     propagate_track() called on it.
        void operator() (
            std::vector<track_type>& tracks,
            const std::vector<detection_type>& dets
        ) const
        {
            std::vector<long> matches = assoc(dets, tracks);

            // One flag per track that existed on entry.  Tracks appended below are
            // deliberately not covered, so they are never propagated in this call.
            std::vector<bool> got_detection(tracks.size(), false);

            for (unsigned long d = 0; d < matches.size(); ++d)
            {
                const long track_idx = matches[d];

                // -1 is handled as "no existing track for this detection":
                // start a fresh track from it.
                if (track_idx == -1)
                {
                    track_type fresh_track;
                    fresh_track.update_track(dets[d]);
                    tracks.push_back(fresh_track);
                    continue;
                }

                tracks[track_idx].update_track(dets[d]);
                got_detection[track_idx] = true;
            }

            // Advance every pre-existing track that was not matched to a detection.
            for (unsigned long t = 0; t < got_detection.size(); ++t)
            {
                if (got_detection[t])
                    continue;

                tracks[t].propagate_track();
            }
        }

        // Writes a format version number (1) followed by the wrapped assignment function.
        friend void serialize (const track_association_function& item, std::ostream& out)
        {
            int version = 1;
            serialize(version, out);
            serialize(item.assoc, out);
        }

        // Reads the format version and, if it is 1, the wrapped assignment function.
        // Throws serialization_error for any other version.
        friend void deserialize (track_association_function& item, std::istream& in)
        {
            int version = 0;
            deserialize(version, in);
            if (version != 1)
            {
                throw serialization_error("Unexpected version found while deserializing dlib::track_association_function.");
            }

            deserialize(item.assoc, in);
        }

    private:

        association_function_type assoc;
    };

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_TRACK_ASSOCiATION_FUNCTION_Hh_