summaryrefslogtreecommitdiffstats
path: root/ml/dlib/dlib/global_optimization/global_function_search.h
diff options
context:
space:
mode:
authorDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-09 13:19:48 +0000
committerDaniel Baumann <daniel.baumann@progress-linux.org>2024-03-09 13:20:02 +0000
commit58daab21cd043e1dc37024a7f99b396788372918 (patch)
tree96771e43bb69f7c1c2b0b4f7374cb74d7866d0cb /ml/dlib/dlib/global_optimization/global_function_search.h
parentReleasing debian version 1.43.2-1. (diff)
downloadnetdata-58daab21cd043e1dc37024a7f99b396788372918.tar.xz
netdata-58daab21cd043e1dc37024a7f99b396788372918.zip
Merging upstream version 1.44.3.
Signed-off-by: Daniel Baumann <daniel.baumann@progress-linux.org>
Diffstat (limited to 'ml/dlib/dlib/global_optimization/global_function_search.h')
-rw-r--r--ml/dlib/dlib/global_optimization/global_function_search.h245
1 files changed, 245 insertions, 0 deletions
diff --git a/ml/dlib/dlib/global_optimization/global_function_search.h b/ml/dlib/dlib/global_optimization/global_function_search.h
new file mode 100644
index 000000000..fa036884a
--- /dev/null
+++ b/ml/dlib/dlib/global_optimization/global_function_search.h
@@ -0,0 +1,245 @@
+// Copyright (C) 2017 Davis E. King (davis@dlib.net)
+// License: Boost Software License See LICENSE.txt for the full license.
+#ifndef DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
+#define DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
+
+#include "global_function_search_abstract.h"
+#include <vector>
+#include "../matrix.h"
+#include <mutex>
+#include "../rand.h"
+#include "upper_bound_function.h"
+#include "../test_for_odr_violations.h"
+
+namespace dlib
+{
+
+// ----------------------------------------------------------------------------------------
+
+ struct function_spec
+ {
+ function_spec(
+ matrix<double,0,1> bound1,
+ matrix<double,0,1> bound2
+ );
+
+ function_spec(
+ matrix<double,0,1> bound1,
+ matrix<double,0,1> bound2,
+ std::vector<bool> is_integer
+ );
+
+ matrix<double,0,1> lower;
+ matrix<double,0,1> upper;
+ std::vector<bool> is_integer_variable;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ namespace gopt_impl
+ {
+ struct outstanding_function_eval_request
+ {
+ size_t request_id = 0; // unique id for this eval request
+ matrix<double,0,1> x; // function x to evaluate
+
+ // trust region specific stuff
+ bool was_trust_region_generated_request = false;
+ double predicted_improvement = std::numeric_limits<double>::quiet_NaN();
+ double anchor_objective_value = std::numeric_limits<double>::quiet_NaN(); // objective value at center of TR step
+
+ bool operator==(const outstanding_function_eval_request& item) const { return request_id == item.request_id; }
+ };
+
+ struct funct_info
+ {
+ funct_info() = delete;
+ funct_info(const funct_info&) = delete;
+ funct_info& operator=(const funct_info&) = delete;
+
+ funct_info(
+ const function_spec& spec,
+ size_t function_idx,
+ const std::shared_ptr<std::mutex>& m
+ ) :
+ spec(spec), function_idx(function_idx), m(m)
+ {
+ best_x = zeros_matrix(spec.lower);
+ }
+
+ upper_bound_function build_upper_bound_with_all_function_evals (
+ ) const;
+
+ static double find_nn (
+ const std::vector<function_evaluation>& evals,
+ const matrix<double,0,1>& x
+ );
+
+
+ function_spec spec;
+ size_t function_idx = 0;
+ std::shared_ptr<std::mutex> m;
+ upper_bound_function ub;
+ std::vector<outstanding_function_eval_request> outstanding_evals;
+ matrix<double,0,1> best_x;
+ double best_objective_value = -std::numeric_limits<double>::infinity();
+ double radius = 0;
+ };
+
+ }
+
+// ----------------------------------------------------------------------------------------
+
+ class function_evaluation_request
+ {
+ public:
+
+ function_evaluation_request() = delete;
+ function_evaluation_request(const function_evaluation_request&) = delete;
+ function_evaluation_request& operator=(const function_evaluation_request&) = delete;
+
+
+ function_evaluation_request(function_evaluation_request&& item);
+ function_evaluation_request& operator=(function_evaluation_request&& item);
+
+ ~function_evaluation_request();
+
+ size_t function_idx (
+ ) const;
+
+ const matrix<double,0,1>& x (
+ ) const;
+
+ bool has_been_evaluated (
+ ) const;
+
+ void set (
+ double y
+ );
+
+ void swap(function_evaluation_request& item);
+
+ private:
+
+ friend class global_function_search;
+
+ explicit function_evaluation_request(
+ const gopt_impl::outstanding_function_eval_request& req,
+ const std::shared_ptr<gopt_impl::funct_info>& info
+ ) : req(req), info(info) {}
+
+ bool m_has_been_evaluated = false;
+ gopt_impl::outstanding_function_eval_request req;
+ std::shared_ptr<gopt_impl::funct_info> info;
+ };
+
+// ----------------------------------------------------------------------------------------
+
+ class global_function_search
+ {
+ public:
+
+ global_function_search() = default;
+
+ explicit global_function_search(
+ const function_spec& function
+ );
+
+ explicit global_function_search(
+ const std::vector<function_spec>& functions_
+ );
+
+ global_function_search(
+ const std::vector<function_spec>& functions_,
+ const std::vector<std::vector<function_evaluation>>& initial_function_evals,
+ const double relative_noise_magnitude = 0.001
+ );
+
+ global_function_search(const global_function_search&) = delete;
+ global_function_search& operator=(const global_function_search& item) = delete;
+
+ global_function_search(global_function_search&& item) = default;
+ global_function_search& operator=(global_function_search&& item) = default;
+
+ size_t num_functions(
+ ) const;
+
+ void set_seed (
+ time_t seed
+ );
+
+ void get_function_evaluations (
+ std::vector<function_spec>& specs,
+ std::vector<std::vector<function_evaluation>>& function_evals
+ ) const;
+
+ void get_best_function_eval (
+ matrix<double,0,1>& x,
+ double& y,
+ size_t& function_idx
+ ) const;
+
+ function_evaluation_request get_next_x (
+ );
+
+ double get_pure_random_search_probability (
+ ) const;
+
+ void set_pure_random_search_probability (
+ double prob
+ );
+
+ double get_solver_epsilon (
+ ) const;
+
+ void set_solver_epsilon (
+ double eps
+ );
+
+ double get_relative_noise_magnitude (
+ ) const;
+
+ void set_relative_noise_magnitude (
+ double value
+ );
+
+ size_t get_monte_carlo_upper_bound_sample_num (
+ ) const;
+
+ void set_monte_carlo_upper_bound_sample_num (
+ size_t num
+ );
+
+ private:
+
+ std::shared_ptr<gopt_impl::funct_info> best_function(
+ ) const;
+
+ std::shared_ptr<gopt_impl::funct_info> best_function(
+ size_t& idx
+ ) const;
+
+ bool has_outstanding_trust_region_request (
+ ) const;
+
+
+ dlib::rand rnd;
+ double pure_random_search_probability = 0.02;
+ double min_trust_region_epsilon = 0;
+ double relative_noise_magnitude = 0.001;
+ size_t num_random_samples = 5000;
+ bool do_trust_region_step = true;
+
+ size_t next_request_id = 1;
+
+ std::vector<std::shared_ptr<gopt_impl::funct_info>> functions;
+ std::shared_ptr<std::mutex> m;
+
+ };
+
+// ----------------------------------------------------------------------------------------
+
+}
+
+#endif // DLIB_GLOBAL_FuNCTION_SEARCH_Hh_
+