// Copyright (C) 2013 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #include "opaque_types.h" #include #include #include #include "testing_results.h" #include using namespace dlib; using namespace std; namespace py = pybind11; typedef matrix sample_type; // ---------------------------------------------------------------------------------------- namespace dlib { template bool operator== ( const ranking_pair&, const ranking_pair& ) { pyassert(false, "It is illegal to compare ranking pair objects for equality."); return false; } } template void resize(T& v, unsigned long n) { v.resize(n); } // ---------------------------------------------------------------------------------------- template typename trainer_type::trained_function_type train1 ( const trainer_type& trainer, const ranking_pair& sample ) { typedef ranking_pair st; pyassert(is_ranking_problem(std::vector(1, sample)), "Invalid inputs"); return trainer.train(sample); } template typename trainer_type::trained_function_type train2 ( const trainer_type& trainer, const std::vector >& samples ) { pyassert(is_ranking_problem(samples), "Invalid inputs"); return trainer.train(samples); } template void set_epsilon ( trainer_type& trainer, double eps) { pyassert(eps > 0, "epsilon must be > 0"); trainer.set_epsilon(eps); } template double get_epsilon ( const trainer_type& trainer) { return trainer.get_epsilon(); } template void set_c ( trainer_type& trainer, double C) { pyassert(C > 0, "C must be > 0"); trainer.set_c(C); } template double get_c (const trainer_type& trainer) { return trainer.get_c(); } template void add_ranker ( py::module& m, const char* name ) { py::class_(m, name) .def(py::init()) .def_property("epsilon", get_epsilon, set_epsilon) .def_property("c", get_c, set_c) .def_property("max_iterations", &trainer::get_max_iterations, &trainer::set_max_iterations) .def_property("force_last_weight_to_1", &trainer::forces_last_weight_to_1, &trainer::force_last_weight_to_1) .def_property("learns_nonnegative_weights", &trainer::learns_nonnegative_weights, &trainer::set_learns_nonnegative_weights) .def_property_readonly("has_prior", &trainer::has_prior) .def("train", train1) .def("train", train2) .def("set_prior", &trainer::set_prior) .def("be_verbose", &trainer::be_verbose) .def("be_quiet", &trainer::be_quiet); } // ---------------------------------------------------------------------------------------- template < typename trainer_type, typename T > const ranking_test _cross_ranking_validate_trainer ( const trainer_type& trainer, const std::vector >& samples, const unsigned long folds ) { pyassert(is_ranking_problem(samples), "Training data does not make a valid training set."); pyassert(1 < folds && folds <= samples.size(), "Invalid number of folds given."); return cross_validate_ranking_trainer(trainer, samples, folds); } // ---------------------------------------------------------------------------------------- void bind_svm_rank_trainer(py::module& m) { py::class_ >(m, "ranking_pair") .def(py::init()) .def_readwrite("relevant", &ranking_pair::relevant) .def_readwrite("nonrelevant", &ranking_pair::nonrelevant) .def(py::pickle(&getstate>, &setstate>)); py::class_ >(m, "sparse_ranking_pair") .def(py::init()) .def_readwrite("relevant", &ranking_pair::relevant) .def_readwrite("nonrelevant", &ranking_pair::nonrelevant) .def(py::pickle(&getstate>, &setstate>)); py::bind_vector(m, "ranking_pairs") .def("clear", &ranking_pairs::clear) .def("resize", resize) .def("extend", extend_vector_with_python_list>) .def(py::pickle(&getstate, &setstate)); py::bind_vector(m, "sparse_ranking_pairs") .def("clear", &sparse_ranking_pairs::clear) .def("resize", resize) .def("extend", extend_vector_with_python_list>) .def(py::pickle(&getstate, &setstate)); add_ranker > >(m, "svm_rank_trainer"); add_ranker > >(m, "svm_rank_trainer_sparse"); m.def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer< svm_rank_trainer >,sample_type>, py::arg("trainer"), py::arg("samples"), py::arg("folds") ); m.def("cross_validate_ranking_trainer", &_cross_ranking_validate_trainer< svm_rank_trainer > ,sparse_vect>, py::arg("trainer"), py::arg("samples"), py::arg("folds") ); }