summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/svm/structural_track_association_trainer.h
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-09 13:19:48 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-09 13:20:02 +0000
commit58daab21cd043e1dc37024a7f99b396788372918 (patch)
tree96771e43bb69f7c1c2b0b4f7374cb74d7866d0cb /ml/dlib/dlib/svm/structural_track_association_trainer.h
parentReleasing debian version 1.43.2-1. (diff)
downloadnetdata-58daab21cd043e1dc37024a7f99b396788372918.tar.xz
netdata-58daab21cd043e1dc37024a7f99b396788372918.zip
Merging upstream version 1.44.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ml/dlib/dlib/svm/structural_track_association_trainer.h')
-rw-r--r--ml/dlib/dlib/svm/structural_track_association_trainer.h404
1 files changed, 404 insertions, 0 deletions
diff --git a/ml/dlib/dlib/svm/structural_track_association_trainer.h b/ml/dlib/dlib/svm/structural_track_association_trainer.h
new file mode 100644
index 000000000..87fb829b2
--- /dev/null
+++ b/ml/dlib/dlib/svm/structural_track_association_trainer.h
@@ -0,0 +1,404 @@
+// Copyright (C) 2014 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
+#define DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
+
+#include "structural_track_association_trainer_abstract.h"
+#include "../algs.h"
+#include "svm.h"
+#include <utility>
+#include "track_association_function.h"
+#include "structural_assignment_trainer.h"
+#include <map>
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ namespace impl
+ {
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ std::vector<detection_type> get_unlabeled_dets (
+ const std::vector<labeled_detection<detection_type,label_type> >& dets
+ )
+ {
+ std::vector<detection_type> temp;
+ temp.reserve(dets.size());
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ temp.push_back(dets[i].det);
+ return temp;
+ }
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class structural_track_association_trainer
+ {
+ public:
+
+ structural_track_association_trainer (
+ )
+ {
+ set_defaults();
+ }
+
+ void set_num_threads (
+ unsigned long num
+ )
+ {
+ num_threads = num;
+ }
+
+ unsigned long get_num_threads (
+ ) const
+ {
+ return num_threads;
+ }
+
+ void set_epsilon (
+ double eps_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(eps_ > 0,
+ "\t void structural_track_association_trainer::set_epsilon()"
+ << "\n\t eps_ must be greater than 0"
+ << "\n\t eps_: " << eps_
+ << "\n\t this: " << this
+ );
+
+ eps = eps_;
+ }
+
+ double get_epsilon (
+ ) const { return eps; }
+
+ void set_max_cache_size (
+ unsigned long max_size
+ )
+ {
+ max_cache_size = max_size;
+ }
+
+ unsigned long get_max_cache_size (
+ ) const
+ {
+ return max_cache_size;
+ }
+
+ void set_loss_per_false_association (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_track_association_trainer::set_loss_per_false_association(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_false_association = loss;
+ }
+
+ double get_loss_per_false_association (
+ ) const
+ {
+ return loss_per_false_association;
+ }
+
+ void set_loss_per_track_break (
+ double loss
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(loss > 0,
+ "\t void structural_track_association_trainer::set_loss_per_track_break(loss)"
+ << "\n\t Invalid inputs were given to this function "
+ << "\n\t loss: " << loss
+ << "\n\t this: " << this
+ );
+
+ loss_per_track_break = loss;
+ }
+
+ double get_loss_per_track_break (
+ ) const
+ {
+ return loss_per_track_break;
+ }
+
+ void be_verbose (
+ )
+ {
+ verbose = true;
+ }
+
+ void be_quiet (
+ )
+ {
+ verbose = false;
+ }
+
+ void set_oca (
+ const oca& item
+ )
+ {
+ solver = item;
+ }
+
+ const oca get_oca (
+ ) const
+ {
+ return solver;
+ }
+
+ void set_c (
+ double C_
+ )
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(C_ > 0,
+ "\t void structural_track_association_trainer::set_c()"
+ << "\n\t C_ must be greater than 0"
+ << "\n\t C_: " << C_
+ << "\n\t this: " << this
+ );
+
+ C = C_;
+ }
+
+ double get_c (
+ ) const
+ {
+ return C;
+ }
+
+ bool learns_nonnegative_weights (
+ ) const { return learn_nonnegative_weights; }
+
+ void set_learns_nonnegative_weights (
+ bool value
+ )
+ {
+ learn_nonnegative_weights = value;
+ }
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ const track_association_function<detection_type> train (
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
+ ) const
+ {
+ // make sure requires clause is not broken
+ DLIB_ASSERT(is_track_association_problem(samples),
+ "\t track_association_function structural_track_association_trainer::train()"
+ << "\n\t invalid inputs were given to this function"
+ << "\n\t is_track_association_problem(samples): " << is_track_association_problem(samples)
+ );
+
+ typedef typename detection_type::track_type track_type;
+
+ const unsigned long num_dims = find_num_dims(samples);
+
+ feature_extractor_track_association<detection_type> fe(num_dims, learn_nonnegative_weights?num_dims:0);
+ structural_assignment_trainer<feature_extractor_track_association<detection_type> > trainer(fe);
+
+
+ if (verbose)
+ trainer.be_verbose();
+
+ trainer.set_c(C);
+ trainer.set_epsilon(eps);
+ trainer.set_max_cache_size(max_cache_size);
+ trainer.set_num_threads(num_threads);
+ trainer.set_oca(solver);
+ trainer.set_loss_per_missed_association(loss_per_track_break);
+ trainer.set_loss_per_false_association(loss_per_false_association);
+
+ std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > > assignment_samples;
+ std::vector<std::vector<long> > labels;
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ convert_dets_to_association_sets(samples[i], assignment_samples, labels);
+
+
+ return track_association_function<detection_type>(trainer.train(assignment_samples, labels));
+ }
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ const track_association_function<detection_type> train (
+ const std::vector<std::vector<labeled_detection<detection_type,label_type> > >& sample
+ ) const
+ {
+ std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > > samples;
+ samples.push_back(sample);
+ return train(samples);
+ }
+
+ private:
+
+ template <
+ typename detection_type,
+ typename label_type
+ >
+ static unsigned long find_num_dims (
+ const std::vector<std::vector<std::vector<labeled_detection<detection_type,label_type> > > >& samples
+ )
+ {
+ typedef typename detection_type::track_type track_type;
+ // find a detection_type object so we can call get_similarity_features() and
+ // find out how big the feature vectors are.
+
+ // for all detection histories
+ for (unsigned long i = 0; i < samples.size(); ++i)
+ {
+ // for all time instances in the detection history
+ for (unsigned j = 0; j < samples[i].size(); ++j)
+ {
+ if (samples[i][j].size() > 0)
+ {
+ track_type new_track;
+ new_track.update_track(samples[i][j][0].det);
+ typename track_type::feature_vector_type feats;
+ new_track.get_similarity_features(samples[i][j][0].det, feats);
+ return feats.size();
+ }
+ }
+ }
+
+ DLIB_CASSERT(false,
+ "No detection objects were given in the call to dlib::structural_track_association_trainer::train()");
+ }
+
+ template <
+ typename detections_at_single_time_step,
+ typename detection_type,
+ typename track_type
+ >
+ static void convert_dets_to_association_sets (
+ const std::vector<detections_at_single_time_step>& det_history,
+ std::vector<std::pair<std::vector<detection_type>, std::vector<track_type> > >& data,
+ std::vector<std::vector<long> >& labels
+ )
+ {
+ if (det_history.size() < 1)
+ return;
+
+ typedef typename detections_at_single_time_step::value_type::label_type label_type;
+ std::vector<track_type> tracks;
+ // track_labels maps from detection labels to the index in tracks. So track
+ // with detection label X is at tracks[track_labels[X]].
+ std::map<label_type,unsigned long> track_labels;
+ add_dets_to_tracks(tracks, track_labels, det_history[0]);
+
+ using namespace impl;
+ for (unsigned long i = 1; i < det_history.size(); ++i)
+ {
+ data.push_back(std::make_pair(get_unlabeled_dets(det_history[i]), tracks));
+ labels.push_back(get_association_labels(det_history[i], track_labels));
+ add_dets_to_tracks(tracks, track_labels, det_history[i]);
+ }
+ }
+
+ template <
+ typename labeled_detection,
+ typename label_type
+ >
+ static std::vector<long> get_association_labels(
+ const std::vector<labeled_detection>& dets,
+ const std::map<label_type,unsigned long>& track_labels
+ )
+ {
+ std::vector<long> assoc(dets.size(),-1);
+ // find out which detections associate to what tracks
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ typename std::map<label_type,unsigned long>::const_iterator j;
+ j = track_labels.find(dets[i].label);
+ // If this detection matches one of the tracks then record which track it
+ // matched with.
+ if (j != track_labels.end())
+ assoc[i] = j->second;
+ }
+ return assoc;
+ }
+
+ template <
+ typename track_type,
+ typename label_type,
+ typename labeled_detection
+ >
+ static void add_dets_to_tracks (
+ std::vector<track_type>& tracks,
+ std::map<label_type,unsigned long>& track_labels,
+ const std::vector<labeled_detection>& dets
+ )
+ {
+ std::vector<bool> updated_track(tracks.size(), false);
+
+ // first assign the dets to the tracks
+ for (unsigned long i = 0; i < dets.size(); ++i)
+ {
+ const label_type& label = dets[i].label;
+ if (track_labels.count(label))
+ {
+ const unsigned long track_idx = track_labels[label];
+ tracks[track_idx].update_track(dets[i].det);
+ updated_track[track_idx] = true;
+ }
+ else
+ {
+ // this detection creates a new track
+ track_type new_track;
+ new_track.update_track(dets[i].det);
+ tracks.push_back(new_track);
+ track_labels[label] = tracks.size()-1;
+ }
+
+ }
+
+ // Now propagate all the tracks that didn't get any detections.
+ for (unsigned long i = 0; i < updated_track.size(); ++i)
+ {
+ if (!updated_track[i])
+ tracks[i].propagate_track();
+ }
+ }
+
+ double C;
+ oca solver;
+ double eps;
+ bool verbose;
+ unsigned long num_threads;
+ unsigned long max_cache_size;
+ bool learn_nonnegative_weights;
+ double loss_per_track_break;
+ double loss_per_false_association;
+
+ void set_defaults ()
+ {
+ C = 100;
+ verbose = false;
+ eps = 0.001;
+ num_threads = 2;
+ max_cache_size = 5;
+ learn_nonnegative_weights = false;
+ loss_per_track_break = 1;
+ loss_per_false_association = 1;
+ }
+ };
+
+}
+
+#endif // DLIB_STRUCTURAL_TRACK_ASSOCIATION_TRAnER_Hh_
+