diff options
author | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-09 13:19:48 +0000 |
---|---|---|
committer | Daniel Baumann <daniel.baumann@progress-linux.org> | 2024-03-09 13:20:02 +0000 |
commit | 58daab21cd043e1dc37024a7f99b396788372918 (patch) | |
tree | 96771e43bb69f7c1c2b0b4f7374cb74d7866d0cb /ml/dlib/dlib/svm/structural_track_association_trainer.h | |
parent | Releasing debian version 1.43.2-1. (diff) | |
download | netdata-58daab21cd043e1dc37024a7f99b396788372918.tar.xz netdata-58daab21cd043e1dc37024a7f99b396788372918.zip |
Merging upstream version 1.44.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ml/dlib/dlib/svm/structural_track_association_trainer.h')
-rw-r--r-- | ml/dlib/dlib/svm/structural_track_association_trainer.h | 404 |
1 files changed, 404 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/structural_track_association_trainer.h b/ml/dlib/dlib/svm/structural_track_association_trainer.h new file mode 100644 index 000000000..87fb829b2 --- /dev/null +++ b/ml/dlib/dlib/svm/structural_track_association_trainer.h @@ -0,0 +1,404 @@ +// Copyright (C) 2014 Davis E. King (davis@dlib.net) +// License: Boost Software License See LICENSE.txt for the full license. +#ifndef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_ +#define DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_ + +#include "structural_track_association_trainer_abstract.h" +#include "../algs.h" +#include "svm.h" +#include <utility> +#include "track_association_function.h" +#include "structural_assignment_trainer.h" +#include <map> + +namespace dlib +{ + +// ---------------------------------------------------------------------------------------- + + namespace impl + { + template < + typename detection_type, + typename label_type + > + std::vector<detection_type> get_unlabeled_dets ( + const std::vector<labeled_detection<detection_type,label_type> >& dets + ) + { + std::vector<detection_type> temp; + temp.reserve(dets.size()); + for (unsigned long i = 0; i < dets.size(); ++i) + temp.push_back(dets[i].det); + return temp; + } + + } + +// ---------------------------------------------------------------------------------------- + + class structural_track_association_trainer + { + public: + + structural_track_association_trainer ( + ) + { + set_defaults(); + } + + void set_num_threads ( + unsigned long num + ) + { + num_threads = num; + } + + unsigned long get_num_threads ( + ) const + { + return num_threads; + } + + void set_epsilon ( + double eps_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(eps_ > 0, + "\t void structural_track_association_trainer::set_epsilon()" + << "\n\t eps_ must be greater than 0" + << "\n\t eps_: " << eps_ + << "\n\t this: " << this + ); + + eps = eps_; + } + + double get_epsilon ( + ) const { return eps; } + + void set_max_cache_size ( + unsigned long max_size + ) + { + max_cache_size = max_size; + } + + unsigned long get_max_cache_size ( + ) const + { + return max_cache_size; + } + + void set_loss_per_false_association ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_track_association_trainer::set_loss_per_false_association(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_false_association = loss; + } + + double get_loss_per_false_association ( + ) const + { + return loss_per_false_association; + } + + void set_loss_per_track_break ( + double loss + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(loss > 0, + "\t void structural_track_association_trainer::set_loss_per_track_break(loss)" + << "\n\t Invalid inputs were given to this function " + << "\n\t loss: " << loss + << "\n\t this: " << this + ); + + loss_per_track_break = loss; + } + + double get_loss_per_track_break ( + ) const + { + return loss_per_track_break; + } + + void be_verbose ( + ) + { + verbose = true; + } + + void be_quiet ( + ) + { + verbose = false; + } + + void set_oca ( + const oca& item + ) + { + solver = item; + } + + const oca get_oca ( + ) const + { + return solver; + } + + void set_c ( + double C_ + ) + { + // make sure requires clause is not broken + DLIB_ASSERT(C_ > 0, + "\t void structural_track_association_trainer::set_c()" + << "\n\t C_ must be greater than 0" + << "\n\t C_: " << C_ + << "\n\t this: " << this + ); + + C = C_; + } + + double get_c ( + ) const + { + return C; + } + + bool learns_nonnegative_weights ( + ) const { return learn_nonnegative_weights; } + + void set_learns_nonnegative_weights ( + bool value + ) + { + learn_nonnegative_weights = value; + } + + template < + typename detection_type, + typename label_type + > + const track_association_function<detection_type> train ( + const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples + ) const + { + // make sure requires clause is not broken + DLIB_ASSERT(is_track_association_problem(samples), + "\t track_association_function structural_track_association_trainer::train()" + << "\n\t invalid inputs were given to this function" + << "\n\t is_track_association_problem(samples): " << is_track_association_problem(samples) + ); + + typedef typename detection_type::track_type track_type; + + const unsigned long num_dims = find_num_dims(samples); + + feature_extractor_track_association<detection_type> fe(num_dims, learn_nonnegative_weights?num_dims:0); + structural_assignment_trainer<feature_extractor_track_association<detection_type> > trainer(fe); + + + if (verbose) + trainer.be_verbose(); + + trainer.set_c(C); + trainer.set_epsilon(eps); + trainer.set_max_cache_size(max_cache_size); + trainer.set_num_threads(num_threads); + trainer.set_oca(solver); + trainer.set_loss_per_missed_association(loss_per_track_break); + trainer.set_loss_per_false_association(loss_per_false_association); + + std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > > assignment_samples; + std::vector<std::vector<long> > labels; + for (unsigned long i = 0; i < samples.size(); ++i) + convert_dets_to_association_sets(samples[i], assignment_samples, labels); + + + return track_association_function<detection_type>(trainer.train(assignment_samples, labels)); + } + + template < + typename detection_type, + typename label_type + > + const track_association_function<detection_type> train ( + const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& sample + ) const + { + std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples; + samples.push_back(sample); + return train(samples); + } + + private: + + template < + typename detection_type, + typename label_type + > + static unsigned long find_num_dims ( + const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples + ) + { + typedef typename detection_type::track_type track_type; + // find a detection_type object so we can call get_similarity_features() and + // find out how big the feature vectors are. + + // for all detection histories + for (unsigned long i = 0; i < samples.size(); ++i) + { + // for all time instances in the detection history + for (unsigned j = 0; j < samples[i].size(); ++j) + { + if (samples[i][j].size() > 0) + { + track_type new_track; + new_track.update_track(samples[i][j][0].det); + typename track_type::feature_vector_type feats; + new_track.get_similarity_features(samples[i][j][0].det, feats); + return feats.size(); + } + } + } + + DLIB_CASSERT(false, + "No detection objects were given in the call to dlib::structural_track_association_trainer::train()"); + } + + template < + typename detections_at_single_time_step, + typename detection_type, + typename track_type + > + static void convert_dets_to_association_sets ( + const std::vector<detections_at_single_time_step>& det_history, + std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > >& data, + std::vector<std::vector<long> >& labels + ) + { + if (det_history.size() < 1) + return; + + typedef typename detections_at_single_time_step::value_type::label_type label_type; + std::vector<track_type> tracks; + // track_labels maps from detection labels to the index in tracks. So track + // with detection label X is at tracks[track_labels[X]]. + std::map<label_type,unsigned long> track_labels; + add_dets_to_tracks(tracks, track_labels, det_history[0]); + + using namespace impl; + for (unsigned long i = 1; i < det_history.size(); ++i) + { + data.push_back(std::make_pair(get_unlabeled_dets(det_history[i]), tracks)); + labels.push_back(get_association_labels(det_history[i], track_labels)); + add_dets_to_tracks(tracks, track_labels, det_history[i]); + } + } + + template < + typename labeled_detection, + typename label_type + > + static std::vector<long> get_association_labels( + const std::vector<labeled_detection>& dets, + const std::map<label_type,unsigned long>& track_labels + ) + { + std::vector<long> assoc(dets.size(),-1); + // find out which detections associate to what tracks + for (unsigned long i = 0; i < dets.size(); ++i) + { + typename std::map<label_type,unsigned long>::const_iterator j; + j = track_labels.find(dets[i].label); + // If this detection matches one of the tracks then record which track it + // matched with. + if (j != track_labels.end()) + assoc[i] = j->second; + } + return assoc; + } + + template < + typename track_type, + typename label_type, + typename labeled_detection + > + static void add_dets_to_tracks ( + std::vector<track_type>& tracks, + std::map<label_type,unsigned long>& track_labels, + const std::vector<labeled_detection>& dets + ) + { + std::vector<bool> updated_track(tracks.size(), false); + + // first assign the dets to the tracks + for (unsigned long i = 0; i < dets.size(); ++i) + { + const label_type& label = dets[i].label; + if (track_labels.count(label)) + { + const unsigned long track_idx = track_labels[label]; + tracks[track_idx].update_track(dets[i].det); + updated_track[track_idx] = true; + } + else + { + // this detection creates a new track + track_type new_track; + new_track.update_track(dets[i].det); + tracks.push_back(new_track); + track_labels[label] = tracks.size()-1; + } + + } + + // Now propagate all the tracks that didn't get any detections. + for (unsigned long i = 0; i < updated_track.size(); ++i) + { + if (!updated_track[i]) + tracks[i].propagate_track(); + } + } + + double C; + oca solver; + double eps; + bool verbose; + unsigned long num_threads; + unsigned long max_cache_size; + bool learn_nonnegative_weights; + double loss_per_track_break; + double loss_per_false_association; + + void set_defaults () + { + C = 100; + verbose = false; + eps = 0.001; + num_threads = 2; + max_cache_size = 5; + learn_nonnegative_weights = false; + loss_per_track_break = 1; + loss_per_false_association = 1; + } + }; + +} + +#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_ + |