// Copyright (C) 2011 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_ASSIGNMENT_FuNCTION_Hh_ #define DLIB_ASSIGNMENT_FuNCTION_Hh_ #include "assignment_function_abstract.h" #include "../matrix.h" #include #include "../optimization/max_cost_assignment.h" namespace dlib { // ---------------------------------------------------------------------------------------- template < typename feature_extractor > class assignment_function { public: typedef typename feature_extractor::lhs_element lhs_element; typedef typename feature_extractor::rhs_element rhs_element; typedef std::pair, std::vector > sample_type; typedef std::vector label_type; typedef label_type result_type; assignment_function() { weights.set_size(fe.num_features()); weights = 0; bias = 0; force_assignment = false; } explicit assignment_function( const matrix& weights_, double bias_ ) : weights(weights_), bias(bias_), force_assignment(false) { // make sure requires clause is not broken DLIB_ASSERT(fe.num_features() == static_cast(weights_.size()), "\t assignment_function::assignment_function(weights_)" << "\n\t These sizes should match" << "\n\t fe.num_features(): " << fe.num_features() << "\n\t weights_.size(): " << weights_.size() << "\n\t this: " << this ); } assignment_function( const matrix& weights_, double bias_, const feature_extractor& fe_ ) : fe(fe_), weights(weights_), bias(bias_), force_assignment(false) { // make sure requires clause is not broken DLIB_ASSERT(fe_.num_features() == static_cast(weights_.size()), "\t assignment_function::assignment_function(weights_,fe_)" << "\n\t These sizes should match" << "\n\t fe_.num_features(): " << fe_.num_features() << "\n\t weights_.size(): " << weights_.size() << "\n\t this: " << this ); } assignment_function( const matrix& weights_, double bias_, const feature_extractor& fe_, bool force_assignment_ ) : fe(fe_), weights(weights_), bias(bias_), force_assignment(force_assignment_) { // make sure requires clause is not broken DLIB_ASSERT(fe_.num_features() == static_cast(weights_.size()), "\t assignment_function::assignment_function(weights_,fe_,force_assignment_)" << "\n\t These sizes should match" << "\n\t fe_.num_features(): " << fe_.num_features() << "\n\t weights_.size(): " << weights_.size() << "\n\t this: " << this ); } const feature_extractor& get_feature_extractor ( ) const { return fe; } const matrix& get_weights ( ) const { return weights; } double get_bias ( ) const { return bias; } bool forces_assignment ( ) const { return force_assignment; } void predict_assignments ( const std::vector& lhs, const std::vector& rhs, result_type& assignment ) const { assignment.clear(); matrix cost; unsigned long size; if (force_assignment) { size = std::max(lhs.size(), rhs.size()); } else { size = rhs.size() + lhs.size(); } cost.set_size(size, size); typedef typename feature_extractor::feature_vector_type feature_vector_type; feature_vector_type feats; // now fill out the cost assignment matrix for (long r = 0; r < cost.nr(); ++r) { for (long c = 0; c < cost.nc(); ++c) { if (r < (long)lhs.size() && c < (long)rhs.size()) { fe.get_features(lhs[r], rhs[c], feats); cost(r,c) = dot(weights, feats) + bias; } else { cost(r,c) = 0; } } } if (cost.size() != 0) { // max_cost_assignment() only works with integer matrices, so convert from // double to integer. const double scale = (std::numeric_limits::max()/1000)/max(abs(cost)); matrix int_cost = matrix_cast(round(cost*scale)); assignment = max_cost_assignment(int_cost); assignment.resize(lhs.size()); } // adjust assignment so that non-assignments have a value of -1 for (unsigned long i = 0; i < assignment.size(); ++i) { if (assignment[i] >= (long)rhs.size()) assignment[i] = -1; } } void predict_assignments ( const sample_type& item, result_type& assignment ) const { predict_assignments(item.first, item.second, assignment); } result_type operator()( const std::vector& lhs, const std::vector& rhs ) const { result_type temp; predict_assignments(lhs,rhs,temp); return temp; } result_type operator() ( const sample_type& item ) const { return (*this)(item.first, item.second); } private: feature_extractor fe; matrix weights; double bias; bool force_assignment; }; // ---------------------------------------------------------------------------------------- template < typename feature_extractor > void serialize ( const assignment_function& item, std::ostream& out ) { int version = 2; serialize(version, out); serialize(item.get_feature_extractor(), out); serialize(item.get_weights(), out); serialize(item.get_bias(), out); serialize(item.forces_assignment(), out); } // ---------------------------------------------------------------------------------------- template < typename feature_extractor > void deserialize ( assignment_function& item, std::istream& in ) { feature_extractor fe; matrix weights; double bias; bool force_assignment; int version = 0; deserialize(version, in); if (version != 2) throw serialization_error("Unexpected version found while deserializing dlib::assignment_function."); deserialize(fe, in); deserialize(weights, in); deserialize(bias, in); deserialize(force_assignment, in); item = assignment_function(weights, bias, fe, force_assignment); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_ASSIGNMENT_FuNCTION_Hh_