// Copyright (C) 2014 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_ #define DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_ #include "structural_track_association_trainer_abstract.h" #include "../algs.h" #include "svm.h" #include #include "track_association_function.h" #include "structural_assignment_trainer.h" #include namespace dlib { // ---------------------------------------------------------------------------------------- namespace impl { template < typename detection_type, typename label_type > std::vector get_unlabeled_dets ( const std::vector >& dets ) { std::vector temp; temp.reserve(dets.size()); for (unsigned long i = 0; i < dets.size(); ++i) temp.push_back(dets[i].det); return temp; } } // ---------------------------------------------------------------------------------------- class structural_track_association_trainer { public: structural_track_association_trainer ( ) { set_defaults(); } void set_num_threads ( unsigned long num ) { num_threads = num; } unsigned long get_num_threads ( ) const { return num_threads; } void set_epsilon ( double eps_ ) { // make sure requires clause is not broken DLIB_ASSERT(eps_ > 0, "\t void structural_track_association_trainer::set_epsilon()" << "\n\t eps_ must be greater than 0" << "\n\t eps_: " << eps_ << "\n\t this: " << this ); eps = eps_; } double get_epsilon ( ) const { return eps; } void set_max_cache_size ( unsigned long max_size ) { max_cache_size = max_size; } unsigned long get_max_cache_size ( ) const { return max_cache_size; } void set_loss_per_false_association ( double loss ) { // make sure requires clause is not broken DLIB_ASSERT(loss > 0, "\t void structural_track_association_trainer::set_loss_per_false_association(loss)" << "\n\t Invalid inputs were given to this function " << "\n\t loss: " << loss << "\n\t this: " << this ); loss_per_false_association = loss; } double get_loss_per_false_association ( ) const { return loss_per_false_association; } void set_loss_per_track_break ( double loss ) { // make sure requires clause is not broken DLIB_ASSERT(loss > 0, "\t void structural_track_association_trainer::set_loss_per_track_break(loss)" << "\n\t Invalid inputs were given to this function " << "\n\t loss: " << loss << "\n\t this: " << this ); loss_per_track_break = loss; } double get_loss_per_track_break ( ) const { return loss_per_track_break; } void be_verbose ( ) { verbose = true; } void be_quiet ( ) { verbose = false; } void set_oca ( const oca& item ) { solver = item; } const oca get_oca ( ) const { return solver; } void set_c ( double C_ ) { // make sure requires clause is not broken DLIB_ASSERT(C_ > 0, "\t void structural_track_association_trainer::set_c()" << "\n\t C_ must be greater than 0" << "\n\t C_: " << C_ << "\n\t this: " << this ); C = C_; } double get_c ( ) const { return C; } bool learns_nonnegative_weights ( ) const { return learn_nonnegative_weights; } void set_learns_nonnegative_weights ( bool value ) { learn_nonnegative_weights = value; } template < typename detection_type, typename label_type > const track_association_function train ( const std::vector > > >& samples ) const { // make sure requires clause is not broken DLIB_ASSERT(is_track_association_problem(samples), "\t track_association_function structural_track_association_trainer::train()" << "\n\t invalid inputs were given to this function" << "\n\t is_track_association_problem(samples): " << is_track_association_problem(samples) ); typedef typename detection_type::track_type track_type; const unsigned long num_dims = find_num_dims(samples); feature_extractor_track_association fe(num_dims, learn_nonnegative_weights?num_dims:0); structural_assignment_trainer > trainer(fe); if (verbose) trainer.be_verbose(); trainer.set_c(C); trainer.set_epsilon(eps); trainer.set_max_cache_size(max_cache_size); trainer.set_num_threads(num_threads); trainer.set_oca(solver); trainer.set_loss_per_missed_association(loss_per_track_break); trainer.set_loss_per_false_association(loss_per_false_association); std::vector, std::vector > > assignment_samples; std::vector > labels; for (unsigned long i = 0; i < samples.size(); ++i) convert_dets_to_association_sets(samples[i], assignment_samples, labels); return track_association_function(trainer.train(assignment_samples, labels)); } template < typename detection_type, typename label_type > const track_association_function train ( const std::vector > >& sample ) const { std::vector > > > samples; samples.push_back(sample); return train(samples); } private: template < typename detection_type, typename label_type > static unsigned long find_num_dims ( const std::vector > > >& samples ) { typedef typename detection_type::track_type track_type; // find a detection_type object so we can call get_similarity_features() and // find out how big the feature vectors are. // for all detection histories for (unsigned long i = 0; i < samples.size(); ++i) { // for all time instances in the detection history for (unsigned j = 0; j < samples[i].size(); ++j) { if (samples[i][j].size() > 0) { track_type new_track; new_track.update_track(samples[i][j][0].det); typename track_type::feature_vector_type feats; new_track.get_similarity_features(samples[i][j][0].det, feats); return feats.size(); } } } DLIB_CASSERT(false, "No detection objects were given in the call to dlib::structural_track_association_trainer::train()"); } template < typename detections_at_single_time_step, typename detection_type, typename track_type > static void convert_dets_to_association_sets ( const std::vector& det_history, std::vector, std::vector > >& data, std::vector >& labels ) { if (det_history.size() < 1) return; typedef typename detections_at_single_time_step::value_type::label_type label_type; std::vector tracks; // track_labels maps from detection labels to the index in tracks. So track // with detection label X is at tracks[track_labels[X]]. std::map track_labels; add_dets_to_tracks(tracks, track_labels, det_history[0]); using namespace impl; for (unsigned long i = 1; i < det_history.size(); ++i) { data.push_back(std::make_pair(get_unlabeled_dets(det_history[i]), tracks)); labels.push_back(get_association_labels(det_history[i], track_labels)); add_dets_to_tracks(tracks, track_labels, det_history[i]); } } template < typename labeled_detection, typename label_type > static std::vector get_association_labels( const std::vector& dets, const std::map& track_labels ) { std::vector assoc(dets.size(),-1); // find out which detections associate to what tracks for (unsigned long i = 0; i < dets.size(); ++i) { typename std::map::const_iterator j; j = track_labels.find(dets[i].label); // If this detection matches one of the tracks then record which track it // matched with. if (j != track_labels.end()) assoc[i] = j->second; } return assoc; } template < typename track_type, typename label_type, typename labeled_detection > static void add_dets_to_tracks ( std::vector& tracks, std::map& track_labels, const std::vector& dets ) { std::vector updated_track(tracks.size(), false); // first assign the dets to the tracks for (unsigned long i = 0; i < dets.size(); ++i) { const label_type& label = dets[i].label; if (track_labels.count(label)) { const unsigned long track_idx = track_labels[label]; tracks[track_idx].update_track(dets[i].det); updated_track[track_idx] = true; } else { // this detection creates a new track track_type new_track; new_track.update_track(dets[i].det); tracks.push_back(new_track); track_labels[label] = tracks.size()-1; } } // Now propagate all the tracks that didn't get any detections. for (unsigned long i = 0; i < updated_track.size(); ++i) { if (!updated_track[i]) tracks[i].propagate_track(); } } double C; oca solver; double eps; bool verbose; unsigned long num_threads; unsigned long max_cache_size; bool learn_nonnegative_weights; double loss_per_track_break; double loss_per_false_association; void set_defaults () { C = 100; verbose = false; eps = 0.001; num_threads = 2; max_cache_size = 5; learn_nonnegative_weights = false; loss_per_track_break = 1; loss_per_false_association = 1; } }; } #endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_