summary refs log tree commit diff stats
path: root/ml/dlib/dlib/test/sequence_segmenter.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'ml/dlib/dlib/test/sequence_segmenter.cpp')
-rw-r--r-- ml/dlib/dlib/test/sequence_segmenter.cpp 294
1 files changed, 294 insertions, 0 deletions
diff --git a/ml/dlib/dlib/test/sequence_segmenter.cpp b/ml/dlib/dlib/test/sequence_segmenter.cpp
new file mode 100644
index 000000000..acdcd69be
--- /dev/null
+++ b/ml/dlib/dlib/test/sequence_segmenter.cpp
@@ -0,0 +1,294 @@
+// Copyright (C) 2013 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+
+
+#include <sstream>
+#include "tester.h"
+#include <dlib/svm_threaded.h>
+#include <dlib/rand.h>
+
+
+namespace
+{
+ using namespace test;
+ using namespace dlib;
+ using namespace std;
+
+ logger dlog("test.sequence_segmenter");  // logger that do_test() writes its diagnostic output to
+
+// ----------------------------------------------------------------------------------------
+
+ dlib::rand rnd;  // file-wide generator; every unigram_extractor constructor draws its random vectors from it
+
+ template <bool use_BIO_model_, bool use_high_order_features_, bool allow_negative_weights_>
+ class unigram_extractor
+ {
+     /*
+         Feature extractor used by the tests in this file.  A sequence is a vector
+         of token ids, and each token id maps to one fixed feature vector of
+         num_features() entries stored in feats.
+
+         Token ids are split into three groups (0-2, 3-5, and the rest), which are
+         the same ranges make_dataset() uses to choose a token's label.  All tokens
+         in a group share one vector.  The vector is random except for two pinned
+         entries: entry g is +1 and entry g+3 is -1 for group g.
+     */
+ public:
+
+     // The three template switches, republished as static members.
+     const static bool use_BIO_model = use_BIO_model_;
+     const static bool use_high_order_features = use_high_order_features_;
+     const static bool allow_negative_weights = allow_negative_weights_;
+
+     typedef std::vector<unsigned long> sequence_type;
+
+     // token id -> feature vector for that token
+     std::map<unsigned long, matrix<double,0,1> > feats;
+
+     unigram_extractor()
+     {
+         // One random vector per token group, drawn in group order from the
+         // file-level rnd.  Each construction therefore advances that generator.
+         matrix<double,0,1> group_vec[3];
+         for (long g = 0; g < 3; ++g)
+         {
+             group_vec[g] = randm(num_features(), 1, rnd);
+             group_vec[g](g) = 1;
+             group_vec[g](g+3) = -1;
+         }
+
+         // Tokens 0-2 use group 0, tokens 3-5 use group 1, all others use group 2.
+         const unsigned long total = num_features();
+         for (unsigned long token = 0; token < total; ++token)
+             feats[token] = group_vec[token < 3 ? 0 : (token < 6 ? 1 : 2)];
+     }
+
+     unsigned long num_features() const { return 10; }
+     unsigned long window_size() const { return 3; }
+
+     // Reports every entry of the feature vector of the token at x[position]
+     // through set_feature(index, value).  The token id must be a key of feats;
+     // the lookup result is not checked.
+     template <typename feature_setter>
+     void get_features (
+         feature_setter& set_feature,
+         const sequence_type& x,
+         unsigned long position
+     ) const
+     {
+         const matrix<double,0,1>& token_vec = feats.find(x[position])->second;
+         const unsigned long dims = num_features();
+         for (unsigned long d = 0; d < dims; ++d)
+             set_feature(d, token_vec(d));
+     }
+
+ };
+
+ // Writes a unigram_extractor to out.  Its token -> feature-vector table is the
+ // only data member, so that table is all that gets written.
+ template <bool bio, bool high_order, bool neg>
+ void serialize(const unigram_extractor<bio,high_order,neg>& item , std::ostream& out )
+ {
+     serialize(item.feats, out);
+ }
+
+ // Reads a unigram_extractor back from in, replacing its token -> feature-vector
+ // table with the stored one.
+ template <bool bio, bool high_order, bool neg>
+ void deserialize(unigram_extractor<bio,high_order,neg>& item, std::istream& in)
+ {
+     deserialize(item.feats, in);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Builds a synthetic tagging dataset.
+ //
+ // On return samples holds dataset_size token sequences of length 10 and labels
+ // holds one tag per token: BEGIN for tokens 0-2, INSIDE for tokens 3-5, and
+ // OUTSIDE for anything larger.  Previous contents of both vectors are discarded.
+ //
+ // A token is redrawn whenever it would put an INSIDE tag directly after an
+ // OUTSIDE tag.  An INSIDE tag at position 0 is not redrawn.
+ void make_dataset (
+     std::vector<std::vector<unsigned long> >& samples,
+     std::vector<std::vector<unsigned long> >& labels,
+     unsigned long dataset_size
+ )
+ {
+     const unsigned long sequence_length = 10;
+
+     samples.assign(dataset_size, std::vector<unsigned long>(sequence_length));
+     labels.assign(dataset_size, std::vector<unsigned long>(sequence_length));
+
+     // Only used for num_features(), but constructing it draws from the
+     // file-level rnd, so it must stay.
+     unigram_extractor<true,true,true> fe;
+     // Separate generator for the tokens; it is distinct from the file-level rnd.
+     dlib::rand token_source;
+
+     for (unsigned long iter = 0; iter < dataset_size; ++iter)
+     {
+         std::vector<unsigned long>& tokens = samples[iter];
+         std::vector<unsigned long>& tags = labels[iter];
+
+         unsigned long pos = 0;
+         while (pos < tokens.size())
+         {
+             const unsigned long token = token_source.get_random_32bit_number()%fe.num_features();
+             tokens[pos] = token;
+
+             if (token < 3)
+                 tags[pos] = impl_ss::BEGIN;
+             else if (token < 6)
+                 tags[pos] = impl_ss::INSIDE;
+             else
+                 tags[pos] = impl_ss::OUTSIDE;
+
+             // Rejection sampling: stay on this position and draw again when the
+             // tag just written cannot follow the previous one.
+             const bool impossible = pos != 0 &&
+                                     tags[pos] == impl_ss::INSIDE &&
+                                     tags[pos-1] == impl_ss::OUTSIDE;
+             if (!impossible)
+                 ++pos;
+         }
+     }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Same data as make_dataset(), but with the per-token tags replaced by explicit
+ // segments.  segments[k] lists, in order, one (begin, end) pair for every BEGIN
+ // tag in sequence k.  end is the index just past the run of INSIDE tags that
+ // follows the BEGIN, so each pair is a half-open range.
+ void make_dataset2 (
+     std::vector<std::vector<unsigned long> >& samples,
+     std::vector<std::vector<std::pair<unsigned long, unsigned long> > >& segments,
+     unsigned long dataset_size
+ )
+ {
+     segments.clear();
+     std::vector<std::vector<unsigned long> > labels;
+     make_dataset(samples, labels, dataset_size);
+     segments.resize(samples.size());
+
+     for (unsigned long k = 0; k < labels.size(); ++k)
+     {
+         const std::vector<unsigned long>& tags = labels[k];
+
+         unsigned long cursor = 0;
+         while (cursor < tags.size())
+         {
+             // Tags that do not open a segment are skipped.
+             if (tags[cursor] != impl_ss::BEGIN)
+             {
+                 ++cursor;
+                 continue;
+             }
+
+             // Extend the segment over the INSIDE tags that follow the BEGIN.
+             const unsigned long start = cursor;
+             unsigned long stop = start + 1;
+             while (stop < tags.size() && tags[stop] == impl_ss::INSIDE)
+                 ++stop;
+
+             segments[k].push_back(std::make_pair(start, stop));
+             cursor = stop;
+         }
+     }
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ // Runs the full segmenter test for one combination of extractor switches:
+ //   1. train on a 100-sequence dataset and check that sequence 1 is segmented
+ //      exactly as labeled,
+ //   2. check the cross-validation and test_sequence_segmenter() scores,
+ //   3. round-trip the segmenter through serialize()/deserialize() and re-check,
+ //   4. check the sign of the learned weights against allow_negative_weights.
+ //
+ // Keep the statement order: every unigram_extractor that is constructed,
+ // including those built inside make_dataset2(), draws from the file-level rnd.
+ template <bool use_BIO_model, bool use_high_order_features, bool allow_negative_weights>
+ void do_test()
+ {
+     typedef unigram_extractor<use_BIO_model,use_high_order_features,allow_negative_weights> fe_type;
+     typedef std::vector<std::pair<unsigned long, unsigned long> > segment_list;
+
+     dlog << LINFO << "use_BIO_model: "<< use_BIO_model;
+     dlog << LINFO << "use_high_order_features: "<< use_high_order_features;
+     dlog << LINFO << "allow_negative_weights: "<< allow_negative_weights;
+
+     std::vector<std::vector<unsigned long> > samples;
+     std::vector<segment_list> segments;
+     make_dataset2( samples, segments, 100);
+     print_spinner();
+
+     // unused_fe is never read, but its constructor advances the file-level rnd
+     // before trainer_fe is built, so removing it would change trainer_fe's features.
+     fe_type unused_fe;
+     fe_type trainer_fe;
+     structural_sequence_segmentation_trainer<fe_type> trainer(trainer_fe);
+     trainer.set_c(5);
+     trainer.set_num_threads(1);
+
+     sequence_segmenter<fe_type> labeler = trainer.train(samples, segments);
+     print_spinner();
+
+     // Spot check: one training sequence must be segmented exactly as labeled.
+     const segment_list predicted_labels = labeler(samples[1]);
+     const segment_list true_labels = segments[1];
+     /*
+     for (unsigned long i = 0; i < predicted_labels.size(); ++i)
+         cout << "["<<predicted_labels[i].first<<","<<predicted_labels[i].second<<") ";
+     cout << endl;
+     for (unsigned long i = 0; i < true_labels.size(); ++i)
+         cout << "["<<true_labels[i].first<<","<<true_labels[i].second<<") ";
+     cout << endl;
+     */
+
+     DLIB_TEST(predicted_labels.size() > 0);
+     DLIB_TEST(predicted_labels.size() == true_labels.size());
+     for (unsigned long i = 0; i < predicted_labels.size(); ++i)
+     {
+         DLIB_TEST(predicted_labels[i].first == true_labels[i].first);
+         DLIB_TEST(predicted_labels[i].second == true_labels[i].second);
+     }
+
+     matrix<double> res;
+
+     // Every reported score must exceed 0.98.
+     res = cross_validate_sequence_segmenter(trainer, samples, segments, 3);
+     dlog << LINFO << "cv res: "<< res;
+     DLIB_TEST(min(res) > 0.98);
+
+     // NOTE(review): make_dataset() draws tokens from a freshly constructed
+     // dlib::rand.  If that generator always starts from the same state, this
+     // call regenerates the training data rather than new data -- confirm.
+     make_dataset2( samples, segments, 100);
+     res = test_sequence_segmenter(labeler, samples, segments);
+     dlog << LINFO << "test res: "<< res;
+     DLIB_TEST(min(res) > 0.98);
+     print_spinner();
+
+     // Round-trip through serialization; the restored segmenter must score the same way.
+     ostringstream buffer;
+     serialize(labeler, buffer);
+     istringstream buffer_in(buffer.str());
+     sequence_segmenter<fe_type> restored;
+     deserialize(restored, buffer_in);
+
+     res = test_sequence_segmenter(restored, samples, segments);
+     dlog << LINFO << "test res2: "<< res;
+     DLIB_TEST(min(res) > 0.98);
+
+     // N weights are excluded from the "normal" minimum below: 3*3+3 for the
+     // BIO model and 5*5+5 otherwise.  Presumably these are the label-transition
+     // weights, and colm() selects the first size()-N entries -- confirm against
+     // the dlib documentation.
+     // Note that min_trans_weight is the minimum over the whole weight vector.
+     const long N = use_BIO_model ? 3*3+3 : 5*5+5;
+     const double min_normal_weight = min(colm(restored.get_weights(), 0, restored.get_weights().size()-N));
+     const double min_trans_weight = min(restored.get_weights());
+     dlog << LINFO << "min_normal_weight: " << min_normal_weight;
+     dlog << LINFO << "min_trans_weight: " << min_trans_weight;
+
+     // The "normal" minimum must be negative when negative weights are allowed
+     // and exactly 0 when they are not.
+     if (allow_negative_weights)
+     {
+         DLIB_TEST(min_normal_weight < 0);
+     }
+     else
+     {
+         DLIB_TEST(min_normal_weight == 0);
+     }
+     // The overall minimum must be negative in both configurations.
+     DLIB_TEST(min_trans_weight < 0);
+ }
+
+// ----------------------------------------------------------------------------------------
+
+
+ // Test case "test_sequence_segmenter".  It runs do_test<>() for all eight
+ // combinations of the three extractor switches.  The single instance `a`
+ // defined with the class is the test object.
+ class unit_test_sequence_segmenter : public tester
+ {
+     // Runs the four (use_BIO_model, use_high_order_features) combinations for
+     // one allow_negative_weights setting, always in the same order.
+     template <bool allow_negative_weights>
+     void run_all_models ()
+     {
+         do_test<true,true,allow_negative_weights>();
+         do_test<true,false,allow_negative_weights>();
+         do_test<false,true,allow_negative_weights>();
+         do_test<false,false,allow_negative_weights>();
+     }
+
+ public:
+     unit_test_sequence_segmenter (
+     ) :
+         tester ("test_sequence_segmenter",
+                 "Runs tests on the sequence segmenting code.")
+     {}
+
+     void perform_test (
+     )
+     {
+         // Non-negative-weight configurations first, then the negative-weight ones.
+         run_all_models<false>();
+         run_all_models<true>();
+     }
+ } a;
+
+}
+
+
+